Nikhil Pravin Pise commited on
Commit
150e705
·
1 Parent(s): 9f32f16

Fix: Dynamic API key loading - secrets now work properly on HuggingFace Spaces

Browse files
Files changed (2) hide show
  1. huggingface/app.py +65 -23
  2. src/llm_config.py +34 -4
huggingface/app.py CHANGED
@@ -36,21 +36,35 @@ logger = logging.getLogger("mediguard.huggingface")
36
  # Configuration
37
  # ---------------------------------------------------------------------------
38
 
39
- # Check for required API keys
40
- GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
41
- GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
 
 
42
 
43
- if not GROQ_API_KEY and not GOOGLE_API_KEY:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  logger.warning(
45
- "No LLM API key found. Set GROQ_API_KEY or GOOGLE_API_KEY environment variable."
46
  )
47
 
48
- # Set default provider based on available keys
49
- if GROQ_API_KEY:
50
- os.environ.setdefault("LLM_PROVIDER", "groq")
51
- elif GOOGLE_API_KEY:
52
- os.environ.setdefault("LLM_PROVIDER", "gemini")
53
-
54
 
55
  # ---------------------------------------------------------------------------
56
  # Guild Initialization (lazy)
@@ -58,24 +72,46 @@ elif GOOGLE_API_KEY:
58
 
59
  _guild = None
60
  _guild_error = None
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  def get_guild():
64
  """Lazy initialization of the Clinical Insight Guild."""
65
- global _guild, _guild_error
 
 
 
 
 
 
66
 
67
  if _guild is not None:
68
  return _guild
69
 
70
  if _guild_error is not None:
71
- raise _guild_error
 
 
72
 
73
  try:
74
  logger.info("Initializing Clinical Insight Guild...")
 
 
 
 
75
  start = time.time()
76
 
77
  from src.workflow import create_guild
78
  _guild = create_guild()
 
79
 
80
  elapsed = time.time() - start
81
  logger.info(f"Guild initialized in {elapsed:.1f}s")
@@ -142,8 +178,10 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
142
  if not input_text.strip():
143
  return "", "", "⚠️ Please enter biomarkers to analyze."
144
 
145
- # Check API key
146
- if not GROQ_API_KEY and not GOOGLE_API_KEY:
 
 
147
  return "", "", (
148
  "❌ **Error**: No LLM API key configured.\n\n"
149
  "Please add your API key in Hugging Face Space Settings → Secrets:\n"
@@ -151,6 +189,10 @@ def analyze_biomarkers(input_text: str, progress=gr.Progress()) -> tuple[str, st
151
  "- or `GOOGLE_API_KEY` (get free at https://aistudio.google.com/app/apikey)"
152
  )
153
 
 
 
 
 
154
  try:
155
  progress(0.1, desc="Parsing biomarkers...")
156
  biomarkers = parse_biomarkers(input_text)
@@ -399,14 +441,14 @@ def create_demo() -> gr.Blocks:
399
  Enter your biomarkers below and get evidence-based insights in seconds.
400
  """)
401
 
402
- # API Key warning (if needed)
403
- if not GROQ_API_KEY and not GOOGLE_API_KEY:
404
- gr.Markdown("""
405
- <div style="background: #ffeeba; padding: 10px; border-radius: 5px; margin: 10px 0;">
406
- ️ <b>API Key Required</b>: Add <code>GROQ_API_KEY</code> or <code>GOOGLE_API_KEY</code>
407
- in Space Settings → Secrets to enable analysis.
408
- </div>
409
- """)
410
 
411
  with gr.Row():
412
  # Input column
 
36
  # Configuration
37
  # ---------------------------------------------------------------------------
38
 
39
def get_api_keys():
    """Read LLM API keys from the environment at call time.

    HuggingFace Spaces can inject secrets after the module has been
    imported, so keys are looked up dynamically on every call instead of
    being cached at import time.

    Returns:
        tuple[str, str]: ``(groq_key, google_key)``; each is an empty
        string when the corresponding variable is unset.
    """
    return os.getenv("GROQ_API_KEY", ""), os.getenv("GOOGLE_API_KEY", "")
44
 
45
+
46
def setup_llm_provider():
    """Select and export the LLM provider based on available API keys.

    Groq takes precedence over Gemini. On success, ``LLM_PROVIDER`` is
    written to the environment (and the winning key re-exported) so that
    downstream configuration picks up the choice.

    Returns:
        The chosen provider name ("groq" or "gemini"), or None when no
        API key is configured.
    """
    # Check providers in priority order; first one with a key wins.
    for provider_name, key_var in (
        ("groq", "GROQ_API_KEY"),
        ("gemini", "GOOGLE_API_KEY"),
    ):
        key = os.getenv(key_var, "")
        if key:
            os.environ["LLM_PROVIDER"] = provider_name
            os.environ[key_var] = key  # Ensure it's set
            return provider_name
    return None
59
+
60
+
61
# Startup check is informational only: HuggingFace Spaces may inject
# secrets after module import, so the analyze-time re-check is the one
# that actually gates requests.
_groq, _google = get_api_keys()
if not (_groq or _google):
    logger.warning(
        "No LLM API key found at startup. Will check again when analyzing."
    )
67
 
 
 
 
 
 
 
68
 
69
  # ---------------------------------------------------------------------------
70
  # Guild Initialization (lazy)
 
72
 
73
# Module-level cache for the lazily initialized guild.
_guild = None           # initialized guild instance, once built
_guild_error = None     # last initialization error, if any
_guild_provider = None  # provider name the cached guild was built with


def reset_guild():
    """Drop the cached guild so the next initialization rebuilds it.

    Useful when the API key or provider changes at runtime (e.g. a
    secret is added to the Space after startup).
    """
    global _guild, _guild_error, _guild_provider
    _guild = _guild_error = _guild_provider = None
84
 
85
 
86
  def get_guild():
87
  """Lazy initialization of the Clinical Insight Guild."""
88
+ global _guild, _guild_error, _guild_provider
89
+
90
+ # Check if we need to reinitialize (provider changed)
91
+ current_provider = os.getenv("LLM_PROVIDER")
92
+ if _guild_provider and _guild_provider != current_provider:
93
+ logger.info(f"Provider changed from {_guild_provider} to {current_provider}, reinitializing...")
94
+ reset_guild()
95
 
96
  if _guild is not None:
97
  return _guild
98
 
99
  if _guild_error is not None:
100
+ # Don't cache errors forever - allow retry
101
+ logger.warning("Previous initialization failed, retrying...")
102
+ _guild_error = None
103
 
104
  try:
105
  logger.info("Initializing Clinical Insight Guild...")
106
+ logger.info(f"LLM_PROVIDER={os.getenv('LLM_PROVIDER')}")
107
+ logger.info(f"GROQ_API_KEY={'set' if os.getenv('GROQ_API_KEY') else 'NOT SET'}")
108
+ logger.info(f"GOOGLE_API_KEY={'set' if os.getenv('GOOGLE_API_KEY') else 'NOT SET'}")
109
+
110
  start = time.time()
111
 
112
  from src.workflow import create_guild
113
  _guild = create_guild()
114
+ _guild_provider = current_provider
115
 
116
  elapsed = time.time() - start
117
  logger.info(f"Guild initialized in {elapsed:.1f}s")
 
178
  if not input_text.strip():
179
  return "", "", "⚠️ Please enter biomarkers to analyze."
180
 
181
+ # Check API key dynamically (HF injects secrets after startup)
182
+ groq_key, google_key = get_api_keys()
183
+
184
+ if not groq_key and not google_key:
185
  return "", "", (
186
  "❌ **Error**: No LLM API key configured.\n\n"
187
  "Please add your API key in Hugging Face Space Settings → Secrets:\n"
 
189
  "- or `GOOGLE_API_KEY` (get free at https://aistudio.google.com/app/apikey)"
190
  )
191
 
192
+ # Setup provider based on available key
193
+ provider = setup_llm_provider()
194
+ logger.info(f"Using LLM provider: {provider}")
195
+
196
  try:
197
  progress(0.1, desc="Parsing biomarkers...")
198
  biomarkers = parse_biomarkers(input_text)
 
441
  Enter your biomarkers below and get evidence-based insights in seconds.
442
  """)
443
 
444
+ # API Key warning - always show since keys are checked dynamically
445
+ # The actual check happens in analyze_biomarkers()
446
+ gr.Markdown("""
447
+ <div style="background: #d4edda; padding: 10px; border-radius: 5px; margin: 10px 0;">
448
+ ️ <b>Note</b>: Make sure you've added <code>GROQ_API_KEY</code> or <code>GOOGLE_API_KEY</code>
449
+ in Space Settings → Secrets for analysis to work.
450
+ </div>
451
+ """)
452
 
453
  with gr.Row():
454
  # Input column
src/llm_config.py CHANGED
@@ -19,8 +19,14 @@ load_dotenv()
19
  # Configure LangSmith tracing
20
  os.environ["LANGCHAIN_PROJECT"] = os.getenv("LANGCHAIN_PROJECT", "MediGuard_AI_RAG_Helper")
21
 
22
- # Default provider (can be overridden via env)
23
- DEFAULT_LLM_PROVIDER = os.getenv("LLM_PROVIDER", "groq")
 
 
 
 
 
 
24
 
25
 
26
  def get_chat_model(
@@ -41,7 +47,8 @@ def get_chat_model(
41
  Returns:
42
  LangChain chat model instance
43
  """
44
- provider = provider or DEFAULT_LLM_PROVIDER
 
45
 
46
  if provider == "groq":
47
  from langchain_groq import ChatGroq
@@ -164,9 +171,11 @@ class LLMConfig:
164
  provider: LLM provider - "groq" (free), "gemini" (free), or "ollama" (local)
165
  lazy: If True, defer model initialization until first use (avoids API key errors at import)
166
  """
167
- self.provider = provider or DEFAULT_LLM_PROVIDER
 
168
  self._lazy = lazy
169
  self._initialized = False
 
170
  self._lock = threading.Lock()
171
 
172
  # Lazy-initialized model instances
@@ -181,8 +190,28 @@ class LLMConfig:
181
  if not lazy:
182
  self._initialize_models()
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  def _initialize_models(self):
185
  """Initialize all model clients (called on first use if lazy)"""
 
 
186
  if self._initialized:
187
  return
188
 
@@ -234,6 +263,7 @@ class LLMConfig:
234
  self._embedding_model = get_embedding_model()
235
 
236
  self._initialized = True
 
237
 
238
  @property
239
  def planner(self):
 
19
  # Configure LangSmith tracing
20
  os.environ["LANGCHAIN_PROJECT"] = os.getenv("LANGCHAIN_PROJECT", "MediGuard_AI_RAG_Helper")
21
 
22
+
23
def get_default_llm_provider() -> str:
    """Return the provider named by ``LLM_PROVIDER``, defaulting to "groq".

    Looked up on every call so that environment changes made after
    import (e.g. HuggingFace Space secrets) are honored.
    """
    return os.environ.get("LLM_PROVIDER", "groq")


# Snapshot kept for backward compatibility with older callers; new code
# should call get_default_llm_provider() for a live value.
DEFAULT_LLM_PROVIDER = get_default_llm_provider()
30
 
31
 
32
  def get_chat_model(
 
47
  Returns:
48
  LangChain chat model instance
49
  """
50
+ # Use dynamic lookup to get current provider from environment
51
+ provider = provider or get_default_llm_provider()
52
 
53
  if provider == "groq":
54
  from langchain_groq import ChatGroq
 
171
  provider: LLM provider - "groq" (free), "gemini" (free), or "ollama" (local)
172
  lazy: If True, defer model initialization until first use (avoids API key errors at import)
173
  """
174
+ # Store explicit provider or None to use dynamic lookup later
175
+ self._explicit_provider = provider
176
  self._lazy = lazy
177
  self._initialized = False
178
+ self._initialized_provider = None # Track which provider was initialized
179
  self._lock = threading.Lock()
180
 
181
  # Lazy-initialized model instances
 
190
  if not lazy:
191
  self._initialize_models()
192
 
193
@property
def provider(self) -> str:
    """Current provider name.

    Returns the provider explicitly passed at construction when one was
    given; otherwise falls back to a dynamic environment lookup via
    get_default_llm_provider() on every access.
    """
    explicit = self._explicit_provider
    return explicit if explicit else get_default_llm_provider()
197
+
198
+ def _check_provider_change(self):
199
+ """Check if provider changed and reinitialize if needed."""
200
+ current = self.provider
201
+ if self._initialized and self._initialized_provider != current:
202
+ print(f"Provider changed from {self._initialized_provider} to {current}, reinitializing...")
203
+ self._initialized = False
204
+ self._planner = None
205
+ self._analyzer = None
206
+ self._explainer = None
207
+ self._synthesizer_7b = None
208
+ self._synthesizer_8b = None
209
+ self._director = None
210
+
211
  def _initialize_models(self):
212
  """Initialize all model clients (called on first use if lazy)"""
213
+ self._check_provider_change()
214
+
215
  if self._initialized:
216
  return
217
 
 
263
  self._embedding_model = get_embedding_model()
264
 
265
  self._initialized = True
266
+ self._initialized_provider = self.provider
267
 
268
  @property
269
  def planner(self):