# smartheal_ai_processor.py
# Fully functional: robust segmentation + safe overlays + conditional GPU wrapper.
# All original class/function names preserved. New helpers are additive.

import os
import time
import logging
from datetime import datetime
from typing import Optional, Dict, List, Tuple

# --- quiet tokenizers fork warning (HF) ---
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import cv2
import numpy as np
from PIL import Image, ImageOps
from PIL.ExifTags import TAGS

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

UPLOADS_DIR = "uploads"
os.makedirs(UPLOADS_DIR, exist_ok=True)

HF_TOKEN = os.getenv("HF_TOKEN", None)
YOLO_MODEL_PATH = "src/best.pt"
SEG_MODEL_PATH = "src/segmentation_model.h5"  # optional
GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
DATASET_ID = "SmartHeal/wound-image-uploads"

DEFAULT_PX_PER_CM = 38.0  # fallback when we cannot calibrate
PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0  # sanity bounds

models_cache: Dict[str, object] = {}
knowledge_base_cache: Dict[str, object] = {}
# ---------- Lazy imports ----------
def _import_ultralytics():
    from ultralytics import YOLO
    return YOLO

def _import_tf_loader():
    import tensorflow as tf
    tf.config.set_visible_devices([], "GPU")  # force CPU for TF to avoid CUDA contention
    from tensorflow.keras.models import load_model
    return load_model

def _import_hf_cls():
    from transformers import pipeline
    return pipeline

def _import_embeddings():
    from langchain_community.embeddings import HuggingFaceEmbeddings
    return HuggingFaceEmbeddings

def _import_langchain_pdf():
    from langchain_community.document_loaders import PyPDFLoader
    return PyPDFLoader

def _import_langchain_faiss():
    from langchain_community.vectorstores import FAISS
    return FAISS

def _import_hf_hub():
    from huggingface_hub import HfApi, HfFolder
    return HfApi, HfFolder
# ---------- Conditional Spaces GPU function ----------
# Avoid scheduling a GPU worker when CUDA is not available (prevents cudaGetDeviceCount crash)
def _cuda_available() -> bool:
    try:
        import torch
        return bool(getattr(torch, "cuda", None)) and torch.cuda.is_available()
    except Exception:
        return False
def _generate_medgemma_report_core(
    patient_info: str,
    visual_results: Dict,
    guideline_context: str,
    image_pil: Image.Image,
    max_new_tokens: Optional[int] = None,
) -> str:
    try:
        from transformers import pipeline

        # Use CPU by default; if CUDA truly available, pipeline can still map automatically
        pipe = pipeline(
            "image-text-to-text",
            model="google/medgemma-4b-it",
            device_map="auto" if _cuda_available() else None,
            token=HF_TOKEN,
            model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
        )
        prompt = (
            "You are a medical AI assistant. Analyze this wound image and patient data.\n\n"
            f"Patient: {patient_info}\n"
            f"Wound: {visual_results.get('wound_type', 'Unknown')} - "
            f"{visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm\n\n"
            "Provide a structured report with:\n"
            "1. Clinical Summary\n2. Treatment Recommendations\n3. Risk Assessment\n4. Monitoring Plan\n"
        )
        messages = [{"role": "user", "content": [
            {"type": "image", "image": image_pil},
            {"type": "text", "text": prompt},
        ]}]
        t0 = time.time()
        # Greedy decoding: with do_sample=False we pass no sampling params
        # (a stray temperature alongside greedy decoding only triggers HF warnings).
        out = pipe(
            text=messages,
            max_new_tokens=max_new_tokens or 800,
            do_sample=False,
        )
        logging.info(f"✅ MedGemma finished in {time.time()-t0:.2f}s")
        if out and len(out) > 0:
            try:
                return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
            except Exception:
                return (out[0].get("generated_text", "") or "").strip() or "⚠️ Empty response"
        return "⚠️ No output generated"
    except Exception as e:
        logging.error(f"❌ MedGemma generation error: {e}")
        return "⚠️ GPU/LLM worker unavailable"
# Preserve the SAME public function name.
# Only decorate with @spaces.GPU if CUDA is truly available.
try:
    import spaces

    if _cuda_available():
        @spaces.GPU  # schedule on a Spaces GPU worker only when CUDA exists
        def generate_medgemma_report(
            patient_info: str,
            visual_results: Dict,
            guideline_context: str,
            image_pil: Image.Image,
            max_new_tokens: Optional[int] = None,
        ) -> str:
            return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
    else:
        def generate_medgemma_report(
            patient_info: str,
            visual_results: Dict,
            guideline_context: str,
            image_pil: Image.Image,
            max_new_tokens: Optional[int] = None,
        ) -> str:
            # no decorator -> no GPU worker init -> no cudaGetDeviceCount crash
            return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
except Exception:
    def generate_medgemma_report(
        patient_info: str,
        visual_results: Dict,
        guideline_context: str,
        image_pil: Image.Image,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
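# Dispatch sketch: with CUDA present, the decorated variant above is scheduled
# on a Spaces GPU worker; on CPU-only hardware the same name runs inline, so
# callers never branch on hardware:
#   report = generate_medgemma_report(patient_info, visual_results, context, pil_image)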
# ---------- Initialize CPU models ----------
def load_yolo_model():
    YOLO = _import_ultralytics()
    return YOLO(YOLO_MODEL_PATH)

def load_segmentation_model():
    load_model = _import_tf_loader()
    return load_model(SEG_MODEL_PATH, compile=False)

def load_classification_pipeline():
    pipe = _import_hf_cls()
    return pipe("image-classification", model="Hemg/Wound-classification", token=HF_TOKEN, device="cpu")

def load_embedding_model():
    Emb = _import_embeddings()
    return Emb(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
def initialize_cpu_models() -> None:
    if HF_TOKEN:
        try:
            HfApi, HfFolder = _import_hf_hub()
            HfFolder.save_token(HF_TOKEN)
            logging.info("✅ HF token set")
        except Exception as e:
            logging.warning(f"HF token save failed: {e}")
    if "det" not in models_cache:
        try:
            models_cache["det"] = load_yolo_model()
            logging.info("✅ YOLO loaded (CPU)")
        except Exception as e:
            logging.error(f"YOLO load failed: {e}")
    if "seg" not in models_cache:
        try:
            if os.path.exists(SEG_MODEL_PATH):
                models_cache["seg"] = load_segmentation_model()
                logging.info("✅ Segmentation model loaded (CPU)")
            else:
                models_cache["seg"] = None
                logging.warning("Segmentation model file missing; skipping.")
        except Exception as e:
            models_cache["seg"] = None
            logging.warning(f"Segmentation unavailable: {e}")
    if "cls" not in models_cache:
        try:
            models_cache["cls"] = load_classification_pipeline()
            logging.info("✅ Classifier loaded (CPU)")
        except Exception as e:
            models_cache["cls"] = None
            logging.warning(f"Classifier unavailable: {e}")
    if "embedding_model" not in models_cache:
        try:
            models_cache["embedding_model"] = load_embedding_model()
            logging.info("✅ Embeddings loaded (CPU)")
        except Exception as e:
            models_cache["embedding_model"] = None
            logging.warning(f"Embeddings unavailable: {e}")
def setup_knowledge_base() -> None:
    if "vector_store" in knowledge_base_cache:
        return
    docs: List = []
    try:
        PyPDFLoader = _import_langchain_pdf()
        for pdf in GUIDELINE_PDFS:
            if os.path.exists(pdf):
                try:
                    docs.extend(PyPDFLoader(pdf).load())
                    logging.info(f"Loaded PDF: {pdf}")
                except Exception as e:
                    logging.warning(f"PDF load failed ({pdf}): {e}")
    except Exception as e:
        logging.warning(f"LangChain PDF loader unavailable: {e}")
    if docs and models_cache.get("embedding_model"):
        try:
            from langchain.text_splitter import RecursiveCharacterTextSplitter
            FAISS = _import_langchain_faiss()
            chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
            knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
            logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)")
        except Exception as e:
            knowledge_base_cache["vector_store"] = None
            logging.warning(f"KB build failed: {e}")
    else:
        knowledge_base_cache["vector_store"] = None
        logging.warning("KB disabled (no docs or embeddings).")

initialize_cpu_models()
setup_knowledge_base()
# ---------- Calibration helpers ----------
def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
    out = {}
    try:
        exif = pil_img.getexif()
        if not exif:
            return out
        for k, v in exif.items():
            tag = TAGS.get(k, k)
            out[tag] = v
    except Exception:
        pass
    return out

def _to_float(val) -> Optional[float]:
    try:
        if val is None:
            return None
        if isinstance(val, tuple) and len(val) == 2:
            num = float(val[0])
            den = float(val[1]) if float(val[1]) != 0 else 1.0
            return num / den
        return float(val)
    except Exception:
        return None

def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]:
    if f_mm and f35 and f35 > 0:
        return 36.0 * f_mm / f35
    return None
def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]:
    meta = {"used": "default", "f_mm": None, "f35": None, "sensor_w_mm": None, "distance_m": None}
    try:
        exif = _exif_to_dict(pil_img)
        f_mm = _to_float(exif.get("FocalLength"))
        f35 = _to_float(exif.get("FocalLengthIn35mmFilm") or exif.get("FocalLengthIn35mm"))
        subj_dist_m = _to_float(exif.get("SubjectDistance"))
        sensor_w_mm = _estimate_sensor_width_mm(f_mm, f35)
        meta.update({"f_mm": f_mm, "f35": f35, "sensor_w_mm": sensor_w_mm, "distance_m": subj_dist_m})
        if f_mm and sensor_w_mm and subj_dist_m and subj_dist_m > 0:
            w_px = pil_img.width
            field_w_mm = sensor_w_mm * (subj_dist_m * 1000.0) / f_mm
            field_w_cm = field_w_mm / 10.0
            px_per_cm = w_px / max(field_w_cm, 1e-6)
            px_per_cm = float(np.clip(px_per_cm, PX_PER_CM_MIN, PX_PER_CM_MAX))
            meta["used"] = "exif"
            return px_per_cm, meta
        return float(default_px_per_cm), meta
    except Exception:
        return float(default_px_per_cm), meta
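# Worked example (illustrative numbers, not from a real photo): a phone frame
# 3024 px wide with FocalLength = 4.2 mm, FocalLengthIn35mmFilm = 26, and
# SubjectDistance = 0.3 m gives sensor_w ≈ 36 * 4.2 / 26 ≈ 5.82 mm, a field
# width of 5.82 * 300 / 4.2 ≈ 415 mm ≈ 41.5 cm, hence ≈ 3024 / 41.5 ≈ 73 px/cm,
# comfortably inside the [PX_PER_CM_MIN, PX_PER_CM_MAX] sanity bounds.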
# ---------- Segmentation helpers (additive; names preserved elsewhere) ----------
def _get_seg_hw(seg_model) -> Tuple[int, int]:
    shp = getattr(seg_model, "input_shape", None)
    if shp and len(shp) >= 4:
        return int(shp[1]), int(shp[2])
    # try Keras .inputs shape
    try:
        shp = seg_model.inputs[0].shape
        return int(shp[1]), int(shp[2])
    except Exception:
        pass
    raise ValueError(f"Cannot infer (H,W) from segmentation model input shape: {shp}")
def _to_prob(mask_pred: np.ndarray) -> np.ndarray:
    m = np.array(mask_pred)
    # squeeze batch/channel dims; each branch strictly reduces ndim,
    # so the loop always terminates
    while m.ndim > 2:
        if m.shape[0] == 1:
            m = np.squeeze(m, axis=0)
        elif m.shape[-1] == 1:
            m = np.squeeze(m, axis=-1)
        elif m.ndim == 3 and m.shape[-1] > 1:
            # pick the most active channel
            ch = int(np.argmax(m.reshape(-1, m.shape[-1]).mean(0)))
            m = m[..., ch]
        else:
            m = m[0]  # e.g. batch > 1: keep the first sample
    m = m.astype("float32")
    # if looks like logits -> sigmoid
    if m.max() > 1.5 or m.min() < -0.5:
        m = 1.0 / (1.0 + np.exp(-m))
    return np.clip(m, 0.0, 1.0)
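# Shape-handling sketch (assumed toy tensor): a (1, 8, 8, 1) logit map from
# model.predict collapses to an (8, 8) probability map, and the values are
# pushed through the sigmoid because 4.0 > 1.5 looks like a logit:
#   _to_prob(np.full((1, 8, 8, 1), 4.0))  # -> all values ≈ 0.982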
def _adaptive_threshold(prob: np.ndarray, hard: float = 0.5) -> np.ndarray:
    if (prob >= hard).sum() > 0:
        return (prob >= hard).astype("uint8")
    # try Otsu; cv2.threshold returns (threshold_value, thresholded_image)
    m8 = (np.clip(prob, 0, 1) * 255).astype("uint8")
    try:
        otsu_thresh, _ = cv2.threshold(m8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        return (m8 >= otsu_thresh).astype("uint8")
    except Exception:
        p = float(np.percentile(prob, 99.0))
        return (prob >= max(0.2, min(0.9, p))).astype("uint8")
def largest_component_mask(binary: np.ndarray, min_area_px: int = 50) -> np.ndarray:
    num, labels, stats, _ = cv2.connectedComponentsWithStats(binary.astype(np.uint8), connectivity=8)
    if num <= 1:
        return binary.astype(np.uint8)
    areas = stats[1:, cv2.CC_STAT_AREA]
    if areas.size == 0 or areas.max() < min_area_px:
        return binary.astype(np.uint8)
    largest_idx = 1 + int(np.argmax(areas))
    return (labels == largest_idx).astype(np.uint8)

def measure_min_area_rect(mask: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return 0.0, 0.0, (None, None)
    cnt = max(contours, key=cv2.contourArea)
    rect = cv2.minAreaRect(cnt)
    (w_px, h_px) = rect[1]
    length_px, breadth_px = (max(w_px, h_px), min(w_px, h_px))
    length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
    breadth_cm = round(breadth_px / max(px_per_cm, 1e-6), 2)
    box = cv2.boxPoints(rect).astype(int)
    return length_cm, breadth_cm, (box, rect[0])

def count_area_cm2(mask: np.ndarray, px_per_cm: float) -> float:
    px_count = float(mask.astype(bool).sum())
    return round(px_count / (max(px_per_cm, 1e-6) ** 2), 2)
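# Worked example (illustrative): at px_per_cm = 10, a filled 40 × 20 px
# rectangular mask measures roughly 4.0 cm × 2.0 cm via the oriented
# minAreaRect (pixel-center coordinates introduce ~0.1 cm quantization),
# and count_area_cm2 returns 800 px / (10 px/cm)² = 8.0 cm².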
def draw_measurement_overlay(
    base_bgr: np.ndarray,
    mask01: np.ndarray,
    rect_box: np.ndarray,
    length_cm: float,
    breadth_cm: float,
    thickness: int = 2,
) -> np.ndarray:
    overlay = base_bgr.copy()
    # safe blend: blend once, then gate with mask (no mask kwarg!)
    colored = np.zeros_like(base_bgr)
    colored[:] = (0, 0, 255)
    blended = cv2.addWeighted(overlay, 1.0, colored, 0.3, 0)
    m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
    blended_masked = cv2.bitwise_and(blended, m3)
    bg = cv2.bitwise_and(overlay, cv2.bitwise_not(m3))
    overlay = cv2.add(bg, blended_masked)
    if rect_box is not None:
        cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
        pts = rect_box.reshape(-1, 2)

        def midpoint(a, b):
            return ((a[0] + b[0]) // 2, (a[1] + b[1]) // 2)

        mids = [midpoint(pts[i], pts[(i + 1) % 4]) for i in range(4)]
        e_lens = [np.linalg.norm(pts[i] - pts[(i + 1) % 4]) for i in range(4)]
        long_pair = (0, 2) if e_lens[0] + e_lens[2] >= e_lens[1] + e_lens[3] else (1, 3)
        short_pair = (1, 3) if long_pair == (0, 2) else (0, 2)

        def draw_arrow(img, p1, p2):
            cv2.arrowedLine(img, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
            cv2.arrowedLine(img, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
            cv2.arrowedLine(img, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
            cv2.arrowedLine(img, p2, p1, (255, 255, 255), thickness, tipLength=0.05)

        draw_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
        draw_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])

        def put_label(text, org):
            cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
            cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

        put_label(f"{length_cm:.2f} cm", mids[long_pair[0]])
        put_label(f"{breadth_cm:.2f} cm", mids[short_pair[0]])
    return overlay
# ---------- AI PROCESSOR ----------
class AIProcessor:
    def __init__(self):
        self.models_cache = models_cache
        self.knowledge_base_cache = knowledge_base_cache
        self.uploads_dir = UPLOADS_DIR
        self.dataset_id = DATASET_ID
        self.hf_token = HF_TOKEN

    def _ensure_analysis_dir(self) -> str:
        out_dir = os.path.join(self.uploads_dir, "analysis")
        os.makedirs(out_dir, exist_ok=True)
        return out_dir
    def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
        """
        Detect → crop ROI → (optional) segment → cleanup → largest component →
        oriented minAreaRect in cm (EXIF-calibrated) → save original/detect/seg/annotated.
        """
        try:
            # --- Auto calibration from EXIF ---
            px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
            # Convert PIL to OpenCV BGR
            image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

            # --- Detection (YOLO) ---
            det_model = self.models_cache.get("det")
            if det_model is None:
                raise RuntimeError("YOLO model not loaded")
            results = det_model.predict(image_cv, verbose=False, device="cpu")
            if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
                import gradio as gr  # local import so the module still loads when gradio is absent
                raise gr.Error("No wound could be detected.")
            box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
            x1, y1, x2, y2 = [int(v) for v in box]
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
            roi = image_cv[y1:y2, x1:x2].copy()
            if roi.size == 0:
                import gradio as gr
                raise gr.Error("Detected ROI is empty.")

            # --- Segmentation (robust) ---
            seg_model = self.models_cache.get("seg")
            mask_roi_01 = None
            if seg_model is not None:
                try:
                    H, W = _get_seg_hw(seg_model)  # robust (H,W)
                    resized = cv2.resize(roi, (W, H))  # cv2.resize expects (W,H)
                    pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)
                    prob = _to_prob(pred)  # (H,W) in [0,1]
                    binmask = _adaptive_threshold(prob, hard=0.5)
                    # gentle cleanup + largest component
                    binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)
                    binmask = cv2.morphologyEx(binmask, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8), iterations=1)
                    binmask = largest_component_mask(binmask, min_area_px=30)
                    # back to ROI size {0,1}
                    mask_roi_01 = cv2.resize(binmask, (roi.shape[1], roi.shape[0]), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
                    logging.info(f"seg prob stats: min={prob.min():.4f}, max={prob.max():.4f}, mean={prob.mean():.4f}; on={(mask_roi_01 == 1).sum()}")
                except Exception as e:
                    logging.warning(f"Segmentation failed: {e}")
                    mask_roi_01 = None
            else:
                logging.info("Skipping segmentation (no model).")

            # --- Measurement ---
            if mask_roi_01 is not None and mask_roi_01.any():
                length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask_roi_01, px_per_cm)
                surface_area_cm2 = count_area_cm2(mask_roi_01, px_per_cm)
                anno_roi = draw_measurement_overlay(roi, mask_roi_01, box_pts, length_cm, breadth_cm)
            else:
                # fallback to detection-box cm
                h_px = max(0, y2 - y1)
                w_px = max(0, x2 - x1)
                length_cm = round(h_px / px_per_cm, 2)
                breadth_cm = round(w_px / px_per_cm, 2)
                surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
                anno_roi = roi.copy()

            # --- Save visualizations ---
            out_dir = self._ensure_analysis_dir()
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            original_path = os.path.join(out_dir, f"original_{ts}.png")
            cv2.imwrite(original_path, image_cv)
            det_vis = image_cv.copy()
            cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
            detection_path = os.path.join(out_dir, f"detection_{ts}.png")
            cv2.imwrite(detection_path, det_vis)
            segmentation_path = None
            annotated_seg_path = None
            if mask_roi_01 is not None and mask_roi_01.any():
                # safe masked blend (no mask kwarg to addWeighted)
                seg_full = image_cv.copy()
                roi_overlay = roi.copy()
                red = np.zeros_like(roi_overlay)
                red[:] = (0, 0, 255)
                blended = cv2.addWeighted(roi_overlay, 1.0, red, 0.3, 0)
                mask_u8 = (mask_roi_01.astype(np.uint8) * 255)
                mask3 = cv2.merge([mask_u8, mask_u8, mask_u8])
                blended_masked = cv2.bitwise_and(blended, mask3)
                roi_bg = cv2.bitwise_and(roi_overlay, cv2.bitwise_not(mask3))
                roi_overlay = cv2.add(roi_bg, blended_masked)
                seg_full[y1:y2, x1:x2] = roi_overlay
                segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
                cv2.imwrite(segmentation_path, seg_full)
                anno_full = image_cv.copy()
                anno_full[y1:y2, x1:x2] = anno_roi
                annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
                cv2.imwrite(annotated_seg_path, anno_full)

            # --- Optional classification ---
            wound_type = "Unknown"
            cls_pipe = self.models_cache.get("cls")
            if cls_pipe is not None:
                try:
                    preds = cls_pipe(Image.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)))
                    if preds:
                        wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
                except Exception as e:
                    logging.warning(f"Classification failed: {e}")

            return {
                "wound_type": wound_type,
                "length_cm": length_cm,
                "breadth_cm": breadth_cm,
                "surface_area_cm2": surface_area_cm2,
                "px_per_cm": round(px_per_cm, 2),
                "calibration_meta": exif_meta,
                "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
                if getattr(results[0].boxes, "conf", None) is not None else 0.0,
                "detection_image_path": detection_path,
                "segmentation_image_path": segmentation_path,
                "segmentation_annotated_path": annotated_seg_path,
                "original_image_path": original_path,
            }
        except Exception as e:
            logging.error(f"Visual analysis failed: {e}", exc_info=True)
            raise
    # ---------- Knowledge base and reporting stay unchanged ----------
    def query_guidelines(self, query: str) -> str:
        try:
            vs = self.knowledge_base_cache.get("vector_store")
            if not vs:
                return "Knowledge base is not available."
            retriever = vs.as_retriever(search_kwargs={"k": 5})
            try:
                docs = retriever.get_relevant_documents(query)
            except Exception:
                docs = retriever.invoke(query)  # newer LangChain retriever API
            lines: List[str] = []
            for d in docs:
                src = (d.metadata or {}).get("source", "N/A")
                txt = (d.page_content or "")[:300]
                lines.append(f"Source: {src}\nContent: {txt}...")
            return "\n\n".join(lines) if lines else "No relevant guideline snippets found."
        except Exception as e:
            logging.warning(f"Guidelines query failed: {e}")
            return f"Guidelines query failed: {str(e)}"
    def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
        return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report

## 📋 Patient Information
{patient_info}

## 🔍 Visual Analysis Results
- **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
- **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
- **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
- **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
- **Calibration**: {visual_results.get('px_per_cm', '?')} px/cm ({(visual_results.get('calibration_meta') or {}).get('used', 'default')})

## 📊 Analysis Images
- **Original**: {visual_results.get('original_image_path', 'N/A')}
- **Detection**: {visual_results.get('detection_image_path', 'N/A')}
- **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
- **Annotated**: {visual_results.get('segmentation_annotated_path', 'N/A')}

## 🎯 Clinical Summary
Automated analysis provides quantitative measurements; verify via clinical examination.

## 💊 Recommendations
- Cleanse wound gently; select dressing per exudate/infection risk
- Debride necrotic tissue if indicated (clinical decision)
- Document with serial photos and measurements

## 📅 Monitoring
- Daily in week 1, then every 2–3 days (or as indicated)
- Weekly progress review

## 📚 Guideline Context
{(guideline_context or '')[:800]}{"..." if guideline_context and len(guideline_context) > 800 else ''}

**Disclaimer:** Automated, for decision support only. Verify clinically.
"""
    def generate_final_report(
        self,
        patient_info: str,
        visual_results: Dict,
        guideline_context: str,
        image_pil: Image.Image,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        try:
            report = generate_medgemma_report(
                patient_info, visual_results, guideline_context, image_pil, max_new_tokens
            )
            if report and report.strip() and not report.startswith(("⚠️", "❌")):
                return report
            logging.warning("MedGemma unavailable/invalid; using fallback.")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
        except Exception as e:
            logging.error(f"Report generation failed: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
    def save_and_commit_image(self, image_pil: Image.Image) -> str:
        try:
            os.makedirs(self.uploads_dir, exist_ok=True)
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{ts}.png"
            path = os.path.join(self.uploads_dir, filename)
            image_pil.convert("RGB").save(path)
            logging.info(f"✅ Image saved locally: {path}")
            if HF_TOKEN and DATASET_ID:
                try:
                    HfApi, HfFolder = _import_hf_hub()
                    HfFolder.save_token(HF_TOKEN)
                    api = HfApi()
                    api.upload_file(
                        path_or_fileobj=path,
                        path_in_repo=f"images/{filename}",
                        repo_id=DATASET_ID,
                        repo_type="dataset",
                        token=HF_TOKEN,
                        commit_message=f"Upload wound image: {filename}",
                    )
                    logging.info("✅ Image committed to HF dataset")
                except Exception as e:
                    logging.warning(f"HF upload failed: {e}")
            return path
        except Exception as e:
            logging.error(f"Failed to save/commit image: {e}")
            return ""
    def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
        try:
            saved_path = self.save_and_commit_image(image_pil)
            visual_results = self.perform_visual_analysis(image_pil)
            pi = questionnaire_data or {}
            patient_info = (
                f"Age: {pi.get('age', 'N/A')}, "
                f"Diabetic: {pi.get('diabetic', 'N/A')}, "
                f"Allergies: {pi.get('allergies', 'N/A')}, "
                f"Date of Wound: {pi.get('date_of_injury', 'N/A')}, "
                f"Professional Care: {pi.get('professional_care', 'N/A')}, "
                f"Oozing/Bleeding: {pi.get('oozing_bleeding', 'N/A')}, "
                f"Infection: {pi.get('infection', 'N/A')}, "
                f"Moisture: {pi.get('moisture', 'N/A')}"
            )
            query = (
                f"best practices for managing a {visual_results.get('wound_type', 'Unknown')} "
                f"with moisture '{pi.get('moisture', 'unknown')}' and infection '{pi.get('infection', 'unknown')}' "
                f"in a diabetic status '{pi.get('diabetic', 'unknown')}'"
            )
            guideline_context = self.query_guidelines(query)
            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
            return {
                "success": True,
                "visual_analysis": visual_results,
                "report": report,
                "saved_image_path": saved_path,
                "guideline_context": (guideline_context or "")[:500] + (
                    "..." if guideline_context and len(guideline_context) > 500 else ""
                ),
            }
        except Exception as e:
            logging.error(f"Pipeline error: {e}")
            return {
                "success": False,
                "error": str(e),
                "visual_analysis": {},
                "report": f"Analysis failed: {str(e)}",
                "saved_image_path": None,
                "guideline_context": "",
            }
    def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
        try:
            if isinstance(image, str):
                if not os.path.exists(image):
                    raise ValueError(f"Image file not found: {image}")
                # honor EXIF orientation so measurements match what the user sees
                image_pil = ImageOps.exif_transpose(Image.open(image))
            elif isinstance(image, Image.Image):
                image_pil = image
            elif isinstance(image, np.ndarray):
                image_pil = Image.fromarray(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")
            return self.full_analysis_pipeline(image_pil, questionnaire_data or {})
        except Exception as e:
            logging.error(f"Wound analysis error: {e}")
            return {
                "success": False,
                "error": str(e),
                "visual_analysis": {},
                "report": f"Analysis initialization failed: {str(e)}",
                "saved_image_path": None,
                "guideline_context": "",
            }
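# Minimal usage sketch — "wound.jpg" and the questionnaire keys are
# illustrative stand-ins for what the SmartHeal UI supplies; this just
# exercises the full CPU pipeline end to end.
if __name__ == "__main__":
    processor = AIProcessor()
    result = processor.analyze_wound(
        "wound.jpg",
        {"age": 54, "diabetic": "yes", "moisture": "moderate", "infection": "no"},
    )
    if result["success"]:
        va = result["visual_analysis"]
        print(f"{va['wound_type']}: {va['length_cm']} × {va['breadth_cm']} cm "
              f"({va['surface_area_cm2']} cm², {va['px_per_cm']} px/cm)")
        print(result["report"][:600])
    else:
        print("Analysis failed:", result["error"])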