SmartHeal committed on
Commit
1ba8e97
·
verified ·
1 Parent(s): 68e317b

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +264 -235
src/ai_processor.py CHANGED
@@ -1,22 +1,20 @@
1
  # smartheal_ai_processor.py
2
- # Full, functional module with an always-present @spaces.GPU function (if `spaces` is importable)
3
- # and robust CPU fallbacks to avoid crashes when GPU isn't actually available yet.
4
- # + Automatic calibration (px/cm) and measurement overlay on segmentation.
5
 
6
  import os
7
  import time
8
  import logging
9
  from datetime import datetime
10
- from typing import Optional, Dict, List, Tuple, Union
11
 
12
  import cv2
13
  import numpy as np
14
- from PIL import Image, TiffImagePlugin
 
15
 
16
- # =============== LOGGING ===============
17
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
18
 
19
- # =============== CONFIG ===============
20
  UPLOADS_DIR = "uploads"
21
  os.makedirs(UPLOADS_DIR, exist_ok=True)
22
 
@@ -24,15 +22,14 @@ HF_TOKEN = os.getenv("HF_TOKEN", None)
24
  YOLO_MODEL_PATH = "src/best.pt"
25
  SEG_MODEL_PATH = "src/segmentation_model.h5" # optional
26
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
27
- DATASET_ID = "SmartHeal/wound-image-uploads" # optional (requires HF_TOKEN)
28
- # Fallback px/cm if we cannot calibrate from EXIF
29
- DEFAULT_PIXELS_PER_CM = 38.0
30
 
31
- # =============== CACHES ===============
32
  models_cache: Dict[str, object] = {}
33
  knowledge_base_cache: Dict[str, object] = {}
34
 
35
- # =============== Optional imports (lazy) ===============
36
  def _import_ultralytics():
37
  from ultralytics import YOLO
38
  return YOLO
@@ -63,7 +60,7 @@ def _import_hf_hub():
63
  from huggingface_hub import HfApi, HfFolder
64
  return HfApi, HfFolder
65
 
66
- # =============== Spaces GPU function (always defined if `spaces` import works) ===============
67
  try:
68
  import spaces
69
 
@@ -75,39 +72,29 @@ try:
75
  image_pil: Image.Image,
76
  max_new_tokens: Optional[int] = None,
77
  ) -> str:
78
- """
79
- This function MUST exist at import time so Spaces Zero detects it.
80
- It is guarded internally so if anything fails (no GPU yet, model load error),
81
- it returns a warning and your pipeline will use the fallback report.
82
- """
83
  try:
84
  import torch
85
  from transformers import pipeline
86
 
87
- # Try to free cache; if no CUDA, this will raise and we return a warning.
88
  try:
89
  if hasattr(torch, "cuda") and torch.cuda.is_available():
90
  torch.cuda.empty_cache()
91
  except Exception:
92
  pass
93
 
94
- prompt = f"""
95
- You are a medical AI assistant. Analyze this wound image and patient data.
96
-
97
- Patient: {patient_info}
98
- Wound: {visual_results.get('wound_type', 'Unknown')} - {visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm
99
-
100
- Provide a structured report with:
101
- 1. Clinical Summary
102
- 2. Treatment Recommendations
103
- 3. Risk Assessment
104
- 4. Monitoring Plan
105
- """.strip()
106
 
 
107
  pipe = pipeline(
108
  "image-text-to-text",
109
  model="google/medgemma-4b-it",
110
- torch_dtype=getattr(torch, "bfloat16", None),
111
  device_map="auto",
112
  token=HF_TOKEN,
113
  model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
@@ -129,7 +116,6 @@ Provide a structured report with:
129
  logging.info(f"✅ MedGemma finished in {time.time()-t0:.2f}s")
130
 
131
  if out and len(out) > 0:
132
- # Defensive extraction (different transformers versions)
133
  try:
134
  return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
135
  except Exception:
@@ -139,7 +125,6 @@ Provide a structured report with:
139
  logging.error(f"❌ MedGemma generation error: {e}")
140
  return "⚠️ GPU worker unavailable"
141
  except Exception:
142
- # If `spaces` cannot be imported locally, expose a CPU-safe stub with same signature.
143
  def generate_medgemma_report(
144
  patient_info: str,
145
  visual_results: Dict,
@@ -149,7 +134,7 @@ except Exception:
149
  ) -> str:
150
  return "⚠️ GPU not available"
151
 
152
- # =============== Model init (CPU-safe) ===============
153
  def load_yolo_model():
154
  YOLO = _import_ultralytics()
155
  return YOLO(YOLO_MODEL_PATH)
@@ -213,7 +198,6 @@ def initialize_cpu_models() -> None:
213
  def setup_knowledge_base() -> None:
214
  if "vector_store" in knowledge_base_cache:
215
  return
216
-
217
  docs: List = []
218
  try:
219
  PyPDFLoader = _import_langchain_pdf()
@@ -241,119 +225,167 @@ def setup_knowledge_base() -> None:
241
  knowledge_base_cache["vector_store"] = None
242
  logging.warning("KB disabled (no docs or embeddings).")
243
 
244
- # Initialize on import so app is ready
245
  initialize_cpu_models()
246
  setup_knowledge_base()
247
 
248
- # =============== Utility: EXIF-based auto calibration ===============
249
- def _rational_to_float(val) -> Optional[float]:
 
 
250
  try:
251
- if isinstance(val, TiffImagePlugin.IFDRational):
252
- return float(val.numerator) / float(val.denominator or 1)
253
- if isinstance(val, tuple) and len(val) == 2 and all(isinstance(x, (int, float)) for x in val):
254
- # (num, den)
255
- den = val[1] if val[1] else 1.0
256
- return float(val[0]) / float(den)
 
 
 
 
 
 
 
 
 
 
 
257
  return float(val)
258
  except Exception:
259
  return None
260
 
261
- def _auto_pixels_per_cm_from_exif(image_pil: Image.Image) -> Tuple[float, str]:
262
  """
263
- Try several EXIF / info sources to estimate pixels-per-cm.
264
- Return (px_per_cm, source_str).
265
- NOTE: Many phones set DPI metadata arbitrarily; we clamp to a sensible range and
266
- fall back to DEFAULT_PIXELS_PER_CM if values look bogus.
267
  """
268
- # 1) PIL .info["dpi"]
269
- try:
270
- dpi_info = image_pil.info.get("dpi")
271
- if isinstance(dpi_info, (tuple, list)) and len(dpi_info) >= 1:
272
- xdpi = float(dpi_info[0]) if dpi_info[0] else None
273
- if xdpi and 40 <= xdpi <= 1200:
274
- ppcm = xdpi / 2.54
275
- if 5 <= ppcm <= 500:
276
- return ppcm, "dpi_info"
277
- except Exception:
278
- pass
279
 
280
- # 2) EXIF XResolution (282), YResolution (283), ResolutionUnit (296) [2 = inch, 3 = cm]
281
- try:
282
- exif = image_pil.getexif()
283
- if exif:
284
- xres = _rational_to_float(exif.get(282)) # XResolution
285
- unit = int(exif.get(296) or 2) # default to inches
286
- if xres:
287
- if unit == 3: # per cm
288
- if 5 <= xres <= 500:
289
- return xres, "EXIF_XRes_cm"
290
- else: # per inch
291
- ppcm = xres / 2.54
292
- if 5 <= ppcm <= 500:
293
- return ppcm, "EXIF_XRes_in"
294
- except Exception:
295
- pass
296
-
297
- # 3) Heuristic fallback
298
- return DEFAULT_PIXELS_PER_CM, "default"
299
 
300
- # =============== Drawing helpers ===============
301
- def _draw_measurement_overlay(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  base_bgr: np.ndarray,
303
- rect_xywh: Tuple[int, int, int, int],
 
304
  length_cm: float,
305
  breadth_cm: float,
 
306
  ) -> np.ndarray:
307
  """
308
- Draw arrows for vertical (length) and horizontal (breadth) on top of base image.
309
- rect_xywh is relative to base_bgr.
310
  """
311
- x, y, w, h = rect_xywh
312
- img = base_bgr.copy()
313
-
314
- # Colors (BGR) and styling
315
- color = (255, 255, 255) # white
316
- shadow = (0, 0, 0) # black outline
317
- thickness = 2
318
- font = cv2.FONT_HERSHEY_SIMPLEX
319
-
320
- # --- Horizontal arrow (breadth) ---
321
- y_mid = y + h // 2
322
- x_left = x
323
- x_right = x + w
324
- # shadow line
325
- cv2.arrowedLine(img, (x_left, y_mid+1), (x_right, y_mid+1), shadow, thickness+2, cv2.LINE_AA, tipLength=0.02)
326
- # main line
327
- cv2.arrowedLine(img, (x_left, y_mid), (x_right, y_mid), color, thickness, cv2.LINE_AA, tipLength=0.02)
328
-
329
- # breadth label
330
- label_b = f"{breadth_cm:.2f} cm"
331
- (tw, th), _ = cv2.getTextSize(label_b, font, 0.7, 2)
332
- tx = x + (w - tw) // 2
333
- ty = y_mid - 8
334
- cv2.putText(img, label_b, (tx+1, ty+1), font, 0.7, shadow, 3, cv2.LINE_AA)
335
- cv2.putText(img, label_b, (tx, ty), font, 0.7, color, 2, cv2.LINE_AA)
336
-
337
- # --- Vertical arrow (length) ---
338
- x_mid = x + w // 2
339
- y_top = y
340
- y_bottom = y + h
341
- # shadow line
342
- cv2.arrowedLine(img, (x_mid+1, y_top), (x_mid+1, y_bottom), shadow, thickness+2, cv2.LINE_AA, tipLength=0.02)
343
- # main line
344
- cv2.arrowedLine(img, (x_mid, y_top), (x_mid, y_bottom), color, thickness, cv2.LINE_AA, tipLength=0.02)
345
-
346
- # length label
347
- label_l = f"{length_cm:.2f} cm"
348
- (tw2, th2), _ = cv2.getTextSize(label_l, font, 0.7, 2)
349
- tx2 = x_mid - (tw2 // 2)
350
- ty2 = y + th2 + 8
351
- cv2.putText(img, label_l, (tx2+1, ty2+1), font, 0.7, shadow, 3, cv2.LINE_AA)
352
- cv2.putText(img, label_l, (tx2, ty2), font, 0.7, color, 2, cv2.LINE_AA)
353
-
354
- return img
355
-
356
- # =============== AI PROCESSOR ===============
357
  class AIProcessor:
358
  def __init__(self):
359
  self.models_cache = models_cache
@@ -368,24 +400,24 @@ class AIProcessor:
368
  return out_dir
369
 
370
  def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
371
- """YOLO detect → (optional) Keras seg → (optional) HF classify → save visuals with measurement overlay."""
 
 
 
 
372
  try:
373
- image_rgb = image_pil.convert("RGB")
374
- image_cv = cv2.cvtColor(np.array(image_rgb), cv2.COLOR_RGB2BGR)
375
 
376
- det = self.models_cache.get("det")
377
- if det is None:
378
- raise RuntimeError("YOLO model not loaded")
379
 
380
- # ---------- Automatic calibration (px/cm) ----------
381
- px_per_cm, calib_src = _auto_pixels_per_cm_from_exif(image_rgb)
382
- # keep within reasonable range
383
- if not (5.0 <= px_per_cm <= 500.0):
384
- px_per_cm, calib_src = DEFAULT_PIXELS_PER_CM, "default"
385
- logging.info(f"Calibration: {px_per_cm:.2f} px/cm (source={calib_src})")
386
 
387
- # YOLO on CPU
388
- results = det.predict(image_cv, verbose=False, device="cpu")
389
  if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
390
  raise ValueError("No wound could be detected.")
391
 
@@ -393,122 +425,122 @@ class AIProcessor:
393
  x1, y1, x2, y2 = [int(v) for v in box]
394
  x1, y1 = max(0, x1), max(0, y1)
395
  x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
396
- detected_region_cv = image_cv[y1:y2, x1:x2]
 
 
397
 
398
- # Optional segmentation
399
  seg_model = self.models_cache.get("seg")
 
400
  length_cm = breadth_cm = surface_area_cm2 = 0.0
401
- seg_path = None
402
 
403
- rect_xywh_global = None # for overlay on full image if seg missing
404
-
405
- if seg_model is not None and detected_region_cv.size > 0:
406
  try:
407
- input_size = seg_model.input_shape[1:3]
408
- resized = cv2.resize(detected_region_cv, (input_size[1], input_size[0]))
409
- mask_pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
410
- mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)
411
-
412
- contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
413
- if contours:
414
- cnt = max(contours, key=cv2.contourArea)
415
- x, y, w, h = cv2.boundingRect(cnt)
416
-
417
- # Measurements using calibration
418
- length_cm = round(h / px_per_cm, 2)
419
- breadth_cm = round(w / px_per_cm, 2)
420
- surface_area_cm2 = round(cv2.contourArea(cnt) / (px_per_cm ** 2), 2)
421
-
422
- # Create segmentation overlay in the cropped region
423
- mask_resized = cv2.resize(
424
- mask_np * 255,
425
- (detected_region_cv.shape[1], detected_region_cv.shape[0]),
426
- interpolation=cv2.INTER_NEAREST,
427
- )
428
- overlay = detected_region_cv.copy()
429
- overlay[mask_resized > 127] = [0, 0, 255] # red overlay
430
- seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
431
-
432
- # Draw measurement arrows on seg_vis
433
- # Map rect from mask space -> cropped image space
434
- scale_x = detected_region_cv.shape[1] / float(input_size[1])
435
- scale_y = detected_region_cv.shape[0] / float(input_size[0])
436
- rect_xywh_cropped = (
437
- int(x * scale_x),
438
- int(y * scale_y),
439
- int(w * scale_x),
440
- int(h * scale_y),
441
- )
442
- seg_vis_meas = _draw_measurement_overlay(seg_vis, rect_xywh_cropped, length_cm, breadth_cm)
443
-
444
- ts = datetime.now().strftime("%Y%m%d_%H%M%S")
445
- out_dir = self._ensure_analysis_dir()
446
- seg_path = os.path.join(out_dir, f"segmentation_{ts}.png")
447
- cv2.imwrite(seg_path, seg_vis_meas)
448
-
449
- # Also store rect in full-image coordinates (if ever needed)
450
- rect_xywh_global = (
451
- x1 + rect_xywh_cropped[0],
452
- y1 + rect_xywh_cropped[1],
453
- rect_xywh_cropped[2],
454
- rect_xywh_cropped[3],
455
- )
456
  except Exception as e:
457
- logging.warning(f"Segmentation skipped: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
458
 
459
- # Optional classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  wound_type = "Unknown"
461
  cls_pipe = self.models_cache.get("cls")
462
  if cls_pipe is not None:
463
  try:
464
- detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
465
- preds = cls_pipe(detected_image_pil)
466
  if preds:
467
  wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
468
  except Exception as e:
469
  logging.warning(f"Classification failed: {e}")
470
 
471
- # Save detection & original
472
- out_dir = self._ensure_analysis_dir()
473
- ts = datetime.now().strftime("%Y%m%d_%H%M%S")
474
- det_vis = image_cv.copy()
475
- cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
476
- det_path = os.path.join(out_dir, f"detection_{ts}.png")
477
- cv2.imwrite(det_path, det_vis)
478
-
479
- original_path = os.path.join(out_dir, f"original_{ts}.png")
480
- cv2.imwrite(original_path, image_cv)
481
-
482
  return {
483
  "wound_type": wound_type,
484
- "length_cm": float(length_cm),
485
- "breadth_cm": float(breadth_cm),
486
- "surface_area_cm2": float(surface_area_cm2),
487
- "calibration_px_per_cm": float(px_per_cm),
488
- "calibration_source": calib_src,
489
  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
490
- if getattr(results[0].boxes, "conf", None) is not None
491
- else 0.0,
492
- "detection_image_path": det_path,
493
- "segmentation_image_path": seg_path, # <-- now includes arrow overlay if seg succeeded
494
  "original_image_path": original_path,
495
  }
496
  except Exception as e:
497
- logging.error(f"Visual analysis failed: {e}")
498
  raise
499
 
 
500
  def query_guidelines(self, query: str) -> str:
501
- """Query the (optional) guideline knowledge base."""
502
  try:
503
  vs = self.knowledge_base_cache.get("vector_store")
504
  if not vs:
505
  return "Knowledge base is not available."
506
  try:
507
  retriever = vs.as_retriever(search_kwargs={"k": 5})
508
- docs = retriever.get_relevant_documents(query) # LC >= 0.2
509
  except Exception:
510
  retriever = vs.as_retriever(search_kwargs={"k": 5})
511
- docs = retriever.invoke(query) # older LC
512
  lines: List[str] = []
513
  for d in docs:
514
  src = (d.metadata or {}).get("source", "N/A")
@@ -530,12 +562,13 @@ class AIProcessor:
530
  - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
531
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
532
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
533
- - **Calibration**: {visual_results.get('calibration_px_per_cm', 0)} px/cm (source: {visual_results.get('calibration_source','n/a')})
534
 
535
  ## 📊 Analysis Images
536
  - **Original**: {visual_results.get('original_image_path', 'N/A')}
537
  - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
538
- - **Segmentation (with measurements)**: {visual_results.get('segmentation_image_path', 'N/A')}
 
539
 
540
  ## 🎯 Clinical Summary
541
  Automated analysis provides quantitative measurements; verify via clinical examination.
@@ -563,7 +596,6 @@ Automated analysis provides quantitative measurements; verify via clinical exami
563
  image_pil: Image.Image,
564
  max_new_tokens: Optional[int] = None,
565
  ) -> str:
566
- """Use GPU path when available, fallback otherwise."""
567
  try:
568
  report = generate_medgemma_report(
569
  patient_info, visual_results, guideline_context, image_pil, max_new_tokens
@@ -577,7 +609,6 @@ Automated analysis provides quantitative measurements; verify via clinical exami
577
  return self._generate_fallback_report(patient_info, visual_results, guideline_context)
578
 
579
  def save_and_commit_image(self, image_pil: Image.Image) -> str:
580
- """Save locally and (optionally) upload to HF dataset."""
581
  try:
582
  os.makedirs(self.uploads_dir, exist_ok=True)
583
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -609,7 +640,6 @@ Automated analysis provides quantitative measurements; verify via clinical exami
609
  return ""
610
 
611
  def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
612
- """End-to-end analysis."""
613
  try:
614
  saved_path = self.save_and_commit_image(image_pil)
615
  visual_results = self.perform_visual_analysis(image_pil)
@@ -656,7 +686,6 @@ Automated analysis provides quantitative measurements; verify via clinical exami
656
  }
657
 
658
  def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
659
- """Public entrypoint used by UI."""
660
  try:
661
  if isinstance(image, str):
662
  if not os.path.exists(image):
@@ -679,4 +708,4 @@ Automated analysis provides quantitative measurements; verify via clinical exami
679
  "report": f"Analysis initialization failed: {str(e)}",
680
  "saved_image_path": None,
681
  "guideline_context": "",
682
- }
 
1
  # smartheal_ai_processor.py
2
+ # Fully functional: auto-calibration from EXIF, mask-based measurements,
3
+ # and annotated overlay with arrows+labels.
 
4
 
5
  import os
6
  import time
7
  import logging
8
  from datetime import datetime
9
+ from typing import Optional, Dict, List, Tuple
10
 
11
  import cv2
12
  import numpy as np
13
+ from PIL import Image, ImageOps
14
+ from PIL.ExifTags import TAGS
15
 
 
16
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
17
 
 
18
  UPLOADS_DIR = "uploads"
19
  os.makedirs(UPLOADS_DIR, exist_ok=True)
20
 
 
22
  YOLO_MODEL_PATH = "src/best.pt"
23
  SEG_MODEL_PATH = "src/segmentation_model.h5" # optional
24
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
25
+ DATASET_ID = "SmartHeal/wound-image-uploads"
26
+ DEFAULT_PX_PER_CM = 38.0 # fallback when we cannot calibrate
27
+ PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0 # sanity bounds
28
 
 
29
  models_cache: Dict[str, object] = {}
30
  knowledge_base_cache: Dict[str, object] = {}
31
 
32
+ # ---------- Lazy imports ----------
33
  def _import_ultralytics():
34
  from ultralytics import YOLO
35
  return YOLO
 
60
  from huggingface_hub import HfApi, HfFolder
61
  return HfApi, HfFolder
62
 
63
+ # ---------- Spaces GPU function (always defined if `spaces` import works) ----------
64
  try:
65
  import spaces
66
 
 
72
  image_pil: Image.Image,
73
  max_new_tokens: Optional[int] = None,
74
  ) -> str:
 
 
 
 
 
75
  try:
76
  import torch
77
  from transformers import pipeline
78
 
 
79
  try:
80
  if hasattr(torch, "cuda") and torch.cuda.is_available():
81
  torch.cuda.empty_cache()
82
  except Exception:
83
  pass
84
 
85
+ prompt = (
86
+ "You are a medical AI assistant. Analyze this wound image and patient data.\n\n"
87
+ f"Patient: {patient_info}\n"
88
+ f"Wound: {visual_results.get('wound_type', 'Unknown')} - "
89
+ f"{visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm\n\n"
90
+ "Provide a structured report with:\n"
91
+ "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
92
+ )
 
 
 
 
93
 
94
+ from transformers import pipeline
95
  pipe = pipeline(
96
  "image-text-to-text",
97
  model="google/medgemma-4b-it",
 
98
  device_map="auto",
99
  token=HF_TOKEN,
100
  model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
 
116
  logging.info(f"✅ MedGemma finished in {time.time()-t0:.2f}s")
117
 
118
  if out and len(out) > 0:
 
119
  try:
120
  return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
121
  except Exception:
 
125
  logging.error(f"❌ MedGemma generation error: {e}")
126
  return "⚠️ GPU worker unavailable"
127
  except Exception:
 
128
  def generate_medgemma_report(
129
  patient_info: str,
130
  visual_results: Dict,
 
134
  ) -> str:
135
  return "⚠️ GPU not available"
136
 
137
+ # ---------- Initialize CPU models ----------
138
  def load_yolo_model():
139
  YOLO = _import_ultralytics()
140
  return YOLO(YOLO_MODEL_PATH)
 
198
  def setup_knowledge_base() -> None:
199
  if "vector_store" in knowledge_base_cache:
200
  return
 
201
  docs: List = []
202
  try:
203
  PyPDFLoader = _import_langchain_pdf()
 
225
  knowledge_base_cache["vector_store"] = None
226
  logging.warning("KB disabled (no docs or embeddings).")
227
 
 
228
  initialize_cpu_models()
229
  setup_knowledge_base()
230
 
231
+ # ---------- Calibration helpers ----------
232
def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
    """Best-effort EXIF parse from a PIL image.

    Args:
        pil_img: Image whose ``getexif()`` result is read.

    Returns:
        A dict mapping tag names (looked up in ``TAGS``; the raw numeric id is
        kept when no name is known) to their values. Tags from the Exif
        sub-IFD are merged in without overriding base-IFD entries. The dict
        is empty when the image carries no EXIF data.

    Never raises: on any failure the entries collected so far are returned.
    """
    out: Dict[str, object] = {}
    try:
        exif = pil_img.getexif()
        if not exif:
            return out
        for k, v in exif.items():
            out[TAGS.get(k, k)] = v

        # Iterating the object returned by getexif() only yields the base IFD.
        # Camera tags such as FocalLength / SubjectDistance are stored in the
        # Exif sub-IFD (pointer tag 0x8769), so read that one as well.
        # get_ifd() is looked up defensively for older Pillow releases.
        get_ifd = getattr(exif, "get_ifd", None)
        if callable(get_ifd):
            for k, v in get_ifd(0x8769).items():
                out.setdefault(TAGS.get(k, k), v)
    except Exception as e:
        # Deliberately best-effort: calibration falls back to a default value.
        logging.debug(f"EXIF parse incomplete: {e}")
    return out
245
+
246
+ def _to_float(val) -> Optional[float]:
247
+ try:
248
+ if val is None:
249
+ return None
250
+ if isinstance(val, tuple) and len(val) == 2:
251
+ num, den = float(val[0]), float(val[1]) if float(val[1]) != 0 else 1.0
252
+ return num / den
253
  return float(val)
254
  except Exception:
255
  return None
256
 
257
+ def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]:
258
  """
259
+ Use 35mm equivalent if present: sensor_width = 36 * f_mm / f35.
 
 
 
260
  """
261
+ if f_mm and f35 and f35 > 0:
262
+ return 36.0 * f_mm / f35
263
+ return None
 
 
 
 
 
 
 
 
264
 
265
def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]:
    """
    Estimate the image scale in pixels per centimetre from EXIF metadata.

    The horizontal field of view at the subject is derived as
        field_width_mm = sensor_width_mm * distance_mm / focal_mm
    and the scale as
        px_per_cm = image_width_px / (field_width_mm / 10)

    Returns:
        ``(px_per_cm, meta)``. ``meta["used"]`` is ``"exif"`` when the scale
        was solved from EXIF (after clamping to the module's sanity bounds),
        otherwise ``"default"`` and ``default_px_per_cm`` is returned. The
        remaining ``meta`` keys echo the intermediate values that were read.
    """
    meta = {"used": "default", **dict.fromkeys(("f_mm", "f35", "sensor_w_mm", "distance_m"))}

    try:
        tags = _exif_to_dict(pil_img)
        focal = _to_float(tags.get("FocalLength"))
        focal35 = _to_float(tags.get("FocalLengthIn35mmFilm") or tags.get("FocalLengthIn35mm"))
        distance = _to_float(tags.get("SubjectDistance"))
        sensor_w = _estimate_sensor_width_mm(focal, focal35)

        meta["f_mm"] = focal
        meta["f35"] = focal35
        meta["sensor_w_mm"] = sensor_w
        meta["distance_m"] = distance

        # Not enough EXIF data to solve the geometry: keep the default scale.
        if not (focal and sensor_w and distance and distance > 0):
            return float(default_px_per_cm), meta

        width_px = pil_img.width
        # distance is scaled by 1000 (to mm) so all lengths share one unit.
        field_w_cm = sensor_w * (distance * 1000.0) / focal / 10.0
        unclamped = width_px / max(field_w_cm, 1e-6)

        # Sanity clamp; "used" is only flipped once the value is final.
        clamped = float(np.clip(unclamped, PX_PER_CM_MIN, PX_PER_CM_MAX))
        meta["used"] = "exif"
        return clamped, meta
    except Exception as e:
        logging.warning(f"EXIF calibration failed: {e}")
        return float(default_px_per_cm), meta
298
+
299
+ # ---------- Mask processing + measurement ----------
300
def largest_component_mask(binary: np.ndarray, min_area_px: int = 50) -> np.ndarray:
    """Reduce a binary mask to its single largest 8-connected component.

    The input is returned unchanged when it has no foreground component or
    when the largest component covers fewer than ``min_area_px`` pixels.
    Otherwise a ``uint8`` mask (1 inside the component, 0 elsewhere) is
    returned.
    """
    n_labels, label_map, comp_stats, _centroids = cv2.connectedComponentsWithStats(
        binary.astype(np.uint8), connectivity=8
    )
    if n_labels > 1:
        # Row 0 of the stats table is the background label; skip it.
        fg_areas = comp_stats[1:, cv2.CC_STAT_AREA]
        best = int(np.argmax(fg_areas))
        if fg_areas[best] >= min_area_px:
            return (label_map == best + 1).astype(np.uint8)
    return binary
311
+
312
def measure_min_area_rect(mask: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
    """
    Measure the largest contour of ``mask`` with an oriented min-area rectangle.

    Returns ``(length_cm, breadth_cm, (box_points, center))`` where length is
    the longer rectangle side and breadth the shorter one, both rounded to two
    decimals. When the mask has no contour the result is
    ``(0.0, 0.0, (None, None))``.
    """
    found, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not found:
        return 0.0, 0.0, (None, None)

    biggest = max(found, key=cv2.contourArea)
    rect = cv2.minAreaRect(biggest)  # (center(x,y), (w,h), angle)
    center, sides = rect[0], rect[1]

    # Guard against a zero/negative scale before converting pixels to cm.
    scale = max(px_per_cm, 1e-6)
    long_cm = round(max(sides[0], sides[1]) / scale, 2)
    short_cm = round(min(sides[0], sides[1]) / scale, 2)

    corners = cv2.boxPoints(rect).astype(int)
    return long_cm, short_cm, (corners, center)
328
+
329
def count_area_cm2(mask: np.ndarray, px_per_cm: float) -> float:
    """Return the area of the non-zero mask pixels in cm², rounded to 2 decimals."""
    # Clamp the scale so a zero/negative px_per_cm cannot divide by zero.
    scale = max(px_per_cm, 1e-6)
    n_foreground = float(np.count_nonzero(mask))
    return round(n_foreground / (scale ** 2), 2)
332
+
333
def draw_measurement_overlay(
    base_bgr: np.ndarray,
    mask: np.ndarray,
    rect_box: np.ndarray,
    length_cm: float,
    breadth_cm: float,
    thickness: int = 2
) -> np.ndarray:
    """
    Draw a semi-transparent red mask plus the measurement rectangle, a
    double-headed arrow along each rectangle axis, and the two size labels.

    Args:
        base_bgr: Image to annotate; it is not modified.
        mask: Foreground mask with the same height/width as ``base_bgr``.
            Any non-zero value counts as foreground.
        rect_box: The four rectangle corners in drawing order, or ``None``.
        length_cm: Value printed next to the arrow along the long axis.
        breadth_cm: Value printed next to the arrow along the short axis.
        thickness: Line thickness of the rectangle and arrows.

    Returns:
        A new annotated image. When ``rect_box`` is ``None`` or does not hold
        exactly four points, only the mask tint is applied.
    """
    # Normalise to 0/255 uint8 so 0/1, 0/255 and bool masks all behave alike
    # (``mask * 255`` would wrap around for 0/255 input).
    mask_u8 = (np.asarray(mask) > 0).astype(np.uint8) * 255
    red = np.zeros_like(base_bgr)
    red[:, :] = (0, 0, 255)
    mask3 = np.dstack([mask_u8] * 3)
    overlay = cv2.addWeighted(base_bgr, 1.0, (red & mask3), 0.3, 0)

    if rect_box is None:
        return overlay
    pts = np.asarray(rect_box).reshape(-1, 2)
    if pts.shape[0] != 4:
        return overlay
    pts = pts.astype(np.int32)

    # draw rectangle
    cv2.polylines(overlay, [pts.reshape(-1, 1, 2)], True, (255, 255, 255), thickness)

    def midpoint(a, b):
        # Plain Python ints: OpenCV point arguments are passed as int tuples.
        return (int(a[0] + b[0]) // 2, int(a[1] + b[1]) // 2)

    # edges: (0-1, 1-2, 2-3, 3-0); edges 0/2 and 1/3 are the opposite pairs
    mids = [midpoint(pts[i], pts[(i + 1) % 4]) for i in range(4)]
    edge_len = [float(np.linalg.norm(pts[i] - pts[(i + 1) % 4])) for i in range(4)]
    long_edges = (0, 2) if edge_len[0] + edge_len[2] >= edge_len[1] + edge_len[3] else (1, 3)
    short_edges = (1, 3) if long_edges == (0, 2) else (0, 2)

    # The segment joining the midpoints of the two SHORT edges runs parallel
    # to the long edges, so it is the one that spans the length; the segment
    # joining the midpoints of the LONG edges spans the breadth.
    length_axis = (mids[short_edges[0]], mids[short_edges[1]])
    breadth_axis = (mids[long_edges[0]], mids[long_edges[1]])

    # double-headed arrows (white with black shadow)
    def draw_arrow(p1, p2):
        cv2.arrowedLine(overlay, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
        cv2.arrowedLine(overlay, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
        cv2.arrowedLine(overlay, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
        cv2.arrowedLine(overlay, p2, p1, (255, 255, 255), thickness, tipLength=0.05)

    draw_arrow(*length_axis)
    draw_arrow(*breadth_axis)

    # labels next to the starting end of the arrow they describe
    def put_label(text, org):
        cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
        cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

    put_label(f"{length_cm:.2f} cm", length_axis[0])
    put_label(f"{breadth_cm:.2f} cm", breadth_axis[0])

    return overlay
387
+
388
+ # ---------- AI PROCESSOR ----------
 
389
  class AIProcessor:
390
  def __init__(self):
391
  self.models_cache = models_cache
 
400
  return out_dir
401
 
402
  def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
403
+ """
404
+ YOLO detect → segmentation → largest-component mask →
405
+ minAreaRect measurement (cm) using px/cm from EXIF when available →
406
+ save original, detection overlay, segmentation overlay, and annotated overlay.
407
+ """
408
  try:
409
+ # --- Auto calibration from EXIF (before any conversion that might drop EXIF) ---
410
+ px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
411
 
412
+ # Convert PIL to OpenCV BGR
413
+ image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
 
414
 
415
+ # --- Detection (YOLO) ---
416
+ det_model = self.models_cache.get("det")
417
+ if det_model is None:
418
+ raise RuntimeError("YOLO model not loaded")
 
 
419
 
420
+ results = det_model.predict(image_cv, verbose=False, device="cpu")
 
421
  if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
422
  raise ValueError("No wound could be detected.")
423
 
 
425
  x1, y1, x2, y2 = [int(v) for v in box]
426
  x1, y1 = max(0, x1), max(0, y1)
427
  x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
428
+ roi = image_cv[y1:y2, x1:x2].copy()
429
+ if roi.size == 0:
430
+ raise ValueError("Detected ROI is empty.")
431
 
432
+ # --- Segmentation (optional but recommended) ---
433
  seg_model = self.models_cache.get("seg")
434
+ mask_resized = None
435
  length_cm = breadth_cm = surface_area_cm2 = 0.0
 
436
 
437
+ if seg_model is not None:
 
 
438
  try:
439
+ H, W = seg_model.input_shape[1:3]
440
+ resized = cv2.resize(roi, (W, H))
441
+ pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
442
+ raw_mask = pred[:, :, 0]
443
+
444
+ # binarize + clean
445
+ mask = (raw_mask > 0.5).astype(np.uint8)
446
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)
447
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
448
+ mask = largest_component_mask(mask)
449
+
450
+ # bring back to ROI size
451
+ mask_resized = cv2.resize(mask * 255, (roi.shape[1], roi.shape[0]), interpolation=cv2.INTER_NEAREST)
452
+ bin_mask_roi = (mask_resized > 127).astype(np.uint8)
453
+
454
+ # measure with oriented rectangle (in ROI pixels)
455
+ length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(bin_mask_roi, px_per_cm)
456
+ surface_area_cm2 = count_area_cm2(bin_mask_roi, px_per_cm)
457
+
458
+ # draw overlay with arrows/labels on ROI
459
+ anno_roi = draw_measurement_overlay(roi, bin_mask_roi, box_pts, length_cm, breadth_cm)
460
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  except Exception as e:
462
+ logging.warning(f"Segmentation failed/partial: {e}")
463
+ mask_resized = None
464
+ anno_roi = roi.copy()
465
+ else:
466
+ # No segmentation → just draw detection box and keep defaults
467
+ anno_roi = roi.copy()
468
+
469
+ # --- Save all visualizations ---
470
+ out_dir = self._ensure_analysis_dir()
471
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
472
+
473
+ # Original
474
+ original_path = os.path.join(out_dir, f"original_{ts}.png")
475
+ cv2.imwrite(original_path, image_cv)
476
 
477
+ # Detection overlay (rectangle on full image)
478
+ det_vis = image_cv.copy()
479
+ cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
480
+ detection_path = os.path.join(out_dir, f"detection_{ts}.png")
481
+ cv2.imwrite(detection_path, det_vis)
482
+
483
+ # Segmentation overlay (ROI pasted back into full frame for consistent display)
484
+ segmentation_path = None
485
+ annotated_seg_path = None
486
+ if mask_resized is not None:
487
+ # compose overlay on full image for "segmentation" view
488
+ seg_full = image_cv.copy()
489
+ roi_overlay = roi.copy()
490
+ red = np.zeros_like(roi_overlay); red[:] = (0, 0, 255)
491
+ alpha = 0.3
492
+ roi_overlay = cv2.addWeighted(roi_overlay, 1.0, red, alpha, 0, mask=mask_resized)
493
+ seg_full[y1:y2, x1:x2] = roi_overlay
494
+ segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
495
+ cv2.imwrite(segmentation_path, seg_full)
496
+
497
+ # annotated overlay (arrows+labels) placed back into full image
498
+ anno_full = image_cv.copy()
499
+ anno_full[y1:y2, x1:x2] = anno_roi
500
+ annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
501
+ cv2.imwrite(annotated_seg_path, anno_full)
502
+
503
+ # --- Optional classification ---
504
  wound_type = "Unknown"
505
  cls_pipe = self.models_cache.get("cls")
506
  if cls_pipe is not None:
507
  try:
508
+ preds = cls_pipe(Image.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)))
 
509
  if preds:
510
  wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
511
  except Exception as e:
512
  logging.warning(f"Classification failed: {e}")
513
 
 
 
 
 
 
 
 
 
 
 
 
514
  return {
515
  "wound_type": wound_type,
516
+ "length_cm": length_cm,
517
+ "breadth_cm": breadth_cm,
518
+ "surface_area_cm2": surface_area_cm2,
519
+ "px_per_cm": round(px_per_cm, 2),
520
+ "calibration_meta": exif_meta, # for debugging/auditing
521
  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
522
+ if getattr(results[0].boxes, "conf", None) is not None else 0.0,
523
+ "detection_image_path": detection_path,
524
+ "segmentation_image_path": segmentation_path,
525
+ "segmentation_annotated_path": annotated_seg_path,
526
  "original_image_path": original_path,
527
  }
528
  except Exception as e:
529
+ logging.error(f"Visual analysis failed: {e}", exc_info=True)
530
  raise
531
 
532
+ # ---------- Knowledge base and reporting ----------
533
  def query_guidelines(self, query: str) -> str:
 
534
  try:
535
  vs = self.knowledge_base_cache.get("vector_store")
536
  if not vs:
537
  return "Knowledge base is not available."
538
  try:
539
  retriever = vs.as_retriever(search_kwargs={"k": 5})
540
+ docs = retriever.get_relevant_documents(query)
541
  except Exception:
542
  retriever = vs.as_retriever(search_kwargs={"k": 5})
543
+ docs = retriever.invoke(query)
544
  lines: List[str] = []
545
  for d in docs:
546
  src = (d.metadata or {}).get("source", "N/A")
 
562
  - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
563
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
564
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
565
+ - **Calibration**: {visual_results.get('px_per_cm','?')} px/cm ({(visual_results.get('calibration_meta') or {}).get('used','default')})
566
 
567
  ## 📊 Analysis Images
568
  - **Original**: {visual_results.get('original_image_path', 'N/A')}
569
  - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
570
+ - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
571
+ - **Annotated**: {visual_results.get('segmentation_annotated_path', 'N/A')}
572
 
573
  ## 🎯 Clinical Summary
574
  Automated analysis provides quantitative measurements; verify via clinical examination.
 
596
  image_pil: Image.Image,
597
  max_new_tokens: Optional[int] = None,
598
  ) -> str:
 
599
  try:
600
  report = generate_medgemma_report(
601
  patient_info, visual_results, guideline_context, image_pil, max_new_tokens
 
609
  return self._generate_fallback_report(patient_info, visual_results, guideline_context)
610
 
611
  def save_and_commit_image(self, image_pil: Image.Image) -> str:
 
612
  try:
613
  os.makedirs(self.uploads_dir, exist_ok=True)
614
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
 
640
  return ""
641
 
642
  def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
 
643
  try:
644
  saved_path = self.save_and_commit_image(image_pil)
645
  visual_results = self.perform_visual_analysis(image_pil)
 
686
  }
687
 
688
  def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
 
689
  try:
690
  if isinstance(image, str):
691
  if not os.path.exists(image):
 
708
  "report": f"Analysis initialization failed: {str(e)}",
709
  "saved_image_path": None,
710
  "guideline_context": "",
711
+ }