SmartHeal committed
Commit bc7b1d8 · verified · 1 parent: a4ec7ff

Update src/ai_processor.py

Files changed (1)
  1. src/ai_processor.py +119 -149
src/ai_processor.py CHANGED
@@ -7,8 +7,9 @@ import logging
 from datetime import datetime
 from typing import Optional, Dict, List, Tuple
 
-# ---- Environment defaults (do NOT globally hint CUDA here) ----
+# ---- Environment defaults (mask CUDA in main process) ----
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")  # ensure main never touches CUDA
 LOGLEVEL = os.getenv("LOGLEVEL", "INFO").upper()
 SMARTHEAL_DEBUG = os.getenv("SMARTHEAL_DEBUG", "0") == "1"
 
@@ -26,14 +27,9 @@ logging.basicConfig(
 def _log_kv(prefix: str, kv: Dict):
     logging.debug(prefix + " | " + " | ".join(f"{k}={v}" for k, v in kv.items()))
 
-# --- Spaces GPU decorator (REQUIRED) ---
-from spaces import GPU as _SPACES_GPU
+# --- Spaces GPU (non-optional) ---
+import spaces  # required; do not stub/optionalize
 
-@_SPACES_GPU(enable_queue=True)  # enable_queue ignored by ZeroGPU but explicit is fine
-def smartheal_gpu_stub(ping: int = 0) -> str:
-    return "ready"
-
-# ---- Paths / constants ----
 UPLOADS_DIR = "uploads"
 os.makedirs(UPLOADS_DIR, exist_ok=True)
 
@@ -53,37 +49,15 @@ SEG_THRESH = float(os.getenv("SEG_THRESH", "0.5"))
 models_cache: Dict[str, object] = {}
 knowledge_base_cache: Dict[str, object] = {}
 
-# ---------- Utilities to prevent CUDA in main process ----------
-from contextlib import contextmanager
-
-@contextmanager
-def _no_cuda_env():
-    """
-    Mask GPUs so any library imported/constructed in the main process
-    cannot see CUDA (required for Spaces Stateless GPU).
-    """
-    prev = os.environ.get("CUDA_VISIBLE_DEVICES")
-    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
-    try:
-        yield
-    finally:
-        if prev is None:
-            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
-        else:
-            os.environ["CUDA_VISIBLE_DEVICES"] = prev
-
-# ---------- Lazy imports (wrapped where needed) ----------
+# ---------- Lazy imports ----------
 def _import_ultralytics():
-    # Prevent Ultralytics from probing CUDA on import
-    with _no_cuda_env():
-        from ultralytics import YOLO
+    from ultralytics import YOLO
     return YOLO
 
 def _import_tf_loader():
     import tensorflow as tf
     try:
-        # Keep TF on CPU only
-        tf.config.set_visible_devices([], "GPU")
+        tf.config.set_visible_devices([], "GPU")  # keep TF on CPU
     except Exception:
         pass
     from tensorflow.keras.models import load_model
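Note on the hunks above: the commit drops the `_no_cuda_env` masking helper and the module-level GPU stub because, on ZeroGPU Spaces, the main process is expected to stay CPU-only and CUDA is touched only inside functions decorated with `spaces.GPU`. A minimal sketch of that pattern (illustrative only, not part of this diff; `gpu_sum` is a made-up example):

import spaces
import torch

@spaces.GPU  # a GPU is attached only while this function runs
def gpu_sum(values):
    # Safe to use CUDA here; never at module import time in the main process.
    return torch.tensor(values, device="cuda").sum().item()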
@@ -94,11 +68,8 @@ def _import_hf_cls():
     return pipeline
 
 def _import_embeddings():
-    # Prefer the new package if available, fallback to community to avoid deprecation warnings
-    try:
-        from langchain_huggingface import HuggingFaceEmbeddings  # type: ignore
-    except Exception:
-        from langchain_community.embeddings import HuggingFaceEmbeddings  # type: ignore
+    # updated per LangChain deprecations
+    from langchain_huggingface import HuggingFaceEmbeddings
     return HuggingFaceEmbeddings
 
 def _import_langchain_pdf():
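The `_import_embeddings` change above switches to the `langchain-huggingface` partner package, which replaces the deprecated `langchain_community` import. A minimal usage sketch (the model name is illustrative, assuming the package is installed):

from langchain_huggingface import HuggingFaceEmbeddings

emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector = emb.embed_query("dressing selection for a heavily exuding wound")
print(len(vector))  # embedding dimensionality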
@@ -113,85 +84,107 @@ def _import_hf_hub():
     from huggingface_hub import HfApi, HfFolder
     return HfApi, HfFolder
 
-# ---------- SmartHeal prompts (system + user prefix) ----------
-SMARTHEAL_SYSTEM_PROMPT = """\
-You are SmartHeal Clinical Assistant, a wound-care decision-support system.
-You analyze wound photographs and brief patient context to produce careful,
-specific, guideline-informed recommendations WITHOUT diagnosing. You always:
-- Use the measurements calculated by the vision pipeline as ground truth.
-- Prefer concise, actionable steps tailored to exudate level, infection risk, and pain.
-- Flag uncertainties and red flags that need escalation to a clinician.
-- Avoid contraindicated advice; do not infer unseen comorbidities.
-- Keep under 300 words and use the requested headings exactly.
-- Tone: professional, clear, and conservative; no definitive medical claims.
-- Safety: remind the user to seek clinician review for changes or red flags.
-"""
-
-SMARTHEAL_USER_PREFIX = """\
-Patient: {patient_info}
-Visual findings: type={wound_type}, size={length_cm}x{breadth_cm} cm, area={area_cm2} cm^2,
-detection_conf={det_conf:.2f}, calibration={px_per_cm} px/cm.
-
-Guideline context (snippets you can draw principles from; do not quote at length):
-{guideline_context}
+# ---------- VLM (MedGemma replacement under the same public function name) ----------
+SMARTHEAL_VLM_ID = os.getenv("SMARTHEAL_VLM_ID", "Qwen/Qwen2-VL-2B-Instruct")
+SMARTHEAL_VLM_MAX_NEW_TOKENS = int(os.getenv("SMARTHEAL_VLM_MAX_NEW_TOKENS", "600"))
+SMARTHEAL_VLM_TEMPERATURE = float(os.getenv("SMARTHEAL_VLM_TEMPERATURE", "0.2"))
+
+SMARTHEAL_SYSTEM_PROMPT = """You are SmartHeal, a medical decision-support assistant specialized in wound assessment.
+You are given: (1) a wound photograph, (2) basic patient context, and (3) visual measurements (length, width, area)
+estimated from computer vision. You must:
+
+- Summarize clinically-relevant visual cues (tissue type, exudate amount, slough/necrosis, peri-wound condition).
+- Interpret in context of diabetes/infection/moisture/bleeding risks.
+- Provide clear next-step care: cleansing, debridement criteria, dressing selection, offloading, escalation triggers.
+- Include risk flags (ischemia, cellulitis, osteomyelitis suspicion) and monitoring frequency.
+- Be concise, structured, and avoid speculation beyond the image and given data.
+- Always add a short disclaimer: “Decision-support only; verify clinically.”"""
+
+def _build_vlm_messages(patient_info: str, visual_results: Dict, guideline_context: str) -> list:
+    wt = visual_results.get("wound_type", "Unknown")
+    L = visual_results.get("length_cm", 0)
+    W = visual_results.get("breadth_cm", 0)
+    A = visual_results.get("surface_area_cm2", 0)
+    ppcm = visual_results.get("px_per_cm", "?")
+
+    ctx = (guideline_context or "")
+    if ctx:
+        ctx = f"\n\nRelevant guideline snippets:\n{ctx[:1200]}{'...' if len(ctx)>1200 else ''}"
+
+    text = (
+        f"{SMARTHEAL_SYSTEM_PROMPT}\n\n"
+        f"Patient: {patient_info}\n"
+        f"Wound visual summary (from CV): type={wt}, length={L} cm, width={W} cm, "
+        f"area={A} cm² (calibration {ppcm} px/cm)."
+        f"{ctx}\n\n"
+        "Analyze the image and provide:\n"
+        "1) Clinical Summary\n2) Dressing & Treatment Plan\n"
+        "3) Risk/Red Flags\n4) Monitoring Plan\n"
+        "Format with short headings and bullets.\n"
+    )
+    return [{"role": "user", "content": [{"type": "text", "text": text}]}]
 
-Write a structured answer with these headings exactly:
-1. Clinical Summary (max 4 bullet points)
-2. Likely Stage/Type (if uncertain, say 'uncertain')
-3. Treatment Plan (specific dressing choices and frequency based on exudate/infection risk)
-4. Red Flags (what to escalate and when)
-5. Follow-up Cadence (days)
-6. Notes (assumptions/uncertainties)
+@spaces.GPU  # non-optional: ensure CUDA work happens only inside the ZeroGPU worker
+def _vlm_infer_gpu(
+    image_pil: Image.Image,
+    messages: list,
+    max_new_tokens: int,
+    temperature: float,
+    model_id: str,
+    token: Optional[str],
+) -> str:
+    import torch
+    from transformers import AutoProcessor, AutoModelForCausalLM
 
-Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
-"""
+    torch.backends.cuda.matmul.allow_tf32 = True
+    device = "cuda"
 
-# ---------- VLM (MedGemma replaced with Qwen2-VL) ----------
-@_SPACES_GPU(enable_queue=True)
-def _vlm_infer_gpu(messages, model_id: str, max_new_tokens: int, token: Optional[str]):
-    """
-    Runs entirely inside a Spaces GPU worker. It's the ONLY place we allow CUDA init.
-    """
-    import torch
-    if not torch.cuda.is_available():
-        raise RuntimeError("CUDA not available in worker (check ZeroGPU torch version).")
-    from transformers import pipeline
-    pipe = pipeline(
-        task="image-text-to-text",
-        model=model_id,
-        device_map={"": 0},  # be explicit: put everything on cuda:0
-        token=token,
+    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, use_fast=True, token=token)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        torch_dtype=torch.float16,
         trust_remote_code=True,
-        model_kwargs={"low_cpu_mem_usage": True},
+        token=token,
+    ).to(device)
+
+    inputs = processor(messages=messages, images=[image_pil], return_tensors="pt").to(device)
+    gen_ids = model.generate(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        do_sample=False,
+        temperature=temperature,
     )
-    out = pipe(text=messages, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.2)
-    try:
-        txt = out[0]["generated_text"][-1].get("content", "")
-    except Exception:
-        txt = out[0].get("generated_text", "")
-    return (txt or "").strip() or "⚠️ Empty response"
+    out = processor.batch_decode(gen_ids, skip_special_tokens=True)[0]
+    return out.strip()
 
-def _vlm_infer_cpu(messages, model_id: str, max_new_tokens: int, token: Optional[str]) -> str:
-    """
-    CPU fallback path when ZeroGPU grant fails or CUDA wheel is unavailable.
-    """
+def _vlm_infer_cpu(
+    image_pil: Image.Image,
+    messages: list,
+    max_new_tokens: int,
+    temperature: float,
+    model_id: str,
+    token: Optional[str],
+) -> str:
     from transformers import pipeline
     pipe = pipeline(
         task="image-text-to-text",
         model=model_id,
-        device_map="cpu",
-        token=token,
+        device="cpu",
         trust_remote_code=True,
-        model_kwargs={"low_cpu_mem_usage": True},
+        token=token,
+    )
+    out = pipe(
+        text=[{"role": "user", "content": [{"type": "image", "image": image_pil}, *messages[0]["content"]]}],
+        max_new_tokens=max_new_tokens,
+        do_sample=False,
+        temperature=temperature,
     )
-    out = pipe(text=messages, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.2)
     try:
-        txt = out[0]["generated_text"][-1].get("content", "")
+        return (out[0]["generated_text"][-1].get("content", "") or "").strip()
     except Exception:
-        txt = out[0].get("generated_text", "")
-    return (txt or "").strip() or "⚠️ Empty response"
+        return (out[0].get("generated_text", "") or "").strip()
 
-def generate_medgemma_report(  # kept name so callers don't change
+def generate_medgemma_report(  # <-- keep the original PUBLIC NAME
     patient_info: str,
     visual_results: Dict,
     guideline_context: str,
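For reference, the new `_vlm_infer_gpu` above loads the model with `AutoModelForCausalLM` and passes `messages=` straight to the processor; the calling convention documented for Qwen2-VL in Transformers instead uses `Qwen2VLForConditionalGeneration` and `apply_chat_template`. A rough sketch of that documented pattern (not this repo's code; the image path is a placeholder):

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_id = "Qwen/Qwen2-VL-2B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

messages = [{"role": "user", "content": [{"type": "image"},
                                         {"type": "text", "text": "Describe the wound."}]}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[Image.open("wound.jpg")], return_tensors="pt").to("cuda")
ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(processor.batch_decode(ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])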
@@ -199,53 +192,34 @@ def generate_medgemma_report(  # kept name so callers don't change
     max_new_tokens: Optional[int] = None,
 ) -> str:
     """
-    MedGemma replacement using Qwen/Qwen2-VL-2B-Instruct via image-text-to-text.
-    Loads & runs ONLY inside a GPU worker to satisfy Stateless GPU constraints.
-    Falls back to CPU pipeline if a GPU grant/initialization fails.
+    Re-implemented to use Qwen/Qwen2-VL-* via ZeroGPU (@spaces.GPU) with CPU fallback.
+    Name preserved for compatibility with existing callers.
     """
-    if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
-        return "⚠️ VLM disabled"
-
-    model_id = os.getenv("SMARTHEAL_VLM_MODEL", "Qwen/Qwen2-VL-2B-Instruct")
-    max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))
-
-    uprompt = SMARTHEAL_USER_PREFIX.format(
-        patient_info=patient_info,
-        wound_type=visual_results.get("wound_type", "Unknown"),
-        length_cm=visual_results.get("length_cm", 0),
-        breadth_cm=visual_results.get("breadth_cm", 0),
-        area_cm2=visual_results.get("surface_area_cm2", 0),
-        det_conf=float(visual_results.get("detection_confidence", 0.0)),
-        px_per_cm=visual_results.get("px_per_cm", "?"),
-        guideline_context=(guideline_context or "")[:900],
-    )
-
-    messages = [
-        {"role": "system", "content": [{"type": "text", "text": SMARTHEAL_SYSTEM_PROMPT}]},
-        {"role": "user", "content": [
-            {"type": "image", "image": image_pil},
-            {"type": "text", "text": uprompt},
-        ]},
-    ]
-
-    # Try GPU worker first, then CPU fallback
+    msgs = _build_vlm_messages(patient_info, visual_results, guideline_context)
     try:
-        return _vlm_infer_gpu(messages, model_id, max_new_tokens, HF_TOKEN)
+        return _vlm_infer_gpu(
+            image_pil=image_pil,
+            messages=msgs,
+            max_new_tokens=max_new_tokens or SMARTHEAL_VLM_MAX_NEW_TOKENS,
+            temperature=SMARTHEAL_VLM_TEMPERATURE,
+            model_id=SMARTHEAL_VLM_ID,
+            token=HF_TOKEN,
+        )
     except Exception as e:
-        logging.warning(f"GPU VLM failed; falling back to CPU: {e}")
-        try:
-            return _vlm_infer_cpu(messages, model_id, max_new_tokens, HF_TOKEN)
-        except Exception as e2:
-            logging.error(f"CPU VLM also failed: {e2}")
-            return "⚠️ VLM error"
+        logging.warning(f"GPU VLM failed; falling back to CPU: {e!r}")
+        return _vlm_infer_cpu(
+            image_pil=image_pil,
+            messages=msgs,
+            max_new_tokens=max_new_tokens or SMARTHEAL_VLM_MAX_NEW_TOKENS,
+            temperature=SMARTHEAL_VLM_TEMPERATURE,
+            model_id=SMARTHEAL_VLM_ID,
+            token=HF_TOKEN,
+        ) or "⚠️ VLM returned empty output"
 
 # ---------- Initialize CPU models ----------
 def load_yolo_model():
     YOLO = _import_ultralytics()
-    # Construct model with CUDA masked to avoid auto-selecting cuda:0
-    with _no_cuda_env():
-        model = YOLO(YOLO_MODEL_PATH)
-    return model
+    return YOLO(YOLO_MODEL_PATH)
 
 def load_segmentation_model():
     load_model = _import_tf_loader()
@@ -287,7 +261,6 @@ def initialize_cpu_models() -> None:
             models_cache["seg"] = None
             logging.warning("Segmentation model file missing; skipping.")
     except Exception as e:
-        # Typical with Keras/TF version mismatch; pin TF/Keras 2.15 in requirements.
         models_cache["seg"] = None
         logging.warning(f"Segmentation unavailable: {e}")
 
@@ -452,7 +425,6 @@ def _grabcut_refine(bgr: np.ndarray, seed01: np.ndarray, iters: int = 3) -> np.n
     seed_dil = cv2.dilate(seed01, k, iterations=1)
     gc[seed01.astype(bool)] = cv2.GC_PR_FGD
     gc[seed_dil.astype(bool)] = cv2.GC_FGD
-    # force borders to background
     gc[0, :], gc[-1, :], gc[:, 0], gc[:, -1] = cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD
     bgdModel = np.zeros((1, 65), np.float64)
     fgdModel = np.zeros((1, 65), np.float64)
@@ -485,9 +457,7 @@ def _clean_mask(mask01: np.ndarray) -> np.ndarray:
         mask01 = (labels == largest_idx).astype(np.uint8)
     return (mask01 > 0).astype(np.uint8)
 
-# Global last debug dict (per-process)
-_last_seg_debug: Dict[str, object] = {}
-
+# ---------- Segmentation pipeline ----------
 def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndarray, Dict[str, object]]:
     """
     TF model → adaptive threshold on prob → GrabCut grow → cleanup.
@@ -710,7 +680,6 @@ class AIProcessor:
         det_model = self.models_cache.get("det")
         if det_model is None:
             raise RuntimeError("YOLO model not loaded")
-        # Force CPU inference and avoid CUDA touch
        results = det_model.predict(image_cv, verbose=False, device="cpu")
         if (not results) or (not getattr(results[0], "boxes", None)) or (len(results[0].boxes) == 0):
             try:
@@ -856,7 +825,7 @@
         if not vs:
             return "Knowledge base is not available."
         retriever = vs.as_retriever(search_kwargs={"k": 5})
-        # Modern API (avoid get_relevant_documents deprecation)
+        # LangChain deprecation fix: use invoke()
         docs = retriever.invoke(query)
         lines: List[str] = []
         for d in docs:
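The comment change above reflects that LangChain retrievers are now Runnables: `retriever.invoke(query)` returns a list of `Document` objects, replacing the deprecated `get_relevant_documents`. A small sketch, assuming `vs` is an already-built vector store:

retriever = vs.as_retriever(search_kwargs={"k": 5})
docs = retriever.invoke("negative pressure wound therapy indications")
for d in docs:
    print(d.metadata.get("source"), d.page_content[:120])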
@@ -914,6 +883,7 @@ Automated analysis provides quantitative measurements; verify via clinical exami
         max_new_tokens: Optional[int] = None,
     ) -> str:
         try:
+            # call the preserved function name (now backed by Qwen2-VL)
             report = generate_medgemma_report(
                 patient_info, visual_results, guideline_context, image_pil, max_new_tokens
             )
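Since the public name `generate_medgemma_report` is preserved, existing callers keep working. A hypothetical call for illustration (values and file path are made up; only the argument order comes from this diff):

from PIL import Image

visual_results = {
    "wound_type": "pressure ulcer",
    "length_cm": 3.2,
    "breadth_cm": 2.1,
    "surface_area_cm2": 5.4,
    "px_per_cm": 38.5,
}
report = generate_medgemma_report(
    "68-year-old with diabetes, sacral wound",   # patient_info
    visual_results,
    "(retrieved guideline snippets)",            # guideline_context
    Image.open("uploads/sample.jpg"),            # image_pil
    600,                                         # max_new_tokens
)
print(report)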