Spaces:
Running
Running
| import os | |
| import io | |
| import base64 | |
| import logging | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from datetime import datetime | |
| from transformers import pipeline | |
| from ultralytics import YOLO | |
| from tensorflow.keras.models import load_model | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from huggingface_hub import HfApi, HfFolder | |
| import spaces | |
| from .config import Config | |
# System prompt sent as the first (system) turn of every MedGemma chat request.
# Assembled clause-by-clause and joined with single spaces into one paragraph.
default_system_prompt = " ".join((
    "You are a world-class medical AI assistant specializing in wound care",
    "with expertise in wound assessment and treatment.",
    "Provide concise, evidence-based medical assessments focusing on:",
    "(1) Precise wound classification based on tissue type and appearance,",
    "(2) Specific treatment recommendations with exact product names or interventions when appropriate,",
    "(3) Objective evaluation of healing progression or deterioration indicators,",
    "and (4) Clear follow-up timelines.",
    "Avoid general statements and prioritize actionable insights based on the visual analysis measurements and patient context.",
))
class AIProcessor:
    """Wound-analysis orchestrator.

    Runs YOLO detection, Keras segmentation and HF classification on CPU,
    retrieves clinical guidelines from a FAISS store, and asks MedGemma for a
    final markdown report. Models that fail to load are simply absent from
    ``models_cache``; callers check for presence before use.
    """

    def __init__(self):
        self.models_cache = {}          # name -> successfully loaded model / pipeline
        self.knowledge_base_cache = {}  # 'vectorstore' -> FAISS store, or None when unavailable
        self.config = Config()
        self.px_per_cm = self.config.PIXELS_PER_CM  # scale used to convert pixel measurements to cm
        self._initialize_models()

    def _initialize_models(self):
        """Initialize AI models; only MedGemma uses GPU.

        Each model is loaded best-effort: a failure is logged as a warning and
        the model is left out of ``models_cache``. Finishes by building the
        guideline knowledge base.
        """
        # Persist the HuggingFace token for later hub calls.
        if self.config.HF_TOKEN:
            HfFolder.save_token(self.config.HF_TOKEN)
            logging.info("HuggingFace token set successfully")
        # MedGemma pipeline on GPU
        try:
            self.models_cache['medgemma_pipe'] = pipeline(
                'image-text-to-text',
                model='google/medgemma-4b-it',
                device='cuda',
                torch_dtype=torch.bfloat16,
                offload_folder='offload',
                token=self.config.HF_TOKEN,
            )
            logging.info("β MedGemma pipeline loaded on GPU")
        except Exception as e:
            logging.warning(f"MedGemma pipeline not available: {e}")
        # YOLO detection on CPU
        try:
            self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
            logging.info("β YOLO detection model loaded on CPU")
        except Exception as e:
            logging.warning(f"YOLO model not available: {e}")
        # Segmentation model on CPU (compile=False: loaded for inference only)
        try:
            self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
            logging.info("β Segmentation model loaded on CPU")
        except Exception as e:
            logging.warning(f"Segmentation model not available: {e}")
        # Classification on CPU
        try:
            self.models_cache['cls'] = pipeline(
                'image-classification',
                model='Hemg/Wound-classification',
                token=self.config.HF_TOKEN,
                device='cpu',
            )
            logging.info("β Wound classification model loaded on CPU")
        except Exception as e:
            logging.warning(f"Wound classification model not available: {e}")
        # Embedding model for the knowledge base
        try:
            self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
                model_name='sentence-transformers/all-MiniLM-L6-v2',
                model_kwargs={'device': 'cpu'},
            )
            logging.info("β Embedding model loaded on CPU")
        except Exception as e:
            logging.warning(f"Embedding model not available: {e}")
        self._load_knowledge_base()

    def _load_knowledge_base(self):
        """Load guideline PDFs into a FAISS vector store (best-effort).

        Missing or unreadable PDFs are skipped with a warning. If no documents
        load, the embedding model is absent, or indexing fails, the store is set
        to None so ``query_guidelines`` reports guidelines as unavailable.
        """
        docs = []
        for pdf in self.config.GUIDELINE_PDFS:
            if not os.path.exists(pdf):
                logging.warning(f"Guideline PDF not found: {pdf}")
                continue
            try:
                docs.extend(PyPDFLoader(pdf).load())
                logging.info(f"Loaded PDF: {pdf}")
            except Exception as e:
                # One corrupt PDF should not prevent the others from being indexed.
                logging.warning(f"Failed to load PDF {pdf}: {e}")

        vectorstore = None
        if docs and 'embedding_model' in self.models_cache:
            try:
                splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
                chunks = splitter.split_documents(docs)
                vectorstore = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
                logging.info(f"β Knowledge base loaded ({len(chunks)} chunks)")
            except Exception as e:
                logging.warning(f"Failed to build knowledge base: {e}")
        self.knowledge_base_cache['vectorstore'] = vectorstore
        if vectorstore is None:
            logging.warning("Knowledge base unavailable")

    def perform_visual_analysis(self, image_pil):
        """Detect, segment, measure and classify the wound on CPU.

        Args:
            image_pil: PIL image of the wound (any mode; normalised to RGB).

        Returns:
            dict with ``wound_type``, ``length_cm``, ``breadth_cm``,
            ``surface_area_cm2``, ``detection_confidence`` and the saved
            ``detection_image_path`` / ``segmentation_image_path`` (the latter
            None when no segmentation model is loaded).

        Raises:
            RuntimeError: the detection model is not loaded.
            ValueError: no wound was detected.
        """
        try:
            detector = self.models_cache.get('det')
            if detector is None:
                raise RuntimeError("YOLO detection model not available")
            # Normalise mode first: RGBA / greyscale / palette images would break the RGB->BGR conversion.
            img_cv = cv2.cvtColor(np.array(image_pil.convert('RGB')), cv2.COLOR_RGB2BGR)

            res = detector.predict(img_cv, verbose=False)[0]
            if not res.boxes:
                raise ValueError("No wound detected")
            # Only the first detected box is analysed.
            x1, y1, x2, y2 = map(int, res.boxes.xyxy[0].cpu().numpy().astype(int))
            region = img_cv[y1:y2, x1:x2]

            out_dir = os.path.join(self.config.UPLOADS_DIR, 'analysis')
            os.makedirs(out_dir, exist_ok=True)
            # Microseconds keep artefacts from analyses in the same second from overwriting each other.
            ts = datetime.now().strftime('%Y%m%d_%H%M%S_%f')

            det_vis = img_cv.copy()
            cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
            det_path = os.path.join(out_dir, f"detection_{ts}.png")
            cv2.imwrite(det_path, det_vis)

            length, breadth, area, seg_path = self._segment_and_measure(region, out_dir, ts)

            return {
                'wound_type': self._classify_region(region),
                'length_cm': length,
                'breadth_cm': breadth,
                'surface_area_cm2': area,
                'detection_confidence': float(res.boxes.conf[0].cpu().item()),
                'detection_image_path': det_path,
                'segmentation_image_path': seg_path,
            }
        except Exception as e:
            logging.error(f"Visual analysis error: {e}")
            raise

    def _segment_and_measure(self, region, out_dir, ts):
        """Segment the cropped BGR wound region, save an overlay, and measure the largest contour.

        Returns:
            (length_cm, breadth_cm, area_cm2, seg_path). Measurements are 0 and
            seg_path is None when no segmentation model is loaded; measurements
            are 0 when the mask contains no contour.
        """
        seg_model = self.models_cache.get('seg')
        if seg_model is None:
            return 0, 0, 0, None

        h, w = seg_model.input_shape[1:3]
        inp = cv2.resize(region, (w, h)) / 255.0
        prob = seg_model.predict(np.expand_dims(inp, 0), verbose=0)[0, :, :, 0]
        mask = (prob > 0.5).astype(np.uint8)
        # Nearest-neighbour keeps the mask binary when scaling back to the crop size.
        mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)

        overlay = region.copy()
        overlay[mask_rs == 1] = [0, 0, 255]
        seg_vis = cv2.addWeighted(region, 0.7, overlay, 0.3, 0)
        seg_path = os.path.join(out_dir, f"segmentation_{ts}.png")
        cv2.imwrite(seg_path, seg_vis)

        contours, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return 0, 0, 0, seg_path
        largest = max(contours, key=cv2.contourArea)
        _, _, box_w, box_h = cv2.boundingRect(largest)
        return (
            round(box_h / self.px_per_cm, 2),
            round(box_w / self.px_per_cm, 2),
            round(cv2.contourArea(largest) / (self.px_per_cm ** 2), 2),
            seg_path,
        )

    def _classify_region(self, region):
        """Return the top-scoring wound label for a BGR crop, or 'Unknown' if unavailable or failing."""
        classifier = self.models_cache.get('cls')
        if classifier is None:
            return 'Unknown'
        try:
            preds = classifier(Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB)))
            return max(preds, key=lambda p: p['score'])['label']
        except Exception as e:
            # Classification is optional; log instead of silently discarding the failure.
            logging.warning(f"Wound classification failed: {e}")
            return 'Unknown'

    def query_guidelines(self, query: str):
        """Retrieve the top-10 guideline chunks for *query* as formatted text.

        Returns "Clinical guidelines unavailable" when there is no vector store
        or retrieval fails.
        """
        vs = self.knowledge_base_cache.get('vectorstore')
        if vs is None:
            return "Clinical guidelines unavailable"
        try:
            docs = vs.as_retriever(search_kwargs={'k': 10}).invoke(query)
        except Exception as e:
            logging.warning(f"Guideline retrieval failed: {e}")
            return "Clinical guidelines unavailable"
        return '\n\n'.join(
            f"Source: {d.metadata.get('source', '?')}, Page: {d.metadata.get('page', '?')}\n{d.page_content}"
            for d in docs
        )

    @staticmethod
    def _build_report_prompt(patient_info, visual_results, guideline_context):
        """Assemble the user text prompt: patient context, wound measurements and retrieved guidelines."""
        return (
            f"## Patient\n{patient_info}\n"
            f"## Visual Type: {visual_results.get('wound_type', 'Unknown')}\n"
            "## Measurements\n"
            f"- Length: {visual_results.get('length_cm', 'N/A')} cm\n"
            f"- Breadth: {visual_results.get('breadth_cm', 'N/A')} cm\n"
            f"- Surface area: {visual_results.get('surface_area_cm2', 'N/A')} cm^2\n"
            f"- Detection confidence: {visual_results.get('detection_confidence', 'N/A')}\n"
            "## Clinical Guidelines\n"
            f"{guideline_context or 'Clinical guidelines unavailable'}"
        )

    def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
        """Run MedGemma on GPU; return a markdown report.

        The model receives the original image, any saved detection /
        segmentation overlays, and a text prompt with patient info,
        measurements and guideline context. Falls back to
        ``_generate_fallback_report`` when the pipeline is not loaded,
        generation raises, or the model returns empty content.

        NOTE(review): ``spaces`` is imported by this module but no
        ``@spaces.GPU`` decorator is applied; if this runs on ZeroGPU hardware
        that may be required — confirm.
        """
        pipe = self.models_cache.get('medgemma_pipe')
        if pipe is None:
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

        user_content = []
        if image_pil is not None:
            user_content.append({'type': 'image', 'image': image_pil})
        for key in ('detection_image_path', 'segmentation_image_path'):
            path = visual_results.get(key)
            if path and os.path.exists(path):
                # Decode eagerly and release the file handle; Image.open alone is lazy and keeps it open.
                with Image.open(path) as overlay:
                    user_content.append({'type': 'image', 'image': overlay.convert('RGB')})
        user_content.append({
            'type': 'text',
            'text': self._build_report_prompt(patient_info, visual_results, guideline_context),
        })

        msgs = [
            {'role': 'system', 'content': [{'type': 'text', 'text': default_system_prompt}]},
            {'role': 'user', 'content': user_content},
        ]
        token_budget = max_new_tokens or self.config.MAX_NEW_TOKENS
        try:
            out = pipe(text=msgs, max_new_tokens=token_budget)
            # The reply is read from the last message of the returned conversation.
            report = out[0]['generated_text'][-1].get('content', '')
        except Exception as e:
            logging.error(f"MedGemma generation failed, using fallback report: {e}")
            report = ''
        return report or self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def _generate_fallback_report(self, patient_info, visual_results, guideline_context):
        """Produce a text-only report when MedGemma is unavailable or returns nothing.

        Missing/None image paths and guideline context render as 'N/A'; the
        guideline excerpt is cut to 200 characters, with '...' only when truncated.
        """
        dp = visual_results.get('detection_image_path') or 'N/A'
        sp = visual_results.get('segmentation_image_path') or 'N/A'
        guidelines = guideline_context or 'N/A'
        excerpt = guidelines[:200] + ("..." if len(guidelines) > 200 else "")
        wound_type = visual_results.get('wound_type', 'Unknown')
        return (
            f"# Report\n{patient_info}\nType: {wound_type}\n"
            f"Detection Image: {dp}\nSegmentation Image: {sp}\nGuidelines: {excerpt}"
        )

    def save_and_commit_image(self, image_pil):
        """Save the image as RGB PNG under UPLOADS_DIR and, if configured, upload it to the HF dataset.

        Upload failures are logged and ignored. Returns the local file path.
        """
        os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
        # Microseconds avoid overwriting (locally and in the dataset) when two uploads share a second.
        fn = f"{datetime.now():%Y%m%d_%H%M%S_%f}.png"
        path = os.path.join(self.config.UPLOADS_DIR, fn)
        image_pil.convert('RGB').save(path)
        if self.config.HF_TOKEN and self.config.DATASET_ID:
            try:
                api = HfApi(token=self.config.HF_TOKEN)
                api.upload_file(
                    path_or_fileobj=path,
                    path_in_repo=f"images/{fn}",
                    repo_id=self.config.DATASET_ID,
                    repo_type='dataset',
                )
            except Exception as e:
                logging.warning(f"HF upload failed: {e}")
        return path

    def full_analysis_pipeline(self, image_pil, questionnaire_data):
        """Orchestrate CPU steps + GPU report.

        Returns a dict that always carries ``success`` and an ISO-format
        ``timestamp``; on success also ``visual_analysis``, ``report`` and
        ``saved_image_path``, on failure an ``error`` message.
        """
        timestamp = datetime.now().isoformat()
        try:
            saved = self.save_and_commit_image(image_pil)
            vis = self.perform_visual_analysis(image_pil)
            # Truthy questionnaire answers flattened to "key:value" pairs for retrieval and prompting.
            info = ", ".join(f"{k}:{v}" for k, v in questionnaire_data.items() if v)
            gc = self.query_guidelines(info)
            report = self.generate_final_report(info, vis, gc, image_pil)
            return {
                'success': True,
                'timestamp': timestamp,
                'visual_analysis': vis,
                'report': report,
                'saved_image_path': saved,
            }
        except Exception as e:
            logging.error(f"Pipeline error: {e}")
            return {'success': False, 'timestamp': timestamp, 'error': str(e)}

    # Legacy methods for backward compatibility

    @staticmethod
    def _coerce_to_pil(image):
        """Return *image* as a PIL Image, opening str/PathLike paths and file-like objects.

        Raises:
            ValueError: *image* is not a supported type, or opening it failed.
        """
        if isinstance(image, Image.Image):
            return image
        try:
            if isinstance(image, (str, os.PathLike)):
                opened = Image.open(image)
                logging.info(f"Converted string path to PIL Image: {opened}")
                return opened
            if hasattr(image, 'read'):
                if hasattr(image, 'seek'):
                    image.seek(0)  # rewind in case the stream was partially consumed
                opened = Image.open(image)
                logging.info("Converted file-like object to PIL Image")
                return opened
        except Exception as e:
            logging.error(f"Error ensuring image is PIL Image: {e}")
            raise ValueError(f"Invalid image format: {type(image)}") from e
        raise ValueError(f"Invalid image format: {type(image)}")

    @staticmethod
    def _legacy_error_response(summary, message, timestamp=None):
        """Build the legacy error-shaped response used by ``analyze_wound``."""
        return {
            'timestamp': timestamp or datetime.now().isoformat(),
            'summary': summary,
            'recommendations': "Please consult with a healthcare professional.",
            'wound_detection': {'status': 'error', 'message': message},
            'segmentation_result': {'status': 'error', 'message': message},
            'risk_assessment': {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []},
            'guideline_recommendations': ["Analysis unavailable due to error"],
        }

    def analyze_wound(self, image, questionnaire_data):
        """Legacy entry point kept for backward compatibility.

        Accepts a PIL image, a path, or a readable file-like object, runs
        ``full_analysis_pipeline`` and reshapes the result into the legacy
        response dict. Never raises; failures come back as an error response.
        """
        try:
            pil_image = self._coerce_to_pil(image)
            result = self.full_analysis_pipeline(pil_image, questionnaire_data)
            timestamp = result.get('timestamp') or datetime.now().isoformat()
            if not result['success']:
                return self._legacy_error_response(f"Analysis failed: {result['error']}", result['error'], timestamp)
            visual = result['visual_analysis']
            return {
                'timestamp': timestamp,
                'summary': f"Analysis completed for {questionnaire_data.get('patient_name', 'patient')}",
                'recommendations': result['report'],
                'wound_detection': {
                    'status': 'success',
                    'detections': [visual],
                    'total_wounds': 1,
                },
                'segmentation_result': {
                    'status': 'success',
                    # Legacy key name: the value is the surface area in cm^2, not a percentage.
                    'wound_area_percentage': visual.get('surface_area_cm2', 0),
                },
                'risk_assessment': self._assess_risk_legacy(questionnaire_data),
                'guideline_recommendations': [result['report'][:200] + "..."],
            }
        except Exception as e:
            logging.error(f"Legacy analyze_wound error: {e}")
            return self._legacy_error_response(f"Analysis error: {str(e)}", str(e))

    @staticmethod
    def _as_number(value, default=0.0):
        """Best-effort float conversion for form values such as 70, '70' or None; *default* on failure."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _assess_risk_legacy(self, questionnaire_data):
        """Legacy rule-based risk score from questionnaire answers.

        Tolerates string or missing values (e.g. age '70', duration None).
        Returns ``{'risk_score', 'risk_level', 'risk_factors'}``; level is
        High (>=7), Moderate (>=4) or Low, and 'Unknown' if assessment fails.
        """
        risk_factors = []
        risk_score = 0
        try:
            age = self._as_number(questionnaire_data.get('patient_age'))
            if age > 65:
                risk_factors.append("Advanced age (>65)")
                risk_score += 2
            elif age > 50:
                risk_factors.append("Older adult (50-65)")
                risk_score += 1

            duration = str(questionnaire_data.get('wound_duration') or '').lower()
            # 'month' also matches 'months'; 'year' also matches 'years'.
            if any(term in duration for term in ('month', 'year')):
                risk_factors.append("Chronic wound (>4 weeks)")
                risk_score += 3

            if self._as_number(questionnaire_data.get('pain_level')) >= 7:
                risk_factors.append("High pain level")
                risk_score += 2

            medical_history = str(questionnaire_data.get('medical_history') or '').lower()
            if 'diabetes' in medical_history:
                risk_factors.append("Diabetes mellitus")
                risk_score += 3
            if 'circulation' in medical_history or 'vascular' in medical_history:
                risk_factors.append("Vascular/circulation issues")
                risk_score += 2
            if 'immune' in medical_history:
                risk_factors.append("Immune system compromise")
                risk_score += 2

            if risk_score >= 7:
                risk_level = "High"
            elif risk_score >= 4:
                risk_level = "Moderate"
            else:
                risk_level = "Low"
            return {
                'risk_score': risk_score,
                'risk_level': risk_level,
                'risk_factors': risk_factors,
            }
        except Exception as e:
            logging.error(f"Risk assessment error: {e}")
            return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}