SmartHeal-Agentic-AI / src /ai_processor.py
SmartHeal's picture
Update src/ai_processor.py
ef69ec1 verified
raw
history blame
31.7 kB
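# src/ai_processor.py — SmartHeal AI processor.
# Wound detection (YOLO), optional segmentation with safe mask overlays, EXIF-based
# px/cm calibration, guideline retrieval (FAISS) and a MedGemma report generator that
# is wrapped with the Spaces GPU decorator only when CUDA is actually available.
# All original class/function names are preserved; new helpers are additive.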
import os
import time
import logging
from datetime import datetime
from typing import Optional, Dict, List, Tuple
# --- quiet tokenizers fork warning (HF) ---
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import cv2
import numpy as np
from PIL import Image, ImageOps
from PIL.ExifTags import TAGS
# Root logging configuration for the process (INFO level, timestamped lines).
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Directory (relative to the working directory) for uploaded and analysis images;
# created at import time so later writes never fail on a missing folder.
UPLOADS_DIR = "uploads"
os.makedirs(UPLOADS_DIR, exist_ok=True)
# Hugging Face token from the environment. When unset (None), token saving and
# dataset uploads are skipped.
HF_TOKEN = os.getenv("HF_TOKEN", None)
# Model and guideline-document locations, relative to the working directory.
YOLO_MODEL_PATH = "src/best.pt"
SEG_MODEL_PATH = "src/segmentation_model.h5" # optional: segmentation is skipped if the file is absent
GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
# HF dataset repository that uploaded wound images are committed to.
DATASET_ID = "SmartHeal/wound-image-uploads"
DEFAULT_PX_PER_CM = 38.0 # fallback scale when EXIF focal length / subject distance are unavailable
PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0 # sanity bounds applied to EXIF-derived scales
# Process-wide caches, also exposed on every AIProcessor instance.
# models_cache keys: "det", "seg", "cls", "embedding_model"
# ("seg"/"cls"/"embedding_model" are stored as None when unavailable).
models_cache: Dict[str, object] = {}
# knowledge_base_cache holds "vector_store": a FAISS index, or None when the KB is disabled.
knowledge_base_cache: Dict[str, object] = {}
# ---------- Lazy imports ----------
def _import_ultralytics():
from ultralytics import YOLO
return YOLO
def _import_tf_loader():
import tensorflow as tf
tf.config.set_visible_devices([], "GPU") # force CPU for TF to avoid CUDA contention
from tensorflow.keras.models import load_model
return load_model
def _import_hf_cls():
from transformers import pipeline
return pipeline
def _import_embeddings():
from langchain_community.embeddings import HuggingFaceEmbeddings
return HuggingFaceEmbeddings
def _import_langchain_pdf():
from langchain_community.document_loaders import PyPDFLoader
return PyPDFLoader
def _import_langchain_faiss():
from langchain_community.vectorstores import FAISS
return FAISS
def _import_hf_hub():
from huggingface_hub import HfApi, HfFolder
return HfApi, HfFolder
# ---------- Conditional Spaces GPU function ----------
# Avoid scheduling a GPU worker when CUDA is not available (prevents cudaGetDeviceCount crash)
def _cuda_available() -> bool:
try:
import torch
return bool(getattr(torch, "cuda", None)) and torch.cuda.is_available()
except Exception:
return False
def _build_medgemma_prompt(
    patient_info: str,
    visual_results: Dict,
    guideline_context: str,
    max_context_chars: int = 2000,
) -> str:
    """Compose the text part of the MedGemma request.

    Includes patient data, the measured wound type/dimensions/area, the retrieved
    guideline excerpts (truncated to *max_context_chars*), and the requested report
    structure.
    """
    vr = visual_results or {}
    prompt = (
        "You are a medical AI assistant. Analyze this wound image and patient data.\n\n"
        f"Patient: {patient_info}\n"
        f"Wound: {vr.get('wound_type', 'Unknown')} - "
        f"{vr.get('length_cm', 0)}×{vr.get('breadth_cm', 0)} cm\n"
        f"Surface area: {vr.get('surface_area_cm2', 'N/A')} cm²\n\n"
    )
    context = (guideline_context or "").strip()
    if context:
        # Ground the report in the retrieved guideline snippets (RAG context).
        prompt += "Relevant guideline excerpts:\n" + context[:max_context_chars] + "\n\n"
    prompt += (
        "Provide a structured report with:\n"
        "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
    )
    return prompt


def _generate_medgemma_report_core(
    patient_info: str,
    visual_results: Dict,
    guideline_context: str,
    image_pil: Image.Image,
    max_new_tokens: Optional[int] = None,
) -> str:
    """Generate a structured wound report with google/medgemma-4b-it.

    Args:
        patient_info: Free-text patient summary.
        visual_results: Output of AIProcessor.perform_visual_analysis (wound_type,
            length_cm, breadth_cm, surface_area_cm2 are used in the prompt).
        guideline_context: Retrieved guideline snippets, included in the prompt.
        image_pil: Wound image sent alongside the prompt.
        max_new_tokens: Generation budget; falsy values mean 800.

    Returns:
        The generated report text. On failure or empty output, a string starting
        with "⚠️" is returned instead of raising.
    """
    try:
        from transformers import pipeline

        # CPU by default; with real CUDA let the pipeline place weights automatically.
        pipe = pipeline(
            "image-text-to-text",
            model="google/medgemma-4b-it",
            device_map="auto" if _cuda_available() else None,
            token=HF_TOKEN,
            model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
        )
        prompt = _build_medgemma_prompt(patient_info, visual_results, guideline_context)
        messages = [{"role": "user", "content": [
            {"type": "image", "image": image_pil},
            {"type": "text", "text": prompt},
        ]}]
        t0 = time.time()
        # Greedy decoding. No sampling parameters are passed: temperature is ignored
        # when do_sample=False and only triggers a transformers warning.
        out = pipe(
            text=messages,
            max_new_tokens=max_new_tokens or 800,
            do_sample=False,
        )
        logging.info(f"✅ MedGemma finished in {time.time()-t0:.2f}s")
        if not out:
            return "⚠️ No output generated"
        try:
            # Chat-style output: generated_text is the conversation; the last turn is the reply.
            text = out[0]["generated_text"][-1].get("content", "")
        except Exception:
            # Plain-string output.
            text = out[0].get("generated_text", "") or ""
        return text.strip() or "⚠️ Empty response"
    except Exception as e:
        logging.error(f"❌ MedGemma generation error: {e}")
        return "⚠️ GPU/LLM worker unavailable"
# Public entry point (name and signature preserved). It is defined once and then
# wrapped with @spaces.GPU only when the `spaces` package imports AND CUDA is
# truly available. Otherwise no GPU worker is scheduled, which avoids the
# cudaGetDeviceCount crash.
def generate_medgemma_report(
    patient_info: str,
    visual_results: Dict,
    guideline_context: str,
    image_pil: Image.Image,
    max_new_tokens: Optional[int] = None,
) -> str:
    """Return a MedGemma wound report (thin wrapper around _generate_medgemma_report_core)."""
    return _generate_medgemma_report_core(
        patient_info, visual_results, guideline_context, image_pil, max_new_tokens
    )


try:
    import spaces
except Exception:
    spaces = None

if spaces is not None and _cuda_available():
    try:
        generate_medgemma_report = spaces.GPU(enable_queue=True, duration=90)(generate_medgemma_report)
    except Exception as _gpu_wrap_exc:
        # Keep the undecorated function usable, but make the failure visible.
        logging.warning(f"spaces.GPU wrapping failed; running MedGemma undecorated: {_gpu_wrap_exc}")
# ---------- Initialize CPU models ----------
def load_yolo_model():
    """Instantiate the Ultralytics YOLO wound detector from YOLO_MODEL_PATH."""
    yolo_cls = _import_ultralytics()
    detector = yolo_cls(YOLO_MODEL_PATH)
    return detector
def load_segmentation_model():
    """Load the Keras segmentation model from SEG_MODEL_PATH for inference only (compile=False)."""
    keras_load = _import_tf_loader()
    seg_model = keras_load(SEG_MODEL_PATH, compile=False)
    return seg_model
def load_classification_pipeline():
    """Build the CPU image-classification pipeline (Hemg/Wound-classification) for wound type."""
    make_pipeline = _import_hf_cls()
    classifier = make_pipeline(
        "image-classification",
        model="Hemg/Wound-classification",
        token=HF_TOKEN,
        device="cpu",
    )
    return classifier
def load_embedding_model():
    """Create the CPU sentence-embedding model (all-MiniLM-L6-v2) used by the knowledge base."""
    embeddings_cls = _import_embeddings()
    embedder = embeddings_cls(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )
    return embedder
def initialize_cpu_models() -> None:
    """Populate ``models_cache`` with the CPU-side models; failures are logged, never raised.

    - When HF_TOKEN is set, it is persisted via ``HfFolder.save_token``.
    - "det" (YOLO): on failure the key stays absent, so a later call retries it.
    - "seg": loaded only if SEG_MODEL_PATH exists; otherwise, or on failure, cached as None.
    - "cls" and "embedding_model": cached as None on failure.
    Keys already present in the cache are left untouched.
    """
    if HF_TOKEN:
        try:
            _hf_api_cls, hf_folder_cls = _import_hf_hub()
            hf_folder_cls.save_token(HF_TOKEN)
            logging.info("✅ HF token set")
        except Exception as exc:
            logging.warning(f"HF token save failed: {exc}")

    if "det" not in models_cache:
        try:
            detector = load_yolo_model()
        except Exception as exc:
            logging.error(f"YOLO load failed: {exc}")
        else:
            models_cache["det"] = detector
            logging.info("✅ YOLO loaded (CPU)")

    if "seg" not in models_cache:
        seg_model = None
        try:
            if os.path.exists(SEG_MODEL_PATH):
                seg_model = load_segmentation_model()
                logging.info("✅ Segmentation model loaded (CPU)")
            else:
                logging.warning("Segmentation model file missing; skipping.")
        except Exception as exc:
            seg_model = None
            logging.warning(f"Segmentation unavailable: {exc}")
        models_cache["seg"] = seg_model

    # Optional models: (cache key, loader, success message, failure label).
    optional_loaders = (
        ("cls", load_classification_pipeline, "✅ Classifier loaded (CPU)", "Classifier unavailable"),
        ("embedding_model", load_embedding_model, "✅ Embeddings loaded (CPU)", "Embeddings unavailable"),
    )
    for cache_key, loader, ok_message, failure_label in optional_loaders:
        if cache_key in models_cache:
            continue
        try:
            models_cache[cache_key] = loader()
            logging.info(ok_message)
        except Exception as exc:
            models_cache[cache_key] = None
            logging.warning(f"{failure_label}: {exc}")
def setup_knowledge_base() -> None:
    """Build the guideline FAISS index once and cache it in knowledge_base_cache["vector_store"].

    Loads every existing PDF listed in GUIDELINE_PDFS, splits the pages into
    1000-character chunks (100-character overlap), and embeds them with
    models_cache["embedding_model"]. If there are no documents or no embedding
    model, or if any build step fails, "vector_store" is cached as None. Because the
    key is always set, later calls return immediately (no retry).
    """
    if "vector_store" in knowledge_base_cache:
        return
    docs: List = []
    try:
        PyPDFLoader = _import_langchain_pdf()
    except Exception as e:
        logging.warning(f"LangChain PDF loader unavailable: {e}")
    else:
        for pdf in GUIDELINE_PDFS:
            if not os.path.exists(pdf):
                logging.warning(f"Guideline PDF not found: {pdf}")
                continue
            try:
                docs.extend(PyPDFLoader(pdf).load())
                logging.info(f"Loaded PDF: {pdf}")
            except Exception as e:
                logging.warning(f"PDF load failed ({pdf}): {e}")

    embedding_model = models_cache.get("embedding_model")
    if not (docs and embedding_model):
        knowledge_base_cache["vector_store"] = None
        logging.warning("KB disabled (no docs or embeddings).")
        return

    try:
        try:
            # Splitters live in the standalone `langchain_text_splitters` package; the
            # `langchain.text_splitter` shim is deprecated and may be missing in newer
            # LangChain releases, so it is only a fallback.
            from langchain_text_splitters import RecursiveCharacterTextSplitter
        except ImportError:
            from langchain.text_splitter import RecursiveCharacterTextSplitter
        FAISS = _import_langchain_faiss()
        chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
        knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, embedding_model)
        logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)")
    except Exception as e:
        knowledge_base_cache["vector_store"] = None
        logging.warning(f"KB build failed: {e}")
# Import-time side effects. Order matters: setup_knowledge_base() reads
# models_cache["embedding_model"], which initialize_cpu_models() populates.
initialize_cpu_models()
setup_knowledge_base()
# ---------- Calibration helpers ----------
def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
    """Return the EXIF tags of *pil_img* keyed by human-readable tag name.

    ``Image.getexif()`` exposes the base IFD (Make, Model, Orientation, ...), but
    the camera tags used for calibration (FocalLength, FocalLengthIn35mmFilm,
    SubjectDistance) live in the Exif sub-IFD referenced by tag 0x8769. They are
    read with ``Exif.get_ifd`` and merged in, with base-IFD entries taking precedence.
    Unknown tag ids are kept as ints. This is best-effort: on any error, whatever was
    collected so far (possibly ``{}``) is returned.
    """
    exif_sub_ifd_tag = 0x8769  # "ExifOffset" pointer to the Exif sub-IFD
    out: Dict[str, object] = {}
    try:
        exif = pil_img.getexif()
        if not exif:
            return out
        for tag_id, value in exif.items():
            out[TAGS.get(tag_id, tag_id)] = value
        try:
            sub_ifd = exif.get_ifd(exif_sub_ifd_tag)
        except Exception:
            sub_ifd = {}  # old Pillow without get_ifd, or a malformed pointer
        for tag_id, value in sub_ifd.items():
            out.setdefault(TAGS.get(tag_id, tag_id), value)
    except Exception as exc:
        logging.debug(f"EXIF read failed: {exc}")
    return out
def _to_float(val) -> Optional[float]:
try:
if val is None:
return None
if isinstance(val, tuple) and len(val) == 2:
num, den = float(val[0]), float(val[1]) if float(val[1]) != 0 else 1.0
return num / den
return float(val)
except Exception:
return None
def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]:
if f_mm and f35 and f35 > 0:
return 36.0 * f_mm / f35
return None
def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]:
    """Estimate the image scale (pixels per cm at the subject plane) from EXIF.

    Uses a pinhole-camera model. The horizontal field width at the subject is
    ``sensor_width_mm * subject_distance_mm / focal_length_mm``, and
    ``px_per_cm = image_width_px / field_width_cm``. The sensor width is derived
    from FocalLength and its 35 mm equivalent. The result is clipped to
    [PX_PER_CM_MIN, PX_PER_CM_MAX].

    Returns:
        ``(px_per_cm, meta)``. ``meta`` records the EXIF inputs; ``meta["used"]``
        is ``"exif"`` when calibration succeeded. Otherwise it is ``"default"`` and
        *default_px_per_cm* is returned. Inputs that are missing, non-positive,
        or produce a non-finite scale always fall back to the default.
    """
    meta = {"used": "default", "f_mm": None, "f35": None, "sensor_w_mm": None, "distance_m": None}
    fallback = float(default_px_per_cm)
    try:
        exif = _exif_to_dict(pil_img)
        f_mm = _to_float(exif.get("FocalLength"))
        f35 = _to_float(exif.get("FocalLengthIn35mmFilm") or exif.get("FocalLengthIn35mm"))
        subj_dist_m = _to_float(exif.get("SubjectDistance"))
        sensor_w_mm = _estimate_sensor_width_mm(f_mm, f35)
        meta.update({"f_mm": f_mm, "f35": f35, "sensor_w_mm": sensor_w_mm, "distance_m": subj_dist_m})
        if not (f_mm and f_mm > 0 and sensor_w_mm and subj_dist_m and subj_dist_m > 0):
            return fallback, meta
        field_w_mm = sensor_w_mm * (subj_dist_m * 1000.0) / f_mm
        field_w_cm = field_w_mm / 10.0
        px_per_cm = pil_img.width / max(field_w_cm, 1e-6)
        if not np.isfinite(px_per_cm):
            # NaN/inf EXIF inputs must never reach the measurements.
            return fallback, meta
        meta["used"] = "exif"
        return float(np.clip(px_per_cm, PX_PER_CM_MIN, PX_PER_CM_MAX)), meta
    except Exception:
        return fallback, meta
# ---------- Segmentation helpers (additive; names preserved elsewhere) ----------
def _get_seg_hw(seg_model) -> Tuple[int, int]:
shp = getattr(seg_model, "input_shape", None)
if shp and len(shp) >= 4:
return int(shp[1]), int(shp[2])
# try Keras .inputs shape
try:
shp = seg_model.inputs[0].shape
return int(shp[1]), int(shp[2])
except Exception:
pass
raise ValueError(f"Cannot infer (H,W) from segmentation model input shape: {shp}")
def _to_prob(mask_pred: np.ndarray) -> np.ndarray:
m = np.array(mask_pred)
# squeeze batch/channel dims
while m.ndim > 2:
if m.shape[0] == 1:
m = np.squeeze(m, axis=0)
if m.ndim > 2 and m.shape[-1] == 1:
m = np.squeeze(m, axis=-1)
if m.ndim == 3 and m.shape[-1] > 1:
# pick the most active channel
ch = np.argmax(m.reshape(-1, m.shape[-1]).mean(0))
m = m[..., ch]
if m.ndim <= 2:
break
m = m.astype("float32")
# if looks like logits -> sigmoid
if m.max() > 1.5 or m.min() < -0.5:
m = 1.0 / (1.0 + np.exp(-m))
return np.clip(m, 0.0, 1.0)
def _adaptive_threshold(prob: np.ndarray, hard: float = 0.5) -> np.ndarray:
if (prob >= hard).sum() > 0:
return (prob >= hard).astype("uint8")
# try Otsu
m8 = (np.clip(prob, 0, 1) * 255).astype("uint8")
try:
# we only need the threshold value _
_, _ = cv2.threshold(m8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
return (m8 >= _).astype("uint8")
except Exception:
p = float(np.percentile(prob, 99.0))
return (prob >= max(0.2, min(0.9, p))).astype("uint8")
def largest_component_mask(binary: np.ndarray, min_area_px: int = 50) -> np.ndarray:
    """Keep only the largest 8-connected foreground component of *binary* as a {0,1} uint8 mask.

    The input (as uint8) is returned unchanged when there is no foreground or
    when even the largest component is smaller than *min_area_px* pixels.
    """
    as_u8 = binary.astype(np.uint8)
    n_labels, label_img, comp_stats, _centroids = cv2.connectedComponentsWithStats(as_u8, connectivity=8)
    if n_labels <= 1:
        return as_u8
    # Row 0 of the stats table is the background; skip it.
    fg_areas = comp_stats[1:, cv2.CC_STAT_AREA]
    too_small = fg_areas.size == 0 or fg_areas.max() < min_area_px
    if too_small:
        return as_u8
    biggest_label = int(np.argmax(fg_areas)) + 1
    return (label_img == biggest_label).astype(np.uint8)
def measure_min_area_rect(mask: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
    """Measure the largest external contour of *mask* with an oriented bounding box.

    Returns:
        ``(length_cm, breadth_cm, (box_points, center))``. Length is the longer
        side of ``cv2.minAreaRect``, breadth the shorter, both rounded to 2 dp.
        ``box_points`` is the 4x2 integer corner array. With no contours the result
        is ``(0.0, 0.0, (None, None))``.
    """
    contours, _hierarchy = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return 0.0, 0.0, (None, None)
    biggest = max(contours, key=cv2.contourArea)
    rect = cv2.minAreaRect(biggest)
    center, (side_a, side_b), _angle = rect
    long_px, short_px = (side_a, side_b) if side_a >= side_b else (side_b, side_a)
    scale = max(px_per_cm, 1e-6)
    corners = cv2.boxPoints(rect).astype(int)
    return round(long_px / scale, 2), round(short_px / scale, 2), (corners, center)
def count_area_cm2(mask: np.ndarray, px_per_cm: float) -> float:
    """Return the area of *mask*'s non-zero pixels in cm² (pixel count / px_per_cm²), rounded to 2 dp."""
    safe_scale = max(px_per_cm, 1e-6)
    on_pixels = float(np.count_nonzero(mask))
    return round(on_pixels / safe_scale ** 2, 2)
def draw_measurement_overlay(
    base_bgr: np.ndarray,
    mask01: np.ndarray,
    rect_box: np.ndarray,
    length_cm: float,
    breadth_cm: float,
    thickness: int = 2
) -> np.ndarray:
    """Return a copy of *base_bgr* annotated with the wound mask and its oriented box.

    The masked region gets a red tint. When *rect_box* (4 corner points) is given,
    the box outline is drawn, plus double-headed arrows joining the midpoints of
    opposite edges. The length label is placed at a long-edge midpoint and the
    breadth label at a short-edge midpoint.

    Args:
        base_bgr: BGR uint8 image.
        mask01: Wound mask; any non-zero value counts as foreground, so {0,1},
            bool, and {0,255} encodings all work.
        rect_box: Corner points from measure_min_area_rect, or None.
        length_cm, breadth_cm: Values printed on the labels.
        thickness: Line thickness for the box and arrows.
    """
    overlay = base_bgr.copy()
    # Normalise the mask to {0, 255}. `mask01 * 255` would overflow uint8 for
    # {0,255} masks and corrupt the bitwise gating below.
    mask_u8 = np.where(np.asarray(mask01) > 0, 255, 0).astype(np.uint8)
    m3 = np.dstack([mask_u8] * 3)
    # safe blend: blend once, then gate with mask (no mask kwarg!)
    colored = np.zeros_like(base_bgr)
    colored[:] = (0, 0, 255)
    blended = cv2.addWeighted(overlay, 1.0, colored, 0.3, 0)
    blended_masked = cv2.bitwise_and(blended, m3)
    bg = cv2.bitwise_and(overlay, cv2.bitwise_not(m3))
    overlay = cv2.add(bg, blended_masked)
    if rect_box is None:
        return overlay

    pts = np.asarray(rect_box, dtype=np.int32).reshape(-1, 2)
    cv2.polylines(overlay, [pts], True, (255, 255, 255), thickness)
    # Plain Python ints: OpenCV's point parsing is strictest about numpy scalars.
    corners = [(int(x), int(y)) for x, y in pts]

    def midpoint(a, b):
        return ((a[0] + b[0]) // 2, (a[1] + b[1]) // 2)

    mids = [midpoint(corners[i], corners[(i + 1) % 4]) for i in range(4)]
    e_lens = [
        float(np.hypot(corners[i][0] - corners[(i + 1) % 4][0], corners[i][1] - corners[(i + 1) % 4][1]))
        for i in range(4)
    ]
    # Edges 0/2 and 1/3 are opposite sides of the rectangle.
    long_pair = (0, 2) if e_lens[0] + e_lens[2] >= e_lens[1] + e_lens[3] else (1, 3)
    short_pair = (1, 3) if long_pair == (0, 2) else (0, 2)

    def draw_arrow(img, p1, p2):
        # Dark outline first, then a white stroke on top for contrast on any background.
        cv2.arrowedLine(img, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
        cv2.arrowedLine(img, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
        cv2.arrowedLine(img, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
        cv2.arrowedLine(img, p2, p1, (255, 255, 255), thickness, tipLength=0.05)

    draw_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
    draw_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])

    def put_label(text, org):
        cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
        cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

    put_label(f"{length_cm:.2f} cm", mids[long_pair[0]])
    put_label(f"{breadth_cm:.2f} cm", mids[short_pair[0]])
    return overlay
# ---------- AI PROCESSOR ----------
class AIProcessor:
    """End-to-end wound analysis: detection, optional segmentation, measurement,
    guideline retrieval, and report generation.

    Instances share the module-level ``models_cache`` and ``knowledge_base_cache``
    (populated at import time), so constructing an AIProcessor loads nothing.
    """

    def __init__(self):
        self.models_cache = models_cache
        self.knowledge_base_cache = knowledge_base_cache
        self.uploads_dir = UPLOADS_DIR
        self.dataset_id = DATASET_ID
        self.hf_token = HF_TOKEN

    @staticmethod
    def _timestamp() -> str:
        """Return a filename-safe timestamp with microsecond resolution.

        With second resolution, two requests in the same second would overwrite
        each other's files, both locally and in the HF dataset.
        """
        return datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    def _ensure_analysis_dir(self) -> str:
        """Create (if needed) and return ``<uploads_dir>/analysis``."""
        out_dir = os.path.join(self.uploads_dir, "analysis")
        os.makedirs(out_dir, exist_ok=True)
        return out_dir

    # ---------- Visual analysis helpers ----------
    def _segment_roi(self, roi: np.ndarray) -> Optional[np.ndarray]:
        """Segment the BGR *roi* and return a {0,1} uint8 mask at ROI resolution.

        Returns None when no segmentation model is loaded or when segmentation raises.
        The returned mask may be all zeros.
        """
        seg_model = self.models_cache.get("seg")
        if seg_model is None:
            logging.info("Skipping segmentation (no model).")
            return None
        try:
            H, W = _get_seg_hw(seg_model)
            resized = cv2.resize(roi, (W, H))  # cv2.resize expects (W, H)
            pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)
            prob = _to_prob(pred)  # (H, W) in [0, 1]
            binmask = _adaptive_threshold(prob, hard=0.5)
            # gentle cleanup + largest component
            kernel = np.ones((3, 3), np.uint8)
            binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, kernel, iterations=1)
            binmask = cv2.morphologyEx(binmask, cv2.MORPH_CLOSE, kernel, iterations=1)
            binmask = largest_component_mask(binmask, min_area_px=30)
            # back to ROI size, {0,1}
            mask = cv2.resize(
                binmask, (roi.shape[1], roi.shape[0]), interpolation=cv2.INTER_NEAREST
            ).astype(np.uint8)
            logging.info(
                f"seg prob stats: min={prob.min():.4f}, max={prob.max():.4f}, "
                f"mean={prob.mean():.4f}; on={(mask == 1).sum()}"
            )
            return mask
        except Exception as e:
            logging.warning(f"Segmentation failed: {e}")
            return None

    def _save_visualizations(
        self,
        image_cv: np.ndarray,
        box: Tuple[int, int, int, int],
        roi: np.ndarray,
        mask_roi_01: Optional[np.ndarray],
        anno_roi: Optional[np.ndarray],
    ) -> Dict[str, Optional[str]]:
        """Write the original, detection, and (when a mask exists) segmentation/annotated PNGs.

        Returns a dict with keys "original", "detection", "segmentation", and
        "annotated"; the last two are None when no mask was provided.
        """
        x1, y1, x2, y2 = box
        out_dir = self._ensure_analysis_dir()
        ts = self._timestamp()

        original_path = os.path.join(out_dir, f"original_{ts}.png")
        cv2.imwrite(original_path, image_cv)

        det_vis = image_cv.copy()
        cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        detection_path = os.path.join(out_dir, f"detection_{ts}.png")
        cv2.imwrite(detection_path, det_vis)

        segmentation_path: Optional[str] = None
        annotated_seg_path: Optional[str] = None
        if mask_roi_01 is not None and anno_roi is not None:
            # safe masked blend (no mask kwarg to addWeighted): tint, then gate with the mask
            red = np.zeros_like(roi)
            red[:] = (0, 0, 255)
            blended = cv2.addWeighted(roi, 1.0, red, 0.3, 0)
            mask_u8 = mask_roi_01.astype(np.uint8) * 255
            mask3 = cv2.merge([mask_u8, mask_u8, mask_u8])
            roi_overlay = cv2.add(
                cv2.bitwise_and(roi, cv2.bitwise_not(mask3)),
                cv2.bitwise_and(blended, mask3),
            )
            seg_full = image_cv.copy()
            seg_full[y1:y2, x1:x2] = roi_overlay
            segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
            cv2.imwrite(segmentation_path, seg_full)

            anno_full = image_cv.copy()
            anno_full[y1:y2, x1:x2] = anno_roi
            annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
            cv2.imwrite(annotated_seg_path, anno_full)

        return {
            "original": original_path,
            "detection": detection_path,
            "segmentation": segmentation_path,
            "annotated": annotated_seg_path,
        }

    def _classify_roi(self, roi: np.ndarray) -> str:
        """Return the top wound-type label for the BGR *roi*; "Unknown" if unavailable or failing."""
        cls_pipe = self.models_cache.get("cls")
        if cls_pipe is None:
            return "Unknown"
        try:
            preds = cls_pipe(Image.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)))
            if preds:
                return max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
        except Exception as e:
            logging.warning(f"Classification failed: {e}")
        return "Unknown"

    def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
        """
        Detect → crop ROI → (optional) segment → cleanup → largest component →
        oriented minAreaRect in cm (EXIF-calibrated) → save original/detect/seg/annotated.

        Only the first detected box is analysed. Without a usable mask, dimensions
        come from the detection box. In both paths ``length_cm`` is the longer side
        and ``breadth_cm`` the shorter; ``measurement_method`` records which path
        was used.

        Raises:
            RuntimeError: the YOLO model is not loaded.
            gradio.Error: nothing was detected, or the detected ROI is empty.
        """
        try:
            # --- Auto calibration from EXIF ---
            px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
            # Convert PIL to OpenCV BGR
            image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

            # --- Detection (YOLO) ---
            det_model = self.models_cache.get("det")
            if det_model is None:
                raise RuntimeError("YOLO model not loaded")
            results = det_model.predict(image_cv, verbose=False, device="cpu")
            boxes = getattr(results[0], "boxes", None) if results else None
            if boxes is None or len(boxes) == 0:
                import gradio as gr  # local import: gradio is only needed for this UI-facing error
                raise gr.Error("No wound could be detected.")
            conf = getattr(boxes, "conf", None)
            detection_confidence = float(conf[0].cpu().item()) if conf is not None else 0.0

            x1, y1, x2, y2 = (int(v) for v in boxes[0].xyxy[0].cpu().numpy().astype(int))
            img_h, img_w = image_cv.shape[:2]
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(img_w, x2), min(img_h, y2)
            roi = image_cv[y1:y2, x1:x2].copy()
            if roi.size == 0:
                import gradio as gr
                raise gr.Error("Detected ROI is empty.")

            # --- Segmentation (robust) ---
            mask_roi_01 = self._segment_roi(roi)
            has_mask = mask_roi_01 is not None and bool(mask_roi_01.any())

            # --- Measurement ---
            if has_mask:
                length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask_roi_01, px_per_cm)
                surface_area_cm2 = count_area_cm2(mask_roi_01, px_per_cm)
                anno_roi = draw_measurement_overlay(roi, mask_roi_01, box_pts, length_cm, breadth_cm)
                measurement_method = "segmentation"
            else:
                # Fallback to the detection box. As in the segmentation path,
                # length is the longer side and breadth the shorter.
                h_px, w_px = max(0, y2 - y1), max(0, x2 - x1)
                length_cm = round(max(h_px, w_px) / px_per_cm, 2)
                breadth_cm = round(min(h_px, w_px) / px_per_cm, 2)
                surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
                anno_roi = None
                measurement_method = "detection_box"

            # --- Save visualizations ---
            paths = self._save_visualizations(
                image_cv, (x1, y1, x2, y2), roi, mask_roi_01 if has_mask else None, anno_roi
            )

            # --- Optional classification ---
            wound_type = self._classify_roi(roi)

            return {
                "wound_type": wound_type,
                "length_cm": length_cm,
                "breadth_cm": breadth_cm,
                "surface_area_cm2": surface_area_cm2,
                "px_per_cm": round(px_per_cm, 2),
                "calibration_meta": exif_meta,
                "measurement_method": measurement_method,
                "detection_confidence": detection_confidence,
                "detection_image_path": paths["detection"],
                "segmentation_image_path": paths["segmentation"],
                "segmentation_annotated_path": paths["annotated"],
                "original_image_path": paths["original"],
            }
        except Exception as e:
            logging.error(f"Visual analysis failed: {e}", exc_info=True)
            raise

    # ---------- Knowledge base and reporting ----------
    def query_guidelines(self, query: str) -> str:
        """Return up to 5 guideline snippets (source + first 300 chars) relevant to *query*.

        Never raises: returns a human-readable message when the KB is unavailable
        or the query fails.
        """
        try:
            vs = self.knowledge_base_cache.get("vector_store")
            if not vs:
                return "Knowledge base is not available."
            retriever = vs.as_retriever(search_kwargs={"k": 5})
            # `invoke` is the current Runnable API. `get_relevant_documents` is the
            # deprecated method and is only used when `invoke` does not exist.
            if hasattr(retriever, "invoke"):
                docs = retriever.invoke(query)
            else:
                docs = retriever.get_relevant_documents(query)
            lines: List[str] = []
            for d in docs:
                src = (d.metadata or {}).get("source", "N/A")
                txt = (d.page_content or "")[:300]
                lines.append(f"Source: {src}\nContent: {txt}...")
            return "\n\n".join(lines) if lines else "No relevant guideline snippets found."
        except Exception as e:
            logging.warning(f"Guidelines query failed: {e}")
            return f"Guidelines query failed: {str(e)}"

    def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
        """Build the template Markdown report used when MedGemma output is unavailable.

        Missing fields render as defaults; a missing or None detection confidence
        renders as 0.0%, and guideline context is truncated to 800 characters.
        """
        vr = visual_results or {}
        confidence = vr.get("detection_confidence") or 0
        calibration_used = (vr.get("calibration_meta") or {}).get("used", "default")
        context = guideline_context or ""
        context_excerpt = context[:800] + ("..." if len(context) > 800 else "")
        return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
## 📋 Patient Information
{patient_info}
## 🔍 Visual Analysis Results
- **Wound Type**: {vr.get('wound_type', 'Unknown')}
- **Dimensions**: {vr.get('length_cm', 0)} cm × {vr.get('breadth_cm', 0)} cm
- **Surface Area**: {vr.get('surface_area_cm2', 0)} cm²
- **Detection Confidence**: {confidence:.1%}
- **Calibration**: {vr.get('px_per_cm', '?')} px/cm ({calibration_used})
## 📊 Analysis Images
- **Original**: {vr.get('original_image_path', 'N/A')}
- **Detection**: {vr.get('detection_image_path', 'N/A')}
- **Segmentation**: {vr.get('segmentation_image_path', 'N/A')}
- **Annotated**: {vr.get('segmentation_annotated_path', 'N/A')}
## 🎯 Clinical Summary
Automated analysis provides quantitative measurements; verify via clinical examination.
## 💊 Recommendations
- Cleanse wound gently; select dressing per exudate/infection risk
- Debride necrotic tissue if indicated (clinical decision)
- Document with serial photos and measurements
## 📅 Monitoring
- Daily in week 1, then every 2–3 days (or as indicated)
- Weekly progress review
## 📚 Guideline Context
{context_excerpt}
**Disclaimer:** Automated, for decision support only. Verify clinically.
"""

    def generate_final_report(
        self,
        patient_info: str,
        visual_results: Dict,
        guideline_context: str,
        image_pil: Image.Image,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Return the MedGemma report, or the template fallback when it is empty or a warning."""
        try:
            report = generate_medgemma_report(
                patient_info, visual_results, guideline_context, image_pil, max_new_tokens
            )
            # generate_medgemma_report signals failure with strings starting with ⚠️/❌.
            if report and report.strip() and not report.startswith(("⚠️", "❌")):
                return report
            logging.warning("MedGemma unavailable/invalid; using fallback.")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
        except Exception as e:
            logging.error(f"Report generation failed: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def save_and_commit_image(self, image_pil: Image.Image) -> str:
        """Save *image_pil* as PNG under uploads_dir and, if configured, upload it to the HF dataset.

        The upload is best-effort: its failure is logged and the local path is still
        returned. Returns "" if the local save itself fails.
        """
        try:
            os.makedirs(self.uploads_dir, exist_ok=True)
            filename = f"{self._timestamp()}.png"
            path = os.path.join(self.uploads_dir, filename)
            image_pil.convert("RGB").save(path)
            logging.info(f"✅ Image saved locally: {path}")
            if self.hf_token and self.dataset_id:
                try:
                    HfApi, _HfFolder = _import_hf_hub()
                    # The token is passed explicitly, so no global token file is needed.
                    HfApi().upload_file(
                        path_or_fileobj=path,
                        path_in_repo=f"images/{filename}",
                        repo_id=self.dataset_id,
                        repo_type="dataset",
                        token=self.hf_token,
                        commit_message=f"Upload wound image: {filename}",
                    )
                    logging.info("✅ Image committed to HF dataset")
                except Exception as e:
                    logging.warning(f"HF upload failed: {e}")
            return path
        except Exception as e:
            logging.error(f"Failed to save/commit image: {e}")
            return ""

    def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
        """Run save → visual analysis → guideline lookup → report; never raises.

        Returns a dict with "success", "visual_analysis", "report", "saved_image_path",
        and "guideline_context" (truncated to 500 chars). On failure "error" is also set.
        """
        try:
            saved_path = self.save_and_commit_image(image_pil)
            visual_results = self.perform_visual_analysis(image_pil)
            pi = questionnaire_data or {}
            patient_info = (
                f"Age: {pi.get('age','N/A')}, "
                f"Diabetic: {pi.get('diabetic','N/A')}, "
                f"Allergies: {pi.get('allergies','N/A')}, "
                f"Date of Wound: {pi.get('date_of_injury','N/A')}, "
                f"Professional Care: {pi.get('professional_care','N/A')}, "
                f"Oozing/Bleeding: {pi.get('oozing_bleeding','N/A')}, "
                f"Infection: {pi.get('infection','N/A')}, "
                f"Moisture: {pi.get('moisture','N/A')}"
            )
            query = (
                f"best practices for managing a {visual_results.get('wound_type','Unknown')} "
                f"with moisture '{pi.get('moisture','unknown')}' and infection '{pi.get('infection','unknown')}' "
                f"in a diabetic status '{pi.get('diabetic','unknown')}'"
            )
            guideline_context = self.query_guidelines(query)
            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
            return {
                "success": True,
                "visual_analysis": visual_results,
                "report": report,
                "saved_image_path": saved_path,
                "guideline_context": (guideline_context or "")[:500] + (
                    "..." if guideline_context and len(guideline_context) > 500 else ""
                ),
            }
        except Exception as e:
            logging.error(f"Pipeline error: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e),
                "visual_analysis": {},
                "report": f"Analysis failed: {str(e)}",
                "saved_image_path": None,
                "guideline_context": "",
            }

    def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
        """Accept a file path, PIL image, or numpy array and run the full pipeline; never raises."""
        try:
            if isinstance(image, str):
                if not os.path.exists(image):
                    raise ValueError(f"Image file not found: {image}")
                image_pil = Image.open(image)
            elif isinstance(image, Image.Image):
                image_pil = image
            elif isinstance(image, np.ndarray):
                image_pil = Image.fromarray(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")
            return self.full_analysis_pipeline(image_pil, questionnaire_data or {})
        except Exception as e:
            logging.error(f"Wound analysis error: {e}")
            return {
                "success": False,
                "error": str(e),
                "visual_analysis": {},
                "report": f"Analysis initialization failed: {str(e)}",
                "saved_image_path": None,
                "guideline_context": "",
            }