SamarpeetGarad committed
Commit 6992528 · verified
1 Parent(s): 20d7590

Upload agents/medgemma_engine.py with huggingface_hub

Files changed (1)
  1. agents/medgemma_engine.py +240 -0
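The commit message says the file was pushed with huggingface_hub. A minimal sketch of such an upload via HfApi.upload_file is shown below; the repo_id, repo_type, and token source are placeholders and assumptions, not values recorded in this commit.

# Hypothetical upload step matching the commit message; repo_id and repo_type are placeholders.
from huggingface_hub import HfApi

api = HfApi()  # uses the token from `huggingface-cli login` or the HF_TOKEN environment variable
api.upload_file(
    path_or_fileobj="agents/medgemma_engine.py",   # local file to push
    path_in_repo="agents/medgemma_engine.py",      # destination path in the repo
    repo_id="<username>/<repo>",                   # placeholder repo id
    repo_type="space",                             # assumption: the target is a Space
    commit_message="Upload agents/medgemma_engine.py with huggingface_hub",
)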
agents/medgemma_engine.py ADDED
@@ -0,0 +1,240 @@
+"""
+MedGemma Engine - Unified interface for MedGemma inference
+Supports both MLX (local Mac) and Transformers (GPU/CPU)
+"""
+
+import os
+import time
+from typing import Optional, Dict, Any
+
+# Detect available backends
+MLX_AVAILABLE = False
+TRANSFORMERS_AVAILABLE = False
+
+try:
+    from mlx_lm import load, generate
+    import mlx.core as mx
+    MLX_AVAILABLE = True
+except ImportError:
+    pass
+
+try:
+    import torch
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+    TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    pass
+
+
+class MedGemmaEngine:
+    """
+    Unified MedGemma inference engine.
+    Automatically selects the best available backend:
+    - MLX for Apple Silicon (M1/M2/M3/M4) - preferred locally
+    - Transformers + CUDA for NVIDIA GPUs (HuggingFace Spaces)
+    - Transformers + CPU as fallback
+    """
+
+    # Model configurations
+    MLX_MODEL = "mlx-community/medgemma-4b-it-4bit"
+    HF_MODEL = "google/medgemma-4b-it"
+
+    def __init__(self, prefer_mlx: Optional[bool] = None, force_demo: bool = False):
+        # Auto-detect best backend preference
+        # On HuggingFace Spaces, prefer transformers (MLX won't work)
+        import os
+        is_spaces = os.environ.get("SPACE_ID") is not None
+
+        if prefer_mlx is None:
+            prefer_mlx = not is_spaces  # Prefer MLX locally, transformers on Spaces
+        self.model = None
+        self.tokenizer = None
+        self.backend = None
+        self.is_loaded = False
+        self.force_demo = force_demo
+        self.prefer_mlx = prefer_mlx
+
+        if force_demo:
+            self.backend = "demo"
+            self.is_loaded = True
+            print("⚠️ MedGemma running in DEMO mode (no real inference)")
+
+    def load(self) -> bool:
+        """Load the model using the best available backend."""
+        if self.force_demo:
+            return True
+
+        if self.is_loaded:
+            return True
+
+        # Try MLX first (best for Mac)
+        if self.prefer_mlx and MLX_AVAILABLE:
+            try:
+                print(f"🔄 Loading MedGemma with MLX ({self.MLX_MODEL})...")
+                start = time.time()
+                self.model, self.tokenizer = load(self.MLX_MODEL)
+                self.backend = "mlx"
+                self.is_loaded = True
+                print(f"✅ MedGemma loaded with MLX in {time.time()-start:.1f}s")
+                return True
+            except Exception as e:
+                print(f"⚠️ MLX loading failed: {e}")
+
+        # Try Transformers with GPU
+        if TRANSFORMERS_AVAILABLE:
+            try:
+                import torch
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+                print(f"🔄 Loading MedGemma with Transformers on {device}...")
+
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.HF_MODEL,
+                    trust_remote_code=True
+                )
+
+                if device == "cuda":
+                    from transformers import BitsAndBytesConfig
+                    quantization_config = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_compute_dtype=torch.float16,
+                    )
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        self.HF_MODEL,
+                        quantization_config=quantization_config,
+                        device_map="auto",
+                        trust_remote_code=True,
+                    )
+                else:
+                    self.model = AutoModelForCausalLM.from_pretrained(
+                        self.HF_MODEL,
+                        trust_remote_code=True,
+                        torch_dtype=torch.float32,
+                    )
+
+                self.backend = f"transformers-{device}"
+                self.is_loaded = True
+                print(f"✅ MedGemma loaded with Transformers ({device})")
+                return True
+
+            except Exception as e:
+                print(f"⚠️ Transformers loading failed: {e}")
+
+        # Fallback to demo mode
+        print("⚠️ No model backend available - using demo mode")
+        self.backend = "demo"
+        self.is_loaded = True
+        return True
+
+    def generate(self, prompt: str, max_tokens: int = 256) -> str:
+        """Generate a response from MedGemma."""
+        if not self.is_loaded:
+            self.load()
+
+        if self.backend == "demo":
+            return self._demo_response(prompt)
+
+        try:
+            if self.backend == "mlx":
+                return self._generate_mlx(prompt, max_tokens)
+            else:
+                return self._generate_transformers(prompt, max_tokens)
+        except Exception as e:
+            print(f"⚠️ Generation error: {e}")
+            return self._demo_response(prompt)
+
+    def _generate_mlx(self, prompt: str, max_tokens: int) -> str:
+        """Generate using MLX backend."""
+        response = generate(
+            self.model,
+            self.tokenizer,
+            prompt=prompt,
+            max_tokens=max_tokens,
+            verbose=False
+        )
+        # Clean up the response (remove the prompt if echoed)
+        if response.startswith(prompt):
+            response = response[len(prompt):].strip()
+        return response
+
+    def _generate_transformers(self, prompt: str, max_tokens: int) -> str:
+        """Generate using Transformers backend."""
+        import torch
+
+        messages = [{"role": "user", "content": prompt}]
+        inputs = self.tokenizer.apply_chat_template(
+            messages, return_tensors="pt", add_generation_prompt=True
+        )
+
+        attention_mask = torch.ones_like(inputs)
+
+        if hasattr(self.model, 'device'):
+            inputs = inputs.to(self.model.device)
+            attention_mask = attention_mask.to(self.model.device)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                inputs,
+                attention_mask=attention_mask,
+                max_new_tokens=max_tokens,
+                do_sample=False,
+                pad_token_id=self.tokenizer.eos_token_id,
+            )
+
+        response = self.tokenizer.decode(
+            outputs[0][inputs.shape[1]:],
+            skip_special_tokens=True
+        )
+        return response.strip()
+
+    def _demo_response(self, prompt: str) -> str:
+        """Fallback demo responses when no model is available."""
+        prompt_lower = prompt.lower()
+
+        if "interpret" in prompt_lower or "finding" in prompt_lower:
+            return "Based on the imaging findings, clinical correlation is recommended. The described abnormality may represent an infectious, inflammatory, or neoplastic process. Further workup including laboratory studies and clinical examination would be beneficial for definitive diagnosis."
+
+        elif "report" in prompt_lower or "generate" in prompt_lower:
+            return """FINDINGS:
+The visualized structures are assessed. Any noted abnormalities are described with their location, size, and characteristics.
+
+IMPRESSION:
+1. Findings as described above.
+2. Clinical correlation recommended.
+
+RECOMMENDATIONS:
+Follow-up imaging as clinically indicated."""
+
+        elif "priority" in prompt_lower or "urgent" in prompt_lower:
+            return "PRIORITY LEVEL: ROUTINE. Based on the findings, this case does not require immediate attention but should be reviewed in the standard workflow timeframe. Clinical correlation with patient symptoms is recommended."
+
+        else:
+            return "Clinical correlation recommended. Please consult with a radiologist for definitive interpretation."
+
+    def get_status(self) -> Dict[str, Any]:
+        """Get engine status."""
+        return {
+            "is_loaded": self.is_loaded,
+            "backend": self.backend,
+            "mlx_available": MLX_AVAILABLE,
+            "transformers_available": TRANSFORMERS_AVAILABLE,
+            "model_name": self.MLX_MODEL if self.backend == "mlx" else self.HF_MODEL
+        }
+
+
+# Global engine instance
+_engine: Optional[MedGemmaEngine] = None
+
+
+def get_engine(force_demo: bool = False) -> MedGemmaEngine:
+    """Get or create the global MedGemma engine."""
+    global _engine
+    if _engine is None:
+        _engine = MedGemmaEngine(force_demo=force_demo)
+        _engine.load()
+    return _engine
+
+
+def generate_response(prompt: str, max_tokens: int = 256) -> str:
+    """Convenience function to generate a response."""
+    engine = get_engine()
+    return engine.generate(prompt, max_tokens)
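
A minimal usage sketch of the module's public helpers, assuming agents/ is importable as a package from the repository root; force_demo=True keeps the example runnable without downloading any model weights.

# Hypothetical caller script; import path assumes the repo root is on PYTHONPATH.
from agents.medgemma_engine import get_engine, generate_response

engine = get_engine(force_demo=True)   # first call creates and loads the global engine
print(engine.get_status())             # e.g. {"backend": "demo", "is_loaded": True, ...}

# Reuses the same global engine created above.
print(generate_response(
    "Interpret the following chest X-ray finding: right lower lobe opacity.",
    max_tokens=128,
))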