Spaces:
Runtime error
Runtime error
Upload agents/medgemma_engine.py with huggingface_hub
Browse files- agents/medgemma_engine.py +240 -0
agents/medgemma_engine.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
MedGemma Engine - Unified interface for MedGemma inference
Supports both MLX (local Mac) and Transformers (GPU/CPU)
"""

import os
import time
from typing import Optional, Dict, Any

# Detect available backends once at import time. Each backend is an
# optional dependency: a failed import simply leaves its flag False,
# and MedGemmaEngine.load() falls back accordingly.
MLX_AVAILABLE = False
TRANSFORMERS_AVAILABLE = False

# MLX backend (Apple Silicon). `load`/`generate` are used by the MLX path.
try:
    from mlx_lm import load, generate
    import mlx.core as mx
    MLX_AVAILABLE = True
except ImportError:
    pass

# Transformers backend (CUDA or CPU).
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    pass
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class MedGemmaEngine:
    """
    Unified MedGemma inference engine.

    Automatically selects the best available backend:
    - MLX for Apple Silicon (M1/M2/M3/M4) - preferred locally
    - Transformers + CUDA for NVIDIA GPUs (HuggingFace Spaces)
    - Transformers + CPU as fallback

    If no real backend can be loaded, the engine degrades to a canned
    "demo" mode so callers always receive a response instead of an error.
    """

    # Model identifiers per backend.
    MLX_MODEL = "mlx-community/medgemma-4b-it-4bit"
    HF_MODEL = "google/medgemma-4b-it"

    def __init__(self, prefer_mlx: Optional[bool] = None, force_demo: bool = False):
        """Create the engine without loading any model weights.

        Args:
            prefer_mlx: Try the MLX backend first. Defaults to True when
                running locally and False on HuggingFace Spaces, where MLX
                cannot run.
            force_demo: Skip real inference entirely and serve canned
                demo responses.
        """
        # SPACE_ID is set by the HuggingFace Spaces runtime; uses the
        # module-level `os` import (previous code re-imported os locally).
        is_spaces = os.environ.get("SPACE_ID") is not None

        if prefer_mlx is None:
            prefer_mlx = not is_spaces  # Prefer MLX locally, transformers on Spaces

        self.model = None
        self.tokenizer = None
        # One of "mlx", "transformers-cuda", "transformers-cpu", "demo", or
        # None before load() has run.
        self.backend: Optional[str] = None
        self.is_loaded = False
        self.force_demo = force_demo
        self.prefer_mlx = prefer_mlx

        if force_demo:
            self.backend = "demo"
            self.is_loaded = True
            print("⚠️ MedGemma running in DEMO mode (no real inference)")

    def load(self) -> bool:
        """Load the model using the best available backend.

        Always returns True: each backend failure is caught and logged,
        and if every real backend fails the engine drops into demo mode
        rather than raising. Safe to call repeatedly (no-op once loaded).
        """
        if self.force_demo:
            return True

        if self.is_loaded:
            return True

        # Try MLX first (best for Apple Silicon Macs).
        if self.prefer_mlx and MLX_AVAILABLE:
            try:
                print(f"🔄 Loading MedGemma with MLX ({self.MLX_MODEL})...")
                start = time.time()
                self.model, self.tokenizer = load(self.MLX_MODEL)
                self.backend = "mlx"
                self.is_loaded = True
                print(f"✅ MedGemma loaded with MLX in {time.time()-start:.1f}s")
                return True
            except Exception as e:
                print(f"⚠️ MLX loading failed: {e}")

        # Try Transformers next (CUDA if available, otherwise CPU).
        if TRANSFORMERS_AVAILABLE:
            try:
                import torch
                device = "cuda" if torch.cuda.is_available() else "cpu"
                print(f"🔄 Loading MedGemma with Transformers on {device}...")

                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.HF_MODEL,
                    trust_remote_code=True
                )

                if device == "cuda":
                    # 4-bit quantization keeps the 4B model within typical
                    # Spaces GPU memory budgets.
                    from transformers import BitsAndBytesConfig
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                    )
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.HF_MODEL,
                        quantization_config=quantization_config,
                        device_map="auto",
                        trust_remote_code=True,
                    )
                else:
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.HF_MODEL,
                        trust_remote_code=True,
                        torch_dtype=torch.float32,
                    )

                self.backend = f"transformers-{device}"
                self.is_loaded = True
                print(f"✅ MedGemma loaded with Transformers ({device})")
                return True

            except Exception as e:
                print(f"⚠️ Transformers loading failed: {e}")

        # Fallback: demo mode keeps the app functional without a model.
        print("⚠️ No model backend available - using demo mode")
        self.backend = "demo"
        self.is_loaded = True
        return True

    def generate(self, prompt: str, max_tokens: int = 256) -> str:
        """Generate a response from MedGemma.

        Loads the model lazily on first call. Any generation error is
        caught and answered with a canned demo response instead of raising.
        """
        if not self.is_loaded:
            self.load()

        if self.backend == "demo":
            return self._demo_response(prompt)

        try:
            if self.backend == "mlx":
                return self._generate_mlx(prompt, max_tokens)
            else:
                return self._generate_transformers(prompt, max_tokens)
        except Exception as e:
            print(f"⚠️ Generation error: {e}")
            return self._demo_response(prompt)

    def _generate_mlx(self, prompt: str, max_tokens: int) -> str:
        """Generate using the MLX backend.

        NOTE(review): the raw prompt is passed without applying the
        tokenizer's chat template, unlike the Transformers path — confirm
        this is intended for the instruction-tuned model.
        """
        response = generate(
            self.model,
            self.tokenizer,
            prompt=prompt,
            max_tokens=max_tokens,
            verbose=False
        )
        # Clean up the response (remove the prompt if echoed back).
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        return response

    def _generate_transformers(self, prompt: str, max_tokens: int) -> str:
        """Generate using the Transformers backend (greedy decoding)."""
        import torch

        # Wrap the prompt in the model's chat format.
        messages = [{"role": "user", "content": prompt}]
        inputs = self.tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True
        )

        attention_mask = torch.ones_like(inputs)

        # With device_map="auto" the model exposes a .device to move inputs to.
        if hasattr(self.model, 'device'):
            inputs = inputs.to(self.model.device)
            attention_mask = attention_mask.to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens (strip the prompt tokens).
        response = self.tokenizer.decode(
            outputs[0][inputs.shape[1]:],
            skip_special_tokens=True
        )
        return response.strip()

    def _demo_response(self, prompt: str) -> str:
        """Fallback demo responses when no model is available.

        Picks a canned answer by keyword-matching the prompt; branch order
        matters (interpretation before report before triage).
        """
        prompt_lower = prompt.lower()

        if "interpret" in prompt_lower or "finding" in prompt_lower:
            return "Based on the imaging findings, clinical correlation is recommended. The described abnormality may represent an infectious, inflammatory, or neoplastic process. Further workup including laboratory studies and clinical examination would be beneficial for definitive diagnosis."

        elif "report" in prompt_lower or "generate" in prompt_lower:
            return """FINDINGS:
The visualized structures are assessed. Any noted abnormalities are described with their location, size, and characteristics.

IMPRESSION:
1. Findings as described above.
2. Clinical correlation recommended.

RECOMMENDATIONS:
Follow-up imaging as clinically indicated."""

        elif "priority" in prompt_lower or "urgent" in prompt_lower:
            return "PRIORITY LEVEL: ROUTINE. Based on the findings, this case does not require immediate attention but should be reviewed in standard workflow timeframe. Clinical correlation with patient symptoms is recommended."

        else:
            return "Clinical correlation recommended. Please consult with a radiologist for definitive interpretation."

    def get_status(self) -> Dict[str, Any]:
        """Get engine status as a plain dict (for health/debug endpoints)."""
        return {
            "is_loaded": self.is_loaded,
            "backend": self.backend,
            "mlx_available": MLX_AVAILABLE,
            "transformers_available": TRANSFORMERS_AVAILABLE,
            "model_name": self.MLX_MODEL if self.backend == "mlx" else self.HF_MODEL
        }
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# Process-wide singleton engine, created lazily by get_engine().
_engine: Optional[MedGemmaEngine] = None


def get_engine(force_demo: bool = False) -> MedGemmaEngine:
    """Return the shared MedGemma engine, creating and loading it on first use.

    Note: ``force_demo`` only takes effect on the call that creates the
    singleton; later calls return the existing engine unchanged.
    """
    global _engine
    if _engine is not None:
        return _engine
    engine = MedGemmaEngine(force_demo=force_demo)
    engine.load()
    _engine = engine
    return _engine
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def generate_response(prompt: str, max_tokens: int = 256) -> str:
    """Convenience wrapper: answer *prompt* via the shared MedGemma engine."""
    return get_engine().generate(prompt, max_tokens)
|