import os
import re
import logging
import cv2
import numpy as np
from PIL import Image
import torch
import json
from datetime import datetime
import tensorflow as tf
from transformers import pipeline
from ultralytics import YOLO
from tensorflow.keras.models import load_model
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from huggingface_hub import HfApi, HfFolder
import spaces
import time
from typing import Dict, Any, Optional, Tuple

from .config import Config

# Answers that open with a negation ("No", "Non-diabetic", "Not present", "None")
# must not be treated as affirmative just because they contain a keyword such as
# "diabetic" or "present". Only the start of the answer is checked so that e.g.
# "Yes, not well controlled" is still affirmative.
_LEADING_NEGATION = re.compile(r"^\s*(?:no|non|not|none|absent|negative|denies)\b")

# Header line emitted by EnhancedAIProcessor.query_guidelines for each retrieved chunk.
_GUIDELINE_HEADER = re.compile(r"^\[Result \d+/\d+\] Source: ", re.MULTILINE)


def _to_float(value: Any, default: float = 0.0) -> float:
    """Coerce a form value (number or numeric string) to float; return *default* otherwise."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _is_affirmative(value: Any, keywords: Tuple[str, ...]) -> bool:
    """Return True if *value* contains any of *keywords* and does not start with a negation.

    ``None`` and empty values are treated as not affirmative. Matching is
    case-insensitive substring matching, as in the original checks.
    """
    text = str(value or "").strip().lower()
    if not text or _LEADING_NEGATION.match(text):
        return False
    return any(keyword in text for keyword in keywords)


class EnhancedAIProcessor:
    """Enhanced AI processor with dashboard integration and analytics tracking"""

    def __init__(self):
        # Loaded models keyed by role: "medgemma_pipe", "det", "seg", "cls", "embedding_model".
        self.models_cache: Dict[str, Any] = {}
        # Holds the FAISS store under "vectorstore" (or None when unavailable).
        self.knowledge_base_cache: Dict[str, Any] = {}
        self.config = Config()
        self.px_per_cm = self.config.PIXELS_PER_CM
        self.model_version = "v1.2.0"  # Version for tracking
        self._initialize_models()

    @spaces.GPU(enable_queue=True, duration=90)
    def _initialize_models(self):
        """Load every model into ``self.models_cache`` and then build the knowledge base.

        Each model is loaded independently: a failure is logged as a warning and
        the corresponding key is simply absent from ``models_cache``, so callers
        must check for a key before using it.
        """
        try:
            # Set HuggingFace token
            if self.config.HF_TOKEN:
                HfFolder.save_token(self.config.HF_TOKEN)
                logging.info("HuggingFace token set successfully")

            # Initialize MedGemma pipeline for medical text generation
            try:
                self.models_cache["medgemma_pipe"] = pipeline(
                    "image-text-to-text",
                    model="google/medgemma-4b-it",
                    torch_dtype=torch.bfloat16,
                    offload_folder="offload",
                    device_map="auto",
                    token=self.config.HF_TOKEN,
                )
                logging.info("✅ MedGemma pipeline loaded successfully")
            except Exception as e:
                logging.warning(f"MedGemma pipeline not available: {e}")

            # Initialize YOLO model for wound detection
            try:
                self.models_cache["det"] = YOLO(self.config.YOLO_MODEL_PATH)
                logging.info("✅ YOLO detection model loaded successfully")
            except Exception as e:
                logging.warning(f"YOLO model not available: {e}")

            # Initialize segmentation model
            try:
                self.models_cache["seg"] = load_model(self.config.SEG_MODEL_PATH, compile=False)
                logging.info("✅ Segmentation model loaded successfully")
            except Exception as e:
                logging.warning(f"Segmentation model not available: {e}")

            # Initialize wound classification model
            try:
                self.models_cache["cls"] = pipeline(
                    "image-classification",
                    model="Hemg/Wound-classification",
                    token=self.config.HF_TOKEN,
                    device="cpu",
                )
                logging.info("✅ Wound classification model loaded successfully")
            except Exception as e:
                logging.warning(f"Wound classification model not available: {e}")

            # Initialize embedding model for knowledge base
            try:
                self.models_cache["embedding_model"] = HuggingFaceEmbeddings(
                    model_name="sentence-transformers/all-MiniLM-L6-v2",
                    model_kwargs={'device': 'cpu'},
                )
                logging.info("✅ Embedding model loaded successfully")
            except Exception as e:
                logging.warning(f"Embedding model not available: {e}")

            logging.info("✅ All models loaded.")
            self._load_knowledge_base()
        except Exception as e:
            logging.error(f"Error initializing AI models: {e}")

    def _load_knowledge_base(self):
        """Build a FAISS vector store from the configured guideline PDFs.

        Missing PDF paths are skipped. When no document was loaded, when the
        embedding model is unavailable, or when any step fails,
        ``knowledge_base_cache['vectorstore']`` is set to ``None``.
        """
        try:
            documents = []
            for pdf_path in self.config.GUIDELINE_PDFS:
                if os.path.exists(pdf_path):
                    loader = PyPDFLoader(pdf_path)
                    docs = loader.load()
                    documents.extend(docs)
                    logging.info(f"Loaded PDF: {pdf_path}")

            if documents and 'embedding_model' in self.models_cache:
                # Split documents into chunks
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
                chunks = text_splitter.split_documents(documents)

                # Create vector store
                vectorstore = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
                self.knowledge_base_cache['vectorstore'] = vectorstore
                logging.info(f"✅ Knowledge base loaded with {len(chunks)} chunks")
            else:
                self.knowledge_base_cache['vectorstore'] = None
                logging.warning("Knowledge base not available - no PDFs found or embedding model unavailable")
        except Exception as e:
            logging.warning(f"Knowledge base loading error: {e}")
            self.knowledge_base_cache['vectorstore'] = None

    def perform_comprehensive_analysis(self, image_pil: Image.Image, patient_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Perform comprehensive analysis with enhanced tracking for dashboard integration

        Runs visual analysis, guideline retrieval, report generation and risk
        scoring. Never raises: on failure it returns a dict with an ``'error'``
        key plus ``processing_time``, ``risk_score`` (0), ``model_version`` and
        ``analysis_timestamp``.
        """
        start_time = time.time()

        try:
            # Perform visual analysis
            visual_results = self.perform_visual_analysis(image_pil)

            # Query guidelines for context
            guideline_query = f"wound care {visual_results.get('wound_type', 'general')} treatment recommendations"
            guideline_context = self.query_guidelines(guideline_query)

            # Generate comprehensive report
            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)

            # Calculate processing time
            processing_time = round(time.time() - start_time, 2)

            # Calculate risk score based on multiple factors
            risk_score = self._calculate_risk_score(visual_results, patient_info)

            # Count the per-chunk headers written by query_guidelines so that status
            # messages ("Knowledge base unavailable ...") count as zero sources and
            # blank lines inside chunk text do not inflate the count.
            guideline_sources = len(_GUIDELINE_HEADER.findall(guideline_context)) if guideline_context else 0

            # Prepare comprehensive analysis data
            analysis_data = {
                'visual_results': visual_results,
                'patient_info': patient_info,
                'guideline_context': guideline_context,
                'report': report,
                'processing_time': processing_time,
                'risk_score': risk_score,
                'model_version': self.model_version,
                'analysis_timestamp': datetime.now().isoformat(),
                'analysis_metadata': {
                    'models_used': list(self.models_cache.keys()),
                    'image_dimensions': image_pil.size,
                    'guideline_sources': guideline_sources,
                },
            }

            logging.info(f"✅ Comprehensive analysis completed in {processing_time}s with risk score {risk_score}")
            return analysis_data

        except Exception as e:
            processing_time = round(time.time() - start_time, 2)
            logging.error(f"❌ Analysis failed after {processing_time}s: {e}")

            # Return error analysis data
            return {
                'error': str(e),
                'processing_time': processing_time,
                'risk_score': 0,
                'model_version': self.model_version,
                'analysis_timestamp': datetime.now().isoformat(),
            }

    def perform_visual_analysis(self, image_pil: Image.Image) -> Dict[str, Any]:
        """Perform comprehensive visual analysis of wound image with enhanced tracking

        Detects the wound with YOLO (required) and optionally segments and
        classifies the detected region. Annotated images are written under
        ``<UPLOADS_DIR>/analysis``.

        Raises:
            ValueError: if the detector is missing, nothing is detected, the
                detected region is empty, or any step fails.
        """
        try:
            # Normalise to 3-channel RGB first: RGBA, greyscale and palette uploads
            # would otherwise make cv2.COLOR_RGB2BGR fail with a channel-count error.
            image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

            # YOLO detection
            if 'det' not in self.models_cache:
                raise ValueError("YOLO detection model not available.")
            results = self.models_cache['det'].predict(image_cv, verbose=False, device="cpu")
            if not results or not results[0].boxes:
                raise ValueError("No wound detected in the image.")

            # Extract bounding box of the first detection
            first_box = results[0].boxes[0]
            detection_confidence = float(first_box.conf.cpu().item())
            img_h, img_w = image_cv.shape[:2]
            # Clip to the image so slicing never uses negative (wrapping) indices.
            box = np.clip(first_box.xyxy[0].cpu().numpy().astype(int), 0, [img_w, img_h, img_w, img_h])
            x1, y1, x2, y2 = (int(v) for v in box)
            if x2 <= x1 or y2 <= y1:
                raise ValueError("Detected wound region is empty.")
            region_cv = image_cv[y1:y2, x1:x2]

            # Save detection image. Microseconds keep file names unique when several
            # analyses run within the same second.
            detection_image_cv = image_cv.copy()
            cv2.rectangle(detection_image_cv, (x1, y1), (x2, y2), (0, 255, 0), 2)
            analysis_dir = os.path.join(self.config.UPLOADS_DIR, "analysis")
            os.makedirs(analysis_dir, exist_ok=True)
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
            detection_image_path = os.path.join(analysis_dir, f"detection_{timestamp}.png")
            cv2.imwrite(detection_image_path, detection_image_cv)
            detection_image_pil = Image.fromarray(cv2.cvtColor(detection_image_cv, cv2.COLOR_BGR2RGB))

            # Initialize outputs
            length = breadth = area = 0
            segmentation_image_pil = None
            segmentation_image_path = None
            segmentation_confidence = 0.0

            # Segmentation (optional)
            if 'seg' in self.models_cache:
                input_size = self.models_cache['seg'].input_shape[1:3]  # (height, width)
                resized_region = cv2.resize(region_cv, (input_size[1], input_size[0]))

                seg_input = np.expand_dims(resized_region / 255.0, 0)
                mask_pred = self.models_cache['seg'].predict(seg_input, verbose=0)[0]
                mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)

                # "Confidence" here is the mean predicted foreground probability over the region.
                segmentation_confidence = float(np.mean(mask_pred[:, :, 0]))

                # Resize mask back to original region size
                mask_resized = cv2.resize(
                    mask_np,
                    (region_cv.shape[1], region_cv.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )

                # Overlay mask on region for visualization
                overlay = region_cv.copy()
                overlay[mask_resized == 1] = [0, 0, 255]  # Red overlay

                # Blend overlay for final output
                segmented_visual = cv2.addWeighted(region_cv, 0.7, overlay, 0.3, 0)

                # Save segmentation image
                segmentation_image_path = os.path.join(analysis_dir, f"segmentation_{timestamp}.png")
                cv2.imwrite(segmentation_image_path, segmented_visual)
                segmentation_image_pil = Image.fromarray(cv2.cvtColor(segmented_visual, cv2.COLOR_BGR2RGB))

                # Wound measurements from the largest contour of the resized mask
                contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                if contours:
                    cnt = max(contours, key=cv2.contourArea)
                    x, y, w, h = cv2.boundingRect(cnt)
                    length = round(h / self.px_per_cm, 2)
                    breadth = round(w / self.px_per_cm, 2)
                    area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)

            # Classification with confidence tracking (optional, best-effort)
            wound_type = "Unknown"
            classification_confidence = 0.0
            classification_scores = []

            if 'cls' in self.models_cache:
                try:
                    region_pil = Image.fromarray(cv2.cvtColor(region_cv, cv2.COLOR_BGR2RGB))
                    cls_result = self.models_cache['cls'](region_pil)

                    if cls_result:
                        best_result = max(cls_result, key=lambda r: r['score'])
                        wound_type = best_result['label']
                        classification_confidence = float(best_result['score'])
                        classification_scores = [
                            {'label': r['label'], 'score': float(r['score'])} for r in cls_result
                        ]
                except Exception as e:
                    logging.warning(f"Wound classification error: {e}")

            return {
                'wound_type': wound_type,
                'length_cm': length,
                'breadth_cm': breadth,
                'surface_area_cm2': area,
                'detection_confidence': detection_confidence,
                'segmentation_confidence': segmentation_confidence,
                'classification_confidence': classification_confidence,
                'classification_scores': classification_scores,
                'bounding_box': box.tolist(),
                'detection_image_path': detection_image_path,
                'detection_image_pil': detection_image_pil,
                'segmentation_image_path': segmentation_image_path,
                'segmentation_image_pil': segmentation_image_pil,
                'analysis_quality': {
                    'detection_quality': 'high' if detection_confidence > 0.8 else 'medium',
                    'segmentation_quality': 'high' if segmentation_confidence > 0.7 else 'medium',
                    'classification_quality': 'high' if classification_confidence > 0.8 else 'medium',
                },
            }

        except Exception as e:
            logging.error(f"Visual analysis error: {e}")
            raise ValueError(f"Visual analysis failed: {str(e)}") from e

    def _calculate_risk_score(self, visual_results: Dict[str, Any], patient_info: Dict[str, Any]) -> int:
        """
        Calculate comprehensive risk score (0-100) based on visual analysis and patient data

        Numeric fields may arrive as numbers or numeric strings; unparseable
        values count as 0. Yes/no style fields are matched by keyword but an
        answer that starts with a negation ("No", "Non-diabetic", "Not present")
        is never counted as positive. Returns 50 if scoring fails unexpectedly.
        """
        try:
            risk_score = 0

            # Wound size risk (5-25 points)
            area = _to_float(visual_results.get('surface_area_cm2', 0))
            if area > 10:
                risk_score += 25
            elif area > 5:
                risk_score += 15
            elif area > 2:
                risk_score += 10
            else:
                risk_score += 5

            # Wound type risk (10-20 points); labels come from the classifier.
            wound_type = str(visual_results.get('wound_type') or '').lower()
            high_risk_types = ['ulcer', 'necrotic', 'infected', 'diabetic']
            medium_risk_types = ['pressure', 'venous', 'arterial']
            if any(risk_type in wound_type for risk_type in high_risk_types):
                risk_score += 20
            elif any(risk_type in wound_type for risk_type in medium_risk_types):
                risk_score += 15
            else:
                risk_score += 10

            # Patient factors (0-30 points): age (5-15)
            age = _to_float(patient_info.get('patient_age', 0))
            if age > 70:
                risk_score += 15
            elif age > 50:
                risk_score += 10
            else:
                risk_score += 5

            # Diabetic status (0 or 15)
            if _is_affirmative(patient_info.get('diabetic_status'), ('yes', 'diabetic')):
                risk_score += 15

            # Pain level (3-10 points)
            pain_level = _to_float(patient_info.get('pain_level', 0))
            if pain_level > 7:
                risk_score += 10
            elif pain_level > 4:
                risk_score += 7
            else:
                risk_score += 3

            # Infection signs (5-15 points)
            infection_signs = patient_info.get('infection_signs')
            if _is_affirmative(infection_signs, ('yes', 'present')):
                risk_score += 15
            elif _is_affirmative(infection_signs, ('possible', 'mild')):
                risk_score += 10
            else:
                risk_score += 5

            # Ensure score is within 0-100 range
            risk_score = min(max(risk_score, 0), 100)

            logging.info(f"Calculated risk score: {risk_score}")
            return risk_score

        except Exception as e:
            logging.error(f"Error calculating risk score: {e}")
            return 50  # Default medium risk

    def query_guidelines(self, query: str, k: int = 10) -> str:
        """Query the knowledge base for relevant guidelines with enhanced tracking

        Args:
            query: free-text retrieval query.
            k: maximum number of chunks to retrieve (default 10).

        Returns:
            The retrieved chunks, each prefixed with a
            ``[Result i/n] Source: ..., Page: ...`` header and separated by a
            blank line, or a human-readable status/error message. Never raises.
        """
        try:
            vector_store = self.knowledge_base_cache.get("vectorstore")
            if vector_store is None:
                return "Knowledge base unavailable - clinical guidelines not loaded"

            # Retrieve relevant documents
            retriever = vector_store.as_retriever(search_kwargs={"k": k})
            docs = retriever.invoke(query)

            if not docs:
                return "No relevant guidelines found for the query"

            # Format the results; the denominator is the number actually retrieved.
            total = len(docs)
            formatted_results = []
            for i, doc in enumerate(docs, start=1):
                source = doc.metadata.get('source', 'Unknown')
                page = doc.metadata.get('page', 'N/A')
                content = doc.page_content.strip()
                relevance = f"Result {i}/{total}"
                formatted_results.append(f"[{relevance}] Source: {source}, Page: {page}\nContent: {content}")

            guideline_text = "\n\n".join(formatted_results)
            logging.info(f"Retrieved {len(docs)} guideline documents for query: {query[:50]}...")

            return guideline_text

        except Exception as e:
            logging.error(f"Guidelines query error: {e}")
            return f"Error querying guidelines: {str(e)}"

    @staticmethod
    def _extract_assistant_text(generated_content: Any) -> str:
        """Return the assistant's text from a pipeline ``generated_text`` value.

        For a list of chat messages, the first message with role ``assistant`` is
        used; if its content is a list of parts, only the ``text`` parts are
        concatenated. Any other value is converted with ``str``.
        """
        if isinstance(generated_content, list):
            for message in generated_content:
                if message.get('role') == 'assistant':
                    report_content = message.get('content', '')
                    if isinstance(report_content, list):
                        return ''.join(
                            item.get('text', '') for item in report_content if item.get('type') == 'text'
                        )
                    return str(report_content)
        return str(generated_content)

    @spaces.GPU(enable_queue=True, duration=90)
    def generate_final_report(
        self,
        patient_info: Dict[str, Any],
        visual_results: Dict[str, Any],
        guideline_context: str,
        image_pil: Image.Image,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Generate comprehensive medical report using MedGemma with enhanced tracking

        Falls back to ``_generate_fallback_report`` when the MedGemma pipeline is
        not loaded or generation fails.

        Args:
            patient_info: patient form data.
            visual_results: output of ``perform_visual_analysis``.
            guideline_context: text returned by ``query_guidelines``; only the
                first 2000 characters are included in the prompt.
            image_pil: the original wound image.
            max_new_tokens: generation budget; falsy values use
                ``config.MAX_NEW_TOKENS``.
        """
        try:
            if 'medgemma_pipe' not in self.models_cache:
                return self._generate_fallback_report(patient_info, visual_results, guideline_context)

            max_tokens = max_new_tokens or self.config.MAX_NEW_TOKENS

            # Get detection and segmentation images if available
            detection_image = visual_results.get('detection_image_pil', None)
            segmentation_image = visual_results.get('segmentation_image_pil', None)

            # Create enhanced prompt with quality indicators
            analysis_quality = visual_results.get('analysis_quality', {})

            # Mark truncation only when the guidelines were actually cut.
            guideline_text = guideline_context or ""
            guideline_excerpt = guideline_text[:2000]
            if len(guideline_text) > 2000:
                guideline_excerpt += "..."

            prompt = f"""
# SmartHeal AI Wound Care Report

## Patient Information
{self._format_patient_info(patient_info)}

## Visual Analysis Summary
- Wound Type: {visual_results.get('wound_type', 'Unknown')} (Confidence: {visual_results.get('classification_confidence', 0):.2f})
- Dimensions: {visual_results.get('length_cm', 0)} × {visual_results.get('breadth_cm', 0)} cm
- Surface Area: {visual_results.get('surface_area_cm2', 0)} cm²
- Detection Quality: {analysis_quality.get('detection_quality', 'medium')}
- Segmentation Quality: {analysis_quality.get('segmentation_quality', 'medium')}

## Clinical Reference Guidelines
{guideline_excerpt}

## Analysis Request
You are SmartHeal-AI Agent, a specialized wound care AI with expertise in clinical assessment and evidence-based treatment planning.

Based on the comprehensive data provided (patient information, precise wound measurements, clinical guidelines, and visual analysis), generate a structured clinical report with the following sections:

### 1. Clinical Assessment
- Detailed wound characterization based on visual analysis
- Tissue type assessment (granulation, slough, necrotic, epithelializing)
- Peri-wound skin condition evaluation
- Infection risk assessment

### 2. Treatment Recommendations
- Specific wound care dressing recommendations based on wound characteristics
- Topical treatments if indicated
- Debridement recommendations if needed
- Pressure offloading strategies if applicable

### 3. Risk Stratification
- Patient-specific risk factors analysis
- Healing prognosis assessment
- Complications to monitor

### 4. Follow-up Plan
- Recommended assessment frequency
- Key monitoring parameters
- Escalation criteria for specialist referral

Generate a concise, evidence-based report suitable for clinical documentation.
"""

            # Images in order of importance (original, detection, segmentation), then the text prompt.
            images = [img for img in (image_pil, detection_image, segmentation_image) if img is not None]
            content_list = [{"type": "image", "image": img} for img in images]
            content_list.append({"type": "text", "text": prompt})

            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": "You are a specialized medical AI assistant for wound care with expertise in clinical assessment, treatment planning, and evidence-based recommendations. Provide structured, actionable clinical reports."}],
                },
                {
                    "role": "user",
                    "content": content_list,
                },
            ]

            # Generate report using MedGemma
            output = self.models_cache['medgemma_pipe'](
                text=messages,
                max_new_tokens=max_tokens,
                do_sample=False,
            )

            report_text = self._extract_assistant_text(output[0]['generated_text'])

            # Add metadata to report
            report_with_metadata = f"""
{report_text}

---

**Report Metadata:**
- Generated by: SmartHeal AI v{self.model_version}
- Analysis Quality: Detection ({analysis_quality.get('detection_quality', 'medium')}), Segmentation ({analysis_quality.get('segmentation_quality', 'medium')})
- Generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""

            logging.info("✅ MedGemma report generated successfully")
            return report_with_metadata

        except Exception as e:
            logging.error(f"MedGemma report generation error: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def _format_patient_info(self, patient_info: Dict[str, Any]) -> str:
        """Format patient information as a markdown bullet list ('N/A' for missing fields)."""
        formatted = f"""
- Name: {patient_info.get('patient_name', 'N/A')}
- Age: {patient_info.get('patient_age', 'N/A')} years
- Gender: {patient_info.get('patient_gender', 'N/A')}
- Wound Location: {patient_info.get('wound_location', 'N/A')}
- Wound Duration: {patient_info.get('wound_duration', 'N/A')}
- Pain Level: {patient_info.get('pain_level', 'N/A')}/10
- Diabetic Status: {patient_info.get('diabetic_status', 'N/A')}
- Infection Signs: {patient_info.get('infection_signs', 'N/A')}
- Previous Treatment: {patient_info.get('previous_treatment', 'N/A')}
- Medical History: {patient_info.get('medical_history', 'N/A')}
- Current Medications: {patient_info.get('medications', 'N/A')}
- Known Allergies: {patient_info.get('allergies', 'N/A')}
"""
        return formatted.strip()

    def _generate_fallback_report(self, patient_info: Dict[str, Any], visual_results: Dict[str, Any], guideline_context: str) -> str:
        """Generate a template-based report when MedGemma is not available.

        ``guideline_context`` is accepted for interface parity but is not used.
        Numeric form fields are coerced leniently so that string or missing
        values cannot crash this last-resort path.
        """
        wound_type = visual_results.get('wound_type', 'Unknown')
        length = visual_results.get('length_cm', 0)
        breadth = visual_results.get('breadth_cm', 0)
        area = visual_results.get('surface_area_cm2', 0)

        # Basic risk assessment (diabetic check matches _calculate_risk_score)
        risk_factors = []
        if _to_float(patient_info.get('patient_age', 0)) > 65:
            risk_factors.append("Advanced age")
        if _is_affirmative(patient_info.get('diabetic_status'), ('yes', 'diabetic')):
            risk_factors.append("Diabetes mellitus")
        if _to_float(patient_info.get('pain_level', 0)) > 6:
            risk_factors.append("High pain level")
        if _to_float(area) > 5:
            risk_factors.append("Large wound size")

        if risk_factors:
            risk_factor_lines = "\n".join(f"- {factor}" for factor in risk_factors)
        else:
            risk_factor_lines = "- No significant risk factors identified"

        report = f"""
# SmartHeal AI Wound Assessment Report

## Clinical Summary
**Patient:** {patient_info.get('patient_name', 'N/A')}, {patient_info.get('patient_age', 'N/A')} years old {patient_info.get('patient_gender', '')}

**Wound Characteristics:**
- Type: {wound_type}
- Location: {patient_info.get('wound_location', 'N/A')}
- Dimensions: {length} × {breadth} cm (Area: {area} cm²)
- Duration: {patient_info.get('wound_duration', 'N/A')}
- Pain Level: {patient_info.get('pain_level', 'N/A')}/10

## Risk Assessment
**Identified Risk Factors:**
{risk_factor_lines}

## Treatment Recommendations
**Wound Care:**
- Regular wound assessment and documentation
- Appropriate dressing selection based on wound characteristics
- Maintain moist wound environment
- Monitor for signs of infection

**Patient Management:**
- Pain management as indicated
- Nutritional assessment and optimization
- Patient education on wound care

## Follow-up Plan
- Reassess wound in 1-2 weeks
- Monitor for signs of healing or deterioration
- Consider specialist referral if no improvement in 4 weeks

---
**Report Generated by:** SmartHeal AI Fallback System v{self.model_version}
**Generated at:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

**Note:** This is a basic assessment. For comprehensive analysis, ensure all AI models are properly loaded.
"""

        logging.info("✅ Fallback report generated")
        return report