SmartHeal committed · Commit abbf692 · verified · 1 parent: 1ba8e97

Update src/ai_processor.py

Files changed (1): src/ai_processor.py (+182, -131)
src/ai_processor.py CHANGED
@@ -1,6 +1,7 @@
  # smartheal_ai_processor.py
- # Fully functional: auto-calibration from EXIF, mask-based measurements,
- # and annotated overlay with arrows+labels.

  import os
  import time
@@ -13,6 +14,13 @@ import numpy as np
  from PIL import Image, ImageOps
  from PIL.ExifTags import TAGS

  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

  UPLOADS_DIR = "uploads"
@@ -60,7 +68,7 @@ def _import_hf_hub():
  from huggingface_hub import HfApi, HfFolder
  return HfApi, HfFolder

- # ---------- Spaces GPU function (always defined if `spaces` import works) ----------
  try:
  import spaces

@@ -91,7 +99,6 @@ try:
  "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
  )

- from transformers import pipeline
  pipe = pipeline(
  "image-text-to-text",
  model="google/medgemma-4b-it",
@@ -134,7 +141,7 @@ except Exception:
  ) -> str:
  return "⚠️ GPU not available"

- # ---------- Initialize CPU models ----------
  def load_yolo_model():
  YOLO = _import_ultralytics()
  return YOLO(YOLO_MODEL_PATH)
@@ -228,9 +235,8 @@ def setup_knowledge_base() -> None:
  initialize_cpu_models()
  setup_knowledge_base()

- # ---------- Calibration helpers ----------
  def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
- """Best-effort EXIF parse from PIL image."""
  out = {}
  try:
  exif = pil_img.getexif()
@@ -255,21 +261,12 @@ def _to_float(val) -> Optional[float]:
  return None

  def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]:
- """
- Use 35mm equivalent if present: sensor_width = 36 * f_mm / f35.
- """
  if f_mm and f35 and f35 > 0:
  return 36.0 * f_mm / f35
  return None
 
  def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]:
- """
- Returns (px_per_cm, meta) using EXIF when available.
- Formula: field_width_mm = sensor_width_mm * distance_mm / focal_mm
-          px_per_cm = image_width_px / (field_width_mm / 10)
- """
  meta = {"used": "default", "f_mm": None, "f35": None, "sensor_w_mm": None, "distance_m": None}
-
  try:
  exif = _exif_to_dict(pil_img)
  f_mm = _to_float(exif.get("FocalLength"))
@@ -284,41 +281,30 @@ def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float
  field_w_mm = sensor_w_mm * (subj_dist_m * 1000.0) / f_mm
  field_w_cm = field_w_mm / 10.0
  px_per_cm = w_px / max(field_w_cm, 1e-6)
-
- # sanity clamp
  px_per_cm = float(np.clip(px_per_cm, PX_PER_CM_MIN, PX_PER_CM_MAX))
  meta["used"] = "exif"
  return px_per_cm, meta
-
- # If EXIF partial but not enough to solve, keep default
  return float(default_px_per_cm), meta
- except Exception as e:
- logging.warning(f"EXIF calibration failed: {e}")
  return float(default_px_per_cm), meta
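
For intuition, the pinhole arithmetic in the deleted docstring checks out by hand. A minimal sketch with hypothetical EXIF values (a 4.2 mm lens, 26 mm 35mm-equivalent, 0.3 m subject distance, 4000 px frame width; none of these numbers come from the repo):

    f_mm, f35 = 4.2, 26.0                 # FocalLength / FocalLengthIn35mmFilm
    sensor_w_mm = 36.0 * f_mm / f35       # crop-factor relation, ~5.82 mm
    distance_m, w_px = 0.30, 4000
    field_w_mm = sensor_w_mm * (distance_m * 1000.0) / f_mm   # ~415.4 mm across the frame
    px_per_cm = w_px / (field_w_mm / 10.0)                    # ~96.3 px per cm
    print(round(px_per_cm, 1))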
 
- # ---------- Mask processing + measurement ----------
  def largest_component_mask(binary: np.ndarray, min_area_px: int = 50) -> np.ndarray:
- """Keep only the largest connected component in a binary mask."""
  num, labels, stats, _ = cv2.connectedComponentsWithStats(binary.astype(np.uint8), connectivity=8)
  if num <= 1:
- return binary
- # stats[:, cv2.CC_STAT_AREA]; skip label 0 (background)
  areas = stats[1:, cv2.CC_STAT_AREA]
  largest_idx = 1 + int(np.argmax(areas))
- if areas.max() < min_area_px:
- return binary
  return (labels == largest_idx).astype(np.uint8)
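
As a sanity check, the connected-components filter can be exercised on a toy mask (toy data, not from the repo):

    import numpy as np, cv2
    m = np.zeros((6, 6), np.uint8)
    m[1:3, 1:3] = 1   # 4-px blob: the largest component
    m[4, 4] = 1       # 1-px speck that should be discarded
    num, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
    keep = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
    print((labels == keep).sum())  # 4 -> only the blob survives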
 
  def measure_min_area_rect(mask: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
- """
- Compute oriented min-area rectangle on mask.
- Returns (length_cm, breadth_cm, (box_points, center)).
- """
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  if not contours:
  return 0.0, 0.0, (None, None)
  cnt = max(contours, key=cv2.contourArea)
- rect = cv2.minAreaRect(cnt)  # (center(x,y), (w,h), angle)
  (w_px, h_px) = rect[1]
  length_px, breadth_px = (max(w_px, h_px), min(w_px, h_px))
  length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
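
The oriented rectangle is what makes length/breadth rotation-invariant; a standalone sketch with a synthetic tilted mask (toy values, assuming 10 px/cm):

    import numpy as np, cv2
    mask = np.zeros((100, 100), np.uint8)
    box = cv2.boxPoints(((50, 50), (40, 20), 30.0)).astype(np.int32)  # 40x20 px rect at 30 deg
    cv2.fillPoly(mask, [box], 1)
    cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    (w, h) = cv2.minAreaRect(max(cnts, key=cv2.contourArea))[1]
    print(round(max(w, h) / 10, 1), round(min(w, h) / 10, 1))  # ~4.0 ~2.0 cm, regardless of angle
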
@@ -332,60 +318,54 @@ def count_area_cm2(mask: np.ndarray, px_per_cm: float) -> float:

  def draw_measurement_overlay(
  base_bgr: np.ndarray,
- mask: np.ndarray,
  rect_box: np.ndarray,
  length_cm: float,
  breadth_cm: float,
  thickness: int = 2
  ) -> np.ndarray:
- """
- Draw semi-transparent mask + measurement arrows along the rectangle sides with labels.
- """
  overlay = base_bgr.copy()
- # red mask overlay
- colored = np.zeros_like(base_bgr)
- colored[:, :] = (0, 0, 255)
- mask3 = np.dstack([mask * 255] * 3)
- overlay = cv2.addWeighted(overlay, 1.0, (colored & mask3), 0.3, 0)
-
- # draw rectangle
- cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
-
- # pick the long side & short side arrows
- # box points are in order; connect midpoints of opposite edges
- pts = rect_box.reshape(-1, 2)
- def midpoint(a, b): return ((a[0] + b[0]) // 2, (a[1] + b[1]) // 2)
-
- # edges: (0-1, 1-2, 2-3, 3-0)
- mids = [midpoint(pts[i], pts[(i+1) % 4]) for i in range(4)]
- # vector lengths
- e_lens = [np.linalg.norm(pts[i] - pts[(i+1) % 4]) for i in range(4)]
- long_pair = (0, 2) if e_lens[0] + e_lens[2] >= e_lens[1] + e_lens[3] else (1, 3)
- short_pair = (1, 3) if long_pair == (0, 2) else (0, 2)
-
- # arrowed lines (white with black shadow)
- def draw_arrow(p1, p2):
- cv2.arrowedLine(overlay, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
- cv2.arrowedLine(overlay, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
- cv2.arrowedLine(overlay, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
- cv2.arrowedLine(overlay, p2, p1, (255, 255, 255), thickness, tipLength=0.05)
-
- draw_arrow(mids[long_pair[0]], mids[long_pair[1]])
- draw_arrow(mids[short_pair[0]], mids[short_pair[1]])
-
- # labels near the midpoints
- def put_label(text, org):
- cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
- cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
-
- put_label(f"{length_cm:.2f} cm", mids[long_pair[0]])
- put_label(f"{breadth_cm:.2f} cm", mids[short_pair[0]])
  return overlay

- # ---------- AI PROCESSOR ----------
  class AIProcessor:
  def __init__(self):
  self.models_cache = models_cache
@@ -399,27 +379,104 @@ class AIProcessor:
  os.makedirs(out_dir, exist_ok=True)
  return out_dir

- def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
  """
- YOLO detect → segmentation → largest-component mask →
- minAreaRect measurement (cm) using px/cm from EXIF when available →
- save original, detection overlay, segmentation overlay, and annotated overlay.
  """
  try:
- # --- Auto calibration from EXIF (before any conversion that might drop EXIF) ---
  px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)

- # Convert PIL to OpenCV BGR
  image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

- # --- Detection (YOLO) ---
  det_model = self.models_cache.get("det")
  if det_model is None:
  raise RuntimeError("YOLO model not loaded")

  results = det_model.predict(image_cv, verbose=False, device="cpu")
  if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
- raise ValueError("No wound could be detected.")

  box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
  x1, y1, x2, y2 = [int(v) for v in box]
@@ -427,80 +484,74 @@ class AIProcessor:
  x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
  roi = image_cv[y1:y2, x1:x2].copy()
  if roi.size == 0:
- raise ValueError("Detected ROI is empty.")

- # --- Segmentation (optional but recommended) ---
  seg_model = self.models_cache.get("seg")
- mask_resized = None
- length_cm = breadth_cm = surface_area_cm2 = 0.0
-
  if seg_model is not None:
- try:
- H, W = seg_model.input_shape[1:3]
- resized = cv2.resize(roi, (W, H))
- pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
- raw_mask = pred[:, :, 0]
-
- # binarize + clean
- mask = (raw_mask > 0.5).astype(np.uint8)
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
- mask = largest_component_mask(mask)
-
- # bring back to ROI size
- mask_resized = cv2.resize(mask * 255, (roi.shape[1], roi.shape[0]), interpolation=cv2.INTER_NEAREST)
- bin_mask_roi = (mask_resized > 127).astype(np.uint8)
-
- # measure with oriented rectangle (in ROI pixels)
- length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(bin_mask_roi, px_per_cm)
- surface_area_cm2 = count_area_cm2(bin_mask_roi, px_per_cm)
-
- # draw overlay with arrows/labels on ROI
- anno_roi = draw_measurement_overlay(roi, bin_mask_roi, box_pts, length_cm, breadth_cm)

- except Exception as e:
- logging.warning(f"Segmentation failed/partial: {e}")
- mask_resized = None
- anno_roi = roi.copy()

  else:
- # No segmentation → just draw detection box and keep defaults
  anno_roi = roi.copy()

- # --- Save all visualizations ---
  out_dir = self._ensure_analysis_dir()
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")

- # Original
  original_path = os.path.join(out_dir, f"original_{ts}.png")
  cv2.imwrite(original_path, image_cv)

- # Detection overlay (rectangle on full image)
  det_vis = image_cv.copy()
  cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
  detection_path = os.path.join(out_dir, f"detection_{ts}.png")
  cv2.imwrite(detection_path, det_vis)

- # Segmentation overlay (ROI pasted back into full frame for consistent display)
  segmentation_path = None
  annotated_seg_path = None
- if mask_resized is not None:
- # compose overlay on full image for "segmentation" view
  seg_full = image_cv.copy()
  roi_overlay = roi.copy()
  red = np.zeros_like(roi_overlay); red[:] = (0, 0, 255)
- alpha = 0.3
- roi_overlay = cv2.addWeighted(roi_overlay, 1.0, red, alpha, 0, mask=mask_resized)
  seg_full[y1:y2, x1:x2] = roi_overlay
  segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
  cv2.imwrite(segmentation_path, seg_full)

- # annotated overlay (arrows+labels) placed back into full image
  anno_full = image_cv.copy()
  anno_full[y1:y2, x1:x2] = anno_roi
  annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
  cv2.imwrite(annotated_seg_path, anno_full)

- # --- Optional classification ---
  wound_type = "Unknown"
  cls_pipe = self.models_cache.get("cls")
  if cls_pipe is not None:
@@ -517,7 +568,7 @@ class AIProcessor:
  "breadth_cm": breadth_cm,
  "surface_area_cm2": surface_area_cm2,
  "px_per_cm": round(px_per_cm, 2),
- "calibration_meta": exif_meta,  # for debugging/auditing
  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
  if getattr(results[0].boxes, "conf", None) is not None else 0.0,
  "detection_image_path": detection_path,
@@ -529,7 +580,7 @@ class AIProcessor:
  logging.error(f"Visual analysis failed: {e}", exc_info=True)
  raise

- # ---------- Knowledge base and reporting stay unchanged ----------
  def query_guidelines(self, query: str) -> str:
  try:
  vs = self.knowledge_base_cache.get("vector_store")

  # smartheal_ai_processor.py
+ # Fully functional: "segment like snippet" while preserving ALL original names.
+ # You can keep using AIProcessor.perform_visual_analysis / analyze_wound / full_analysis_pipeline
+ # exactly as before. A convenience AIProcessor.segment_like_snippet(...) is added.

  import os
  import time

  from PIL import Image, ImageOps
  from PIL.ExifTags import TAGS

+ try:
+ import gradio as gr
+ except Exception:
+ class _GrErr(RuntimeError): ...
+ class gr:  # shim so `gr.Error` won’t crash if Gradio isn’t present
+ Error = _GrErr
+
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

  UPLOADS_DIR = "uploads"
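
The try/except shim a few lines up keeps `raise gr.Error(...)` working whether or not Gradio is installed; with the fallback class, gr.Error is a RuntimeError subclass, so generic handlers still fire. A standalone sketch of the same pattern (illustrative, not repo code):

    try:
        import gradio as gr
    except Exception:
        class _GrErr(RuntimeError): ...
        class gr:
            Error = _GrErr

    try:
        raise gr.Error("No wound could be detected.")
    except Exception as e:   # real gr.Error and the shim both derive from Exception
        print(f"handled: {e}")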
 
  from huggingface_hub import HfApi, HfFolder
  return HfApi, HfFolder

+ # ---------- Spaces GPU function (kept name/behavior) ----------
  try:
  import spaces

  "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
  )

  pipe = pipeline(
  "image-text-to-text",
  model="google/medgemma-4b-it",

  ) -> str:
  return "⚠️ GPU not available"

+ # ---------- Initialize CPU models (same function names/behavior) ----------
  def load_yolo_model():
  YOLO = _import_ultralytics()
  return YOLO(YOLO_MODEL_PATH)

  initialize_cpu_models()
  setup_knowledge_base()
 
+ # ---------- Calibration helpers (added, names unchanged elsewhere) ----------
  def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
  out = {}
  try:
  exif = pil_img.getexif()

  return None

  def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]:
  if f_mm and f35 and f35 > 0:
  return 36.0 * f_mm / f35
  return None

  def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]:
  meta = {"used": "default", "f_mm": None, "f35": None, "sensor_w_mm": None, "distance_m": None}
  try:
  exif = _exif_to_dict(pil_img)
  f_mm = _to_float(exif.get("FocalLength"))

  field_w_mm = sensor_w_mm * (subj_dist_m * 1000.0) / f_mm
  field_w_cm = field_w_mm / 10.0
  px_per_cm = w_px / max(field_w_cm, 1e-6)
  px_per_cm = float(np.clip(px_per_cm, PX_PER_CM_MIN, PX_PER_CM_MAX))
  meta["used"] = "exif"
  return px_per_cm, meta

  return float(default_px_per_cm), meta
+ except Exception:
  return float(default_px_per_cm), meta
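
For reference, the tags these helpers consume can be inspected directly with PIL; a minimal sketch (the photo path is hypothetical, and on many files the camera tags live in the Exif sub-IFD rather than the base IFD):

    from PIL import Image
    from PIL.ExifTags import TAGS

    exif = Image.open("photo.jpg").getexif()          # hypothetical file
    merged = {**dict(exif), **dict(exif.get_ifd(0x8769))}
    named = {TAGS.get(k, k): v for k, v in merged.items()}
    print(named.get("FocalLength"), named.get("FocalLengthIn35mmFilm"), named.get("SubjectDistance"))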
 
+ # ---------- Mask processing + measurement (helpers added) ----------
  def largest_component_mask(binary: np.ndarray, min_area_px: int = 50) -> np.ndarray:
  num, labels, stats, _ = cv2.connectedComponentsWithStats(binary.astype(np.uint8), connectivity=8)
  if num <= 1:
+ return binary.astype(np.uint8)
  areas = stats[1:, cv2.CC_STAT_AREA]
+ if areas.size == 0 or areas.max() < min_area_px:
+ return binary.astype(np.uint8)
  largest_idx = 1 + int(np.argmax(areas))
  return (labels == largest_idx).astype(np.uint8)

  def measure_min_area_rect(mask: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
+ contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  if not contours:
  return 0.0, 0.0, (None, None)
  cnt = max(contours, key=cv2.contourArea)
+ rect = cv2.minAreaRect(cnt)
  (w_px, h_px) = rect[1]
  length_px, breadth_px = (max(w_px, h_px), min(w_px, h_px))
  length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
 
  def draw_measurement_overlay(
  base_bgr: np.ndarray,
+ mask01: np.ndarray,
  rect_box: np.ndarray,
  length_cm: float,
  breadth_cm: float,
  thickness: int = 2
  ) -> np.ndarray:
+ """Safe overlay (no mask arg to addWeighted)."""
  overlay = base_bgr.copy()

+ # red mask overlay only where mask==1
+ colored = np.zeros_like(base_bgr); colored[:] = (0, 0, 255)
+ mask3 = np.dstack([mask01 * 255] * 3).astype(np.uint8)
+ blended = cv2.addWeighted(overlay, 1.0, colored, 0.3, 0)
+ # keep blended only on mask
+ blended_masked = cv2.bitwise_and(blended, mask3)
+ bg = cv2.bitwise_and(overlay, cv2.bitwise_not(mask3))
+ overlay = cv2.add(bg, blended_masked)
+
+ # guard: without a rectangle there is nothing to measure or label
+ if rect_box is None:
+ return overlay
+ cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
+
+ pts = rect_box.reshape(-1, 2)
+ def midpoint(a, b): return ((a[0] + b[0]) // 2, (a[1] + b[1]) // 2)
+ mids = [midpoint(pts[i], pts[(i+1) % 4]) for i in range(4)]
+ e_lens = [np.linalg.norm(pts[i] - pts[(i+1) % 4]) for i in range(4)]
+ long_pair = (0, 2) if e_lens[0] + e_lens[2] >= e_lens[1] + e_lens[3] else (1, 3)
+ short_pair = (1, 3) if long_pair == (0, 2) else (0, 2)
+
+ def draw_arrow(img, p1, p2):
+ cv2.arrowedLine(img, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
+ cv2.arrowedLine(img, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
+ cv2.arrowedLine(img, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
+ cv2.arrowedLine(img, p2, p1, (255, 255, 255), thickness, tipLength=0.05)
+
+ draw_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
+ draw_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])
+
+ def put_label(text, org):
+ cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
+ cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
+
+ put_label(f"{length_cm:.2f} cm", mids[long_pair[0]])
+ put_label(f"{breadth_cm:.2f} cm", mids[short_pair[0]])
  return overlay
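
The bitwise gating above exists because cv2.addWeighted accepts no mask argument (its signature is addWeighted(src1, alpha, src2, beta, gamma)); blending the whole frame and then keeping blended pixels only where the mask is set gives the intended result. An equivalent NumPy formulation, as a standalone sketch with toy data:

    import numpy as np, cv2
    img = np.full((4, 4, 3), 100, np.uint8)
    mask01 = np.zeros((4, 4), np.uint8); mask01[:2, :2] = 1
    red = np.zeros_like(img); red[:] = (0, 0, 255)
    blended = cv2.addWeighted(img, 1.0, red, 0.3, 0)
    out = np.where(mask01[..., None].astype(bool), blended, img)  # gate by mask
    print(out[0, 0], out[3, 3])  # tinted pixel vs. untouched pixel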
 
+ # ---------- AI PROCESSOR (ALL names preserved) ----------
  class AIProcessor:
  def __init__(self):
  self.models_cache = models_cache

  os.makedirs(out_dir, exist_ok=True)
  return out_dir
 
+ # NEW helper that mirrors your short snippet exactly (you can call or ignore)
+ def segment_like_snippet(self, image_pil: Image.Image) -> Tuple[Dict, Image.Image, Image.Image]:
  """
+ Returns (visual_results, detected_image_pil, mask_pil) exactly like your snippet.
+ Uses EXIF-calibrated px/cm if available; otherwise DEFAULT_PX_PER_CM.
  """
+ if image_pil is None:
+ raise gr.Error("No image provided.")
+
+ px_per_cm, _ = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
+
+ # Convert image
+ image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
+
+ # Detection
+ det_model = self.models_cache.get("det")
+ if det_model is None:
+ raise gr.Error("Detection model not loaded.")
+ results = det_model.predict(image_cv, verbose=False, device="cpu")
+ if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
+ raise gr.Error("No wound could be detected.")
+ box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
+ x1, y1, x2, y2 = [int(v) for v in box]
+ x1, y1 = max(0, x1), max(0, y1)
+ x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
+ detected_region_cv = image_cv[y1:y2, x1:x2]
+ if detected_region_cv.size == 0:
+ raise gr.Error("Detected ROI is empty.")
+
+ # Segmentation
+ seg_model = self.models_cache.get("seg")
+ mask_roi_01 = None
+ if seg_model is not None:
+ H, W = seg_model.input_shape[1:3]
+ resized = cv2.resize(detected_region_cv, (W, H))
+ pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
+ raw = pred[:, :, 0]
+ mask = (raw > 0.5).astype(np.uint8)
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
+ mask = largest_component_mask(mask, min_area_px=50)
+ mask_roi_01 = cv2.resize(mask, (detected_region_cv.shape[1], detected_region_cv.shape[0]),
+ interpolation=cv2.INTER_NEAREST).astype(np.uint8)
+ else:
+ mask_roi_01 = np.zeros(detected_region_cv.shape[:2], dtype=np.uint8)
+
+ # Measurement (oriented rect)
+ if mask_roi_01.any():
+ length_cm, breadth_cm, _ = measure_min_area_rect(mask_roi_01, px_per_cm)
+ area_cm2 = count_area_cm2(mask_roi_01, px_per_cm)
+ else:
+ # fall back to detection box
+ h_px = max(0, y2 - y1)
+ w_px = max(0, x2 - x1)
+ length_cm, breadth_cm = round(h_px / px_per_cm, 2), round(w_px / px_per_cm, 2)
+ area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
+
+ # Classification (optional)
+ wound_type = "Unknown"
+ cls_pipe = self.models_cache.get("cls")
+ if cls_pipe is not None:
+ try:
+ detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
+ preds = cls_pipe(detected_image_pil)
+ if preds:
+ wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
+ except Exception as e:
+ logging.warning(f"Classification failed: {e}")
+ detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
+ else:
+ detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
+
+ visual_results = {
+ "wound_type": wound_type,
+ "length_cm": length_cm,
+ "breadth_cm": breadth_cm,
+ "surface_area_cm2": area_cm2
+ }
+ mask_pil = Image.fromarray((mask_roi_01 * 255).astype(np.uint8))
+ return visual_results, detected_image_pil, mask_pil
+
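
A minimal caller for the new helper, assuming the module-level models have already been initialized (the image path is hypothetical):

    from PIL import Image

    proc = AIProcessor()
    results, roi_img, mask_img = proc.segment_like_snippet(Image.open("wound.jpg"))
    print(results)             # {"wound_type": ..., "length_cm": ..., "breadth_cm": ..., "surface_area_cm2": ...}
    mask_img.save("mask.png")  # 0/255 ROI mask, sized like the detected region
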
+ # ORIGINAL NAME preserved; inside it we follow the snippet-style flow and also save overlays
+ def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
  try:
+ # --- Auto calibration from EXIF ---
  px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)

+ # Convert image
  image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

+ # --- Detection ---
  det_model = self.models_cache.get("det")
  if det_model is None:
  raise RuntimeError("YOLO model not loaded")

  results = det_model.predict(image_cv, verbose=False, device="cpu")
  if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
+ raise gr.Error("No wound could be detected.")

  box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
  x1, y1, x2, y2 = [int(v) for v in box]
  x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
  roi = image_cv[y1:y2, x1:x2].copy()
  if roi.size == 0:
+ raise gr.Error("Detected ROI is empty.")

+ # --- Segmentation (snippet style) ---
  seg_model = self.models_cache.get("seg")
+ mask_roi_01 = None
  if seg_model is not None:
+ H, W = seg_model.input_shape[1:3]
+ resized = cv2.resize(roi, (W, H))
+ pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
+ raw_mask = pred[:, :, 0]
+ mask = (raw_mask > 0.5).astype(np.uint8)
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
+ mask = largest_component_mask(mask)
+ mask_roi_01 = cv2.resize(mask, (roi.shape[1], roi.shape[0]), interpolation=cv2.INTER_NEAREST)
+ else:
+ mask_roi_01 = np.zeros(roi.shape[:2], dtype=np.uint8)
 
+ # --- Measurement with oriented rect (better than boundingRect) ---
+ if mask_roi_01.any():
+ length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask_roi_01, px_per_cm)
+ surface_area_cm2 = count_area_cm2(mask_roi_01, px_per_cm)
+ anno_roi = draw_measurement_overlay(roi, mask_roi_01, box_pts, length_cm, breadth_cm)
  else:
+ # fallback to detection box if segmentation missing/empty
+ h_px = max(0, y2 - y1)
+ w_px = max(0, x2 - x1)
+ length_cm = round(h_px / px_per_cm, 2)
+ breadth_cm = round(w_px / px_per_cm, 2)
+ surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
  anno_roi = roi.copy()

+ # --- Save visuals ---
  out_dir = self._ensure_analysis_dir()
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")

  original_path = os.path.join(out_dir, f"original_{ts}.png")
  cv2.imwrite(original_path, image_cv)

  det_vis = image_cv.copy()
  cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
  detection_path = os.path.join(out_dir, f"detection_{ts}.png")
  cv2.imwrite(detection_path, det_vis)

  segmentation_path = None
  annotated_seg_path = None
+ if mask_roi_01 is not None and mask_roi_01.any():
+ # Safe blending: blend once, then gate by mask
  seg_full = image_cv.copy()
  roi_overlay = roi.copy()
  red = np.zeros_like(roi_overlay); red[:] = (0, 0, 255)
+ blended = cv2.addWeighted(roi_overlay, 1.0, red, 0.3, 0)
+ mask_u8 = (mask_roi_01.astype(np.uint8) * 255)
+ mask3 = cv2.merge([mask_u8, mask_u8, mask_u8])
+ blended_masked = cv2.bitwise_and(blended, mask3)
+ roi_bg = cv2.bitwise_and(roi_overlay, cv2.bitwise_not(mask3))
+ roi_overlay = cv2.add(roi_bg, blended_masked)
+
  seg_full[y1:y2, x1:x2] = roi_overlay
  segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
  cv2.imwrite(segmentation_path, seg_full)

  anno_full = image_cv.copy()
  anno_full[y1:y2, x1:x2] = anno_roi
  annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
  cv2.imwrite(annotated_seg_path, anno_full)

+ # --- Classification (optional) ---
  wound_type = "Unknown"
  cls_pipe = self.models_cache.get("cls")
  if cls_pipe is not None:

  "breadth_cm": breadth_cm,
  "surface_area_cm2": surface_area_cm2,
  "px_per_cm": round(px_per_cm, 2),
+ "calibration_meta": exif_meta,
  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
  if getattr(results[0].boxes, "conf", None) is not None else 0.0,
  "detection_image_path": detection_path,

  logging.error(f"Visual analysis failed: {e}", exc_info=True)
  raise
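
The detection-box fallback is plain arithmetic and easy to verify by hand; with toy numbers (a 120x80 px box at an assumed 10 px/cm):

    px_per_cm = 10.0
    h_px, w_px = 120, 80
    print(round(h_px / px_per_cm, 2))                  # length: 12.0 cm
    print(round(w_px / px_per_cm, 2))                  # breadth: 8.0 cm
    print(round((h_px * w_px) / (px_per_cm ** 2), 2))  # area: 96.0 cm^2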
 
+ # ---------- Knowledge base and reporting (names preserved) ----------
  def query_guidelines(self, query: str) -> str:
  try:
  vs = self.knowledge_base_cache.get("vector_store")