SmartHeal committed on
Commit 860a048 · verified · 1 Parent(s): be9c884

Update src/ai_processor.py

Files changed (1)
  1. src/ai_processor.py +204 -88
src/ai_processor.py CHANGED
@@ -1,11 +1,6 @@
  # smartheal_ai_processor.py
- # Preserves ALL original class/function names.
- # What you get:
- # - Uses your segmentation_model.h5 first; clean KMeans fallback if it fails/missing
- # - Safe overlay (no 'mask' kwarg with addWeighted)
- # - Always writes a segmentation view (so it never looks like the plain original)
- # - CPU by default; optional VLM (MedGemma) is OFF unless SMARTHEAL_ENABLE_VLM=1
- # - Optional @spaces.GPU **stub** (no queue) to satisfy Spaces startup without touching CUDA

  import os
  import time
@@ -13,31 +8,36 @@ import logging
  from datetime import datetime
  from typing import Optional, Dict, List, Tuple

- # Quiet tokenizers; default to CPU for safety on ZeroGPU/Spaces
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

  import cv2
  import numpy as np
  from PIL import Image
  from PIL.ExifTags import TAGS

- # --- Optional: register a harmless @spaces.GPU-decorated stub to silence startup warning ---
  try:
      import spaces as _spaces
-
-     @_spaces.GPU(enable_queue=False)  # not queued -> won't start a ZeroGPU worker
      def smartheal_gpu_stub(ping: int = 0) -> str:
-         """No-op so Spaces detects at least one @spaces.GPU function without touching CUDA."""
          return "ready"
-
-     logging.info("Registered @spaces.GPU stub (enable_queue=False); startup detector satisfied.")
- except Exception as _e:
-     # It's fine if 'spaces' isn't available locally.
      pass

- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
-
  UPLOADS_DIR = "uploads"
  os.makedirs(UPLOADS_DIR, exist_ok=True)
@@ -49,6 +49,11 @@ DATASET_ID = "SmartHeal/wound-image-uploads"
  DEFAULT_PX_PER_CM = 38.0
  PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0

  models_cache: Dict[str, object] = {}
  knowledge_base_cache: Dict[str, object] = {}
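The constants above bound any pixel-density estimate before it is used for measurements. A minimal sketch of how such a clamp would typically look (the helper name clamp_px_per_cm is hypothetical, not part of this commit):

    def clamp_px_per_cm(value, lo=5.0, hi=1200.0, default=38.0):
        """Fall back to DEFAULT_PX_PER_CM when the estimate is missing or implausible."""
        try:
            v = float(value)
        except (TypeError, ValueError):
            return default
        return v if lo <= v <= hi else default

    assert clamp_px_per_cm(47.2) == 47.2
    assert clamp_px_per_cm(0.5) == 38.0   # below PX_PER_CM_MIN -> default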
 
@@ -86,7 +91,7 @@ def _import_hf_hub():
      from huggingface_hub import HfApi, HfFolder
      return HfApi, HfFolder

- # ---------- LLM report (OFF by default; enable with SMARTHEAL_ENABLE_VLM=1) ----------
  def generate_medgemma_report(
      patient_info: str,
      visual_results: Dict,
@@ -94,24 +99,18 @@ def generate_medgemma_report(
      image_pil: Image.Image,
      max_new_tokens: Optional[int] = None,
  ) -> str:
-     """
-     CPU-only MedGemma call (safe). Disabled by default to avoid env mismatches.
-     Set SMARTHEAL_ENABLE_VLM=1 to try loading the model.
-     """
      if os.getenv("SMARTHEAL_ENABLE_VLM", "0") != "1":
          return "⚠️ VLM disabled"
-
      try:
          from transformers import pipeline
          pipe = pipeline(
              task="image-text-to-text",
              model="google/medgemma-4b-it",
-             device_map=None,  # CPU
              token=HF_TOKEN,
              trust_remote_code=True,
-             model_kwargs={"low_cpu_mem_usage": True},  # avoid 'use_cache' arg mismatch
          )
-
          prompt = (
              "You are a medical AI assistant. Analyze this wound image and patient data.\n\n"
              f"Patient: {patient_info}\n"
@@ -120,12 +119,10 @@ def generate_medgemma_report(
              "Provide a structured report with:\n"
              "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
          )
-
          messages = [{"role": "user", "content": [
              {"type": "image", "image": image_pil},
              {"type": "text", "text": prompt},
          ]}]
-
          out = pipe(text=messages, max_new_tokens=max_new_tokens or 600, do_sample=False, temperature=0.7)
          if out and len(out) > 0:
              try:
@@ -174,7 +171,10 @@ def initialize_cpu_models() -> None:
      try:
          if os.path.exists(SEG_MODEL_PATH):
              models_cache["seg"] = load_segmentation_model()
-             logging.info("✅ Segmentation model loaded (CPU)")
          else:
              models_cache["seg"] = None
              logging.warning("Segmentation model file missing; skipping.")
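load_segmentation_model itself sits outside this hunk; a plausible sketch, assuming the segmentation_model.h5 Keras file referenced in the removed header comment (illustrative, not necessarily the repo's actual loader):

    import tensorflow as tf

    def load_segmentation_model():
        # compile=False: inference only, skip optimizer/loss deserialization
        return tf.keras.models.load_model(SEG_MODEL_PATH, compile=False)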
@@ -283,47 +283,133 @@ def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float
      except Exception:
          return float(default_px_per_cm), meta

- # ---------- Segmentation (model-first, KMeans fallback) ----------
- def segment_wound(image: np.ndarray) -> np.ndarray:
      """
-     Segments wound from a preprocessed ROI image, with a fallback to KMeans if the model fails.
-     Returns a mask in 0..255 (uint8), same HxW as input image.
      """
-     segmentation_model = models_cache.get("seg", None)

-     if segmentation_model is not None:
          try:
-             input_shape = getattr(segmentation_model, "input_shape", None)
-             if input_shape is None or len(input_shape) < 3:
-                 raise ValueError(f"Bad seg input_shape: {input_shape}")
-             H, W = int(input_shape[1]), int(input_shape[2])  # (None,H,W,C)
-
-             resized = cv2.resize(image, (W, H))  # (W,H)
-             norm = np.expand_dims(resized / 255.0, axis=0)  # (1,H,W,3)
-             prediction = segmentation_model.predict(norm, verbose=0)
-
-             # Handle models with multiple outputs
-             if isinstance(prediction, (list, tuple)):
-                 prediction = prediction[0]
-             # squeeze batch dim if present
-             prediction = prediction[0] if getattr(prediction, "ndim", 0) >= 3 else prediction
-
-             pred2d = np.squeeze(prediction)  # (H,W) or (H,W,1)->(H,W)
-             mask_prob = cv2.resize(pred2d, (image.shape[1], image.shape[0]))
-             mask = (mask_prob >= 0.5).astype(np.uint8) * 255
-             return mask.astype(np.uint8)
          except Exception as e:
-             logging.warning(f"⚠️ Segmentation model prediction failed: {e}. Falling back to KMeans.")

-     # --- Fallback: color clustering (KMeans, k=2), pick 'reddest' cluster in Lab a* ---
-     Z = image.reshape((-1, 3)).astype(np.float32)
      criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
      _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
      centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
      centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
-     wound_idx = int(np.argmax(centers_lab[:, 1]))  # a* channel (redness)
-     mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
-     return mask.astype(np.uint8)
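The deleted fallback survives, instrumented, in the new segment_wound further down. A self-contained demo of the technique on a synthetic BGR patch:

    import cv2
    import numpy as np

    bgr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)  # stand-in ROI
    Z = bgr.reshape((-1, 3)).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
    centers_lab = cv2.cvtColor(centers.astype(np.uint8).reshape(1, 2, 3), cv2.COLOR_BGR2LAB)[0]
    wound_idx = int(np.argmax(centers_lab[:, 1]))  # highest a* = most red
    mask = (labels.reshape(bgr.shape[:2]) == wound_idx).astype(np.uint8) * 255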

  # ---------- Measurement + overlay helpers ----------
  def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
@@ -362,16 +448,21 @@ def draw_measurement_overlay(
      thickness: int = 2
  ) -> np.ndarray:
      overlay = base_bgr.copy()
-     # Safe masked blend (OpenCV addWeighted has no 'mask' kwarg)
      red = np.zeros_like(overlay); red[:] = (0, 0, 255)
-     blended = cv2.addWeighted(overlay, 1.0, red, 0.3, 0)
-     m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
-     overlay = cv2.add(cv2.bitwise_and(overlay, cv2.bitwise_not(m3)),
-                       cv2.bitwise_and(blended, m3))

      if rect_box is not None:
          cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
-
          pts = rect_box.reshape(-1, 2)
          def midpoint(a, b): return ((a[0] + b[0]) // 2, (a[1] + b[1]) // 2)
          mids = [midpoint(pts[i], pts[(i+1) % 4]) for i in range(4)]
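Both the deleted bitwise blend and the np.where splice that replaces it below implement the same masked-tint mechanics (only the blend weights differ); a quick equivalence check:

    import cv2
    import numpy as np

    base = np.full((8, 8, 3), 128, np.uint8)
    mask01 = np.zeros((8, 8), np.uint8); mask01[2:6, 2:6] = 1
    red = np.zeros_like(base); red[:] = (0, 0, 255)
    m3 = np.dstack([mask01 * 255] * 3).astype(np.uint8)

    blended = cv2.addWeighted(base, 1.0, red, 0.3, 0)
    old = cv2.add(cv2.bitwise_and(base, cv2.bitwise_not(m3)), cv2.bitwise_and(blended, m3))
    new = np.where(m3 > 0, blended, base)
    assert (old == new).all()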
@@ -444,32 +535,40 @@ class AIProcessor:
              except Exception:
                  raise RuntimeError("Detected ROI is empty.")

              # --- Segmentation (model-first + KMeans fallback) ---
-             mask_u8_255 = segment_wound(roi)  # 0..255
              mask01 = (mask_u8_255 > 127).astype(np.uint8)
              if mask01.any():
                  mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1)
                  mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1)
                  mask01 = largest_component_mask(mask01, min_area_px=30)

              # --- Measurement ---
              if mask01.any():
                  length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask01, px_per_cm)
                  surface_area_cm2 = count_area_cm2(mask01, px_per_cm)
                  anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
              else:
-                 # fallback to detection box if segmentation is empty
                  h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
                  length_cm = round(h_px / px_per_cm, 2)
                  breadth_cm = round(w_px / px_per_cm, 2)
                  surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
                  anno_roi = roi.copy()
                  box_pts = None

-             # --- Save visualizations (ALWAYS create a segmentation image) ---
-             out_dir = self._ensure_analysis_dir()
-             ts = datetime.now().strftime("%Y%m%d_%H%M%S")
-
              original_path = os.path.join(out_dir, f"original_{ts}.png")
              cv2.imwrite(original_path, image_cv)
@@ -478,27 +577,30 @@ class AIProcessor:
              detection_path = os.path.join(out_dir, f"detection_{ts}.png")
              cv2.imwrite(detection_path, det_vis)

-             # Save ROI mask image (helps debug)
              roi_mask_path = os.path.join(out_dir, f"roi_mask_{ts}.png")
              cv2.imwrite(roi_mask_path, (mask01 * 255).astype(np.uint8))

-             # Segmentation overlay (paste back to full image). If mask empty, tint ROI red so it's NOT identical to original.
-             seg_full = image_cv.copy()
-             red = np.zeros_like(roi); red[:] = (0, 0, 255)
-             if mask01.any():
-                 blended = cv2.addWeighted(roi, 1.0, red, 0.30, 0)
-                 m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
-                 roi_overlay = cv2.add(cv2.bitwise_and(roi, cv2.bitwise_not(m3)),
-                                       cv2.bitwise_and(blended, m3))
              else:
-                 # No mask -> light red tint over the ROI to make the "segmentation" view visually distinct.
-                 roi_overlay = cv2.addWeighted(roi, 0.75, red, 0.25, 0)

              seg_full[y1:y2, x1:x2] = roi_overlay
              segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
              cv2.imwrite(segmentation_path, seg_full)

-             # Annotated (arrows + labels)
              anno_full = image_cv.copy()
              anno_full[y1:y2, x1:x2] = anno_roi
              annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
@@ -515,6 +617,17 @@ class AIProcessor:
              except Exception as e:
                  logging.warning(f"Classification failed: {e}")

              return {
                  "wound_type": wound_type,
                  "length_cm": length_cm,
@@ -525,16 +638,19 @@ class AIProcessor:
                  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
                      if getattr(results[0].boxes, "conf", None) is not None else 0.0,
                  "detection_image_path": detection_path,
-                 "segmentation_image_path": segmentation_path,  # always present
                  "segmentation_annotated_path": annotated_seg_path,
-                 "roi_mask_path": roi_mask_path,  # helpful for debugging
                  "original_image_path": original_path,
              }
          except Exception as e:
              logging.error(f"Visual analysis failed: {e}", exc_info=True)
              raise

-     # ---------- Knowledge base + reporting (unchanged names) ----------
      def query_guidelines(self, query: str) -> str:
          try:
              vs = self.knowledge_base_cache.get("vector_store")
 
  # smartheal_ai_processor.py
+ # Verbose, instrumented version — preserves public class/function names
+ # Turn on deep logging: export LOGLEVEL=DEBUG SMARTHEAL_DEBUG=1

  import os
  import time
  import logging
  from datetime import datetime
  from typing import Optional, Dict, List, Tuple

+ # ---- Environment defaults ----
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
+ LOGLEVEL = os.getenv("LOGLEVEL", "INFO").upper()
+ SMARTHEAL_DEBUG = os.getenv("SMARTHEAL_DEBUG", "0") == "1"

  import cv2
  import numpy as np
  from PIL import Image
  from PIL.ExifTags import TAGS

+ # --- Logging config ---
+ logging.basicConfig(
+     level=getattr(logging, LOGLEVEL, logging.INFO),
+     format="%(asctime)s - %(levelname)s - %(message)s",
+ )
+
+ def _log_kv(prefix: str, kv: Dict):
+     logging.debug(prefix + " | " + " | ".join(f"{k}={v}" for k, v in kv.items()))
+
+ # --- Optional Spaces GPU stub (harmless) ---
  try:
      import spaces as _spaces
+     @_spaces.GPU(enable_queue=False)
      def smartheal_gpu_stub(ping: int = 0) -> str:
          return "ready"
+     logging.info("Registered @spaces.GPU stub (enable_queue=False).")
+ except Exception:
      pass

  UPLOADS_DIR = "uploads"
  os.makedirs(UPLOADS_DIR, exist_ok=True)
 
 
  DEFAULT_PX_PER_CM = 38.0
  PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0

+ # Segmentation preprocessing knobs
+ SEG_EXPECTS_RGB = os.getenv("SEG_EXPECTS_RGB", "1") == "1"  # most TF models trained on RGB
+ SEG_NORM = os.getenv("SEG_NORM", "0to1")  # "0to1" | "imagenet"
+ SEG_THRESH = float(os.getenv("SEG_THRESH", "0.5"))
+
  models_cache: Dict[str, object] = {}
  knowledge_base_cache: Dict[str, object] = {}
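These knobs are read once at import time, so they must be set in the environment before the module loads; illustrative values:

    import os
    os.environ["SEG_EXPECTS_RGB"] = "1"   # model was trained on RGB input
    os.environ["SEG_NORM"] = "imagenet"   # ImageNet mean/std instead of plain 0..1
    os.environ["SEG_THRESH"] = "0.35"     # lower cutoff for faint predictions
    # ...then import the module so the values are picked up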
 
 
      from huggingface_hub import HfApi, HfFolder
      return HfApi, HfFolder

+ # ---------- VLM (disabled by default) ----------
  def generate_medgemma_report(
      patient_info: str,
      visual_results: Dict,

      image_pil: Image.Image,
      max_new_tokens: Optional[int] = None,
  ) -> str:
      if os.getenv("SMARTHEAL_ENABLE_VLM", "0") != "1":
          return "⚠️ VLM disabled"
      try:
          from transformers import pipeline
          pipe = pipeline(
              task="image-text-to-text",
              model="google/medgemma-4b-it",
+             device_map=None,
              token=HF_TOKEN,
              trust_remote_code=True,
+             model_kwargs={"low_cpu_mem_usage": True},
          )
          prompt = (
              "You are a medical AI assistant. Analyze this wound image and patient data.\n\n"
              f"Patient: {patient_info}\n"

              "Provide a structured report with:\n"
              "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
          )
          messages = [{"role": "user", "content": [
              {"type": "image", "image": image_pil},
              {"type": "text", "text": prompt},
          ]}]
          out = pipe(text=messages, max_new_tokens=max_new_tokens or 600, do_sample=False, temperature=0.7)
          if out and len(out) > 0:
              try:
 
      try:
          if os.path.exists(SEG_MODEL_PATH):
              models_cache["seg"] = load_segmentation_model()
+             m = models_cache["seg"]
+             ishape = getattr(m, "input_shape", None)
+             oshape = getattr(m, "output_shape", None)
+             logging.info(f"✅ Segmentation model loaded (CPU) | input_shape={ishape} output_shape={oshape}")
          else:
              models_cache["seg"] = None
              logging.warning("Segmentation model file missing; skipping.")
 
      except Exception:
          return float(default_px_per_cm), meta

+ # ---------- Segmentation helpers ----------
+ def _imagenet_norm(arr: np.ndarray) -> np.ndarray:
+     # expects RGB 0..255 -> float
+     mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+     std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+     return (arr.astype(np.float32) - mean) / std
+
+ def _preprocess_for_seg(bgr_roi: np.ndarray, target_hw: Tuple[int, int]) -> np.ndarray:
+     H, W = target_hw
+     # Resize first
+     resized = cv2.resize(bgr_roi, (W, H), interpolation=cv2.INTER_LINEAR)
+     # Convert to RGB if required
+     if SEG_EXPECTS_RGB:
+         resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
+     # Normalize
+     if SEG_NORM.lower() == "imagenet":
+         x = _imagenet_norm(resized)
+     else:
+         x = resized.astype(np.float32) / 255.0
+     # Add batch dim
+     x = np.expand_dims(x, axis=0)  # (1,H,W,3)
+     return x
+
+ def _to_prob(pred: np.ndarray) -> np.ndarray:
+     # Pred could be (1,H,W,1), (H,W,1), (1,H,W), (H,W), or logits
+     p = np.squeeze(pred)
+     # If values look like logits, apply sigmoid
+     pmin, pmax = float(p.min()), float(p.max())
+     if pmax > 1.0 or pmin < 0.0:
+         p = 1.0 / (1.0 + np.exp(-p))
+     return p.astype(np.float32)
+
+ # Global last debug dict (per-process) to attach into results
+ _last_seg_debug: Dict[str, object] = {}
+
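A quick sanity check of the _to_prob heuristic above: arrays already inside [0, 1] pass through untouched, anything outside is treated as logits and squashed (known blind spot: genuine logits that all happen to fall in [0, 1] are passed through as-is):

    import numpy as np

    probs = np.array([[0.2, 0.9]], dtype=np.float32)
    logits = np.array([[-2.0, 3.0]], dtype=np.float32)
    print(_to_prob(probs))   # [0.2 0.9] (unchanged)
    print(_to_prob(logits))  # [0.119... 0.952...] (sigmoid applied)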
+ def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndarray, Dict[str, object]]:
      """
+     Attempts TF segmentation first; falls back to KMeans if needed.
+     Returns (mask_uint8_0_255, debug_dict)
      """
+     global _last_seg_debug
+     _last_seg_debug = {}
+
+     seg_model = models_cache.get("seg", None)
+     used = "fallback_kmeans"
+     reason = "no_model"
+     heatmap_path = None
+     saw_roi_path = None

+     if seg_model is not None:
          try:
+             ishape = getattr(seg_model, "input_shape", None)
+             if not ishape or len(ishape) < 4:
+                 raise ValueError(f"Bad seg input_shape: {ishape}")
+             th, tw = int(ishape[1]), int(ishape[2])
+             x = _preprocess_for_seg(image_bgr, (th, tw))
+             saw_roi = (cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) if SEG_EXPECTS_RGB else image_bgr)
+             if SMARTHEAL_DEBUG:
+                 saw_roi_path = os.path.join(out_dir, f"roi_for_seg_{ts}.png")
+                 cv2.imwrite(saw_roi_path, (cv2.cvtColor(saw_roi, cv2.COLOR_RGB2BGR) if SEG_EXPECTS_RGB else saw_roi))
+
+             # Inference
+             pred = seg_model.predict(x, verbose=0)
+             if isinstance(pred, (list, tuple)):
+                 pred = pred[0]
+             p = _to_prob(pred)  # HxW
+             p = cv2.resize(p, (image_bgr.shape[1], image_bgr.shape[0]))  # back to ROI size
+
+             # Debug stats
+             pmin, pmax, pmean = float(p.min()), float(p.max()), float(p.mean())
+             _log_kv("SEG_PROB_STATS", {"min": pmin, "max": pmax, "mean": pmean})
+
+             if SMARTHEAL_DEBUG:
+                 # save heatmap (0..255)
+                 hm = (np.clip(p, 0, 1) * 255).astype(np.uint8)
+                 heat = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
+                 heatmap_path = os.path.join(out_dir, f"seg_pred_heatmap_{ts}.png")
+                 cv2.imwrite(heatmap_path, heat)
+
+             # Threshold
+             thr = SEG_THRESH
+             mask = (p >= thr).astype(np.uint8) * 255
+             pos = int((mask > 0).sum())
+             frac = pos / float(mask.size)
+             logging.info(f"SegModel USED | thr={thr} pos_px={pos} pos_frac={frac:.4f} ex_rgb={SEG_EXPECTS_RGB} norm={SEG_NORM}")
+
+             used = "tf_model"
+             reason = "ok"
+
+             _last_seg_debug = {
+                 "used": used,
+                 "reason": reason,
+                 "input_shape": ishape,
+                 "prob_min": pmin, "prob_max": pmax, "prob_mean": pmean,
+                 "threshold": thr,
+                 "positive_fraction": frac,
+                 "heatmap_path": heatmap_path,
+                 "roi_seen_by_model": saw_roi_path,
+             }
+             return mask.astype(np.uint8), _last_seg_debug
+
          except Exception as e:
+             reason = f"model_failed: {e}"
+             logging.warning(f"⚠️ Segmentation model prediction failed → fallback. Reason: {e}")

+     # --- Fallback: KMeans (k=2), pick 'reddest' cluster in Lab a* ---
+     Z = image_bgr.reshape((-1, 3)).astype(np.float32)
      criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
      _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
      centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
      centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
+     wound_idx = int(np.argmax(centers_lab[:, 1]))  # maximize a* (redness)
+     mask = (labels.reshape(image_bgr.shape[:2]) == wound_idx).astype(np.uint8) * 255
+
+     pos = int((mask > 0).sum()); frac = pos / float(mask.size)
+     logging.info(f"KMeans USED | pos_px={pos} pos_frac={frac:.4f}")
+
+     _last_seg_debug = {
+         "used": used,
+         "reason": reason,
+         "kmeans_centers_bgr": centers.tolist(),
+         "kmeans_centers_lab": centers_lab.astype(float).tolist(),
+         "positive_fraction": frac,
+         "heatmap_path": heatmap_path,
+         "roi_seen_by_model": saw_roi_path,
+     }
+     return mask.astype(np.uint8), _last_seg_debug
  # ---------- Measurement + overlay helpers ----------
  def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
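The body of largest_component_mask is unchanged by this commit and elided from the diff; a plausible implementation of the behavior its call sites rely on (keep only the largest blob, return an all-zero mask if it is smaller than min_area_px), offered as a sketch:

    import cv2
    import numpy as np

    def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
        n, labels, stats, _ = cv2.connectedComponentsWithStats(binary01.astype(np.uint8), connectivity=8)
        if n <= 1:  # background only
            return np.zeros_like(binary01, dtype=np.uint8)
        areas = stats[1:, cv2.CC_STAT_AREA]   # label 0 is the background
        best = 1 + int(np.argmax(areas))
        if stats[best, cv2.CC_STAT_AREA] < min_area_px:
            return np.zeros_like(binary01, dtype=np.uint8)
        return (labels == best).astype(np.uint8)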
 
      thickness: int = 2
  ) -> np.ndarray:
      overlay = base_bgr.copy()
+
+     # Strong overlay + contour
      red = np.zeros_like(overlay); red[:] = (0, 0, 255)
+     alpha = 0.55
+     tinted = cv2.addWeighted(overlay, 1 - alpha, red, alpha, 0)
+     m3 = cv2.merge([mask01 * 255] * 3).astype("uint8")
+     overlay = np.where(m3 > 0, tinted, overlay)
+
+     # Draw contour
+     cnts, _ = cv2.findContours((mask01 * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     if cnts:
+         cv2.drawContours(overlay, cnts, -1, (255, 255, 255), 2)

      if rect_box is not None:
          cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
          pts = rect_box.reshape(-1, 2)
          def midpoint(a, b): return ((a[0] + b[0]) // 2, (a[1] + b[1]) // 2)
          mids = [midpoint(pts[i], pts[(i+1) % 4]) for i in range(4)]
 
              except Exception:
                  raise RuntimeError("Detected ROI is empty.")

+             out_dir = self._ensure_analysis_dir()
+             ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+
              # --- Segmentation (model-first + KMeans fallback) ---
+             mask_u8_255, seg_debug = segment_wound(roi, ts, out_dir)
              mask01 = (mask_u8_255 > 127).astype(np.uint8)
+
+             # Post-processing + metrics
              if mask01.any():
+                 mask_before = mask01.sum()
                  mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1)
                  mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1)
                  mask01 = largest_component_mask(mask01, min_area_px=30)
+                 logging.debug(f"Mask postproc: px_before={mask_before} px_after={int(mask01.sum())}")

              # --- Measurement ---
              if mask01.any():
                  length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask01, px_per_cm)
                  surface_area_cm2 = count_area_cm2(mask01, px_per_cm)
                  anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
+                 segmentation_empty = False
              else:
                  h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
                  length_cm = round(h_px / px_per_cm, 2)
                  breadth_cm = round(w_px / px_per_cm, 2)
                  surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
                  anno_roi = roi.copy()
+                 cv2.rectangle(anno_roi, (2, 2), (anno_roi.shape[1]-3, anno_roi.shape[0]-3), (0, 0, 255), 3)
+                 cv2.line(anno_roi, (0, 0), (anno_roi.shape[1]-1, anno_roi.shape[0]-1), (0, 0, 255), 2)
+                 cv2.line(anno_roi, (anno_roi.shape[1]-1, 0), (0, anno_roi.shape[0]-1), (0, 0, 255), 2)
                  box_pts = None
+                 segmentation_empty = True
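Worked example of the detection-box fallback above at the default calibration of 38 px/cm: a 190 x 76 px box maps to 5.0 cm x 2.0 cm with a 10.0 cm² area:

    px_per_cm = 38.0
    h_px, w_px = 190, 76
    length_cm = round(h_px / px_per_cm, 2)                         # 5.0
    breadth_cm = round(w_px / px_per_cm, 2)                        # 2.0
    surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)  # 10.0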

+             # --- Save visualizations ---
              original_path = os.path.join(out_dir, f"original_{ts}.png")
              cv2.imwrite(original_path, image_cv)

              detection_path = os.path.join(out_dir, f"detection_{ts}.png")
              cv2.imwrite(detection_path, det_vis)

              roi_mask_path = os.path.join(out_dir, f"roi_mask_{ts}.png")
              cv2.imwrite(roi_mask_path, (mask01 * 255).astype(np.uint8))

+             # ROI overlay (very clear)
+             mask255 = (mask01 * 255).astype(np.uint8)
+             mask3 = cv2.merge([mask255, mask255, mask255])
+             red = np.zeros_like(roi); red[:] = (0, 0, 255)
+             alpha = 0.55
+             tinted = cv2.addWeighted(roi, 1 - alpha, red, alpha, 0)
+             if mask255.any():
+                 roi_overlay = np.where(mask3 > 0, tinted, roi)
+                 cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+                 cv2.drawContours(roi_overlay, cnts, -1, (255, 255, 255), 2)
              else:
+                 roi_overlay = anno_roi  # already marked X

+             seg_full = image_cv.copy()
              seg_full[y1:y2, x1:x2] = roi_overlay
              segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
              cv2.imwrite(segmentation_path, seg_full)

+             segmentation_roi_path = os.path.join(out_dir, f"segmentation_roi_{ts}.png")
+             cv2.imwrite(segmentation_roi_path, roi_overlay)
+
              anno_full = image_cv.copy()
              anno_full[y1:y2, x1:x2] = anno_roi
              annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
 
              except Exception as e:
                  logging.warning(f"Classification failed: {e}")

+             # Log end-of-seg summary
+             seg_summary = {
+                 "seg_used": seg_debug.get("used"),
+                 "seg_reason": seg_debug.get("reason"),
+                 "positive_fraction": round(float(seg_debug.get("positive_fraction", 0.0)), 6),
+                 "threshold": seg_debug.get("threshold", SEG_THRESH),
+                 "segmentation_empty": segmentation_empty,
+                 "exif_px_per_cm": round(px_per_cm, 3),
+             }
+             _log_kv("SEG_SUMMARY", seg_summary)
+
              return {
                  "wound_type": wound_type,
                  "length_cm": length_cm,

                  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
                      if getattr(results[0].boxes, "conf", None) is not None else 0.0,
                  "detection_image_path": detection_path,
+                 "segmentation_image_path": segmentation_path,
                  "segmentation_annotated_path": annotated_seg_path,
+                 "segmentation_roi_path": segmentation_roi_path,
+                 "roi_mask_path": roi_mask_path,
+                 "segmentation_empty": segmentation_empty,
+                 "segmentation_debug": seg_debug,
                  "original_image_path": original_path,
              }
          except Exception as e:
              logging.error(f"Visual analysis failed: {e}", exc_info=True)
              raise

+     # ---------- Knowledge base + reporting ----------
      def query_guidelines(self, query: str) -> str:
          try:
              vs = self.knowledge_base_cache.get("vector_store")