""" SmartHeal AI Processor - Zero GPU Compatible Version Designed specifically for Hugging Face Spaces with Zero GPU architecture """ import os import logging import cv2 import numpy as np from PIL import Image import json from datetime import datetime from typing import Optional, Dict, List, Tuple, Any from contextlib import contextmanager # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) # Environment setup for Zero GPU compatibility os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1") # Hide GPU from main process # Import spaces decorator try: import spaces _SPACES_GPU = spaces.GPU except ImportError: logging.warning("spaces package not available - running in CPU mode") # Create dummy decorator for local testing def _SPACES_GPU_dummy(*args, **kwargs): def decorator(func): return func return decorator _SPACES_GPU = _SPACES_GPU_dummy @contextmanager def _no_cuda_env(): """Context manager to prevent CUDA initialization in main process""" prev_cuda = os.environ.get("CUDA_VISIBLE_DEVICES") os.environ["CUDA_VISIBLE_DEVICES"] = "-1" try: yield finally: if prev_cuda is None: os.environ.pop("CUDA_VISIBLE_DEVICES", None) else: os.environ["CUDA_VISIBLE_DEVICES"] = prev_cuda # ---- Paths / constants ---- UPLOADS_DIR = "uploads" os.makedirs(UPLOADS_DIR, exist_ok=True) HF_TOKEN = os.getenv("HF_TOKEN", None) YOLO_MODEL_PATH = "src/best.pt" SEG_MODEL_PATH = "src/segmentation_model.h5" # optional GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"] DATASET_ID = "SmartHeal/wound-image-uploads" DEFAULT_PX_PER_CM = 38.0 PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0 models_cache: Dict[str, object] = {} knowledge_base_cache: Dict[str, object] = {} # ---------- Lazy imports (wrapped where needed) ---------- def _import_ultralytics(): # Prevent Ultralytics from probing CUDA on import with _no_cuda_env(): from ultralytics import YOLO return YOLO def _import_tf_loader(): # Ensure TensorFlow does not try to use GPU in main process with _no_cuda_env(): import tensorflow as tf tf.config.set_visible_devices([], "GPU") from tensorflow.keras.models import load_model return load_model def _import_hf_cls(): with _no_cuda_env(): from transformers import pipeline return pipeline def _import_embeddings(): with _no_cuda_env(): from langchain_community.embeddings import HuggingFaceEmbeddings return HuggingFaceEmbeddings def _import_langchain_pdf(): with _no_cuda_env(): from langchain_community.document_loaders import PyPDFLoader return PyPDFLoader def _import_langchain_faiss(): with _no_cuda_env(): from langchain_community.vectorstores import FAISS return FAISS def _import_hf_hub(): with _no_cuda_env(): from huggingface_hub import HfApi, HfFolder return HfApi, HfFolder # ---------- SmartHeal prompts (system + user prefix) ---------- SMARTHEAL_SYSTEM_PROMPT = """ You are SmartHeal Clinical Assistant, a wound-care decision-support system. You analyze wound photographs and brief patient context to produce careful, specific, guideline-informed recommendations WITHOUT diagnosing. You always: - Use the measurements calculated by the vision pipeline as ground truth. - Prefer concise, actionable steps tailored to exudate level, infection risk, and pain. - Flag uncertainties and red flags that need escalation to a clinician. - Avoid contraindicated advice; do not infer unseen comorbidities. - Keep under 300 words and use the requested headings exactly. - Tone: professional, clear, and conservative; no definitive medical claims. - Safety: remind the user to seek clinician review for changes or red flags. """ SMARTHEAL_USER_PREFIX = """ Patient: {patient_info} Visual findings: type={wound_type}, size={length_cm}x{breadth_cm} cm, area={area_cm2} cm^2, detection_conf={det_conf:.2f}, calibration={px_per_cm} px/cm. Guideline context (snippets you can draw principles from; do not quote at length): {guideline_context} Write a structured answer with these headings exactly: 1. Clinical Summary (max 4 bullet points) 2. Likely Stage/Type (if uncertain, say 'uncertain') 3. Treatment Plan (specific dressing choices and frequency based on exudate/infection risk) 4. Red Flags (what to escalate and when) 5. Follow-up Cadence (days) 6. Notes (assumptions/uncertainties) Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice. """ # ---------- VLM (MedGemma replaced with Qwen2-VL) ---------- @_SPACES_GPU(enable_queue=True) def _vlm_infer_gpu(messages, model_id: str, max_new_tokens: int, token: Optional[str]): """ Runs entirely inside a Spaces GPU worker. It's the ONLY place we allow CUDA init. """ from transformers import pipeline import torch # Ensure torch is imported here pipe = pipeline( task="image-text-to-text", model=model_id, torch_dtype=torch.bfloat16, # Use torch_dtype from the working example device_map="auto", # CUDA init happens here, safely in GPU worker token=token, trust_remote_code=True, model_kwargs={"low_cpu_mem_usage": True}, ) out = pipe(text=messages, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.2) try: txt = out[0]["generated_text"][-1].get("content", "") except Exception: txt = out[0].get("generated_text", "") return (txt or "").strip() or "⚠️ Empty response" def generate_medgemma_report( # kept name so callers don't change patient_info: str, visual_results: Dict, guideline_context: str, image_pil: Image.Image, max_new_tokens: Optional[int] = None, ) -> str: """ MedGemma replacement using Qwen/Qwen2-VL-2B-Instruct via image-text-to-text. Loads & runs ONLY inside a GPU worker to satisfy Stateless GPU constraints. """ if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1": return "⚠️ VLM disabled" model_id = os.getenv("SMARTHEAL_VLM_MODEL", "Qwen/Qwen2-VL-2B-Instruct") max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600")) uprompt = SMARTHEAL_USER_PREFIX.format( patient_info=patient_info, wound_type=visual_results.get("wound_type", "Unknown"), length_cm=visual_results.get("length_cm", 0), breadth_cm=visual_results.get("breadth_cm", 0), area_cm2=visual_results.get("surface_area_cm2", 0), det_conf=float(visual_results.get("detection_confidence", 0.0)), px_per_cm=visual_results.get("px_per_cm", "?"), guideline_context=(guideline_context or "")[:900], ) messages = [ {"role": "system", "content": [{"type": "text", "text": SMARTHEAL_SYSTEM_PROMPT}]}, {"role": "user", "content": [ {"type": "image", "image": image_pil}, {"type": "text", "text": uprompt}, ]}, ] try: # IMPORTANT: do not import transformers or touch CUDA here. Only call the GPU worker. return _vlm_infer_gpu(messages, model_id, max_new_tokens, HF_TOKEN) except Exception as e: logging.error(f"VLM call failed: {e}") return "⚠️ VLM error" # ---------- Initialize CPU models ---------- def load_yolo_model(): YOLO = _import_ultralytics() # Construct model with CUDA masked to avoid auto-selecting cuda:0 with _no_cuda_env(): model = YOLO(YOLO_MODEL_PATH) return model def load_segmentation_model(): load_model = _import_tf_loader() return load_model(SEG_MODEL_PATH, compile=False, custom_objects={'InputLayer': _import_tf_loader().layers.InputLayer}) def load_classification_pipeline(): pipe = _import_hf_cls() return pipe("image-classification", model="Hemg/Wound-classification", token=HF_TOKEN, device="cpu") def load_embedding_model(): Emb = _import_embeddings() return Emb(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"}) def initialize_cpu_models() -> None: if HF_TOKEN: try: HfApi, HfFolder = _import_hf_hub() HfFolder.save_token(HF_TOKEN) logging.info("✅ HF token set") except Exception as e: logging.warning(f"HF token save failed: {e}") if "det" not in models_cache: try: models_cache["det"] = load_yolo_model() logging.info("✅ YOLO loaded (CPU; CUDA masked in main)") except Exception as e: logging.error(f"YOLO load failed: {e}") if "seg" not in models_cache: try: if os.path.exists(SEG_MODEL_PATH): models_cache["seg"] = load_segmentation_model() m = models_cache["seg"] ishape = getattr(m, "input_shape", None) oshape = getattr(m, "output_shape", None) logging.info(f"✅ Segmentation model loaded (CPU) | input_shape={ishape} output_shape={oshape}") else: models_cache["seg"] = None logging.warning("Segmentation model file missing; skipping.") except Exception as e: models_cache["seg"] = None logging.warning(f"Segmentation unavailable: {e}") if "cls" not in models_cache: try: models_cache["cls"] = load_classification_pipeline() logging.info("✅ Classifier loaded (CPU)") except Exception as e: models_cache["cls"] = None logging.warning(f"Classifier unavailable: {e}") if "embedding_model" not in models_cache: try: models_cache["embedding_model"] = load_embedding_model() logging.info("✅ Embeddings loaded (CPU)") except Exception as e: models_cache["embedding_model"] = None logging.warning(f"Embeddings unavailable: {e}") def setup_knowledge_base() -> None: if "vector_store" in knowledge_base_cache: return docs: List = [] try: PyPDFLoader = _import_langchain_pdf() for pdf in GUIDELINE_PDFS: if os.path.exists(pdf): try: docs.extend(PyPDFLoader(pdf).load()) logging.info(f"Loaded PDF: {pdf}") except Exception as e: logging.warning(f"PDF load failed ({pdf}): {e}") except Exception as e: logging.warning(f"LangChain PDF loader unavailable: {e}") if docs and models_cache.get("embedding_model"): try: from langchain.text_splitter import RecursiveCharacterTextSplitter FAISS = _import_langchain_faiss() chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs) knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"]) logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)") except Exception as e: knowledge_base_cache["vector_store"] = None logging.warning(f"KB build failed: {e}") else: knowledge_base_cache["vector_store"] = None logging.warning("KB disabled (no docs or embeddings).") initialize_cpu_models() setup_knowledge_base() # ---------- Calibration helpers ---------- def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]: out = {} try: exif = pil_img.getexif() if not exif: return out for k, v in exif.items(): tag = TAGS.get(k, k) out[tag] = v except Exception: pass return out def _to_float(val) -> Optional[float]: try: if val is None: return None if isinstance(val, tuple) and len(val) == 2: num, den = float(val[0]), float(val[1]) if float(val[1]) != 0 else 1.0 return num / den return float(val) except Exception: return None def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]: if f_mm and f35 and f35 > 0: return 36.0 * f_mm / f35 return None def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]: meta = {"used": "default", "f_mm": None, "f35": None, "sensor_w_mm": None, "distance_m": None} try: exif = _exif_to_dict(pil_img) f_mm = _to_float(exif.get("FocalLength")) f35 = _to_float(exif.get("FocalLengthIn35mmFilm") or exif.get("FocalLengthIn35mm")) subj_dist_m = _to_float(exif.get("SubjectDistance")) sensor_w_mm = _estimate_sensor_width_mm(f_mm, f35) meta.update({"f_mm": f_mm, "f35": f35, "sensor_w_mm": sensor_w_mm, "distance_m": subj_dist_m}) if f_mm and sensor_w_mm and subj_dist_m and subj_dist_m > 0: w_px = pil_img.width field_w_mm = sensor_w_mm * (subj_dist_m * 1000.0) / f_mm field_w_cm = field_w_mm / 10.0 px_per_cm = w_px / max(field_w_cm, 1e-6) px_per_cm = float(np.clip(px_per_cm, PX_PER_CM_MIN, PX_PER_CM_MAX)) meta["used"] = "exif" return px_per_cm, meta return float(default_px_per_cm), meta except Exception: return float(default_px_per_cm), meta # ---------- Segmentation helpers ---------- def _imagenet_norm(arr: np.ndarray) -> np.ndarray: mean = np.array([123.675, 116.28, 103.53], dtype=np.float32) std = np.array([58.395, 57.12, 57.375], dtype=np.float32) return (arr.astype(np.float32) - mean) / std def _preprocess_for_seg(bgr_roi: np.ndarray, target_hw: Tuple[int, int]) -> np.ndarray: H, W = target_hw resized = cv2.resize(bgr_roi, (W, H), interpolation=cv2.INTER_LINEAR) x = resized.astype(np.float32) / 255.0 x = np.expand_dims(x, axis=0) # (1,H,W,3) return x def _to_prob(pred: np.ndarray) -> np.ndarray: p = np.squeeze(pred) pmin, pmax = float(p.min()), float(p.max()) if pmax > 1.0 or pmin < 0.0: p = 1.0 / (1.0 + np.exp(-p)) return p.astype(np.float32) # ---- Adaptive threshold + GrabCut grow ---- def _adaptive_prob_threshold(p: np.ndarray) -> float: """ Choose a threshold that avoids tiny blobs while not swallowing skin. Try Otsu and the 90th percentile, clamp to [0.25, 0.65], pick by area heuristic. """ p01 = np.clip(p.astype(np.float32), 0, 1) p255 = (p01 * 255).astype(np.uint8) ret_otsu, _ = cv2.threshold(p255, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) thr_otsu = float(np.clip(ret_otsu / 255.0, 0.25, 0.65)) thr_pctl = float(np.clip(np.percentile(p01, 90), 0.25, 0.65)) def area_frac(thr: float) -> float: return float((p01 >= thr).sum()) / float(p01.size) af_otsu = area_frac(thr_otsu) af_pctl = area_frac(thr_pctl) def score(af: float) -> float: target_low, target_high = 0.03, 0.10 if af < target_low: return abs(af - target_low) * 3.0 if af > target_high: return abs(af - target_high) * 1.5 return 0.0 return thr_otsu if score(af_otsu) <= score(af_pctl) else thr_pctl def _grabcut_refine(bgr: np.ndarray, seed01: np.ndarray, iters: int = 3) -> np.ndarray: """ Grow from a confident core into low-contrast margins. """ h, w = bgr.shape[:2] gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8) k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) seed_dil = cv2.dilate(seed01, k, iterations=1) gc[seed01.astype(bool)] = cv2.GC_PR_FGD gc[seed_dil.astype(bool)] = cv2.GC_FGD gc[0, :], gc[-1, :], gc[:, 0], gc[:, 1] = cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD bgdModel = np.zeros((1, 65), np.float64) fgdModel = np.zeros((1, 65), np.float64) cv2.grabCut(bgr, gc, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK) return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 1, 0).astype(np.uint8) # ---------- Main AIProcessor Class ---------- class AIProcessor: def __init__(self): self.config = type("Config", (object,), { "HF_TOKEN": HF_TOKEN, "YOLO_MODEL_PATH": YOLO_MODEL_PATH, "SEG_MODEL_PATH": SEG_MODEL_PATH, "DATASET_ID": DATASET_ID, "UPLOADS_DIR": UPLOADS_DIR, "GUIDELINE_PDFS": GUIDELINE_PDFS })() self.models_cache = models_cache self.knowledge_base_cache = knowledge_base_cache self.px_per_cm = DEFAULT_PX_PER_CM # Use default from constants # Ensure CPU models and KB are initialized initialize_cpu_models() setup_knowledge_base() def perform_visual_analysis(self, image_pil: Image.Image) -> Dict[str, Any]: image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) if "det" not in self.models_cache or not self.models_cache["det"]: raise ValueError("YOLO model not initialized.") results = self.models_cache["det"].predict(image_cv, verbose=False, device="cpu") if not results or not results[0].boxes: raise ValueError("No wound detected.") box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int) region_cv = image_cv[box[1]:box[3], box[0]:box[2]] detection_confidence = float(results[0].boxes[0].conf[0].cpu().numpy()) length = breadth = area = 0 if "seg" in self.models_cache and self.models_cache["seg"]: try: seg_model = self.models_cache["seg"] input_size = seg_model.input_shape[1:3] preprocessed_roi = _preprocess_for_seg(region_cv, input_size) mask_pred = seg_model.predict(preprocessed_roi, verbose=0)[0] prob_mask = _to_prob(mask_pred) # Adaptive thresholding and GrabCut refinement initial_mask = (prob_mask >= _adaptive_prob_threshold(prob_mask)).astype(np.uint8) refined_mask = _grabcut_refine(region_cv, initial_mask) contours, _ = cv2.findContours(refined_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) if contours: cnt = max(contours, key=cv2.contourArea) x, y, w, h = cv2.boundingRect(cnt) length = round(h / self.px_per_cm, 2) breadth = round(w / self.px_per_cm, 2) area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2) except Exception as e: logging.warning(f"Segmentation process failed: {e}") wound_type = "Unknown" if "cls" in self.models_cache and self.models_cache["cls"]: try: wound_region_pil = Image.fromarray(cv2.cvtColor(region_cv, cv2.COLOR_BGR2RGB)) classification_results = self.models_cache["cls"](wound_region_pil) wound_type = max(classification_results, key=lambda x: x["score"])["label"] except Exception as e: logging.warning(f"Classification process failed: {e}") return { "wound_type": wound_type, "length_cm": length, "breadth_cm": breadth, "surface_area_cm2": area, "detection_confidence": detection_confidence, "px_per_cm": self.px_per_cm } def query_guidelines(self, query: str) -> str: vector_store = self.knowledge_base_cache.get("vector_store") if not vector_store: return "Knowledge base unavailable." retriever = vector_store.as_retriever(search_kwargs={"k": 10}) docs = retriever.invoke(query) return "\n\n".join([ f"Source: {doc.metadata.get('source', 'N/A')}, Page: {doc.metadata.get('page', 'N/A')}\nContent: {doc.page_content}" for doc in docs ]) def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=2048): return generate_medgemma_report(patient_info, visual_results, guideline_context, image_pil, max_new_tokens) def save_and_commit_image(self, image_pil): filename = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.png" local_path = os.path.join(self.config.UPLOADS_DIR, filename) image_pil.convert("RGB").save(local_path) logging.info(f"Image saved locally: {local_path}") if self.config.HF_TOKEN and self.config.DATASET_ID: try: api = _import_hf_hub()[0]() # HfApi api.upload_file( path_or_fileobj=local_path, path_in_repo=f"images/{filename}", repo_id=self.config.DATASET_ID, repo_type="dataset", commit_message=f"Upload wound image: {filename}", token=self.config.HF_TOKEN ) logging.info("✅ Image uploaded to HF dataset.") except Exception as e: logging.warning(f"Upload failed: {e}") @_SPACES_GPU(enable_queue=True, duration=120) def full_analysis_pipeline(self, image, questionnaire_data): try: self.save_and_commit_image(image) visual = self.perform_visual_analysis(image) patient_info = ", ".join([f"{k}: {v}" for k, v in questionnaire_data.items()]) query = f"best practices for managing a {visual['wound_type']} with moisture level '{questionnaire_data.get('moisture')}' and signs of infection '{questionnaire_data.get('infection')}' in a patient who is diabetic '{questionnaire_data.get('diabetic')}'" guideline_context = self.query_guidelines(query) return self.generate_final_report(patient_info, visual, guideline_context, image) except Exception as e: logging.error(f"Pipeline error: {e}", exc_info=True) return f"❌ Error: {e}"