Spaces:
Sleeping
Sleeping
SebAustin committed on
Commit ·
29e2ed7
1
Parent(s): 3265b47
V1.1
Browse files- app.py +5 -0
- src/models/medgemma_client.py +37 -31
app.py
CHANGED
|
@@ -2,9 +2,14 @@
|
|
| 2 |
Hugging Face Spaces entry point.
|
| 3 |
Runs the MedGemma triage Gradio demo. Set HF_TOKEN in Space secrets for gated model access.
|
| 4 |
"""
|
|
|
|
| 5 |
import sys
|
| 6 |
from pathlib import Path
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
# Ensure project root is on path (Spaces run from repo root)
|
| 9 |
ROOT = Path(__file__).resolve().parent
|
| 10 |
if str(ROOT) not in sys.path:
|
|
|
|
| 2 |
Hugging Face Spaces entry point.
|
| 3 |
Runs the MedGemma triage Gradio demo. Set HF_TOKEN in Space secrets for gated model access.
|
| 4 |
"""
|
| 5 |
+
import os
|
| 6 |
import sys
|
| 7 |
from pathlib import Path
|
| 8 |
|
| 9 |
+
# Avoid "Invalid value for environment variable OMP_NUM_THREADS" on Spaces (set before any OpenMP use)
|
| 10 |
+
if "OMP_NUM_THREADS" not in os.environ or not str(os.environ.get("OMP_NUM_THREADS", "")).strip().isdigit():
|
| 11 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
| 12 |
+
|
| 13 |
# Ensure project root is on path (Spaces run from repo root)
|
| 14 |
ROOT = Path(__file__).resolve().parent
|
| 15 |
if str(ROOT) not in sys.path:
|
src/models/medgemma_client.py
CHANGED
|
@@ -147,12 +147,27 @@ class MedGemmaClient:
|
|
| 147 |
# Set default parameters
|
| 148 |
max_length = max_length or ModelConfig.MAX_LENGTH
|
| 149 |
max_new_tokens = max_new_tokens or ModelConfig.MAX_NEW_TOKENS
|
| 150 |
-
temperature = temperature
|
| 151 |
top_p = top_p or ModelConfig.TOP_P
|
| 152 |
top_k = top_k or ModelConfig.TOP_K
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
inputs = self.tokenizer(
|
| 157 |
prompt,
|
| 158 |
return_tensors="pt",
|
|
@@ -160,39 +175,30 @@ class MedGemmaClient:
|
|
| 160 |
truncation=True,
|
| 161 |
max_length=max_length
|
| 162 |
).to(self.device)
|
| 163 |
-
|
| 164 |
-
# Create generation config
|
| 165 |
-
generation_config = GenerationConfig(
|
| 166 |
-
max_new_tokens=max_new_tokens,
|
| 167 |
-
temperature=temperature,
|
| 168 |
-
top_p=top_p,
|
| 169 |
-
top_k=top_k,
|
| 170 |
-
do_sample=temperature > 0,
|
| 171 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
| 172 |
-
eos_token_id=self.tokenizer.eos_token_id,
|
| 173 |
-
**kwargs
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
# Generate
|
| 177 |
with torch.no_grad():
|
| 178 |
-
outputs = self.model.generate(
|
| 179 |
-
|
| 180 |
-
generation_config=generation_config
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
# Decode output
|
| 184 |
-
generated_text = self.tokenizer.decode(
|
| 185 |
-
outputs[0],
|
| 186 |
-
skip_special_tokens=True
|
| 187 |
-
)
|
| 188 |
-
|
| 189 |
-
# Remove the input prompt from output
|
| 190 |
if generated_text.startswith(prompt):
|
| 191 |
generated_text = generated_text[len(prompt):].strip()
|
| 192 |
-
|
| 193 |
return generated_text
|
| 194 |
-
|
|
|
|
|
|
|
| 195 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
logger.error(f"Generation failed: {e}")
|
| 197 |
raise
|
| 198 |
|
|
|
|
| 147 |
# Set default parameters
|
| 148 |
max_length = max_length or ModelConfig.MAX_LENGTH
|
| 149 |
max_new_tokens = max_new_tokens or ModelConfig.MAX_NEW_TOKENS
|
| 150 |
+
temperature = temperature if temperature is not None else ModelConfig.TEMPERATURE
|
| 151 |
top_p = top_p or ModelConfig.TOP_P
|
| 152 |
top_k = top_k or ModelConfig.TOP_K
|
| 153 |
|
| 154 |
+
# Avoid CUDA "probability tensor contains inf/nan or element < 0" by using greedy when temp is low
|
| 155 |
+
# and clamping temperature when sampling (very low temp can make logits explode in float16)
|
| 156 |
+
do_sample = float(temperature) > 0.05
|
| 157 |
+
if do_sample:
|
| 158 |
+
temperature = max(0.05, min(float(temperature), 2.0))
|
| 159 |
+
|
| 160 |
+
def _run_generate(do_samp: bool, temp: float) -> str:
|
| 161 |
+
gen_config = GenerationConfig(
|
| 162 |
+
max_new_tokens=max_new_tokens,
|
| 163 |
+
temperature=temp if do_samp else 1.0,
|
| 164 |
+
top_p=top_p if do_samp else 1.0,
|
| 165 |
+
top_k=top_k if do_samp else 0,
|
| 166 |
+
do_sample=do_samp,
|
| 167 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 168 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
| 169 |
+
**kwargs
|
| 170 |
+
)
|
| 171 |
inputs = self.tokenizer(
|
| 172 |
prompt,
|
| 173 |
return_tensors="pt",
|
|
|
|
| 175 |
truncation=True,
|
| 176 |
max_length=max_length
|
| 177 |
).to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
with torch.no_grad():
|
| 179 |
+
outputs = self.model.generate(**inputs, generation_config=gen_config)
|
| 180 |
+
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
if generated_text.startswith(prompt):
|
| 182 |
generated_text = generated_text[len(prompt):].strip()
|
|
|
|
| 183 |
return generated_text
|
| 184 |
+
|
| 185 |
+
try:
|
| 186 |
+
return _run_generate(do_sample, temperature)
|
| 187 |
except Exception as e:
|
| 188 |
+
err_msg = str(e).lower()
|
| 189 |
+
is_cuda_assert = (
|
| 190 |
+
"cuda" in err_msg
|
| 191 |
+
or "device-side assert" in err_msg
|
| 192 |
+
or "acceleratorerror" in err_msg
|
| 193 |
+
or "probability tensor" in err_msg
|
| 194 |
+
)
|
| 195 |
+
if is_cuda_assert and do_sample:
|
| 196 |
+
logger.warning(f"Generation failed with device error, retrying with greedy decoding: {e}")
|
| 197 |
+
try:
|
| 198 |
+
return _run_generate(do_samp=False, temp=1.0)
|
| 199 |
+
except Exception as retry_e:
|
| 200 |
+
logger.error(f"Generation failed after greedy fallback: {retry_e}")
|
| 201 |
+
raise retry_e
|
| 202 |
logger.error(f"Generation failed: {e}")
|
| 203 |
raise
|
| 204 |
|