Spaces:
Sleeping
Sleeping
| # smartheal_ai_processor.py | |
| # Full, functional module with an always-present @spaces.GPU function (if `spaces` is importable) | |
| # and robust CPU fallbacks to avoid crashes when GPU isn't actually available yet. | |
| # + Automatic calibration (px/cm) and measurement overlay on segmentation. | |
| import os | |
| import time | |
| import logging | |
| from datetime import datetime | |
| from typing import Optional, Dict, List, Tuple, Union | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image, TiffImagePlugin | |
| # =============== LOGGING =============== | |
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# =============== CONFIG ===============
# Local folder for saved uploads and analysis images; created at import time.
UPLOADS_DIR: str = "uploads"
os.makedirs(UPLOADS_DIR, exist_ok=True)

# Hugging Face access token (None when the environment variable is unset).
HF_TOKEN: Optional[str] = os.environ.get("HF_TOKEN")

# Model weights shipped with the app.
YOLO_MODEL_PATH: str = "src/best.pt"
SEG_MODEL_PATH: str = "src/segmentation_model.h5"  # optional; skipped if absent

# Guideline PDFs indexed into the retrieval knowledge base (missing files are skipped).
GUIDELINE_PDFS: List[str] = [
    "src/eHealth in Wound Care.pdf",
    "src/IWGDF Guideline.pdf",
    "src/evaluation.pdf",
]

# HF dataset that receives uploaded images; only used together with HF_TOKEN.
DATASET_ID: str = "SmartHeal/wound-image-uploads"

# Pixels-per-centimetre used whenever metadata-based calibration is unavailable.
DEFAULT_PIXELS_PER_CM: float = 38.0
# =============== CACHES ===============
# Process-wide caches shared by every AIProcessor instance.
models_cache: Dict[str, object] = {}
knowledge_base_cache: Dict[str, object] = {}
| # =============== Optional imports (lazy) =============== | |
| def _import_ultralytics(): | |
| from ultralytics import YOLO | |
| return YOLO | |
| def _import_tf_loader(): | |
| import tensorflow as tf | |
| tf.config.set_visible_devices([], "GPU") # force CPU for TF | |
| from tensorflow.keras.models import load_model | |
| return load_model | |
| def _import_hf_cls(): | |
| from transformers import pipeline | |
| return pipeline | |
| def _import_embeddings(): | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| return HuggingFaceEmbeddings | |
| def _import_langchain_pdf(): | |
| from langchain_community.document_loaders import PyPDFLoader | |
| return PyPDFLoader | |
| def _import_langchain_faiss(): | |
| from langchain_community.vectorstores import FAISS | |
| return FAISS | |
| def _import_hf_hub(): | |
| from huggingface_hub import HfApi, HfFolder | |
| return HfApi, HfFolder | |
| # =============== Spaces GPU function (always defined if `spaces` import works) =============== | |
| try: | |
| import spaces | |
| def generate_medgemma_report( | |
| patient_info: str, | |
| visual_results: Dict, | |
| guideline_context: str, | |
| image_pil: Image.Image, | |
| max_new_tokens: Optional[int] = None, | |
| ) -> str: | |
| """ | |
| This function MUST exist at import time so Spaces Zero detects it. | |
| It is guarded internally so if anything fails (no GPU yet, model load error), | |
| it returns a warning and your pipeline will use the fallback report. | |
| """ | |
| try: | |
| import torch | |
| from transformers import pipeline | |
| # Try to free cache; if no CUDA, this will raise and we return a warning. | |
| try: | |
| if hasattr(torch, "cuda") and torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| except Exception: | |
| pass | |
| prompt = f""" | |
| You are a medical AI assistant. Analyze this wound image and patient data. | |
| Patient: {patient_info} | |
| Wound: {visual_results.get('wound_type', 'Unknown')} - {visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm | |
| Provide a structured report with: | |
| 1. Clinical Summary | |
| 2. Treatment Recommendations | |
| 3. Risk Assessment | |
| 4. Monitoring Plan | |
| """.strip() | |
| pipe = pipeline( | |
| "image-text-to-text", | |
| model="google/medgemma-4b-it", | |
| torch_dtype=getattr(torch, "bfloat16", None), | |
| device_map="auto", | |
| token=HF_TOKEN, | |
| model_kwargs={"low_cpu_mem_usage": True, "use_cache": True}, | |
| ) | |
| messages = [{"role": "user", "content": [ | |
| {"type": "image", "image": image_pil}, | |
| {"type": "text", "text": prompt}, | |
| ]}] | |
| t0 = time.time() | |
| out = pipe( | |
| text=messages, | |
| max_new_tokens=max_new_tokens or 800, | |
| do_sample=False, | |
| temperature=0.7, | |
| pad_token_id=pipe.tokenizer.eos_token_id, | |
| ) | |
| logging.info(f"✅ MedGemma finished in {time.time()-t0:.2f}s") | |
| if out and len(out) > 0: | |
| # Defensive extraction (different transformers versions) | |
| try: | |
| return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response" | |
| except Exception: | |
| return (out[0].get("generated_text", "") or "").strip() or "⚠️ Empty response" | |
| return "⚠️ No output generated" | |
| except Exception as e: | |
| logging.error(f"❌ MedGemma generation error: {e}") | |
| return "⚠️ GPU worker unavailable" | |
| except Exception: | |
| # If `spaces` cannot be imported locally, expose a CPU-safe stub with same signature. | |
| def generate_medgemma_report( | |
| patient_info: str, | |
| visual_results: Dict, | |
| guideline_context: str, | |
| image_pil: Image.Image, | |
| max_new_tokens: Optional[int] = None, | |
| ) -> str: | |
| return "⚠️ GPU not available" | |
| # =============== Model init (CPU-safe) =============== | |
def load_yolo_model():
    """Construct the Ultralytics YOLO detector from the weights at ``YOLO_MODEL_PATH``."""
    detector_cls = _import_ultralytics()
    detector = detector_cls(YOLO_MODEL_PATH)
    return detector
def load_segmentation_model():
    """Load the Keras segmentation model at ``SEG_MODEL_PATH`` (TensorFlow on CPU).

    ``compile=False``: the model is only used for ``predict`` here, so no
    training configuration is restored.
    """
    keras_load = _import_tf_loader()
    segmenter = keras_load(SEG_MODEL_PATH, compile=False)
    return segmenter
def load_classification_pipeline():
    """Build the CPU ``image-classification`` pipeline for the wound-type model."""
    build_pipeline = _import_hf_cls()
    options = {
        "model": "Hemg/Wound-classification",
        "token": HF_TOKEN,
        "device": "cpu",
    }
    return build_pipeline("image-classification", **options)
def load_embedding_model():
    """Create the MiniLM sentence-embedding model (CPU) used for guideline retrieval."""
    embeddings_cls = _import_embeddings()
    return embeddings_cls(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )
def _cache_optional_model(key: str, loader, loaded_msg: str, failed_label: str) -> None:
    """Load ``key`` into ``models_cache`` unless already present; store ``None`` on failure."""
    if key in models_cache:
        return
    try:
        models_cache[key] = loader()
        logging.info(loaded_msg)
    except Exception as exc:
        models_cache[key] = None
        logging.warning(f"{failed_label}: {exc}")


def initialize_cpu_models() -> None:
    """Populate ``models_cache`` with the CPU models; keys already present are skipped.

    Keys written: ``det`` (YOLO; left unset on failure), ``seg`` (Keras, or ``None``),
    ``cls`` (HF classifier, or ``None``), ``embedding_model`` (or ``None``).
    Also persists ``HF_TOKEN`` via ``HfFolder`` when a token is configured.
    Never raises; every failure is logged.
    """
    if HF_TOKEN:
        try:
            _hf_api, hf_folder = _import_hf_hub()
            hf_folder.save_token(HF_TOKEN)
            logging.info("✅ HF token set")
        except Exception as exc:
            logging.warning(f"HF token save failed: {exc}")

    # Detector: a failure is an error, and the key stays absent so a later call retries.
    if "det" not in models_cache:
        try:
            models_cache["det"] = load_yolo_model()
            logging.info("✅ YOLO loaded (CPU)")
        except Exception as exc:
            logging.error(f"YOLO load failed: {exc}")

    # Segmentation is optional: a missing weights file just disables it.
    if "seg" not in models_cache:
        try:
            if not os.path.exists(SEG_MODEL_PATH):
                models_cache["seg"] = None
                logging.warning("Segmentation model file missing; skipping.")
            else:
                models_cache["seg"] = load_segmentation_model()
                logging.info("✅ Segmentation model loaded (CPU)")
        except Exception as exc:
            models_cache["seg"] = None
            logging.warning(f"Segmentation unavailable: {exc}")

    _cache_optional_model(
        "cls", load_classification_pipeline, "✅ Classifier loaded (CPU)", "Classifier unavailable"
    )
    _cache_optional_model(
        "embedding_model", load_embedding_model, "✅ Embeddings loaded (CPU)", "Embeddings unavailable"
    )
def setup_knowledge_base() -> None:
    """Build the guideline vector store once and cache it in ``knowledge_base_cache``.

    Loads every existing file listed in ``GUIDELINE_PDFS``, splits the pages into
    1000-character chunks (100 overlap) and indexes them in FAISS with the cached
    embedding model. ``knowledge_base_cache["vector_store"]`` ends up holding the
    store, or ``None`` when there are no documents, no embeddings, or the build
    fails; in every case later calls return immediately. Never raises.
    """
    if "vector_store" in knowledge_base_cache:
        return

    docs: List = []
    try:
        PyPDFLoader = _import_langchain_pdf()
    except Exception as e:
        PyPDFLoader = None
        logging.warning(f"LangChain PDF loader unavailable: {e}")

    if PyPDFLoader is not None:
        for pdf in GUIDELINE_PDFS:
            if not os.path.exists(pdf):
                logging.warning(f"Guideline PDF not found: {pdf}")
                continue
            try:
                docs.extend(PyPDFLoader(pdf).load())
                logging.info(f"Loaded PDF: {pdf}")
            except Exception as e:
                logging.warning(f"PDF load failed ({pdf}): {e}")

    embeddings = models_cache.get("embedding_model")
    if not docs or not embeddings:
        knowledge_base_cache["vector_store"] = None
        logging.warning("KB disabled (no docs or embeddings).")
        return

    try:
        try:
            # Splitters live in the standalone `langchain-text-splitters` package.
            from langchain_text_splitters import RecursiveCharacterTextSplitter
        except ImportError:
            # Legacy location, only present in older `langchain` releases.
            from langchain.text_splitter import RecursiveCharacterTextSplitter
        FAISS = _import_langchain_faiss()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = splitter.split_documents(docs)
        knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, embeddings)
        logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)")
    except Exception as e:
        knowledge_base_cache["vector_store"] = None
        logging.warning(f"KB build failed: {e}")
# Initialize on import so the app is ready for its first request. Order matters:
# setup_knowledge_base() reads the embedding model that initialize_cpu_models()
# stores in models_cache.
for _startup_step in (initialize_cpu_models, setup_knowledge_base):
    _startup_step()
del _startup_step
| # =============== Utility: EXIF-based auto calibration =============== | |
| def _rational_to_float(val) -> Optional[float]: | |
| try: | |
| if isinstance(val, TiffImagePlugin.IFDRational): | |
| return float(val.numerator) / float(val.denominator or 1) | |
| if isinstance(val, tuple) and len(val) == 2 and all(isinstance(x, (int, float)) for x in val): | |
| # (num, den) | |
| den = val[1] if val[1] else 1.0 | |
| return float(val[0]) / float(den) | |
| return float(val) | |
| except Exception: | |
| return None | |
def _auto_pixels_per_cm_from_exif(image_pil: Image.Image) -> Tuple[float, str]:
    """Estimate pixels-per-centimetre from the image's resolution metadata.

    Sources, in order:
      1. ``image_pil.info["dpi"]`` (horizontal value), accepted for 40-1200 dpi.
      2. EXIF ``XResolution`` (tag 282) with ``ResolutionUnit`` (tag 296):
         2 = per inch (also assumed when the tag is absent), 3 = per centimetre.
         Unit 1 ("no absolute unit") and invalid units carry no physical scale and
         are ignored; previously they were silently treated as inches.
    Any candidate outside 5-500 px/cm is rejected (not clamped).

    NOTE: many phones set these values arbitrarily, so the result is only a rough
    heuristic.

    Returns:
        ``(px_per_cm, source)`` where ``source`` is ``"dpi_info"``,
        ``"EXIF_XRes_cm"``, ``"EXIF_XRes_in"``, or ``"default"`` (then
        ``px_per_cm`` is ``DEFAULT_PIXELS_PER_CM``).
    """
    min_ppcm, max_ppcm = 5.0, 500.0

    # 1) PIL-decoded DPI.
    try:
        dpi_info = image_pil.info.get("dpi")
        if isinstance(dpi_info, (tuple, list)) and dpi_info:
            xdpi = float(dpi_info[0]) if dpi_info[0] else None
            if xdpi and 40 <= xdpi <= 1200:
                ppcm = xdpi / 2.54
                if min_ppcm <= ppcm <= max_ppcm:
                    return ppcm, "dpi_info"
    except Exception:  # malformed metadata simply means "no calibration"
        pass

    # 2) EXIF XResolution + ResolutionUnit.
    try:
        exif = image_pil.getexif()
        if exif:
            xres = _rational_to_float(exif.get(282))  # XResolution
            unit = int(exif.get(296) or 2)  # ResolutionUnit; absent -> inches
            if xres and unit in (2, 3):
                ppcm = xres if unit == 3 else xres / 2.54
                if min_ppcm <= ppcm <= max_ppcm:
                    return ppcm, ("EXIF_XRes_cm" if unit == 3 else "EXIF_XRes_in")
    except Exception:
        pass

    # 3) Fixed fallback.
    return DEFAULT_PIXELS_PER_CM, "default"
| # =============== Drawing helpers =============== | |
def _draw_measurement_overlay(
    base_bgr: np.ndarray,
    rect_xywh: Tuple[int, int, int, int],
    length_cm: float,
    breadth_cm: float,
) -> np.ndarray:
    """Return a copy of ``base_bgr`` annotated with measurement arrows and labels.

    A horizontal arrow across the middle of ``rect_xywh`` is labelled with
    ``breadth_cm``. A vertical arrow through its middle is labelled with ``length_cm``.
    Arrows and text are white with a 1-pixel-offset black shadow. ``rect_xywh``
    is ``(x, y, w, h)`` in ``base_bgr`` pixel coordinates. The input is not modified.
    """
    left, top, width, height = rect_xywh
    canvas = base_bgr.copy()

    white, black = (255, 255, 255), (0, 0, 0)  # BGR
    stroke = 2
    face, scale = cv2.FONT_HERSHEY_SIMPLEX, 0.7

    def arrow_with_shadow(p_from, p_to, dx, dy):
        # Thicker shadow shifted by (dx, dy) first, then the white arrow on top.
        cv2.arrowedLine(
            canvas,
            (p_from[0] + dx, p_from[1] + dy),
            (p_to[0] + dx, p_to[1] + dy),
            black, stroke + 2, cv2.LINE_AA, tipLength=0.02,
        )
        cv2.arrowedLine(canvas, p_from, p_to, white, stroke, cv2.LINE_AA, tipLength=0.02)

    def text_with_shadow(text, org_x, org_y):
        cv2.putText(canvas, text, (org_x + 1, org_y + 1), face, scale, black, 3, cv2.LINE_AA)
        cv2.putText(canvas, text, (org_x, org_y), face, scale, white, 2, cv2.LINE_AA)

    # Breadth: horizontal arrow at mid-height, label centred just above it.
    mid_y = top + height // 2
    arrow_with_shadow((left, mid_y), (left + width, mid_y), 0, 1)
    breadth_text = f"{breadth_cm:.2f} cm"
    (breadth_w, _breadth_h), _ = cv2.getTextSize(breadth_text, face, scale, 2)
    text_with_shadow(breadth_text, left + (width - breadth_w) // 2, mid_y - 8)

    # Length: vertical arrow at mid-width, label centred near the top edge.
    mid_x = left + width // 2
    arrow_with_shadow((mid_x, top), (mid_x, top + height), 1, 0)
    length_text = f"{length_cm:.2f} cm"
    (length_w, length_h), _ = cv2.getTextSize(length_text, face, scale, 2)
    text_with_shadow(length_text, mid_x - (length_w // 2), top + length_h + 8)

    return canvas
| # =============== AI PROCESSOR =============== | |
class AIProcessor:
    """Wound-analysis facade over the module-level model and knowledge-base caches.

    ``analyze_wound`` is the public entry point. It saves the upload, runs YOLO
    detection, optional Keras segmentation (with calibrated measurements) and
    optional HF classification, retrieves guideline snippets, and produces a
    report (MedGemma when it returns a usable answer, otherwise a template).
    """

    def __init__(self):
        # References (not copies) of the module-level caches filled at import time.
        self.models_cache = models_cache
        self.knowledge_base_cache = knowledge_base_cache
        self.uploads_dir = UPLOADS_DIR
        self.dataset_id = DATASET_ID
        self.hf_token = HF_TOKEN

    # ---------------- helpers ----------------
    @staticmethod
    def _timestamp() -> str:
        """Return a filename-safe timestamp with microsecond resolution.

        Second resolution let two requests in the same second overwrite each
        other's files (locally and in the HF dataset).
        """
        return datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    @staticmethod
    def _imwrite(path: str, image_bgr: np.ndarray) -> bool:
        """Write an image with OpenCV, logging a warning instead of failing silently."""
        ok = bool(cv2.imwrite(path, image_bgr))
        if not ok:
            logging.warning(f"Could not write image: {path}")
        return ok

    def _ensure_analysis_dir(self) -> str:
        """Create (if needed) and return ``<uploads_dir>/analysis``."""
        out_dir = os.path.join(self.uploads_dir, "analysis")
        os.makedirs(out_dir, exist_ok=True)
        return out_dir

    def _detect_wound(self, image_cv: np.ndarray) -> Tuple[Tuple[int, int, int, int], float]:
        """Run YOLO on CPU and return the first box (clipped to the image) and its confidence.

        Raises:
            RuntimeError: the detector is not loaded.
            ValueError: no box was detected.
        """
        det = self.models_cache.get("det")
        if det is None:
            raise RuntimeError("YOLO model not loaded")
        results = det.predict(image_cv, verbose=False, device="cpu")
        if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
            raise ValueError("No wound could be detected.")
        boxes = results[0].boxes
        x1, y1, x2, y2 = (int(v) for v in boxes[0].xyxy[0].cpu().numpy().astype(int))
        img_h, img_w = image_cv.shape[:2]
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img_w, x2), min(img_h, y2)
        conf = getattr(boxes, "conf", None)
        confidence = float(conf[0].cpu().item()) if conf is not None else 0.0
        return (x1, y1, x2, y2), confidence

    def _segment_and_measure(
        self, crop_bgr: np.ndarray, px_per_cm: float, out_dir: str, ts: str
    ) -> Tuple[float, float, float, Optional[str]]:
        """Segment the wound inside ``crop_bgr`` and measure its largest contour.

        Returns ``(length_cm, breadth_cm, surface_area_cm2, seg_path)``. Length is
        the vertical and breadth the horizontal extent of the contour's bounding
        box. Values stay 0.0 and ``seg_path`` stays ``None`` when segmentation is
        unavailable, finds nothing, or fails (failures are logged, not raised).
        """
        length_cm = breadth_cm = surface_area_cm2 = 0.0
        seg_path = None
        seg_model = self.models_cache.get("seg")
        if seg_model is None or crop_bgr.size == 0:
            return length_cm, breadth_cm, surface_area_cm2, seg_path
        try:
            in_h, in_w = seg_model.input_shape[1:3]
            resized = cv2.resize(crop_bgr, (in_w, in_h))
            mask_pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
            mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)
            contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                return length_cm, breadth_cm, surface_area_cm2, seg_path
            cnt = max(contours, key=cv2.contourArea)
            x, y, w, h = cv2.boundingRect(cnt)

            # The mask is at the model's input resolution, but px_per_cm is calibrated
            # for original-image pixels (the crop is not rescaled). Convert mask
            # pixels to crop pixels before converting to centimetres.
            crop_h, crop_w = crop_bgr.shape[:2]
            scale_x = crop_w / float(in_w)
            scale_y = crop_h / float(in_h)
            length_cm = round(h * scale_y / px_per_cm, 2)
            breadth_cm = round(w * scale_x / px_per_cm, 2)
            area_px = cv2.contourArea(cnt) * scale_x * scale_y
            surface_area_cm2 = round(area_px / (px_per_cm ** 2), 2)

            # Red-tinted mask blended over the crop, then measurement arrows.
            mask_full = cv2.resize(mask_np * 255, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST)
            tinted = crop_bgr.copy()
            tinted[mask_full > 127] = [0, 0, 255]
            seg_vis = cv2.addWeighted(crop_bgr, 0.7, tinted, 0.3, 0)
            rect_crop = (int(x * scale_x), int(y * scale_y), int(w * scale_x), int(h * scale_y))
            seg_vis = _draw_measurement_overlay(seg_vis, rect_crop, length_cm, breadth_cm)

            candidate = os.path.join(out_dir, f"segmentation_{ts}.png")
            if self._imwrite(candidate, seg_vis):
                seg_path = candidate
        except Exception as e:
            logging.warning(f"Segmentation skipped: {e}")
        return length_cm, breadth_cm, surface_area_cm2, seg_path

    def _classify_region(self, crop_bgr: np.ndarray) -> str:
        """Return the top-scoring label from the optional classifier, else "Unknown"."""
        cls_pipe = self.models_cache.get("cls")
        if cls_pipe is None:
            return "Unknown"
        try:
            crop_pil = Image.fromarray(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB))
            preds = cls_pipe(crop_pil)
            if preds:
                return max(preds, key=lambda p: p.get("score", 0)).get("label", "Unknown")
        except Exception as e:
            logging.warning(f"Classification failed: {e}")
        return "Unknown"

    # ---------------- analysis ----------------
    def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
        """YOLO detect → (optional) Keras seg → (optional) HF classify → save visuals.

        All images of one analysis share a single timestamp.

        Returns:
            Dict with ``wound_type``, ``length_cm``, ``breadth_cm``,
            ``surface_area_cm2``, ``calibration_px_per_cm``, ``calibration_source``,
            ``detection_confidence`` and the ``detection``/``segmentation``/``original``
            image paths (``segmentation_image_path`` is ``None`` when unavailable).

        Raises:
            RuntimeError / ValueError: no detector, or nothing detected (logged, re-raised).
        """
        try:
            image_rgb = image_pil.convert("RGB")
            image_cv = cv2.cvtColor(np.array(image_rgb), cv2.COLOR_RGB2BGR)

            (x1, y1, x2, y2), confidence = self._detect_wound(image_cv)

            # Calibration comes from the source image's own metadata.
            px_per_cm, calib_src = _auto_pixels_per_cm_from_exif(image_pil)
            if not (5.0 <= px_per_cm <= 500.0):
                px_per_cm, calib_src = DEFAULT_PIXELS_PER_CM, "default"
            logging.info(f"Calibration: {px_per_cm:.2f} px/cm (source={calib_src})")

            crop_bgr = image_cv[y1:y2, x1:x2]
            out_dir = self._ensure_analysis_dir()
            ts = self._timestamp()

            length_cm, breadth_cm, surface_area_cm2, seg_path = self._segment_and_measure(
                crop_bgr, px_per_cm, out_dir, ts
            )
            wound_type = self._classify_region(crop_bgr)

            det_vis = image_cv.copy()
            cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
            det_path = os.path.join(out_dir, f"detection_{ts}.png")
            self._imwrite(det_path, det_vis)
            original_path = os.path.join(out_dir, f"original_{ts}.png")
            self._imwrite(original_path, image_cv)

            return {
                "wound_type": wound_type,
                "length_cm": float(length_cm),
                "breadth_cm": float(breadth_cm),
                "surface_area_cm2": float(surface_area_cm2),
                "calibration_px_per_cm": float(px_per_cm),
                "calibration_source": calib_src,
                "detection_confidence": confidence,
                "detection_image_path": det_path,
                "segmentation_image_path": seg_path,  # includes arrow overlay when present
                "original_image_path": original_path,
            }
        except Exception as e:
            logging.error(f"Visual analysis failed: {e}")
            raise

    def query_guidelines(self, query: str) -> str:
        """Return up to 5 guideline snippets (source + first 300 chars) for ``query``.

        Never raises: returns an explanatory string when the knowledge base is
        missing or the lookup fails.
        """
        try:
            vs = self.knowledge_base_cache.get("vector_store")
            if not vs:
                return "Knowledge base is not available."
            retriever = vs.as_retriever(search_kwargs={"k": 5})
            # `invoke` is the Runnable API of current LangChain; `get_relevant_documents`
            # is the older, deprecated method, kept only as a fallback.
            if hasattr(retriever, "invoke"):
                docs = retriever.invoke(query)
            else:
                docs = retriever.get_relevant_documents(query)
            lines: List[str] = []
            for d in docs or []:
                src = (d.metadata or {}).get("source", "N/A")
                content = d.page_content or ""
                snippet = content[:300] + ("..." if len(content) > 300 else "")
                lines.append(f"Source: {src}\nContent: {snippet}")
            return "\n\n".join(lines) if lines else "No relevant guideline snippets found."
        except Exception as e:
            logging.warning(f"Guidelines query failed: {e}")
            return f"Guidelines query failed: {str(e)}"

    def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
        """Build the templated Markdown report used when MedGemma is unavailable.

        Guideline context is truncated to 800 characters. Missing image paths are
        shown as "N/A".
        """
        vr = visual_results or {}
        context = guideline_context or ""
        context_snippet = context[:800] + ("..." if len(context) > 800 else "")
        confidence = float(vr.get("detection_confidence") or 0.0)
        calibration = float(vr.get("calibration_px_per_cm") or 0.0)
        sections = [
            "# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report",
            "",
            "## 📋 Patient Information",
            f"{patient_info}",
            "",
            "## 🔍 Visual Analysis Results",
            f"- **Wound Type**: {vr.get('wound_type', 'Unknown')}",
            f"- **Dimensions**: {vr.get('length_cm', 0)} cm × {vr.get('breadth_cm', 0)} cm",
            f"- **Surface Area**: {vr.get('surface_area_cm2', 0)} cm²",
            f"- **Detection Confidence**: {confidence:.1%}",
            f"- **Calibration**: {calibration:.2f} px/cm (source: {vr.get('calibration_source', 'n/a')})",
            "",
            "## 📊 Analysis Images",
            f"- **Original**: {vr.get('original_image_path') or 'N/A'}",
            f"- **Detection**: {vr.get('detection_image_path') or 'N/A'}",
            f"- **Segmentation (with measurements)**: {vr.get('segmentation_image_path') or 'N/A'}",
            "",
            "## 🎯 Clinical Summary",
            "Automated analysis provides quantitative measurements; verify via clinical examination.",
            "",
            "## 💊 Recommendations",
            "- Cleanse wound gently; select dressing per exudate/infection risk",
            "- Debride necrotic tissue if indicated (clinical decision)",
            "- Document with serial photos and measurements",
            "",
            "## 📅 Monitoring",
            "- Daily in week 1, then every 2–3 days (or as indicated)",
            "- Weekly progress review",
            "",
            "## 📚 Guideline Context",
            context_snippet,
            "",
            "**Disclaimer:** Automated, for decision support only. Verify clinically.",
        ]
        return "\n".join(sections) + "\n"

    def generate_final_report(
        self,
        patient_info: str,
        visual_results: Dict,
        guideline_context: str,
        image_pil: Image.Image,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Return the MedGemma report when usable, otherwise the templated fallback.

        A MedGemma answer that is empty or starts with "⚠️"/"❌" counts as unusable.
        """
        try:
            report = generate_medgemma_report(
                patient_info, visual_results, guideline_context, image_pil, max_new_tokens
            )
            if report and report.strip() and not report.startswith(("⚠️", "❌")):
                return report
            logging.warning("MedGemma unavailable/invalid; using fallback.")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
        except Exception as e:
            logging.error(f"Report generation failed: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def save_and_commit_image(self, image_pil: Image.Image) -> str:
        """Save the image as PNG under ``uploads_dir`` and, when configured, upload it
        to the HF dataset.

        Returns the local path, or "" when the local save failed. Upload failures
        are logged and do not affect the return value.
        """
        try:
            os.makedirs(self.uploads_dir, exist_ok=True)
            filename = f"{self._timestamp()}.png"
            path = os.path.join(self.uploads_dir, filename)
            image_pil.convert("RGB").save(path)
            logging.info(f"✅ Image saved locally: {path}")
            if self.hf_token and self.dataset_id:
                try:
                    hf_api_cls, _ = _import_hf_hub()
                    # Token is passed explicitly, so it need not be persisted again here.
                    hf_api_cls().upload_file(
                        path_or_fileobj=path,
                        path_in_repo=f"images/{filename}",
                        repo_id=self.dataset_id,
                        repo_type="dataset",
                        token=self.hf_token,
                        commit_message=f"Upload wound image: {filename}",
                    )
                    logging.info("✅ Image committed to HF dataset")
                except Exception as e:
                    logging.warning(f"HF upload failed: {e}")
            return path
        except Exception as e:
            logging.error(f"Failed to save/commit image: {e}")
            return ""

    def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
        """Run save → visual analysis → guideline lookup → report.

        Returns a dict with ``success``, ``visual_analysis``, ``report``,
        ``saved_image_path`` and ``guideline_context`` (truncated to 500 chars).
        On failure ``success`` is False and ``error`` holds the message.
        """
        try:
            saved_path = self.save_and_commit_image(image_pil)
            visual_results = self.perform_visual_analysis(image_pil)
            pi = questionnaire_data or {}
            patient_info = (
                f"Age: {pi.get('age','N/A')}, "
                f"Diabetic: {pi.get('diabetic','N/A')}, "
                f"Allergies: {pi.get('allergies','N/A')}, "
                f"Date of Wound: {pi.get('date_of_injury','N/A')}, "
                f"Professional Care: {pi.get('professional_care','N/A')}, "
                f"Oozing/Bleeding: {pi.get('oozing_bleeding','N/A')}, "
                f"Infection: {pi.get('infection','N/A')}, "
                f"Moisture: {pi.get('moisture','N/A')}"
            )
            query = (
                f"best practices for managing a {visual_results.get('wound_type','Unknown')} "
                f"with moisture '{pi.get('moisture','unknown')}' and infection '{pi.get('infection','unknown')}' "
                f"in a diabetic status '{pi.get('diabetic','unknown')}'"
            )
            guideline_context = self.query_guidelines(query)
            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
            context = guideline_context or ""
            return {
                "success": True,
                "visual_analysis": visual_results,
                "report": report,
                "saved_image_path": saved_path,
                "guideline_context": context[:500] + ("..." if len(context) > 500 else ""),
            }
        except Exception as e:
            logging.error(f"Pipeline error: {e}")
            return {
                "success": False,
                "error": str(e),
                "visual_analysis": {},
                "report": f"Analysis failed: {str(e)}",
                "saved_image_path": None,
                "guideline_context": "",
            }

    def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
        """Public entry point used by the UI.

        ``image`` may be a file path, a PIL image, or a numpy array. Never raises;
        errors are returned in the result dict (``success`` False).
        """
        try:
            if isinstance(image, str):
                if not os.path.exists(image):
                    raise ValueError(f"Image file not found: {image}")
                image_pil = Image.open(image)
            elif isinstance(image, Image.Image):
                image_pil = image
            elif isinstance(image, np.ndarray):
                image_pil = Image.fromarray(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")
            return self.full_analysis_pipeline(image_pil, questionnaire_data or {})
        except Exception as e:
            logging.error(f"Wound analysis error: {e}")
            return {
                "success": False,
                "error": str(e),
                "visual_analysis": {},
                "report": f"Analysis initialization failed: {str(e)}",
                "saved_image_path": None,
                "guideline_context": "",
            }