SmartHeal commited on
Commit
68e317b
·
verified ·
1 Parent(s): 990c773

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +164 -19
src/ai_processor.py CHANGED
@@ -1,16 +1,17 @@
1
  # smartheal_ai_processor.py
2
  # Full, functional module with an always-present @spaces.GPU function (if `spaces` is importable)
3
  # and robust CPU fallbacks to avoid crashes when GPU isn't actually available yet.
 
4
 
5
  import os
6
  import time
7
  import logging
8
  from datetime import datetime
9
- from typing import Optional, Dict, List
10
 
11
  import cv2
12
  import numpy as np
13
- from PIL import Image
14
 
15
  # =============== LOGGING ===============
16
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -24,7 +25,8 @@ YOLO_MODEL_PATH = "src/best.pt"
24
  SEG_MODEL_PATH = "src/segmentation_model.h5" # optional
25
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
26
  DATASET_ID = "SmartHeal/wound-image-uploads" # optional (requires HF_TOKEN)
27
- PIXELS_PER_CM = 38
 
28
 
29
  # =============== CACHES ===============
30
  models_cache: Dict[str, object] = {}
@@ -243,12 +245,119 @@ def setup_knowledge_base() -> None:
243
  initialize_cpu_models()
244
  setup_knowledge_base()
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  # =============== AI PROCESSOR ===============
247
  class AIProcessor:
248
  def __init__(self):
249
  self.models_cache = models_cache
250
  self.knowledge_base_cache = knowledge_base_cache
251
- self.px_per_cm = PIXELS_PER_CM
252
  self.uploads_dir = UPLOADS_DIR
253
  self.dataset_id = DATASET_ID
254
  self.hf_token = HF_TOKEN
@@ -259,14 +368,22 @@ class AIProcessor:
259
  return out_dir
260
 
261
  def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
262
- """YOLO detect → (optional) Keras seg → (optional) HF classify → save visuals."""
263
  try:
264
- image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
 
265
 
266
  det = self.models_cache.get("det")
267
  if det is None:
268
  raise RuntimeError("YOLO model not loaded")
269
 
 
 
 
 
 
 
 
270
  # YOLO on CPU
271
  results = det.predict(image_cv, verbose=False, device="cpu")
272
  if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
@@ -280,8 +397,11 @@ class AIProcessor:
280
 
281
  # Optional segmentation
282
  seg_model = self.models_cache.get("seg")
283
- length = breadth = area = 0.0
284
  seg_path = None
 
 
 
285
  if seg_model is not None and detected_region_cv.size > 0:
286
  try:
287
  input_size = seg_model.input_shape[1:3]
@@ -293,24 +413,46 @@ class AIProcessor:
293
  if contours:
294
  cnt = max(contours, key=cv2.contourArea)
295
  x, y, w, h = cv2.boundingRect(cnt)
296
- length = round(h / self.px_per_cm, 2)
297
- breadth = round(w / self.px_per_cm, 2)
298
- area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
299
 
300
- # overlay visualization
 
 
 
 
 
301
  mask_resized = cv2.resize(
302
  mask_np * 255,
303
  (detected_region_cv.shape[1], detected_region_cv.shape[0]),
304
  interpolation=cv2.INTER_NEAREST,
305
  )
306
  overlay = detected_region_cv.copy()
307
- overlay[mask_resized > 127] = [0, 0, 255]
308
  seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
309
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
311
  out_dir = self._ensure_analysis_dir()
312
  seg_path = os.path.join(out_dir, f"segmentation_{ts}.png")
313
- cv2.imwrite(seg_path, seg_vis)
 
 
 
 
 
 
 
 
314
  except Exception as e:
315
  logging.warning(f"Segmentation skipped: {e}")
316
 
@@ -339,14 +481,16 @@ class AIProcessor:
339
 
340
  return {
341
  "wound_type": wound_type,
342
- "length_cm": length,
343
- "breadth_cm": breadth,
344
- "surface_area_cm2": area,
 
 
345
  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
346
  if getattr(results[0].boxes, "conf", None) is not None
347
  else 0.0,
348
  "detection_image_path": det_path,
349
- "segmentation_image_path": seg_path,
350
  "original_image_path": original_path,
351
  }
352
  except Exception as e:
@@ -386,11 +530,12 @@ class AIProcessor:
386
  - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
387
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
388
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
 
389
 
390
  ## 📊 Analysis Images
391
  - **Original**: {visual_results.get('original_image_path', 'N/A')}
392
  - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
393
- - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
394
 
395
  ## 🎯 Clinical Summary
396
  Automated analysis provides quantitative measurements; verify via clinical examination.
@@ -534,4 +679,4 @@ Automated analysis provides quantitative measurements; verify via clinical exami
534
  "report": f"Analysis initialization failed: {str(e)}",
535
  "saved_image_path": None,
536
  "guideline_context": "",
537
- }
 
1
  # smartheal_ai_processor.py
2
  # Full, functional module with an always-present @spaces.GPU function (if `spaces` is importable)
3
  # and robust CPU fallbacks to avoid crashes when GPU isn't actually available yet.
4
+ # + Automatic calibration (px/cm) and measurement overlay on segmentation.
5
 
6
  import os
7
  import time
8
  import logging
9
  from datetime import datetime
10
+ from typing import Optional, Dict, List, Tuple, Union
11
 
12
  import cv2
13
  import numpy as np
14
+ from PIL import Image, TiffImagePlugin
15
 
16
  # =============== LOGGING ===============
17
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
25
  SEG_MODEL_PATH = "src/segmentation_model.h5" # optional
26
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
27
  DATASET_ID = "SmartHeal/wound-image-uploads" # optional (requires HF_TOKEN)
28
+ # Fallback px/cm if we cannot calibrate from EXIF
29
+ DEFAULT_PIXELS_PER_CM = 38.0
30
 
31
  # =============== CACHES ===============
32
  models_cache: Dict[str, object] = {}
 
245
  initialize_cpu_models()
246
  setup_knowledge_base()
247
 
248
+ # =============== Utility: EXIF-based auto calibration ===============
249
+ def _rational_to_float(val) -> Optional[float]:
250
+ try:
251
+ if isinstance(val, TiffImagePlugin.IFDRational):
252
+ return float(val.numerator) / float(val.denominator or 1)
253
+ if isinstance(val, tuple) and len(val) == 2 and all(isinstance(x, (int, float)) for x in val):
254
+ # (num, den)
255
+ den = val[1] if val[1] else 1.0
256
+ return float(val[0]) / float(den)
257
+ return float(val)
258
+ except Exception:
259
+ return None
260
+
261
+ def _auto_pixels_per_cm_from_exif(image_pil: Image.Image) -> Tuple[float, str]:
262
+ """
263
+ Try several EXIF / info sources to estimate pixels-per-cm.
264
+ Return (px_per_cm, source_str).
265
+ NOTE: Many phones set DPI metadata arbitrarily; we clamp to a sensible range and
266
+ fall back to DEFAULT_PIXELS_PER_CM if values look bogus.
267
+ """
268
+ # 1) PIL .info["dpi"]
269
+ try:
270
+ dpi_info = image_pil.info.get("dpi")
271
+ if isinstance(dpi_info, (tuple, list)) and len(dpi_info) >= 1:
272
+ xdpi = float(dpi_info[0]) if dpi_info[0] else None
273
+ if xdpi and 40 <= xdpi <= 1200:
274
+ ppcm = xdpi / 2.54
275
+ if 5 <= ppcm <= 500:
276
+ return ppcm, "dpi_info"
277
+ except Exception:
278
+ pass
279
+
280
+ # 2) EXIF XResolution (282), YResolution (283), ResolutionUnit (296) [2 = inch, 3 = cm]
281
+ try:
282
+ exif = image_pil.getexif()
283
+ if exif:
284
+ xres = _rational_to_float(exif.get(282)) # XResolution
285
+ unit = int(exif.get(296) or 2) # default to inches
286
+ if xres:
287
+ if unit == 3: # per cm
288
+ if 5 <= xres <= 500:
289
+ return xres, "EXIF_XRes_cm"
290
+ else: # per inch
291
+ ppcm = xres / 2.54
292
+ if 5 <= ppcm <= 500:
293
+ return ppcm, "EXIF_XRes_in"
294
+ except Exception:
295
+ pass
296
+
297
+ # 3) Heuristic fallback
298
+ return DEFAULT_PIXELS_PER_CM, "default"
299
+
300
+ # =============== Drawing helpers ===============
301
+ def _draw_measurement_overlay(
302
+ base_bgr: np.ndarray,
303
+ rect_xywh: Tuple[int, int, int, int],
304
+ length_cm: float,
305
+ breadth_cm: float,
306
+ ) -> np.ndarray:
307
+ """
308
+ Draw arrows for vertical (length) and horizontal (breadth) on top of base image.
309
+ rect_xywh is relative to base_bgr.
310
+ """
311
+ x, y, w, h = rect_xywh
312
+ img = base_bgr.copy()
313
+
314
+ # Colors (BGR) and styling
315
+ color = (255, 255, 255) # white
316
+ shadow = (0, 0, 0) # black outline
317
+ thickness = 2
318
+ font = cv2.FONT_HERSHEY_SIMPLEX
319
+
320
+ # --- Horizontal arrow (breadth) ---
321
+ y_mid = y + h // 2
322
+ x_left = x
323
+ x_right = x + w
324
+ # shadow line
325
+ cv2.arrowedLine(img, (x_left, y_mid+1), (x_right, y_mid+1), shadow, thickness+2, cv2.LINE_AA, tipLength=0.02)
326
+ # main line
327
+ cv2.arrowedLine(img, (x_left, y_mid), (x_right, y_mid), color, thickness, cv2.LINE_AA, tipLength=0.02)
328
+
329
+ # breadth label
330
+ label_b = f"{breadth_cm:.2f} cm"
331
+ (tw, th), _ = cv2.getTextSize(label_b, font, 0.7, 2)
332
+ tx = x + (w - tw) // 2
333
+ ty = y_mid - 8
334
+ cv2.putText(img, label_b, (tx+1, ty+1), font, 0.7, shadow, 3, cv2.LINE_AA)
335
+ cv2.putText(img, label_b, (tx, ty), font, 0.7, color, 2, cv2.LINE_AA)
336
+
337
+ # --- Vertical arrow (length) ---
338
+ x_mid = x + w // 2
339
+ y_top = y
340
+ y_bottom = y + h
341
+ # shadow line
342
+ cv2.arrowedLine(img, (x_mid+1, y_top), (x_mid+1, y_bottom), shadow, thickness+2, cv2.LINE_AA, tipLength=0.02)
343
+ # main line
344
+ cv2.arrowedLine(img, (x_mid, y_top), (x_mid, y_bottom), color, thickness, cv2.LINE_AA, tipLength=0.02)
345
+
346
+ # length label
347
+ label_l = f"{length_cm:.2f} cm"
348
+ (tw2, th2), _ = cv2.getTextSize(label_l, font, 0.7, 2)
349
+ tx2 = x_mid - (tw2 // 2)
350
+ ty2 = y + th2 + 8
351
+ cv2.putText(img, label_l, (tx2+1, ty2+1), font, 0.7, shadow, 3, cv2.LINE_AA)
352
+ cv2.putText(img, label_l, (tx2, ty2), font, 0.7, color, 2, cv2.LINE_AA)
353
+
354
+ return img
355
+
356
  # =============== AI PROCESSOR ===============
357
  class AIProcessor:
358
  def __init__(self):
359
  self.models_cache = models_cache
360
  self.knowledge_base_cache = knowledge_base_cache
 
361
  self.uploads_dir = UPLOADS_DIR
362
  self.dataset_id = DATASET_ID
363
  self.hf_token = HF_TOKEN
 
368
  return out_dir
369
 
370
  def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
371
+ """YOLO detect → (optional) Keras seg → (optional) HF classify → save visuals with measurement overlay."""
372
  try:
373
+ image_rgb = image_pil.convert("RGB")
374
+ image_cv = cv2.cvtColor(np.array(image_rgb), cv2.COLOR_RGB2BGR)
375
 
376
  det = self.models_cache.get("det")
377
  if det is None:
378
  raise RuntimeError("YOLO model not loaded")
379
 
380
+ # ---------- Automatic calibration (px/cm) ----------
381
+ px_per_cm, calib_src = _auto_pixels_per_cm_from_exif(image_rgb)
382
+ # keep within reasonable range
383
+ if not (5.0 <= px_per_cm <= 500.0):
384
+ px_per_cm, calib_src = DEFAULT_PIXELS_PER_CM, "default"
385
+ logging.info(f"Calibration: {px_per_cm:.2f} px/cm (source={calib_src})")
386
+
387
  # YOLO on CPU
388
  results = det.predict(image_cv, verbose=False, device="cpu")
389
  if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
 
397
 
398
  # Optional segmentation
399
  seg_model = self.models_cache.get("seg")
400
+ length_cm = breadth_cm = surface_area_cm2 = 0.0
401
  seg_path = None
402
+
403
+ rect_xywh_global = None # for overlay on full image if seg missing
404
+
405
  if seg_model is not None and detected_region_cv.size > 0:
406
  try:
407
  input_size = seg_model.input_shape[1:3]
 
413
  if contours:
414
  cnt = max(contours, key=cv2.contourArea)
415
  x, y, w, h = cv2.boundingRect(cnt)
 
 
 
416
 
417
+ # Measurements using calibration
418
+ length_cm = round(h / px_per_cm, 2)
419
+ breadth_cm = round(w / px_per_cm, 2)
420
+ surface_area_cm2 = round(cv2.contourArea(cnt) / (px_per_cm ** 2), 2)
421
+
422
+ # Create segmentation overlay in the cropped region
423
  mask_resized = cv2.resize(
424
  mask_np * 255,
425
  (detected_region_cv.shape[1], detected_region_cv.shape[0]),
426
  interpolation=cv2.INTER_NEAREST,
427
  )
428
  overlay = detected_region_cv.copy()
429
+ overlay[mask_resized > 127] = [0, 0, 255] # red overlay
430
  seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
431
 
432
+ # Draw measurement arrows on seg_vis
433
+ # Map rect from mask space -> cropped image space
434
+ scale_x = detected_region_cv.shape[1] / float(input_size[1])
435
+ scale_y = detected_region_cv.shape[0] / float(input_size[0])
436
+ rect_xywh_cropped = (
437
+ int(x * scale_x),
438
+ int(y * scale_y),
439
+ int(w * scale_x),
440
+ int(h * scale_y),
441
+ )
442
+ seg_vis_meas = _draw_measurement_overlay(seg_vis, rect_xywh_cropped, length_cm, breadth_cm)
443
+
444
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
445
  out_dir = self._ensure_analysis_dir()
446
  seg_path = os.path.join(out_dir, f"segmentation_{ts}.png")
447
+ cv2.imwrite(seg_path, seg_vis_meas)
448
+
449
+ # Also store rect in full-image coordinates (if ever needed)
450
+ rect_xywh_global = (
451
+ x1 + rect_xywh_cropped[0],
452
+ y1 + rect_xywh_cropped[1],
453
+ rect_xywh_cropped[2],
454
+ rect_xywh_cropped[3],
455
+ )
456
  except Exception as e:
457
  logging.warning(f"Segmentation skipped: {e}")
458
 
 
481
 
482
  return {
483
  "wound_type": wound_type,
484
+ "length_cm": float(length_cm),
485
+ "breadth_cm": float(breadth_cm),
486
+ "surface_area_cm2": float(surface_area_cm2),
487
+ "calibration_px_per_cm": float(px_per_cm),
488
+ "calibration_source": calib_src,
489
  "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
490
  if getattr(results[0].boxes, "conf", None) is not None
491
  else 0.0,
492
  "detection_image_path": det_path,
493
+ "segmentation_image_path": seg_path, # <-- now includes arrow overlay if seg succeeded
494
  "original_image_path": original_path,
495
  }
496
  except Exception as e:
 
530
  - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
531
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
532
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
533
+ - **Calibration**: {visual_results.get('calibration_px_per_cm', 0)} px/cm (source: {visual_results.get('calibration_source','n/a')})
534
 
535
  ## 📊 Analysis Images
536
  - **Original**: {visual_results.get('original_image_path', 'N/A')}
537
  - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
538
+ - **Segmentation (with measurements)**: {visual_results.get('segmentation_image_path', 'N/A')}
539
 
540
  ## 🎯 Clinical Summary
541
  Automated analysis provides quantitative measurements; verify via clinical examination.
 
679
  "report": f"Analysis initialization failed: {str(e)}",
680
  "saved_image_path": None,
681
  "guideline_context": "",
682
+ }