SmartHeal committed
Commit 83e490e · verified · 1 Parent(s): f85c4fc

Update src/ai_processor.py

Files changed (1)
  1. src/ai_processor.py +98 -33
src/ai_processor.py CHANGED
@@ -140,50 +140,101 @@ Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.

# ---------- MedGemma-only text generator ----------
@_SPACES_GPU(enable_queue=True)
- def _medgemma_generate_gpu(prompt: str, model_id: str, max_new_tokens: int, token: Optional[str]):
+ def _medgemma_generate_gpu_with_pipeline(
+     prompt: str,
+     image_pil,                      # PIL.Image (the wound image)
+     model_id: str | None = None,    # e.g. "unsloth/medgemma-4b-it-bnb-4bit"
+     max_new_tokens: int = 256,
+     token: str | None = None,
+ ) -> str:
    """
-     Runs entirely inside a Spaces GPU worker. Uses Med-Gemma (text-only) to draft the report.
+     Vision LLM via Transformers pipeline using the "messages" format:
+     [{"role": "user", "content": [{"type": "image", "image": PIL}, {"type": "text", "text": "..."}]}]
+     Returns a generated string.
    """
-     import torch
+     import os, torch
    from transformers import pipeline
-
-     pipe = pipeline(
-         "image-text-to-text",
-         model="unsloth/medgemma-4b-it-unsloth-bnb-4bit",
-         torch_dtype=torch.bfloat16,
-         device="cuda",
+     try:
+         from transformers import BitsAndBytesConfig  # only needed for 4-bit
+     except Exception:
+         BitsAndBytesConfig = None
+
+     hf_token = token or os.getenv("HF_TOKEN")
+     mid = model_id or "unsloth/medgemma-4b-it-bnb-4bit"
+
+     # device / dtype
+     use_cuda = torch.cuda.is_available()
+     device = 0 if use_cuda else -1
+     dtype = torch.bfloat16 if use_cuda else torch.float32
+
+     # Build messages in the doc format
+     messages = [{
+         "role": "user",
+         "content": [
+             {"type": "image", "image": image_pil},  # local PIL image
+             {"type": "text", "text": prompt},
+         ],
+     }]
+
+     pipe_kwargs = dict(
+         task="image-text-to-text",
+         model=mid,
+         torch_dtype=dtype,
+         device=device,              # GPU=0 or CPU=-1
+         trust_remote_code=True,
    )
-     out = pipe(
-         prompt,
-         max_new_tokens=max_new_tokens,
+
+     # Pass HF token (newer Transformers uses `token`; older uses `use_auth_token`)
+     if hf_token:
+         try:
+             pipe_kwargs["token"] = hf_token
+         except TypeError:
+             pipe_kwargs["use_auth_token"] = hf_token
+
+     # If this is the 4-bit Unsloth build, attach quantization (requires CUDA + bitsandbytes)
+     if "bnb-4bit" in mid.lower():
+         if not use_cuda or BitsAndBytesConfig is None:
+             raise RuntimeError("Unsloth 4-bit requires CUDA + bitsandbytes; no GPU available.")
+         bnb = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=torch.bfloat16,
+         )
+         pipe_kwargs["model_kwargs"] = {"quantization_config": bnb}
+
+     # Create pipeline and run with messages
+     p = pipeline(**pipe_kwargs)
+     out = p(
+         text=messages,
+         max_new_tokens=int(max_new_tokens or 256),
        do_sample=False,
        temperature=0.2,
-         return_full_text=True,
+         return_full_text=False,  # we just want the answer
    )
-     text = (out[0].get("generated_text") if isinstance(out, list) else out).strip()
-     # Remove the prompt echo if present
-     if text.startswith(prompt):
-         text = text[len(prompt):].lstrip()
-     return text or "⚠️ Empty response"

+     # Normalize output to a string
+     if isinstance(out, list):
+         # pipelines often return a list of strings or dicts; handle both
+         first = out[0]
+         text = first.get("generated_text") if isinstance(first, dict) else str(first)
+     else:
+         text = str(out)
+
+     return (text or "").strip() or "⚠️ Empty response"
+
+
- def generate_medgemma_report(  # kept name so callers don't change
+ def generate_medgemma_report(
    patient_info: str,
    visual_results: Dict,
    guideline_context: str,
-     image_pil: Image.Image,         # kept for signature compatibility; not used by MedGemma
-     max_new_tokens: Optional[int] = None,
+     image_pil,                      # keep passing the PIL image
+     max_new_tokens: int | None = None,
) -> str:
-     """
-     MedGemma (text-only) report generation.
-     The image is analyzed by the vision pipeline; MedGemma formats clinical guidance text.
-     """
    if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
        return "⚠️ VLM disabled"

-     # Default to a public Med-Gemma instruction-tuned model (update via env if you have access to another).
-     model_id = os.getenv("SMARTHEAL_MEDGEMMA_MODEL", "google/med-gemma-2-2b-it")
-     max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))
-
+     # Build your prompt as before
    uprompt = SMARTHEAL_USER_PREFIX.format(
        patient_info=patient_info,
        wound_type=visual_results.get("wound_type", "Unknown"),
@@ -194,16 +245,30 @@ def generate_medgemma_report(  # kept name so callers don't change
        px_per_cm=visual_results.get("px_per_cm", "?"),
        guideline_context=(guideline_context or "")[:900],
    )
-
-     # Compose a single text prompt
    prompt = f"{SMARTHEAL_SYSTEM_PROMPT}\n\n{uprompt}\n\nAnswer:"

+     model_id = os.getenv("SMARTHEAL_MEDGEMMA_MODEL", "unsloth/medgemma-4b-it-bnb-4bit")
+     max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))
+
    try:
-         return _medgemma_generate_gpu(prompt, model_id, max_new_tokens, HF_TOKEN)
+         return _medgemma_generate_gpu_with_pipeline(prompt, image_pil, model_id, max_new_tokens, HF_TOKEN)
    except Exception as e:
-         logging.error(f"MedGemma call failed: {e}")
+         # Optional: automatic tiny fallback if CUDA/bnb/space issues show up
+         err = str(e)
+         if any(s in err for s in ("No space left", "bitsandbytes", "CUDA", "requires CUDA")):
+             try:
+                 return _medgemma_generate_gpu_with_pipeline(
+                     prompt, image_pil,
+                     model_id="bczhou/tiny-llava-v1-hf",  # ~1GB; CPU OK
+                     max_new_tokens=max_new_tokens,
+                     token=HF_TOKEN,
+                 )
+             except Exception:
+                 pass
+         logging.error(f"MedGemma pipeline failed: {e}", exc_info=True)
        return "⚠️ VLM error"

+
# ---------- Input-shape helpers (avoid `.as_list()` on strings) ----------
def _shape_to_hw(shape) -> Tuple[Optional[int], Optional[int]]:
    try:
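
For reference, a minimal caller sketch for the updated generate_medgemma_report (a sketch only: the patient details, image path, and dictionary values below are hypothetical; the signature, environment variables, and visual_results keys are taken from the diff above):

# Hypothetical caller; assumes the SmartHeal module and HF_TOKEN are already set up.
import os
from PIL import Image

os.environ.setdefault("SMARTHEAL_ENABLE_VLM", "1")
os.environ.setdefault("SMARTHEAL_MEDGEMMA_MODEL", "unsloth/medgemma-4b-it-bnb-4bit")
os.environ.setdefault("SMARTHEAL_VLM_MAX_TOKENS", "600")

wound_image = Image.open("wound.jpg")            # hypothetical local image
visual_results = {                               # subset of the keys the prompt template reads
    "wound_type": "Pressure ulcer",
    "px_per_cm": 35.2,
}

report = generate_medgemma_report(
    patient_info="68-year-old, diabetic, sacral wound",   # hypothetical
    visual_results=visual_results,
    guideline_context="(retrieved guideline excerpt)",
    image_pil=wound_image,
)
print(report)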
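
Similarly, a rough smoke test of the CPU fallback path in the new except branch (a sketch assuming no GPU is available: the fallback model id "bczhou/tiny-llava-v1-hf" comes from the diff; the prompt and placeholder image are made up):

# Exercise the tiny fallback model directly on CPU (hypothetical test code).
import os
from PIL import Image

placeholder = Image.new("RGB", (224, 224), "gray")   # stand-in image for the test
answer = _medgemma_generate_gpu_with_pipeline(
    prompt="Describe the visible wound characteristics.",
    image_pil=placeholder,
    model_id="bczhou/tiny-llava-v1-hf",
    max_new_tokens=64,
    token=os.getenv("HF_TOKEN"),
)
print(answer)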