import os
import logging
from typing import Optional

import numpy as np
import cv2
from PIL import Image
from datetime import datetime
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from huggingface_hub import HfApi, HfFolder
import spaces

from .config import Config

# Inline system prompt for the MedGemma GPU pipeline.
DEFAULT_SYSTEM_PROMPT = (
    "You are a world-class medical AI assistant specializing in wound care "
    "with expertise in wound assessment and treatment. Provide concise, "
    "evidence-based medical assessments focusing on: (1) Precise wound "
    "classification based on tissue type and appearance, (2) Specific "
    "treatment recommendations with exact product names or interventions when "
    "appropriate, (3) Objective evaluation of healing progression or deterioration "
    "indicators, and (4) Clear follow-up timelines. Avoid general statements and "
    "prioritize actionable insights based on the visual analysis, measurements, and "
    "patient context."
)


# No torch/transformers imports at module level: on ZeroGPU Spaces, anything
# that touches CUDA must be imported inside the @spaces.GPU-decorated function.
@spaces.GPU(enable_queue=True, duration=120)
def generate_medgemma_report(
    patient_info: str,
    visual_results: dict,
    guideline_context: str,
    detection_image_path: str,
    segmentation_image_path: str,
    max_new_tokens: Optional[int] = None,
) -> Optional[str]:
    """Generate a wound-care report with MedGemma on GPU; return None on failure."""
    # All GPU-related imports and model loading happen here, inside the
    # decorator context, so the module stays importable on CPU-only hosts.
    import torch  # noqa: F401 -- ensures CUDA is initialized in this context
    from transformers import pipeline

    # Lazy-load the MedGemma pipeline once and cache it on the function object.
    if not hasattr(generate_medgemma_report, "_pipe"):
        try:
            cfg = Config()
            generate_medgemma_report._pipe = pipeline(
                'image-text-to-text',
                model='google/medgemma-4b-it',
                device='cuda',  # explicitly on GPU
                torch_dtype='auto',
                model_kwargs={'offload_folder': 'offload'},
                token=cfg.HF_TOKEN,
            )
            logging.info("✅ MedGemma pipeline loaded on GPU")
        except Exception as e:
            logging.warning(f"MedGemma pipeline load failed: {e}")
            return None
    pipe = generate_medgemma_report._pipe

    # Compose chat messages: system prompt first, then the user turn.
    msgs = [
        {'role': 'system', 'content': [{'type': 'text', 'text': DEFAULT_SYSTEM_PROMPT}]},
        {'role': 'user', 'content': []},
    ]
    # Attach the detection/segmentation overlays if they exist on disk.
    for path in (detection_image_path, segmentation_image_path):
        if path and os.path.exists(path):
            msgs[1]['content'].append({'type': 'image', 'image': Image.open(path)})
    # Attach the text prompt, including the retrieved guideline context
    # (previously passed in but never used).
    prompt = (
        f"## Patient\n{patient_info}\n"
        f"## Wound Type: {visual_results.get('wound_type', 'Unknown')}\n"
        f"## Guideline Context\n{guideline_context}"
    )
    msgs[1]['content'].append({'type': 'text', 'text': prompt})

    out = pipe(
        text=msgs,
        max_new_tokens=max_new_tokens or Config().MAX_NEW_TOKENS,
        do_sample=False,
    )
    # The pipeline echoes the chat; the last message is the assistant's reply.
    return out[0]['generated_text'][-1].get('content', '')
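
# Illustrative call (a sketch only: the visual_results keys mirror what
# AIProcessor.perform_visual_analysis returns, and the file paths below are
# hypothetical):
#
#   report = generate_medgemma_report(
#       patient_info="age: 54, medical_history: diabetes",
#       visual_results={'wound_type': 'Diabetic Ulcer'},
#       guideline_context="(retrieved guideline excerpts)",
#       detection_image_path="uploads/analysis/detection_20240101_120000.png",
#       segmentation_image_path="uploads/analysis/segmentation_20240101_120000.png",
#   )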


class AIProcessor:
    def __init__(self):
        self.models_cache = {}
        self.knowledge_base_cache = {}
        self.config = Config()
        self.px_per_cm = self.config.PIXELS_PER_CM
        self._initialize_models()
        self._load_knowledge_base()

    def _initialize_models(self):
        """Load all CPU-only models."""
        # Persist the HuggingFace token for downstream hub calls.
        if self.config.HF_TOKEN:
            HfFolder.save_token(self.config.HF_TOKEN)
            logging.info("✅ HuggingFace token set")

        # YOLO detection (CPU-only). Required: failure aborts initialization.
        try:
            from ultralytics import YOLO
            self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
            logging.info("✅ YOLO model loaded (CPU only)")
        except Exception as e:
            logging.error(f"YOLO load failed: {e}")
            raise

        # Segmentation model (CPU). Optional: measurements degrade gracefully.
        try:
            from tensorflow.keras.models import load_model
            self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
            logging.info("✅ Segmentation model loaded (CPU)")
        except Exception as e:
            logging.warning(f"Segmentation model not available: {e}")

        # Classification pipeline (CPU). Optional.
        try:
            from transformers import pipeline
            self.models_cache['cls'] = pipeline(
                'image-classification',
                model='Hemg/Wound-classification',
                token=self.config.HF_TOKEN,
                device='cpu',
            )
            logging.info("✅ Classification pipeline loaded (CPU)")
        except Exception as e:
            logging.warning(f"Classification pipeline not available: {e}")

        # Embedding model (CPU). Needed for the guideline knowledge base.
        try:
            self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
                model_name='sentence-transformers/all-MiniLM-L6-v2',
                model_kwargs={'device': 'cpu'},
            )
            logging.info("✅ Embedding model loaded (CPU)")
        except Exception as e:
            logging.warning(f"Embedding model not available: {e}")

    def _load_knowledge_base(self):
        """Load PDF guidelines into a FAISS vector store."""
        docs = []
        for pdf in self.config.GUIDELINE_PDFS:
            if os.path.exists(pdf):
                loader = PyPDFLoader(pdf)
                docs.extend(loader.load())
                logging.info(f"Loaded PDF: {pdf}")
        if docs and 'embedding_model' in self.models_cache:
            splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            chunks = splitter.split_documents(docs)
            self.knowledge_base_cache['vectorstore'] = FAISS.from_documents(
                chunks, self.models_cache['embedding_model']
            )
            logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
        else:
            self.knowledge_base_cache['vectorstore'] = None
            logging.warning("Knowledge base unavailable")

    def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
        """Detect and segment on CPU; return metrics plus overlay file paths."""
        if 'det' not in self.models_cache:
            raise RuntimeError("YOLO model ('det') not loaded")
        img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
        res = self.models_cache['det'].predict(img_cv, verbose=False)[0]
        if len(res.boxes) == 0:
            raise ValueError("No wound detected")
        x1, y1, x2, y2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
        region = img_cv[y1:y2, x1:x2]

        # Save the detection overlay (bounding box on the full frame).
        det_vis = img_cv.copy()
        cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        os.makedirs(f"{self.config.UPLOADS_DIR}/analysis", exist_ok=True)
        ts = datetime.now().strftime('%Y%m%d_%H%M%S')
        det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
        cv2.imwrite(det_path, det_vis)

        # Segmentation: mask the wound, then convert pixel spans to cm using
        # the calibrated pixels-per-cm ratio.
        length = breadth = area = 0
        seg_path = None
        if 'seg' in self.models_cache:
            h, w = self.models_cache['seg'].input_shape[1:3]
            inp = cv2.resize(region, (w, h)) / 255.0
            mask = (self.models_cache['seg'].predict(inp[None])[0, :, :, 0] > 0.5).astype(np.uint8)
            mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]),
                                 interpolation=cv2.INTER_NEAREST)
            ov = region.copy()
            ov[mask_rs == 1] = [0, 0, 255]
            seg_vis = cv2.addWeighted(region, 0.7, ov, 0.3, 0)
            seg_path = f"{self.config.UPLOADS_DIR}/analysis/segmentation_{ts}.png"
            cv2.imwrite(seg_path, seg_vis)
            cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if cnts:
                cnt = max(cnts, key=cv2.contourArea)
                _, _, w0, h0 = cv2.boundingRect(cnt)
                length = round(h0 / self.px_per_cm, 2)
                breadth = round(w0 / self.px_per_cm, 2)
                area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)

        # Classification (best effort; falls back to 'Unknown').
        wound_type = 'Unknown'
        if 'cls' in self.models_cache:
            try:
                preds = self.models_cache['cls'](
                    Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB))
                )
                wound_type = max(preds, key=lambda x: x['score'])['label']
            except Exception:
                pass

        return {
            'wound_type': wound_type,
            'length_cm': length,
            'breadth_cm': breadth,
            'surface_area_cm2': area,
            'detection_confidence': float(res.boxes.conf[0].cpu().item()),
            'detection_image_path': det_path,
            'segmentation_image_path': seg_path,
        }
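
    # Illustrative shape of the dict returned above (values are made up;
    # segmentation_image_path stays None when the segmentation model is
    # unavailable):
    #
    #   {'wound_type': 'Pressure Ulcer', 'length_cm': 3.2, 'breadth_cm': 2.1,
    #    'surface_area_cm2': 5.4, 'detection_confidence': 0.91,
    #    'detection_image_path': '.../detection_20240101_120000.png',
    #    'segmentation_image_path': '.../segmentation_20240101_120000.png'}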

    def query_guidelines(self, query: str) -> str:
        """Return the top-k guideline chunks relevant to the query."""
        vs = self.knowledge_base_cache.get('vectorstore')
        if not vs:
            return "Clinical guidelines unavailable"
        docs = vs.as_retriever(search_kwargs={'k': 10}).invoke(query)
        return '\n\n'.join(
            f"Source: {d.metadata.get('source', '?')}, "
            f"Page: {d.metadata.get('page', '?')}\n{d.page_content}"
            for d in docs
        )

    def generate_final_report(
        self,
        patient_info: str,
        visual_results: dict,
        guideline_context: str,
        image_pil: Image.Image,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        det = visual_results.get('detection_image_path', '')
        seg = visual_results.get('segmentation_image_path', '')
        # This GPU call is safe: all CUDA/model code runs *inside* the
        # @spaces.GPU decorator context.
        report = generate_medgemma_report(
            patient_info, visual_results, guideline_context, det, seg, max_new_tokens
        )
        if report:
            return report
        return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def _generate_fallback_report(
        self, patient_info: str, visual_results: dict, guideline_context: str
    ) -> str:
        """Plain-text report used when the MedGemma call fails."""
        dp = visual_results.get('detection_image_path', 'N/A')
        sp = visual_results.get('segmentation_image_path', 'N/A')
        return (
            f"# Report\n{patient_info}\n"
            f"Type: {visual_results.get('wound_type', 'Unknown')}\n"
            f"Detection Image: {dp}\n"
            f"Segmentation Image: {sp}\n"
            f"Guidelines: {guideline_context[:200]}..."
        )

    def save_and_commit_image(self, image_pil: Image.Image) -> str:
        """Save the upload locally and, if configured, mirror it to an HF dataset."""
        os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
        fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
        path = os.path.join(self.config.UPLOADS_DIR, fn)
        image_pil.convert('RGB').save(path)
        if self.config.HF_TOKEN and getattr(self.config, 'DATASET_ID', None):
            try:
                HfApi().upload_file(
                    path_or_fileobj=path,
                    path_in_repo=f"images/{fn}",
                    repo_id=self.config.DATASET_ID,
                    repo_type='dataset',
                )
            except Exception as e:
                logging.warning(f"HF upload failed: {e}")
        return path

    def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
        """End-to-end flow: save, detect/segment, retrieve guidelines, report."""
        try:
            saved = self.save_and_commit_image(image_pil)
            vis = self.perform_visual_analysis(image_pil)
            info = ", ".join(f"{k}:{v}" for k, v in questionnaire_data.items() if v)
            gc = self.query_guidelines(info)
            report = self.generate_final_report(info, vis, gc, image_pil)
            return {'success': True, 'visual_analysis': vis,
                    'report': report, 'saved_image_path': saved}
        except Exception as e:
            logging.error(f"Pipeline error: {e}")
            return {'success': False, 'error': str(e)}

    def analyze_wound(self, image, questionnaire_data: dict) -> dict:
        if isinstance(image, str):
            image = Image.open(image)
        return self.full_analysis_pipeline(image, questionnaire_data)
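
    # Worked example for the legacy scoring below: a 70-year-old (+2) with a
    # wound duration of "3 months" (+3) and "diabetes" in the history (+3)
    # scores 8, which maps to "High" (>= 7); 4-6 is "Moderate", below 4 "Low".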
    def _assess_risk_legacy(self, questionnaire_data: dict) -> dict:
        """Heuristic risk scoring from questionnaire answers (legacy)."""
        risk_factors, risk_score = [], 0
        try:
            age = questionnaire_data.get('patient_age', 0)
            if age > 65:
                risk_factors.append("Advanced age (>65)")
                risk_score += 2
            elif age > 50:
                risk_factors.append("Older adult (50-65)")
                risk_score += 1
            dur = questionnaire_data.get('wound_duration', '').lower()
            if any(t in dur for t in ['month', 'year']):
                risk_factors.append("Chronic wound (>4 weeks)")
                risk_score += 3
            pain = questionnaire_data.get('pain_level', 0)
            if pain >= 7:
                risk_factors.append("High pain level")
                risk_score += 2
            hist = questionnaire_data.get('medical_history', '').lower()
            if 'diabetes' in hist:
                risk_factors.append("Diabetes mellitus")
                risk_score += 3
            if 'vascular' in hist:
                risk_factors.append("Vascular issues")
                risk_score += 2
            if 'immune' in hist:
                risk_factors.append("Immune compromise")
                risk_score += 2
            level = ("High" if risk_score >= 7
                     else "Moderate" if risk_score >= 4
                     else "Low")
            return {'risk_score': risk_score, 'risk_level': level,
                    'risk_factors': risk_factors}
        except Exception as e:
            logging.error(f"Risk assessment error: {e}")
            return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
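

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumptions: Config points at real model
    # paths, and "sample_wound.jpg" is a hypothetical local image; because of
    # the relative import above, run this as a module, e.g. `python -m <pkg>.<module>`).
    logging.basicConfig(level=logging.INFO)
    proc = AIProcessor()
    result = proc.analyze_wound("sample_wound.jpg", {
        'patient_age': 70,
        'wound_duration': '3 months',
        'pain_level': 5,
        'medical_history': 'diabetes',
    })
    print(result['report'] if result.get('success') else result.get('error'))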