SmartHeal committed
Commit 8599b0e · verified · 1 Parent(s): b51469c

Update src/ai_processor.py

Files changed (1)
1. src/ai_processor.py (+123 −74)
src/ai_processor.py CHANGED
@@ -1,10 +1,9 @@
 # smartheal_ai_processor.py
 # Preserves ALL original class/function names.
-# Changes:
-# - Adds segment_wound(image) with your logic (+ KMeans fallback)
-# - perform_visual_analysis() now calls segment_wound() for mask
-# - Safe overlay (no mask kwarg in addWeighted)
-# - Conditional @spaces.GPU to avoid cudaGetDeviceCount crash
+# Same logic as your Colab run:
+# - Uses segmentation_model.h5 if present (fallback to KMeans)
+# - Safe overlay (no 'mask' kwarg in addWeighted)
+# - CPU-only by default (no CUDA probe). Optional Spaces GPU is opt-in.
 
 import os
 import time
@@ -14,12 +13,31 @@ from typing import Optional, Dict, List, Tuple
 
 # Quiet HF tokenizers fork warning
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+# Default to CPU-only to match Colab logic
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
 
 import cv2
 import numpy as np
 from PIL import Image
 from PIL.ExifTags import TAGS
-import spaces
+
+# --- Optional Spaces GPU (explicit opt-in) ---
+ENABLE_SPACES_GPU = os.getenv("ENABLE_SPACES_GPU", "0") == "1"
+ALLOW_CUDA_PROBE = os.getenv("ALLOW_CUDA_PROBE", "0") == "1"  # leave "0" for ZeroGPU safety
+
+try:
+    import spaces as _spaces
+except Exception:
+    _spaces = None
+
+def _cuda_available() -> bool:
+    if not ALLOW_CUDA_PROBE:
+        return False
+    try:
+        import torch
+        return bool(getattr(torch, "cuda", None)) and torch.cuda.is_available()
+    except Exception:
+        return False
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
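Everything GPU-related is now gated behind these two flags: with the defaults ("0"), the module never imports torch to probe CUDA and never registers a @spaces.GPU worker. A minimal opt-in sketch (the launcher script and module path are assumptions, not part of this commit; the flag names come from the diff above):

    # hypothetical launcher: set both flags *before* importing the module,
    # since they are read once at import time via os.getenv(...)
    import os
    os.environ["ENABLE_SPACES_GPU"] = "1"  # register the @spaces.GPU path
    os.environ["ALLOW_CUDA_PROBE"] = "1"   # let torch.cuda.is_available() run

    import src.ai_processor as ai_processor  # flags already visible here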
@@ -44,7 +62,10 @@ def _import_ultralytics():
 
 def _import_tf_loader():
     import tensorflow as tf
-    tf.config.set_visible_devices([], "GPU")  # force TF CPU
+    try:
+        tf.config.set_visible_devices([], "GPU")  # force TF CPU
+    except Exception:
+        pass
     from tensorflow.keras.models import load_model
     return load_model
 
@@ -68,16 +89,8 @@ def _import_hf_hub():
     from huggingface_hub import HfApi, HfFolder
     return HfApi, HfFolder
 
-# ---------- Conditional Spaces GPU wrapper ----------
-def _cuda_available() -> bool:
-    try:
-        import torch
-        return bool(getattr(torch, "cuda", None)) and torch.cuda.is_available()
-    except Exception:
-        return False
-
-@spaces.GPU(enable_queue=True, duration=90)
-def _generate_medgemma_report_core(
+# ---------- LLM report: CPU by default; optional Spaces GPU if enabled ----------
+def _generate_medgemma_report_cpu(
     patient_info: str,
     visual_results: Dict,
     guideline_context: str,
@@ -89,7 +102,7 @@ def _generate_medgemma_report_core(
        pipe = pipeline(
            "image-text-to-text",
            model="google/medgemma-4b-it",
-           device_map="auto" if _cuda_available() else None,
+           device_map=None,  # CPU
            token=HF_TOKEN,
            model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
        )
@@ -127,28 +140,9 @@ def _generate_medgemma_report_core(
         logging.error(f"❌ MedGemma generation error: {e}")
         return "⚠️ GPU/LLM worker unavailable"
 
-try:
-    import spaces
-    if _cuda_available():
-        @spaces.GPU(enable_queue=True, duration=90)
-        def generate_medgemma_report(
-            patient_info: str,
-            visual_results: Dict,
-            guideline_context: str,
-            image_pil: Image.Image,
-            max_new_tokens: Optional[int] = None,
-        ) -> str:
-            return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
-    else:
-        def generate_medgemma_report(
-            patient_info: str,
-            visual_results: Dict,
-            guideline_context: str,
-            image_pil: Image.Image,
-            max_new_tokens: Optional[int] = None,
-        ) -> str:
-            return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
-except Exception:
+# Optional GPU path if you *explicitly* enable it and the env supports it
+if ENABLE_SPACES_GPU and _spaces is not None:
+    @_spaces.GPU(enable_queue=True, duration=90)
     def generate_medgemma_report(
         patient_info: str,
         visual_results: Dict,
@@ -156,7 +150,53 @@ except Exception:
         image_pil: Image.Image,
         max_new_tokens: Optional[int] = None,
     ) -> str:
-        return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
+        # Even here, avoid probing CUDA unless allowed; device_map="auto" if we trust the env
+        try:
+            from transformers import pipeline
+            pipe = pipeline(
+                "image-text-to-text",
+                model="google/medgemma-4b-it",
+                device_map="auto" if _cuda_available() else None,
+                token=HF_TOKEN,
+                model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
+            )
+            prompt = (
+                "You are a medical AI assistant. Analyze this wound image and patient data.\n\n"
+                f"Patient: {patient_info}\n"
+                f"Wound: {visual_results.get('wound_type', 'Unknown')} - "
+                f"{visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm\n\n"
+                "Provide a structured report with:\n"
+                "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
+            )
+            messages = [{"role": "user", "content": [
+                {"type": "image", "image": image_pil},
+                {"type": "text", "text": prompt},
+            ]}]
+            out = pipe(
+                text=messages,
+                max_new_tokens=max_new_tokens or 800,
+                do_sample=False,
+                temperature=0.7,
+            )
+            if out and len(out) > 0:
+                try:
+                    return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
+                except Exception:
+                    return (out[0].get("generated_text", "") or "").strip() or "⚠️ Empty response"
+            return "⚠️ No output generated"
+        except Exception as e:
+            logging.error(f"❌ MedGemma (GPU path) error: {e}")
+            return _generate_medgemma_report_cpu(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
+else:
+    # CPU default (Colab-like behavior)
+    def generate_medgemma_report(
+        patient_info: str,
+        visual_results: Dict,
+        guideline_context: str,
+        image_pil: Image.Image,
+        max_new_tokens: Optional[int] = None,
+    ) -> str:
+        return _generate_medgemma_report_cpu(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
 
 # ---------- Initialize CPU models ----------
 def load_yolo_model():
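Either branch exports the same name and signature, so callers never need to know which path was taken. A minimal invocation sketch (patient data, file name, and measurements are made-up placeholders):

    from PIL import Image

    img = Image.open("wound.jpg")  # hypothetical input photo
    report = generate_medgemma_report(
        patient_info="68F, type 2 diabetes, ankle wound, 3 weeks",
        visual_results={"wound_type": "Ulcer", "length_cm": 2.1, "breadth_cm": 1.4},
        guideline_context="",
        image_pil=img,
        max_new_tokens=400,
    )
    print(report)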
@@ -304,7 +344,7 @@ def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float
     except Exception:
         return float(default_px_per_cm), meta
 
-# ---------- Your requested segmentation logic ----------
+# ---------- Segmentation (model-first, KMeans fallback) ----------
 def segment_wound(image: np.ndarray) -> np.ndarray:
     """
     Segments wound from a preprocessed ROI image, with a fallback to KMeans if the model fails.
@@ -314,39 +354,35 @@ def segment_wound(image: np.ndarray) -> np.ndarray:
 
     if segmentation_model is not None:
         try:
-            input_size = getattr(segmentation_model, "input_shape", None)
-            if input_size is None or len(input_size) < 3:
-                raise ValueError(f"Bad seg input_shape: {input_size}")
-            H, W = int(input_size[1]), int(input_size[2])  # (None,H,W,C)
+            input_shape = getattr(segmentation_model, "input_shape", None)
+            if input_shape is None or len(input_shape) < 3:
+                raise ValueError(f"Bad seg input_shape: {input_shape}")
+            H, W = int(input_shape[1]), int(input_shape[2])  # (None,H,W,C)
 
-            resized = cv2.resize(image, (W, H))  # cv2 takes (W,H)
+            resized = cv2.resize(image, (W, H))  # (W,H)
             norm = np.expand_dims(resized / 255.0, axis=0)  # (1,H,W,3)
             prediction = segmentation_model.predict(norm, verbose=0)
 
             # Handle models with multiple outputs
-            if isinstance(prediction, list):
+            if isinstance(prediction, (list, tuple)):
                 prediction = prediction[0]
             # squeeze batch dim if present
-            prediction = prediction[0] if prediction.ndim >= 3 else prediction
+            prediction = prediction[0] if getattr(prediction, "ndim", 0) >= 3 else prediction
 
-            # prediction can be (H,W,1) or (H,W)
-            pred2d = prediction.squeeze()
-            mask_prob = cv2.resize(pred2d, (image.shape[1], image.shape[0]))  # back to ROI size
+            pred2d = np.squeeze(prediction)  # (H,W) or (H,W,1)->(H,W)
+            mask_prob = cv2.resize(pred2d, (image.shape[1], image.shape[0]))
             mask = (mask_prob >= 0.5).astype(np.uint8) * 255
-            if mask.max() == 0:
-                logging.info("Seg model returned empty mask at 0.5 — keeping as-is (KMeans fallback will handle if needed).")
             return mask.astype(np.uint8)
         except Exception as e:
             logging.warning(f"⚠️ Segmentation model prediction failed: {e}. Falling back to KMeans.")
 
-    # --- Fallback: color clustering (KMeans, k=2) ---
+    # --- Fallback: color clustering (KMeans, k=2), pick 'reddest' cluster in Lab a* ---
     Z = image.reshape((-1, 3)).astype(np.float32)
     criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
-    _K = 2
-    _, labels, centers = cv2.kmeans(Z, _K, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
-    centers = centers.astype(np.uint8).reshape(1, _K, 3)
-    centers_lab = cv2.cvtColor(centers, cv2.COLOR_BGR2LAB)[0]
-    wound_idx = int(np.argmax(centers_lab[:, 1]))  # reddest cluster (a* channel)
+    _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
+    centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
+    centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
+    wound_idx = int(np.argmax(centers_lab[:, 1]))  # a* channel (redness)
     mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
     return mask.astype(np.uint8)
 
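The fallback needs no model at all: it k-means-clusters the ROI's BGR pixels into two groups and keeps whichever cluster center is redder in Lab space (larger a*). A self-contained sketch of the same idea on a synthetic ROI (toy data, not the project's images):

    import cv2
    import numpy as np

    # toy ROI: gray "skin" with a reddish square "wound" patch (BGR)
    roi = np.full((64, 64, 3), (120, 130, 140), np.uint8)
    roi[20:44, 20:44] = (40, 40, 200)

    Z = roi.reshape((-1, 3)).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)

    centers_lab = cv2.cvtColor(centers.astype(np.uint8).reshape(1, 2, 3), cv2.COLOR_BGR2LAB)[0]
    wound_idx = int(np.argmax(centers_lab[:, 1]))  # a* axis: higher = redder
    mask = (labels.reshape(roi.shape[:2]) == wound_idx).astype(np.uint8) * 255
    print(mask[32, 32], mask[5, 5])  # expect 255 inside the patch, 0 outside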
@@ -387,6 +423,7 @@ def draw_measurement_overlay(
     thickness: int = 2
 ) -> np.ndarray:
     overlay = base_bgr.copy()
+    # Safe masked blend (OpenCV addWeighted has no 'mask' kwarg)
     red = np.zeros_like(overlay); red[:] = (0, 0, 255)
     blended = cv2.addWeighted(overlay, 1.0, red, 0.3, 0)
     m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
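The hunk ends before the composite step, but the pattern is clear from the m3 helper: blend the whole frame, then copy the blended pixels back only where the mask is set. A minimal sketch of that safe-blend idiom (the final np.where composite is an assumption based on the surrounding code, since cv2.addWeighted itself accepts no mask):

    import cv2
    import numpy as np

    overlay = np.zeros((32, 32, 3), np.uint8)              # stand-in base image
    mask01 = np.zeros((32, 32), np.uint8); mask01[8:24, 8:24] = 1

    red = np.zeros_like(overlay); red[:] = (0, 0, 255)
    blended = cv2.addWeighted(overlay, 1.0, red, 0.3, 0)   # blends everywhere
    m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
    overlay = np.where(m3 > 0, blended, overlay)           # keep blend only under mask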
@@ -417,8 +454,8 @@ def draw_measurement_overlay(
 
     draw_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
     draw_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])
-    put_label(f"{length_cm:.2f} cm", mids[long_pair[0]])
-    put_label(f"{breadth_cm:.2f} cm", mids[short_pair[0]])
+    put_label(f"Length: {length_cm:.2f} cm", mids[long_pair[0]])
+    put_label(f"Breadth: {breadth_cm:.2f} cm", mids[short_pair[0]])
     return overlay
 
 # ---------- AI PROCESSOR ----------
@@ -449,9 +486,12 @@ class AIProcessor:
         if det_model is None:
             raise RuntimeError("YOLO model not loaded")
         results = det_model.predict(image_cv, verbose=False, device="cpu")
-        if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
-            import gradio as gr
-            raise gr.Error("No wound could be detected.")
+        if (not results) or (not getattr(results[0], "boxes", None)) or (len(results[0].boxes) == 0):
+            try:
+                import gradio as gr
+            except Exception:  # gradio not installed: fall back to a plain error
+                raise RuntimeError("No wound could be detected.")
+            raise gr.Error("No wound could be detected.")
 
         box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
         x1, y1, x2, y2 = [int(v) for v in box]
@@ -459,16 +499,19 @@ class AIProcessor:
         x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
         roi = image_cv[y1:y2, x1:x2].copy()
         if roi.size == 0:
-            import gradio as gr
-            raise gr.Error("Detected ROI is empty.")
+            try:
+                import gradio as gr
+            except Exception:  # gradio not installed: fall back to a plain error
+                raise RuntimeError("Detected ROI is empty.")
+            raise gr.Error("Detected ROI is empty.")
 
-        # --- Segmentation (your logic + fallback) ---
+        # --- Segmentation (model-first + KMeans fallback) ---
         mask_u8_255 = segment_wound(roi)  # 0..255
-        # Clean up & keep largest component (in 0/1)
         mask01 = (mask_u8_255 > 127).astype(np.uint8)
-        mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1)
-        mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1)
-        mask01 = largest_component_mask(mask01, min_area_px=30)
+        if mask01.any():
+            mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1)
+            mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1)
+            mask01 = largest_component_mask(mask01, min_area_px=30)
 
         # --- Measurement ---
         if mask01.any():
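largest_component_mask is called here but defined outside the hunks shown. A plausible sketch of such a helper via cv2.connectedComponentsWithStats (names and the min_area_px semantics are inferred from the call site, not taken from the repository):

    import cv2
    import numpy as np

    def largest_component_mask(mask01: np.ndarray, min_area_px: int = 30) -> np.ndarray:
        """Keep only the largest connected blob, if it clears the area floor."""
        n, labels, stats, _ = cv2.connectedComponentsWithStats(mask01, connectivity=8)
        if n <= 1:  # label 0 is background; nothing detected
            return np.zeros_like(mask01)
        areas = stats[1:, cv2.CC_STAT_AREA]  # skip the background row
        best = 1 + int(np.argmax(areas))
        if stats[best, cv2.CC_STAT_AREA] < min_area_px:
            return np.zeros_like(mask01)
        return (labels == best).astype(np.uint8)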
@@ -482,6 +525,7 @@ class AIProcessor:
             breadth_cm = round(w_px / px_per_cm, 2)
             surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
             anno_roi = roi.copy()
+            box_pts = None
 
         # --- Save visualizations ---
         out_dir = self._ensure_analysis_dir()
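The measurement itself is plain unit conversion from the EXIF-derived scale. A worked example with made-up numbers (px_per_cm = 37.8 corresponds to roughly 96 dpi):

    px_per_cm = 37.8                                                # e.g. from EXIF, else default
    h_px, w_px = 151, 76                                            # hypothetical mask extents
    length_cm = round(h_px / px_per_cm, 2)                          # 3.99
    breadth_cm = round(w_px / px_per_cm, 2)                         # 2.01
    surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)   # 8.03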
@@ -498,8 +542,12 @@ class AIProcessor:
         segmentation_path = None
         annotated_seg_path = None
         if mask01.any():
+            # Raw mask (ROI size)
+            mask_path = os.path.join(out_dir, f"segmentation_mask_{ts}.png")
+            cv2.imwrite(mask_path, (mask01 * 255).astype(np.uint8))
+
+            # Segmentation overlay (paste back to full image)
             seg_full = image_cv.copy()
-            # safe masked blend (no mask kwarg)
             red = np.zeros_like(roi); red[:] = (0, 0, 255)
             blended = cv2.addWeighted(roi, 1.0, red, 0.3, 0)
             m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
@@ -509,6 +557,7 @@ class AIProcessor:
             segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
             cv2.imwrite(segmentation_path, seg_full)
 
+            # Annotated (arrows + labels)
             anno_full = image_cv.copy()
             anno_full[y1:y2, x1:x2] = anno_roi
             annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
@@ -722,4 +771,4 @@ Automated analysis provides quantitative measurements; verify via clinical exami
             "report": f"Analysis initialization failed: {str(e)}",
             "saved_image_path": None,
             "guideline_context": "",
-        }
+        }