SmartHeal commited on
Commit
be9c884
·
verified ·
1 Parent(s): a56a9f6

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +52 -51
src/ai_processor.py CHANGED
@@ -1,11 +1,11 @@
1
  # smartheal_ai_processor.py
2
  # Preserves ALL original class/function names.
3
- # Same logic you confirmed on Colab:
4
- # - Uses segmentation_model.h5 first (fallback to KMeans)
5
- # - Safe overlay (no 'mask' kwarg in addWeighted)
6
- # - CPU-only by default to avoid ZeroGPU cuda probe
7
- # - Registers a harmless @spaces.GPU stub (enable_queue=False) to silence
8
- # "No @spaces.GPU function detected during startup" without starting a GPU worker.
9
 
10
  import os
11
  import time
@@ -13,28 +13,28 @@ import logging
13
  from datetime import datetime
14
  from typing import Optional, Dict, List, Tuple
15
 
16
- # Quieter tokenizer + default CPU
17
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
18
- os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") # keep torch/TF on CPU
19
 
20
  import cv2
21
  import numpy as np
22
  from PIL import Image
23
  from PIL.ExifTags import TAGS
24
 
25
- # --- Register a non-queue GPU stub so Spaces detects @spaces.GPU but doesn't start a worker ---
26
  try:
27
  import spaces as _spaces
28
 
29
- @_spaces.GPU(enable_queue=False) # NOTE: no queue, so ZeroGPU worker is not launched
30
- def _spaces_gpu_stub(ping: int = 0) -> str:
31
- """Harmless stub to satisfy Spaces startup scan without touching CUDA."""
32
  return "ready"
33
 
34
  logging.info("Registered @spaces.GPU stub (enable_queue=False); startup detector satisfied.")
35
  except Exception as _e:
36
- _spaces = None
37
- logging.info("No 'spaces' module or stub registration failed: %s", _e)
38
 
39
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
40
 
@@ -60,7 +60,7 @@ def _import_ultralytics():
60
  def _import_tf_loader():
61
  import tensorflow as tf
62
  try:
63
- tf.config.set_visible_devices([], "GPU") # force TF CPU
64
  except Exception:
65
  pass
66
  from tensorflow.keras.models import load_model
@@ -86,7 +86,7 @@ def _import_hf_hub():
86
  from huggingface_hub import HfApi, HfFolder
87
  return HfApi, HfFolder
88
 
89
- # ---------- LLM report: CPU-only path (safe on ZeroGPU) ----------
90
  def generate_medgemma_report(
91
  patient_info: str,
92
  visual_results: Dict,
@@ -95,16 +95,21 @@ def generate_medgemma_report(
95
  max_new_tokens: Optional[int] = None,
96
  ) -> str:
97
  """
98
- CPU-only MedGemma call (safe on Spaces/ZeroGPU). If it fails, fallback text is provided by caller.
 
99
  """
 
 
 
100
  try:
101
  from transformers import pipeline
102
  pipe = pipeline(
103
- "image-text-to-text",
104
  model="google/medgemma-4b-it",
105
- device_map=None, # CPU
106
  token=HF_TOKEN,
107
- model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
 
108
  )
109
 
110
  prompt = (
@@ -121,15 +126,7 @@ def generate_medgemma_report(
121
  {"type": "text", "text": prompt},
122
  ]}]
123
 
124
- t0 = time.time()
125
- out = pipe(
126
- text=messages,
127
- max_new_tokens=max_new_tokens or 800,
128
- do_sample=False,
129
- temperature=0.7,
130
- )
131
- logging.info(f"✅ MedGemma finished in {time.time()-t0:.2f}s")
132
-
133
  if out and len(out) > 0:
134
  try:
135
  return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
@@ -138,7 +135,7 @@ def generate_medgemma_report(
138
  return "⚠️ No output generated"
139
  except Exception as e:
140
  logging.error(f"❌ MedGemma generation error: {e}")
141
- return "⚠️ GPU/LLM worker unavailable"
142
 
143
  # ---------- Initialize CPU models ----------
144
  def load_yolo_model():
@@ -461,14 +458,15 @@ class AIProcessor:
461
  surface_area_cm2 = count_area_cm2(mask01, px_per_cm)
462
  anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
463
  else:
464
- # fallback to detection box
465
  h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
466
  length_cm = round(h_px / px_per_cm, 2)
467
  breadth_cm = round(w_px / px_per_cm, 2)
468
  surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
469
  anno_roi = roi.copy()
 
470
 
471
- # --- Save visualizations ---
472
  out_dir = self._ensure_analysis_dir()
473
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
474
 
@@ -480,29 +478,31 @@ class AIProcessor:
480
  detection_path = os.path.join(out_dir, f"detection_{ts}.png")
481
  cv2.imwrite(detection_path, det_vis)
482
 
483
- segmentation_path = None
484
- annotated_seg_path = None
 
 
 
 
 
485
  if mask01.any():
486
- # Raw mask (ROI size)
487
- mask_path = os.path.join(out_dir, f"segmentation_mask_{ts}.png")
488
- cv2.imwrite(mask_path, (mask01 * 255).astype(np.uint8))
489
-
490
- # Segmentation overlay (paste back to full image)
491
- seg_full = image_cv.copy()
492
- red = np.zeros_like(roi); red[:] = (0, 0, 255)
493
- blended = cv2.addWeighted(roi, 1.0, red, 0.3, 0)
494
  m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
495
  roi_overlay = cv2.add(cv2.bitwise_and(roi, cv2.bitwise_not(m3)),
496
  cv2.bitwise_and(blended, m3))
497
- seg_full[y1:y2, x1:x2] = roi_overlay
498
- segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
499
- cv2.imwrite(segmentation_path, seg_full)
 
 
 
 
500
 
501
- # Annotated (arrows + labels)
502
- anno_full = image_cv.copy()
503
- anno_full[y1:y2, x1:x2] = anno_roi
504
- annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
505
- cv2.imwrite(annotated_seg_path, anno_full)
506
 
507
  # --- Optional classification ---
508
  wound_type = "Unknown"
@@ -525,8 +525,9 @@ class AIProcessor:
525
  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
526
  if getattr(results[0].boxes, "conf", None) is not None else 0.0,
527
  "detection_image_path": detection_path,
528
- "segmentation_image_path": segmentation_path,
529
  "segmentation_annotated_path": annotated_seg_path,
 
530
  "original_image_path": original_path,
531
  }
532
  except Exception as e:
 
1
  # smartheal_ai_processor.py
2
  # Preserves ALL original class/function names.
3
+ # What you get:
4
+ # - Uses your segmentation_model.h5 first; clean KMeans fallback if it fails/missing
5
+ # - Safe overlay (no 'mask' kwarg with addWeighted)
6
+ # - Always writes a segmentation view (so it never looks like the plain original)
7
+ # - CPU by default; optional VLM (MedGemma) is OFF unless SMARTHEAL_ENABLE_VLM=1
8
+ # - Optional @spaces.GPU **stub** (no queue) to satisfy Spaces startup without touching CUDA
9
 
10
  import os
11
  import time
 
13
  from datetime import datetime
14
  from typing import Optional, Dict, List, Tuple
15
 
16
+ # Quiet tokenizers; default to CPU for safety on ZeroGPU/Spaces
17
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
18
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
19
 
20
  import cv2
21
  import numpy as np
22
  from PIL import Image
23
  from PIL.ExifTags import TAGS
24
 
25
+ # --- Optional: register a harmless @spaces.GPU-decorated stub to silence startup warning ---
26
  try:
27
  import spaces as _spaces
28
 
29
+ @_spaces.GPU(enable_queue=False) # not queued -> won't start a ZeroGPU worker
30
+ def smartheal_gpu_stub(ping: int = 0) -> str:
31
+ """No-op so Spaces detects at least one @spaces.GPU function without touching CUDA."""
32
  return "ready"
33
 
34
  logging.info("Registered @spaces.GPU stub (enable_queue=False); startup detector satisfied.")
35
  except Exception as _e:
36
+ # It's fine if 'spaces' isn't available locally.
37
+ pass
38
 
39
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
40
 
 
60
  def _import_tf_loader():
61
  import tensorflow as tf
62
  try:
63
+ tf.config.set_visible_devices([], "GPU") # keep TF on CPU
64
  except Exception:
65
  pass
66
  from tensorflow.keras.models import load_model
 
86
  from huggingface_hub import HfApi, HfFolder
87
  return HfApi, HfFolder
88
 
89
+ # ---------- LLM report (OFF by default; enable with SMARTHEAL_ENABLE_VLM=1) ----------
90
  def generate_medgemma_report(
91
  patient_info: str,
92
  visual_results: Dict,
 
95
  max_new_tokens: Optional[int] = None,
96
  ) -> str:
97
  """
98
+ CPU-only MedGemma call (safe). Disabled by default to avoid env mismatches.
99
+ Set SMARTHEAL_ENABLE_VLM=1 to try loading the model.
100
  """
101
+ if os.getenv("SMARTHEAL_ENABLE_VLM", "0") != "1":
102
+ return "⚠️ VLM disabled"
103
+
104
  try:
105
  from transformers import pipeline
106
  pipe = pipeline(
107
+ task="image-text-to-text",
108
  model="google/medgemma-4b-it",
109
+ device_map=None, # CPU
110
  token=HF_TOKEN,
111
+ trust_remote_code=True,
112
+ model_kwargs={"low_cpu_mem_usage": True}, # avoid 'use_cache' arg mismatch
113
  )
114
 
115
  prompt = (
 
126
  {"type": "text", "text": prompt},
127
  ]}]
128
 
129
+ out = pipe(text=messages, max_new_tokens=max_new_tokens or 600, do_sample=False, temperature=0.7)
 
 
 
 
 
 
 
 
130
  if out and len(out) > 0:
131
  try:
132
  return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
 
135
  return "⚠️ No output generated"
136
  except Exception as e:
137
  logging.error(f"❌ MedGemma generation error: {e}")
138
+ return "⚠️ VLM error"
139
 
140
  # ---------- Initialize CPU models ----------
141
  def load_yolo_model():
 
458
  surface_area_cm2 = count_area_cm2(mask01, px_per_cm)
459
  anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
460
  else:
461
+ # fallback to detection box if segmentation is empty
462
  h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
463
  length_cm = round(h_px / px_per_cm, 2)
464
  breadth_cm = round(w_px / px_per_cm, 2)
465
  surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
466
  anno_roi = roi.copy()
467
+ box_pts = None
468
 
469
+ # --- Save visualizations (ALWAYS create a segmentation image) ---
470
  out_dir = self._ensure_analysis_dir()
471
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
472
 
 
478
  detection_path = os.path.join(out_dir, f"detection_{ts}.png")
479
  cv2.imwrite(detection_path, det_vis)
480
 
481
+ # Save ROI mask image (helps debug)
482
+ roi_mask_path = os.path.join(out_dir, f"roi_mask_{ts}.png")
483
+ cv2.imwrite(roi_mask_path, (mask01 * 255).astype(np.uint8))
484
+
485
+ # Segmentation overlay (paste back to full image). If mask empty, tint ROI red so it's NOT identical to original.
486
+ seg_full = image_cv.copy()
487
+ red = np.zeros_like(roi); red[:] = (0, 0, 255)
488
  if mask01.any():
489
+ blended = cv2.addWeighted(roi, 1.0, red, 0.30, 0)
 
 
 
 
 
 
 
490
  m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
491
  roi_overlay = cv2.add(cv2.bitwise_and(roi, cv2.bitwise_not(m3)),
492
  cv2.bitwise_and(blended, m3))
493
+ else:
494
+ # No mask → light red tint over the ROI to make the "segmentation" view visually distinct.
495
+ roi_overlay = cv2.addWeighted(roi, 0.75, red, 0.25, 0)
496
+
497
+ seg_full[y1:y2, x1:x2] = roi_overlay
498
+ segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
499
+ cv2.imwrite(segmentation_path, seg_full)
500
 
501
+ # Annotated (arrows + labels)
502
+ anno_full = image_cv.copy()
503
+ anno_full[y1:y2, x1:x2] = anno_roi
504
+ annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
505
+ cv2.imwrite(annotated_seg_path, anno_full)
506
 
507
  # --- Optional classification ---
508
  wound_type = "Unknown"
 
525
  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
526
  if getattr(results[0].boxes, "conf", None) is not None else 0.0,
527
  "detection_image_path": detection_path,
528
+ "segmentation_image_path": segmentation_path, # always present
529
  "segmentation_annotated_path": annotated_seg_path,
530
+ "roi_mask_path": roi_mask_path, # helpful for debugging
531
  "original_image_path": original_path,
532
  }
533
  except Exception as e: