Spaces:
Running
Running
| import os | |
| # Ensure all CPU-only models never touch CUDA | |
| os.environ['CUDA_VISIBLE_DEVICES'] = '' | |
| import io | |
| import base64 | |
| import logging | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from datetime import datetime | |
| from transformers import pipeline | |
| from ultralytics import YOLO | |
| from tensorflow.keras.models import load_model | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from huggingface_hub import HfApi, HfFolder | |
| import spaces | |
| from .config import Config | |
# System prompt sent to MedGemma as the 'system' message of every report
# request. The clauses are joined with single spaces into one paragraph.
default_system_prompt = " ".join([
    "You are a world-class medical AI assistant specializing in wound care "
    "with expertise in wound assessment and treatment.",
    "Provide concise, evidence-based medical assessments focusing on:",
    "(1) Precise wound classification based on tissue type and appearance,",
    "(2) Specific treatment recommendations with exact product names or "
    "interventions when appropriate,",
    "(3) Objective evaluation of healing progression or deterioration "
    "indicators,",
    "and (4) Clear follow-up timelines.",
    "Avoid general statements and prioritize actionable insights based on "
    "the visual analysis measurements and patient context.",
])
def generate_medgemma_report(
    patient_info: str,
    visual_results: dict,
    guideline_context: str,
    detection_image_path: str,
    segmentation_image_path: str,
    max_new_tokens: int = None
) -> str:
    """
    Generate the markdown wound report with the MedGemma pipeline.

    The pipeline is lazy-loaded on first use and cached as an attribute of
    this function. Accepts only primitive types and file paths, so the
    arguments can be pickled.

    Args:
        patient_info: Free-text patient summary placed under "## Patient".
        visual_results: Visual-analysis dict; 'wound_type' and, when non-zero,
            'length_cm', 'breadth_cm' and 'surface_area_cm2' are read.
        guideline_context: Retrieved guideline text appended to the prompt
            when non-empty.
        detection_image_path: Path of the detection overlay image; skipped
            when empty or missing on disk.
        segmentation_image_path: Path of the segmentation overlay image;
            skipped when empty or missing on disk.
        max_new_tokens: Generation budget; falls back to
            Config().MAX_NEW_TOKENS when None or 0.

    Returns:
        The generated report text, or None when the pipeline cannot be loaded
        or inference fails, so callers can fall back to another report.
    """
    # Lazy-load pipeline
    if not hasattr(generate_medgemma_report, "_pipe"):
        try:
            cfg = Config()
            generate_medgemma_report._pipe = pipeline(
                'image-text-to-text',
                model='google/medgemma-4b-it',
                # 'auto' is a device_map strategy, not a device name; passing
                # it as `device` makes the load fail on every call.
                device_map='auto',
                torch_dtype='auto',
                # from_pretrained() arguments are forwarded via model_kwargs.
                model_kwargs={'offload_folder': 'offload'},
                token=cfg.HF_TOKEN
            )
            logging.info("✅ MedGemma pipeline loaded on GPU")
        except Exception as e:
            logging.warning(f"MedGemma pipeline load failed: {e}")
            return None
    pipe = generate_medgemma_report._pipe
    results = visual_results or {}

    # Assemble messages
    user_content = []
    for path in (detection_image_path, segmentation_image_path):
        if path and os.path.exists(path):
            # convert() returns a fully loaded copy, so the file handle can be
            # released immediately instead of leaking until garbage collection.
            with Image.open(path) as img_file:
                user_content.append({'type': 'image', 'image': img_file.convert('RGB')})

    prompt = f"## Patient\n{patient_info}\n## Wound Type: {results.get('wound_type','Unknown')}"
    # The system prompt asks the model to rely on measurements and context,
    # so pass them along; zero/missing measurements are omitted.
    measurements = [
        f"{label}: {results[key]} {unit}"
        for key, label, unit in (
            ('length_cm', 'Length', 'cm'),
            ('breadth_cm', 'Breadth', 'cm'),
            ('surface_area_cm2', 'Surface area', 'cm^2'),
        )
        if results.get(key)
    ]
    if measurements:
        prompt += "\n## Measurements\n" + ", ".join(measurements)
    if guideline_context:
        prompt += f"\n## Guidelines\n{guideline_context}"
    user_content.append({'type': 'text', 'text': prompt})

    msgs = [
        {'role': 'system', 'content': [{'type': 'text', 'text': default_system_prompt}]},
        {'role': 'user', 'content': user_content},
    ]

    if not max_new_tokens:
        max_new_tokens = Config().MAX_NEW_TOKENS

    try:
        out = pipe(
            text=msgs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
        generated = out[0]['generated_text']
        if isinstance(generated, str):
            return generated
        # Chat-style output: the last message is the assistant reply.
        reply = generated[-1]
        content = reply.get('content', '') if isinstance(reply, dict) else reply
        if isinstance(content, list):
            content = ''.join(
                part.get('text', '') for part in content if isinstance(part, dict)
            )
        return content
    except Exception as e:
        # Returning None (as for a load failure) lets the caller fall back.
        logging.warning(f"MedGemma inference failed: {e}")
        return None
class AIProcessor:
    """Wound-analysis pipeline: detection, segmentation, classification,
    guideline retrieval and report generation.

    Models are loaded once at construction and kept in ``models_cache``; the
    guideline vector store (or None) is kept in ``knowledge_base_cache``.
    """

    def __init__(self):
        """Read the configuration, then load models and the knowledge base."""
        self.models_cache = {}
        self.knowledge_base_cache = {}
        self.config = Config()
        self.px_per_cm = self.config.PIXELS_PER_CM
        self._initialize_models()
        self._load_knowledge_base()

    def _initialize_models(self):
        """Load all CPU-only models here.

        A YOLO load failure is re-raised (detection is mandatory); the
        segmentation, classification and embedding models are optional and
        only produce a warning when unavailable.
        """
        # Set HuggingFace token
        if self.config.HF_TOKEN:
            HfFolder.save_token(self.config.HF_TOKEN)
            logging.info("✅ HuggingFace token set")

        # YOLO detection (CPU)
        try:
            self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
            logging.info("✅ YOLO model loaded (CPU only)")
        except Exception as e:
            logging.error(f"YOLO load failed: {e}")
            raise

        # Segmentation (CPU)
        try:
            self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
            logging.info("✅ Segmentation model loaded (CPU)")
        except Exception as e:
            logging.warning(f"Segmentation model not available: {e}")

        # Classification (CPU)
        try:
            self.models_cache['cls'] = pipeline(
                'image-classification',
                model='Hemg/Wound-classification',
                token=self.config.HF_TOKEN,
                device='cpu'
            )
            logging.info("✅ Classification pipeline loaded (CPU)")
        except Exception as e:
            logging.warning(f"Classification pipeline not available: {e}")

        # Embedding model (CPU)
        try:
            self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
                model_name='sentence-transformers/all-MiniLM-L6-v2',
                model_kwargs={'device': 'cpu'}
            )
            logging.info("✅ Embedding model loaded (CPU)")
        except Exception as e:
            logging.warning(f"Embedding model not available: {e}")

    def _load_knowledge_base(self):
        """Load PDF guidelines into a FAISS vector store.

        Stores the vector store under ``knowledge_base_cache['vectorstore']``,
        or None when no PDF could be read, the embedding model is missing, or
        building the index fails.
        """
        docs = []
        for pdf in self.config.GUIDELINE_PDFS:
            if not os.path.exists(pdf):
                continue
            # One unreadable PDF must not abort construction of the processor.
            try:
                docs.extend(PyPDFLoader(pdf).load())
                logging.info(f"Loaded PDF: {pdf}")
            except Exception as e:
                logging.warning(f"Failed to load PDF {pdf}: {e}")

        self.knowledge_base_cache['vectorstore'] = None
        if docs and 'embedding_model' in self.models_cache:
            try:
                splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
                chunks = splitter.split_documents(docs)
                vs = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
                self.knowledge_base_cache['vectorstore'] = vs
                logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
                return
            except Exception as e:
                logging.warning(f"Knowledge base build failed: {e}")
        logging.warning("Knowledge base unavailable")

    def perform_visual_analysis(self, image_pil: "Image.Image") -> dict:
        """Detect & segment on CPU; return metrics + file paths.

        Uses the first detected box only. Segmentation and classification are
        best-effort: when they are unavailable or fail, measurements stay 0,
        'segmentation_image_path' stays None and 'wound_type' stays 'Unknown'.

        Raises:
            RuntimeError: the YOLO model is not loaded.
            ValueError: no box was detected, or the box is empty.
        """
        if 'det' not in self.models_cache:
            raise RuntimeError("YOLO model ('det') not loaded")

        # COLOR_RGB2BGR needs exactly three channels; RGBA, palette or
        # grayscale uploads would otherwise raise inside OpenCV.
        img_cv = cv2.cvtColor(np.array(image_pil.convert('RGB')), cv2.COLOR_RGB2BGR)
        res = self.models_cache['det'].predict(img_cv, verbose=False)[0]
        if not res.boxes:
            raise ValueError("No wound detected")

        # Clamp the box to the image so slicing never wraps around on a
        # negative coordinate, and reject empty crops up front.
        bx1, by1, bx2, by2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
        img_h, img_w = img_cv.shape[:2]
        x1, x2 = max(0, int(bx1)), min(img_w, int(bx2))
        y1, y2 = max(0, int(by1)), min(img_h, int(by2))
        if x2 <= x1 or y2 <= y1:
            raise ValueError("No wound detected")
        region = img_cv[y1:y2, x1:x2]

        # Save detection overlay
        det_vis = img_cv.copy()
        cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        analysis_dir = os.path.join(self.config.UPLOADS_DIR, 'analysis')
        os.makedirs(analysis_dir, exist_ok=True)
        # Microseconds keep two analyses in the same second from overwriting
        # each other's files.
        ts = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        det_path = os.path.join(analysis_dir, f"detection_{ts}.png")
        cv2.imwrite(det_path, det_vis)

        # Segmentation metrics
        length = breadth = area = 0
        seg_path = None
        if 'seg' in self.models_cache:
            candidate = os.path.join(analysis_dir, f"segmentation_{ts}.png")
            try:
                length, breadth, area = self._segment_region(region, candidate)
                seg_path = candidate
            except Exception as e:
                logging.warning(f"Segmentation failed: {e}")

        return {
            'wound_type': self._classify_region(region),
            'length_cm': length,
            'breadth_cm': breadth,
            'surface_area_cm2': area,
            'detection_confidence': float(res.boxes.conf[0].cpu().item()),
            'detection_image_path': det_path,
            'segmentation_image_path': seg_path
        }

    def _segment_region(self, region, seg_path):
        """Segment ``region``, save the overlay to ``seg_path`` and return
        (length_cm, breadth_cm, area_cm2) of the largest contour (zeros when
        the mask is empty)."""
        model = self.models_cache['seg']
        in_h, in_w = model.input_shape[1:3]
        inp = cv2.resize(region, (in_w, in_h)) / 255.0
        mask = (model.predict(inp[None])[0, :, :, 0] > 0.5).astype(np.uint8)
        mask_rs = cv2.resize(
            mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST
        )

        # Paint masked pixels red (BGR) and blend over the original crop.
        overlay = region.copy()
        overlay[mask_rs == 1] = [0, 0, 255]
        seg_vis = cv2.addWeighted(region, 0.7, overlay, 0.3, 0)
        cv2.imwrite(seg_path, seg_vis)

        length = breadth = area = 0
        cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if cnts:
            cnt = max(cnts, key=cv2.contourArea)
            _, _, w0, h0 = cv2.boundingRect(cnt)
            length = round(h0 / self.px_per_cm, 2)
            breadth = round(w0 / self.px_per_cm, 2)
            area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
        return length, breadth, area

    def _classify_region(self, region):
        """Return the top-scoring label for the BGR crop, or 'Unknown' when the
        classifier is missing or fails."""
        if 'cls' not in self.models_cache:
            return 'Unknown'
        try:
            preds = self.models_cache['cls'](
                Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB))
            )
            return max(preds, key=lambda x: x['score'])['label']
        except Exception as e:
            # Classification stays best-effort, but the failure is recorded.
            logging.warning(f"Wound classification failed: {e}")
            return 'Unknown'

    def query_guidelines(self, query: str) -> str:
        """Return the 10 most relevant guideline chunks for ``query``, each
        prefixed with its source and page, or a fixed notice when no vector
        store is loaded."""
        vs = self.knowledge_base_cache.get('vectorstore')
        if not vs:
            return "Clinical guidelines unavailable"
        docs = vs.as_retriever(search_kwargs={'k': 10}).invoke(query)
        return '\n\n'.join(
            f"Source: {d.metadata.get('source','?')}, Page: {d.metadata.get('page','?')}\n{d.page_content}"
            for d in docs
        )

    def generate_final_report(
        self,
        patient_info: str,
        visual_results: dict,
        guideline_context: str,
        image_pil: "Image.Image",
        max_new_tokens: int = None
    ) -> str:
        """
        Signature unchanged. Gathers arguments, calls the MedGemma function,
        and falls back to the plain-text report when it returns nothing or
        raises. ``image_pil`` is accepted for compatibility and not used.
        """
        det = visual_results.get('detection_image_path') or ''
        seg = visual_results.get('segmentation_image_path') or ''
        try:
            report = generate_medgemma_report(
                patient_info,
                visual_results,
                guideline_context,
                det,
                seg,
                max_new_tokens
            )
        except Exception as e:
            # An inference error must not discard the analysis; fall back.
            logging.warning(f"MedGemma report failed, using fallback: {e}")
            report = None
        if report:
            return report
        return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def _generate_fallback_report(
        self,
        patient_info: str,
        visual_results: dict,
        guideline_context: str
    ) -> str:
        """Build a minimal plain-text report from the inputs alone.

        Missing or None image paths are shown as 'N/A'; the guideline text is
        cut to 200 characters, with '...' appended only when it was cut.
        """
        results = visual_results or {}
        # The keys may be present with a None value, so `or` is needed here.
        dp = results.get('detection_image_path') or 'N/A'
        sp = results.get('segmentation_image_path') or 'N/A'
        context = guideline_context or ''
        excerpt = context[:200] + ('...' if len(context) > 200 else '')
        return (
            f"# Report\n{patient_info}\n"
            f"Type: {results.get('wound_type','Unknown')}\n"
            f"Detection Image: {dp}\n"
            f"Segmentation Image: {sp}\n"
            f"Guidelines: {excerpt}"
        )

    def save_and_commit_image(self, image_pil: "Image.Image") -> str:
        """Save the image as an RGB PNG under UPLOADS_DIR and, when a token and
        DATASET_ID are configured, upload it to that dataset repo (upload
        errors are logged, not raised). Returns the local path."""
        os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
        # Microseconds avoid two uploads in the same second sharing a name.
        fn = f"{datetime.now():%Y%m%d_%H%M%S_%f}.png"
        path = os.path.join(self.config.UPLOADS_DIR, fn)
        image_pil.convert('RGB').save(path)
        if self.config.HF_TOKEN and getattr(self.config, 'DATASET_ID', None):
            try:
                # Pass the token explicitly rather than relying on one cached
                # on disk by an earlier call.
                HfApi(token=self.config.HF_TOKEN).upload_file(
                    path_or_fileobj=path,
                    path_in_repo=f"images/{fn}",
                    repo_id=self.config.DATASET_ID,
                    repo_type='dataset'
                )
            except Exception as e:
                logging.warning(f"HF upload failed: {e}")
        return path

    def full_analysis_pipeline(
        self,
        image_pil: "Image.Image",
        questionnaire_data: dict
    ) -> dict:
        """Run save -> visual analysis -> guideline lookup -> report.

        Returns a dict with 'success': True plus 'visual_analysis', 'report'
        and 'saved_image_path', or {'success': False, 'error': <message>} when
        any step raises.
        """
        try:
            answers = questionnaire_data or {}
            saved = self.save_and_commit_image(image_pil)
            vis = self.perform_visual_analysis(image_pil)
            info = ", ".join(f"{k}:{v}" for k, v in answers.items() if v)
            gc = self.query_guidelines(info)
            report = self.generate_final_report(info, vis, gc, image_pil)
            return {
                'success': True,
                'visual_analysis': vis,
                'report': report,
                'saved_image_path': saved
            }
        except Exception as e:
            # logging.exception keeps the traceback next to the message.
            logging.exception(f"Pipeline error: {e}")
            return {'success': False, 'error': str(e)}

    def analyze_wound(self, image, questionnaire_data: dict) -> dict:
        """Entry point accepting either an image object or a filesystem path
        (str or os.PathLike), which is opened first."""
        if isinstance(image, (str, os.PathLike)):
            image = Image.open(image)
        return self.full_analysis_pipeline(image, questionnaire_data)

    @staticmethod
    def _to_number(value):
        """Coerce a questionnaire answer to float; 0.0 when missing or not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    def _assess_risk_legacy(self, questionnaire_data):
        """Legacy risk assessment for backward compatibility.

        Scores age, wound duration, pain level and medical-history keywords
        and returns {'risk_score', 'risk_level', 'risk_factors'}. Numeric
        answers may arrive as strings and text answers as None; both are
        normalised instead of discarding the whole assessment. Any other
        failure yields score 0 with level 'Unknown'.
        """
        risk_factors = []
        risk_score = 0
        try:
            answers = questionnaire_data or {}

            # Age factor
            age = self._to_number(answers.get('patient_age'))
            if age > 65:
                risk_factors.append("Advanced age (>65)")
                risk_score += 2
            elif age > 50:
                risk_factors.append("Older adult (50-65)")
                risk_score += 1

            # Duration factor ('month' also matches 'months')
            duration = str(answers.get('wound_duration') or '').lower()
            if any(term in duration for term in ('month', 'year')):
                risk_factors.append("Chronic wound (>4 weeks)")
                risk_score += 3

            # Pain level
            if self._to_number(answers.get('pain_level')) >= 7:
                risk_factors.append("High pain level")
                risk_score += 2

            # Medical history risk factors
            medical_history = str(answers.get('medical_history') or '').lower()
            if 'diabetes' in medical_history:
                risk_factors.append("Diabetes mellitus")
                risk_score += 3
            if 'circulation' in medical_history or 'vascular' in medical_history:
                risk_factors.append("Vascular/circulation issues")
                risk_score += 2
            if 'immune' in medical_history:
                risk_factors.append("Immune system compromise")
                risk_score += 2

            # Determine risk level
            if risk_score >= 7:
                risk_level = "High"
            elif risk_score >= 4:
                risk_level = "Moderate"
            else:
                risk_level = "Low"

            return {
                'risk_score': risk_score,
                'risk_level': risk_level,
                'risk_factors': risk_factors
            }
        except Exception as e:
            logging.error(f"Risk assessment error: {e}")
            return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}