SebAustin committed
Commit 29e2ed7 · 1 Parent(s): 3265b47
Files changed (2)
  1. app.py +5 -0
  2. src/models/medgemma_client.py +37 -31
app.py CHANGED
@@ -2,9 +2,14 @@
 Hugging Face Spaces entry point.
 Runs the MedGemma triage Gradio demo. Set HF_TOKEN in Space secrets for gated model access.
 """
+import os
 import sys
 from pathlib import Path
 
+# Avoid "Invalid value for environment variable OMP_NUM_THREADS" on Spaces (set before any OpenMP use)
+if "OMP_NUM_THREADS" not in os.environ or not str(os.environ.get("OMP_NUM_THREADS", "")).strip().isdigit():
+    os.environ["OMP_NUM_THREADS"] = "1"
+
 # Ensure project root is on path (Spaces run from repo root)
 ROOT = Path(__file__).resolve().parent
 if str(ROOT) not in sys.path:
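The guard only rewrites OMP_NUM_THREADS when the variable is missing or not a plain integer string, so an operator-supplied value survives. A minimal sketch of the same check in isolation (the `_sanitize_omp_threads` helper is illustrative, not part of the repo), assuming it runs before torch or any other OpenMP-backed library is imported:

```python
import os

def _sanitize_omp_threads(default: str = "1") -> str:
    """Force OMP_NUM_THREADS to a plain integer string; leave valid values alone."""
    value = os.environ.get("OMP_NUM_THREADS", "")
    if not value.strip().isdigit():       # unset, empty, or junk like "None" / "0.5"
        os.environ["OMP_NUM_THREADS"] = default
    return os.environ["OMP_NUM_THREADS"]

# Example: a platform may export an empty or non-numeric value
os.environ["OMP_NUM_THREADS"] = " "
print(_sanitize_omp_threads())            # -> "1"
os.environ["OMP_NUM_THREADS"] = "4"
print(_sanitize_omp_threads())            # -> "4" (left untouched)
```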
src/models/medgemma_client.py CHANGED
@@ -147,12 +147,27 @@ class MedGemmaClient:
         # Set default parameters
         max_length = max_length or ModelConfig.MAX_LENGTH
         max_new_tokens = max_new_tokens or ModelConfig.MAX_NEW_TOKENS
-        temperature = temperature or ModelConfig.TEMPERATURE
+        temperature = temperature if temperature is not None else ModelConfig.TEMPERATURE
         top_p = top_p or ModelConfig.TOP_P
         top_k = top_k or ModelConfig.TOP_K
 
-        try:
-            # Tokenize input
+        # Avoid CUDA "probability tensor contains inf/nan or element < 0" by using greedy when temp is low
+        # and clamping temperature when sampling (very low temp can make logits explode in float16)
+        do_sample = float(temperature) > 0.05
+        if do_sample:
+            temperature = max(0.05, min(float(temperature), 2.0))
+
+        def _run_generate(do_samp: bool, temp: float) -> str:
+            gen_config = GenerationConfig(
+                max_new_tokens=max_new_tokens,
+                temperature=temp if do_samp else 1.0,
+                top_p=top_p if do_samp else 1.0,
+                top_k=top_k if do_samp else 0,
+                do_sample=do_samp,
+                pad_token_id=self.tokenizer.pad_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                **kwargs
+            )
             inputs = self.tokenizer(
                 prompt,
                 return_tensors="pt",
@@ -160,39 +175,30 @@
                 truncation=True,
                 max_length=max_length
             ).to(self.device)
-
-            # Create generation config
-            generation_config = GenerationConfig(
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                do_sample=temperature > 0,
-                pad_token_id=self.tokenizer.pad_token_id,
-                eos_token_id=self.tokenizer.eos_token_id,
-                **kwargs
-            )
-
-            # Generate
             with torch.no_grad():
-                outputs = self.model.generate(
-                    **inputs,
-                    generation_config=generation_config
-                )
-
-            # Decode output
-            generated_text = self.tokenizer.decode(
-                outputs[0],
-                skip_special_tokens=True
-            )
-
-            # Remove the input prompt from output
+                outputs = self.model.generate(**inputs, generation_config=gen_config)
+            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             if generated_text.startswith(prompt):
                 generated_text = generated_text[len(prompt):].strip()
-
             return generated_text
-
+
+        try:
+            return _run_generate(do_sample, temperature)
         except Exception as e:
+            err_msg = str(e).lower()
+            is_cuda_assert = (
+                "cuda" in err_msg
+                or "device-side assert" in err_msg
+                or "acceleratorerror" in err_msg
+                or "probability tensor" in err_msg
+            )
+            if is_cuda_assert and do_sample:
+                logger.warning(f"Generation failed with device error, retrying with greedy decoding: {e}")
+                try:
+                    return _run_generate(do_samp=False, temp=1.0)
+                except Exception as retry_e:
+                    logger.error(f"Generation failed after greedy fallback: {retry_e}")
+                    raise retry_e
             logger.error(f"Generation failed: {e}")
             raise
 
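The diff wraps generation in `_run_generate` so that a sampling failure (commonly the float16 CUDA assert "probability tensor contains inf, nan or element < 0") can be retried greedily with the exact same inputs. A self-contained sketch of that retry pattern with `transformers`; the `generate_with_fallback` helper, the 0.05 threshold, and the commented-out model id are illustrative assumptions, not taken from this repo:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

def generate_with_fallback(model, tokenizer, prompt: str,
                           temperature: float = 0.7, max_new_tokens: int = 128) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    def _run(do_sample: bool) -> str:
        cfg = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            # clamp when sampling; neutral value when greedy so sampling knobs are inert
            temperature=max(0.05, min(temperature, 2.0)) if do_sample else 1.0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        with torch.no_grad():
            out = model.generate(**inputs, generation_config=cfg)
        text = tokenizer.decode(out[0], skip_special_tokens=True)
        return text[len(prompt):].strip() if text.startswith(prompt) else text

    try:
        # go greedy outright for near-zero temperatures, otherwise sample
        return _run(do_sample=temperature > 0.05)
    except RuntimeError as e:  # CUDA device-side asserts surface as RuntimeError
        msg = str(e).lower()
        if "probability tensor" in msg or "device-side assert" in msg or "cuda" in msg:
            return _run(do_sample=False)  # deterministic greedy fallback
        raise

# Usage sketch (model id is a placeholder; the Space's gated checkpoint may differ):
# tok = AutoTokenizer.from_pretrained("<medgemma-checkpoint>")
# model = AutoModelForCausalLM.from_pretrained("<medgemma-checkpoint>",
#                                              torch_dtype=torch.float16, device_map="auto")
# print(generate_with_fallback(model, tok, "Summarize: patient reports chest pain.", temperature=0.2))
```

Note that a true device-side assert can leave the CUDA context unusable, so the greedy retry is best-effort; the temperature clamp in the diff is what keeps that path from being hit in the first place.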