SmartHeal commited on
Commit
ef69ec1
·
verified ·
1 Parent(s): abbf692

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +174 -186
src/ai_processor.py CHANGED
@@ -1,7 +1,6 @@
1
  # smartheal_ai_processor.py
2
- # Fully functional: "segment like snippet" while preserving ALL original names.
3
- # You can keep using AIProcessor.perform_visual_analysis / analyze_wound / full_analysis_pipeline
4
- # exactly as before. A convenience AIProcessor.segment_like_snippet(...) is added.
5
 
6
  import os
7
  import time
@@ -9,18 +8,14 @@ import logging
9
  from datetime import datetime
10
  from typing import Optional, Dict, List, Tuple
11
 
 
 
 
12
  import cv2
13
  import numpy as np
14
  from PIL import Image, ImageOps
15
  from PIL.ExifTags import TAGS
16
 
17
- try:
18
- import gradio as gr
19
- except Exception:
20
- class _GrErr(RuntimeError): ...
21
- class gr: # shim so `gr.Error` won’t crash if Gradio isn’t present
22
- Error = _GrErr
23
-
24
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
25
 
26
  UPLOADS_DIR = "uploads"
@@ -44,7 +39,7 @@ def _import_ultralytics():
44
 
45
  def _import_tf_loader():
46
  import tensorflow as tf
47
- tf.config.set_visible_devices([], "GPU") # force CPU for TF
48
  from tensorflow.keras.models import load_model
49
  return load_model
50
 
@@ -68,69 +63,90 @@ def _import_hf_hub():
68
  from huggingface_hub import HfApi, HfFolder
69
  return HfApi, HfFolder
70
 
71
- # ---------- Spaces GPU function (kept name/behavior) ----------
72
- try:
73
- import spaces
74
-
75
- @spaces.GPU(enable_queue=True, duration=90)
76
- def generate_medgemma_report(
77
- patient_info: str,
78
- visual_results: Dict,
79
- guideline_context: str,
80
- image_pil: Image.Image,
81
- max_new_tokens: Optional[int] = None,
82
- ) -> str:
83
- try:
84
- import torch
85
- from transformers import pipeline
86
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  try:
88
- if hasattr(torch, "cuda") and torch.cuda.is_available():
89
- torch.cuda.empty_cache()
90
  except Exception:
91
- pass
92
-
93
- prompt = (
94
- "You are a medical AI assistant. Analyze this wound image and patient data.\n\n"
95
- f"Patient: {patient_info}\n"
96
- f"Wound: {visual_results.get('wound_type', 'Unknown')} - "
97
- f"{visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm\n\n"
98
- "Provide a structured report with:\n"
99
- "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
100
- )
101
-
102
- pipe = pipeline(
103
- "image-text-to-text",
104
- model="google/medgemma-4b-it",
105
- device_map="auto",
106
- token=HF_TOKEN,
107
- model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
108
- )
109
-
110
- messages = [{"role": "user", "content": [
111
- {"type": "image", "image": image_pil},
112
- {"type": "text", "text": prompt},
113
- ]}]
114
-
115
- t0 = time.time()
116
- out = pipe(
117
- text=messages,
118
- max_new_tokens=max_new_tokens or 800,
119
- do_sample=False,
120
- temperature=0.7,
121
- pad_token_id=pipe.tokenizer.eos_token_id,
122
- )
123
- logging.info(f"✅ MedGemma finished in {time.time()-t0:.2f}s")
124
 
125
- if out and len(out) > 0:
126
- try:
127
- return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
128
- except Exception:
129
- return (out[0].get("generated_text", "") or "").strip() or "⚠️ Empty response"
130
- return "⚠️ No output generated"
131
- except Exception as e:
132
- logging.error(f"❌ MedGemma generation error: {e}")
133
- return "⚠️ GPU worker unavailable"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  except Exception:
135
  def generate_medgemma_report(
136
  patient_info: str,
@@ -139,9 +155,9 @@ except Exception:
139
  image_pil: Image.Image,
140
  max_new_tokens: Optional[int] = None,
141
  ) -> str:
142
- return "⚠️ GPU not available"
143
 
144
- # ---------- Initialize CPU models (same function names/behavior) ----------
145
  def load_yolo_model():
146
  YOLO = _import_ultralytics()
147
  return YOLO(YOLO_MODEL_PATH)
@@ -235,7 +251,7 @@ def setup_knowledge_base() -> None:
235
  initialize_cpu_models()
236
  setup_knowledge_base()
237
 
238
- # ---------- Calibration helpers (added, names unchanged elsewhere) ----------
239
  def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
240
  out = {}
241
  try:
@@ -288,7 +304,52 @@ def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float
288
  except Exception:
289
  return float(default_px_per_cm), meta
290
 
291
- # ---------- Mask processing + measurement (helpers added) ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  def largest_component_mask(binary: np.ndarray, min_area_px: int = 50) -> np.ndarray:
293
  num, labels, stats, _ = cv2.connectedComponentsWithStats(binary.astype(np.uint8), connectivity=8)
294
  if num <= 1:
@@ -324,16 +385,13 @@ def draw_measurement_overlay(
324
  breadth_cm: float,
325
  thickness: int = 2
326
  ) -> np.ndarray:
327
- """Safe overlay (no mask arg to addWeighted)."""
328
  overlay = base_bgr.copy()
329
-
330
- # red mask overlay only where mask==1
331
  colored = np.zeros_like(base_bgr); colored[:] = (0, 0, 255)
332
- mask3 = np.dstack([mask01 * 255] * 3).astype(np.uint8)
333
  blended = cv2.addWeighted(overlay, 1.0, colored, 0.3, 0)
334
- # keep blended only on mask
335
- blended_masked = cv2.bitwise_and(blended, mask3)
336
- bg = cv2.bitwise_and(overlay, cv2.bitwise_not(mask3))
337
  overlay = cv2.add(bg, blended_masked)
338
 
339
  if rect_box is not None:
@@ -365,7 +423,7 @@ def draw_measurement_overlay(
365
  put_label(f"{breadth_cm:.2f} cm", mids[short_pair[0]])
366
  return overlay
367
 
368
- # ---------- AI PROCESSOR (ALL names preserved) ----------
369
  class AIProcessor:
370
  def __init__(self):
371
  self.models_cache = models_cache
@@ -379,103 +437,26 @@ class AIProcessor:
379
  os.makedirs(out_dir, exist_ok=True)
380
  return out_dir
381
 
382
- # NEW helper that mirrors your short snippet exactly (you can call or ignore)
383
- def segment_like_snippet(self, image_pil: Image.Image) -> Tuple[Dict, Image.Image, Image.Image]:
384
  """
385
- Returns (visual_results, detected_image_pil, mask_pil) exactly like your snippet.
386
- Uses EXIF-calibrated px/cm if available; otherwise DEFAULT_PX_PER_CM.
387
  """
388
- if image_pil is None:
389
- raise gr.Error("No image provided.")
390
-
391
- px_per_cm, _ = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
392
-
393
- # Convert image
394
- image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
395
-
396
- # Detection
397
- det_model = self.models_cache.get("det")
398
- if det_model is None:
399
- raise gr.Error("Detection model not loaded.")
400
- results = det_model.predict(image_cv, verbose=False, device="cpu")
401
- if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
402
- raise gr.Error("No wound could be detected.")
403
- box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
404
- x1, y1, x2, y2 = [int(v) for v in box]
405
- x1, y1 = max(0, x1), max(0, y1)
406
- x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
407
- detected_region_cv = image_cv[y1:y2, x1:x2]
408
- if detected_region_cv.size == 0:
409
- raise gr.Error("Detected ROI is empty.")
410
-
411
- # Segmentation
412
- seg_model = self.models_cache.get("seg")
413
- mask_roi_01 = None
414
- if seg_model is not None:
415
- H, W = seg_model.input_shape[1:3]
416
- resized = cv2.resize(detected_region_cv, (W, H))
417
- pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
418
- raw = pred[:, :, 0]
419
- mask = (raw > 0.5).astype(np.uint8)
420
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)
421
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
422
- mask = largest_component_mask(mask, min_area_px=50)
423
- mask_roi_01 = cv2.resize(mask, (detected_region_cv.shape[1], detected_region_cv.shape[0]),
424
- interpolation=cv2.INTER_NEAREST).astype(np.uint8)
425
- else:
426
- mask_roi_01 = np.zeros(detected_region_cv.shape[:2], dtype=np.uint8)
427
-
428
- # Measurement (oriented rect)
429
- if mask_roi_01.any():
430
- length_cm, breadth_cm, _ = measure_min_area_rect(mask_roi_01, px_per_cm)
431
- area_cm2 = count_area_cm2(mask_roi_01, px_per_cm)
432
- else:
433
- # fall back to detection box
434
- h_px = max(0, y2 - y1)
435
- w_px = max(0, x2 - x1)
436
- length_cm, breadth_cm = round(h_px / px_per_cm, 2), round(w_px / px_per_cm, 2)
437
- area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
438
-
439
- # Classification (optional)
440
- wound_type = "Unknown"
441
- cls_pipe = self.models_cache.get("cls")
442
- if cls_pipe is not None:
443
- try:
444
- detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
445
- preds = cls_pipe(detected_image_pil)
446
- if preds:
447
- wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
448
- except Exception as e:
449
- logging.warning(f"Classification failed: {e}")
450
- detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
451
- else:
452
- detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
453
-
454
- visual_results = {
455
- "wound_type": wound_type,
456
- "length_cm": length_cm,
457
- "breadth_cm": breadth_cm,
458
- "surface_area_cm2": area_cm2
459
- }
460
- mask_pil = Image.fromarray((mask_roi_01 * 255).astype(np.uint8))
461
- return visual_results, detected_image_pil, mask_pil
462
-
463
- # ORIGINAL NAME preserved; inside it we follow the snippet-style flow and also save overlays
464
- def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
465
  try:
466
  # --- Auto calibration from EXIF ---
467
  px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
468
 
469
- # Convert image
470
  image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
471
 
472
- # --- Detection ---
473
  det_model = self.models_cache.get("det")
474
  if det_model is None:
475
  raise RuntimeError("YOLO model not loaded")
476
 
477
  results = det_model.predict(image_cv, verbose=False, device="cpu")
478
  if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
 
479
  raise gr.Error("No wound could be detected.")
480
 
481
  box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
@@ -484,39 +465,46 @@ class AIProcessor:
484
  x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
485
  roi = image_cv[y1:y2, x1:x2].copy()
486
  if roi.size == 0:
 
487
  raise gr.Error("Detected ROI is empty.")
488
 
489
- # --- Segmentation (snippet style) ---
490
  seg_model = self.models_cache.get("seg")
491
  mask_roi_01 = None
492
  if seg_model is not None:
493
- H, W = seg_model.input_shape[1:3]
494
- resized = cv2.resize(roi, (W, H))
495
- pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
496
- raw_mask = pred[:, :, 0]
497
- mask = (raw_mask > 0.5).astype(np.uint8)
498
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)
499
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
500
- mask = largest_component_mask(mask)
501
- mask_roi_01 = cv2.resize(mask, (roi.shape[1], roi.shape[0]), interpolation=cv2.INTER_NEAREST)
 
 
 
 
 
 
 
502
  else:
503
- mask_roi_01 = np.zeros(roi.shape[:2], dtype=np.uint8)
504
 
505
- # --- Measurement with oriented rect (better than boundingRect) ---
506
- if mask_roi_01.any():
507
  length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask_roi_01, px_per_cm)
508
  surface_area_cm2 = count_area_cm2(mask_roi_01, px_per_cm)
509
  anno_roi = draw_measurement_overlay(roi, mask_roi_01, box_pts, length_cm, breadth_cm)
510
  else:
511
- # fallback to detection box if segmentation missing/empty
512
- h_px = max(0, y2 - y1)
513
- w_px = max(0, x2 - x1)
514
  length_cm = round(h_px / px_per_cm, 2)
515
  breadth_cm = round(w_px / px_per_cm, 2)
516
  surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
517
  anno_roi = roi.copy()
518
 
519
- # --- Save visuals ---
520
  out_dir = self._ensure_analysis_dir()
521
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
522
 
@@ -531,7 +519,7 @@ class AIProcessor:
531
  segmentation_path = None
532
  annotated_seg_path = None
533
  if mask_roi_01 is not None and mask_roi_01.any():
534
- # Safe blending: blend once, then gate by mask
535
  seg_full = image_cv.copy()
536
  roi_overlay = roi.copy()
537
  red = np.zeros_like(roi_overlay); red[:] = (0, 0, 255)
@@ -551,7 +539,7 @@ class AIProcessor:
551
  annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
552
  cv2.imwrite(annotated_seg_path, anno_full)
553
 
554
- # --- Classification (optional) ---
555
  wound_type = "Unknown"
556
  cls_pipe = self.models_cache.get("cls")
557
  if cls_pipe is not None:
@@ -580,7 +568,7 @@ class AIProcessor:
580
  logging.error(f"Visual analysis failed: {e}", exc_info=True)
581
  raise
582
 
583
- # ---------- Knowledge base and reporting (names preserved) ----------
584
  def query_guidelines(self, query: str) -> str:
585
  try:
586
  vs = self.knowledge_base_cache.get("vector_store")
 
1
  # smartheal_ai_processor.py
2
+ # Fully functional: robust segmentation + safe overlays + conditional GPU wrapper.
3
+ # All original class/function names preserved. New helpers are additive.
 
4
 
5
  import os
6
  import time
 
8
  from datetime import datetime
9
  from typing import Optional, Dict, List, Tuple
10
 
11
+ # --- quiet tokenizers fork warning (HF) ---
12
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
13
+
14
  import cv2
15
  import numpy as np
16
  from PIL import Image, ImageOps
17
  from PIL.ExifTags import TAGS
18
 
 
 
 
 
 
 
 
19
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
20
 
21
  UPLOADS_DIR = "uploads"
 
39
 
40
  def _import_tf_loader():
41
  import tensorflow as tf
42
+ tf.config.set_visible_devices([], "GPU") # force CPU for TF to avoid CUDA contention
43
  from tensorflow.keras.models import load_model
44
  return load_model
45
 
 
63
  from huggingface_hub import HfApi, HfFolder
64
  return HfApi, HfFolder
65
 
66
+ # ---------- Conditional Spaces GPU function ----------
67
+ # Avoid scheduling a GPU worker when CUDA is not available (prevents cudaGetDeviceCount crash)
68
+ def _cuda_available() -> bool:
69
+ try:
70
+ import torch
71
+ return bool(getattr(torch, "cuda", None)) and torch.cuda.is_available()
72
+ except Exception:
73
+ return False
74
+
75
+ def _generate_medgemma_report_core(
76
+ patient_info: str,
77
+ visual_results: Dict,
78
+ guideline_context: str,
79
+ image_pil: Image.Image,
80
+ max_new_tokens: Optional[int] = None,
81
+ ) -> str:
82
+ try:
83
+ from transformers import pipeline
84
+ # Use CPU by default; if CUDA truly available, pipeline can still map automatically
85
+ pipe = pipeline(
86
+ "image-text-to-text",
87
+ model="google/medgemma-4b-it",
88
+ device_map="auto" if _cuda_available() else None,
89
+ token=HF_TOKEN,
90
+ model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
91
+ )
92
+
93
+ prompt = (
94
+ "You are a medical AI assistant. Analyze this wound image and patient data.\n\n"
95
+ f"Patient: {patient_info}\n"
96
+ f"Wound: {visual_results.get('wound_type', 'Unknown')} - "
97
+ f"{visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm\n\n"
98
+ "Provide a structured report with:\n"
99
+ "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
100
+ )
101
+
102
+ messages = [{"role": "user", "content": [
103
+ {"type": "image", "image": image_pil},
104
+ {"type": "text", "text": prompt},
105
+ ]}]
106
+
107
+ t0 = time.time()
108
+ out = pipe(
109
+ text=messages,
110
+ max_new_tokens=max_new_tokens or 800,
111
+ do_sample=False,
112
+ temperature=0.7,
113
+ )
114
+ logging.info(f"✅ MedGemma finished in {time.time()-t0:.2f}s")
115
+
116
+ if out and len(out) > 0:
117
  try:
118
+ return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
 
119
  except Exception:
120
+ return (out[0].get("generated_text", "") or "").strip() or "⚠️ Empty response"
121
+ return "⚠️ No output generated"
122
+ except Exception as e:
123
+ logging.error(f" MedGemma generation error: {e}")
124
+ return "⚠️ GPU/LLM worker unavailable"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # Preserve the SAME public function name.
127
+ # Only decorate with @spaces.GPU if CUDA is truly available.
128
+ try:
129
+ import spaces
130
+ if _cuda_available():
131
+ @spaces.GPU(enable_queue=True, duration=90)
132
+ def generate_medgemma_report(
133
+ patient_info: str,
134
+ visual_results: Dict,
135
+ guideline_context: str,
136
+ image_pil: Image.Image,
137
+ max_new_tokens: Optional[int] = None,
138
+ ) -> str:
139
+ return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
140
+ else:
141
+ def generate_medgemma_report(
142
+ patient_info: str,
143
+ visual_results: Dict,
144
+ guideline_context: str,
145
+ image_pil: Image.Image,
146
+ max_new_tokens: Optional[int] = None,
147
+ ) -> str:
148
+ # no decorator -> no GPU worker init -> no cudaGetDeviceCount crash
149
+ return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
150
  except Exception:
151
  def generate_medgemma_report(
152
  patient_info: str,
 
155
  image_pil: Image.Image,
156
  max_new_tokens: Optional[int] = None,
157
  ) -> str:
158
+ return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
159
 
160
+ # ---------- Initialize CPU models ----------
161
  def load_yolo_model():
162
  YOLO = _import_ultralytics()
163
  return YOLO(YOLO_MODEL_PATH)
 
251
  initialize_cpu_models()
252
  setup_knowledge_base()
253
 
254
+ # ---------- Calibration helpers ----------
255
  def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
256
  out = {}
257
  try:
 
304
  except Exception:
305
  return float(default_px_per_cm), meta
306
 
307
+ # ---------- Segmentation helpers (additive; names preserved elsewhere) ----------
308
+ def _get_seg_hw(seg_model) -> Tuple[int, int]:
309
+ shp = getattr(seg_model, "input_shape", None)
310
+ if shp and len(shp) >= 4:
311
+ return int(shp[1]), int(shp[2])
312
+ # try Keras .inputs shape
313
+ try:
314
+ shp = seg_model.inputs[0].shape
315
+ return int(shp[1]), int(shp[2])
316
+ except Exception:
317
+ pass
318
+ raise ValueError(f"Cannot infer (H,W) from segmentation model input shape: {shp}")
319
+
320
+ def _to_prob(mask_pred: np.ndarray) -> np.ndarray:
321
+ m = np.array(mask_pred)
322
+ # squeeze batch/channel dims
323
+ while m.ndim > 2:
324
+ if m.shape[0] == 1:
325
+ m = np.squeeze(m, axis=0)
326
+ if m.ndim > 2 and m.shape[-1] == 1:
327
+ m = np.squeeze(m, axis=-1)
328
+ if m.ndim == 3 and m.shape[-1] > 1:
329
+ # pick the most active channel
330
+ ch = np.argmax(m.reshape(-1, m.shape[-1]).mean(0))
331
+ m = m[..., ch]
332
+ if m.ndim <= 2:
333
+ break
334
+ m = m.astype("float32")
335
+ # if looks like logits -> sigmoid
336
+ if m.max() > 1.5 or m.min() < -0.5:
337
+ m = 1.0 / (1.0 + np.exp(-m))
338
+ return np.clip(m, 0.0, 1.0)
339
+
340
+ def _adaptive_threshold(prob: np.ndarray, hard: float = 0.5) -> np.ndarray:
341
+ if (prob >= hard).sum() > 0:
342
+ return (prob >= hard).astype("uint8")
343
+ # try Otsu
344
+ m8 = (np.clip(prob, 0, 1) * 255).astype("uint8")
345
+ try:
346
+ # we only need the threshold value _
347
+ _, _ = cv2.threshold(m8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
348
+ return (m8 >= _).astype("uint8")
349
+ except Exception:
350
+ p = float(np.percentile(prob, 99.0))
351
+ return (prob >= max(0.2, min(0.9, p))).astype("uint8")
352
+
353
  def largest_component_mask(binary: np.ndarray, min_area_px: int = 50) -> np.ndarray:
354
  num, labels, stats, _ = cv2.connectedComponentsWithStats(binary.astype(np.uint8), connectivity=8)
355
  if num <= 1:
 
385
  breadth_cm: float,
386
  thickness: int = 2
387
  ) -> np.ndarray:
 
388
  overlay = base_bgr.copy()
389
+ # safe blend: blend once, then gate with mask (no mask kwarg!)
 
390
  colored = np.zeros_like(base_bgr); colored[:] = (0, 0, 255)
 
391
  blended = cv2.addWeighted(overlay, 1.0, colored, 0.3, 0)
392
+ m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
393
+ blended_masked = cv2.bitwise_and(blended, m3)
394
+ bg = cv2.bitwise_and(overlay, cv2.bitwise_not(m3))
395
  overlay = cv2.add(bg, blended_masked)
396
 
397
  if rect_box is not None:
 
423
  put_label(f"{breadth_cm:.2f} cm", mids[short_pair[0]])
424
  return overlay
425
 
426
+ # ---------- AI PROCESSOR ----------
427
  class AIProcessor:
428
  def __init__(self):
429
  self.models_cache = models_cache
 
437
  os.makedirs(out_dir, exist_ok=True)
438
  return out_dir
439
 
440
+ def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
 
441
  """
442
+ Detect crop ROI → (optional) segment cleanup → largest component →
443
+ oriented minAreaRect in cm (EXIF-calibrated) save original/detect/seg/annotated.
444
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  try:
446
  # --- Auto calibration from EXIF ---
447
  px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
448
 
449
+ # Convert PIL to OpenCV BGR
450
  image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
451
 
452
+ # --- Detection (YOLO) ---
453
  det_model = self.models_cache.get("det")
454
  if det_model is None:
455
  raise RuntimeError("YOLO model not loaded")
456
 
457
  results = det_model.predict(image_cv, verbose=False, device="cpu")
458
  if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
459
+ import gradio as gr # local import to keep class name intact if gradio missing
460
  raise gr.Error("No wound could be detected.")
461
 
462
  box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
 
465
  x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
466
  roi = image_cv[y1:y2, x1:x2].copy()
467
  if roi.size == 0:
468
+ import gradio as gr
469
  raise gr.Error("Detected ROI is empty.")
470
 
471
+ # --- Segmentation (robust) ---
472
  seg_model = self.models_cache.get("seg")
473
  mask_roi_01 = None
474
  if seg_model is not None:
475
+ try:
476
+ H, W = _get_seg_hw(seg_model) # robust (H,W)
477
+ resized = cv2.resize(roi, (W, H)) # cv2.resize expects (W,H)
478
+ pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)
479
+ prob = _to_prob(pred) # (H,W) in [0,1]
480
+ binmask = _adaptive_threshold(prob, hard=0.5)
481
+ # gentle cleanup + largest component
482
+ binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1)
483
+ binmask = cv2.morphologyEx(binmask, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1)
484
+ binmask = largest_component_mask(binmask, min_area_px=30)
485
+ # back to ROI size {0,1}
486
+ mask_roi_01 = cv2.resize(binmask, (roi.shape[1], roi.shape[0]), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
487
+ logging.info(f"seg prob stats: min={prob.min():.4f}, max={prob.max():.4f}, mean={prob.mean():.4f}; on={(mask_roi_01==1).sum()}")
488
+ except Exception as e:
489
+ logging.warning(f"Segmentation failed: {e}")
490
+ mask_roi_01 = None
491
  else:
492
+ logging.info("Skipping segmentation (no model).")
493
 
494
+ # --- Measurement ---
495
+ if mask_roi_01 is not None and mask_roi_01.any():
496
  length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask_roi_01, px_per_cm)
497
  surface_area_cm2 = count_area_cm2(mask_roi_01, px_per_cm)
498
  anno_roi = draw_measurement_overlay(roi, mask_roi_01, box_pts, length_cm, breadth_cm)
499
  else:
500
+ # fallback to detection-box cm
501
+ h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
 
502
  length_cm = round(h_px / px_per_cm, 2)
503
  breadth_cm = round(w_px / px_per_cm, 2)
504
  surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
505
  anno_roi = roi.copy()
506
 
507
+ # --- Save visualizations ---
508
  out_dir = self._ensure_analysis_dir()
509
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
510
 
 
519
  segmentation_path = None
520
  annotated_seg_path = None
521
  if mask_roi_01 is not None and mask_roi_01.any():
522
+ # safe masked blend (no mask kwarg to addWeighted)
523
  seg_full = image_cv.copy()
524
  roi_overlay = roi.copy()
525
  red = np.zeros_like(roi_overlay); red[:] = (0, 0, 255)
 
539
  annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
540
  cv2.imwrite(annotated_seg_path, anno_full)
541
 
542
+ # --- Optional classification ---
543
  wound_type = "Unknown"
544
  cls_pipe = self.models_cache.get("cls")
545
  if cls_pipe is not None:
 
568
  logging.error(f"Visual analysis failed: {e}", exc_info=True)
569
  raise
570
 
571
+ # ---------- Knowledge base and reporting stay unchanged ----------
572
  def query_guidelines(self, query: str) -> str:
573
  try:
574
  vs = self.knowledge_base_cache.get("vector_store")