Update src/ai_processor.py
src/ai_processor.py (CHANGED: +180 -419)
@@ -1,12 +1,12 @@
 import os
+import io
+import base64
 import logging
 import cv2
 import numpy as np
 from PIL import Image
 import torch
-import json
 from datetime import datetime
-import tensorflow as tf
 from transformers import pipeline
 from ultralytics import YOLO
 from tensorflow.keras.models import load_model
@@ -16,9 +16,21 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 from huggingface_hub import HfApi, HfFolder
 import spaces
-
 from .config import Config
 
+# Inline system prompt for MedGemma GPU pipeline
+default_system_prompt = (
+    "You are a world-class medical AI assistant specializing in wound care "
+    "with expertise in wound assessment and treatment. Provide concise, "
+    "evidence-based medical assessments focusing on: (1) Precise wound "
+    "classification based on tissue type and appearance, (2) Specific "
+    "treatment recommendations with exact product names or interventions when "
+    "appropriate, (3) Objective evaluation of healing progression or deterioration "
+    "indicators, and (4) Clear follow-up timelines. Avoid general statements and "
+    "prioritize actionable insights based on the visual analysis measurements and "
+    "patient context."
+)
+
 class AIProcessor:
     def __init__(self):
         self.models_cache = {}
@@ -28,458 +40,207 @@ class AIProcessor:
         self._initialize_models()
 
     def _initialize_models(self):
-        """Initialize
-        logging.info("HuggingFace token set successfully")
+        """Initialize AI models; only MedGemma uses GPU."""
+        # Set HuggingFace token
+        if self.config.HF_TOKEN:
+            HfFolder.save_token(self.config.HF_TOKEN)
+            logging.info("HuggingFace token set successfully")
 
+        # MedGemma pipeline on GPU
+        try:
+            self.models_cache['medgemma_pipe'] = pipeline(
+                'image-text-to-text',
+                model='google/medgemma-4b-it',
+                device='cuda',
+                torch_dtype=torch.bfloat16,
+                offload_folder='offload',
+                token=self.config.HF_TOKEN
+            )
+            logging.info("✅ MedGemma pipeline loaded on GPU")
+        except Exception as e:
+            logging.warning(f"MedGemma pipeline not available: {e}")
+
-        # Initialize YOLO model for wound detection
+        # YOLO detection on CPU
         try:
-            self.models_cache["det"] = YOLO(self.config.YOLO_MODEL_PATH)
-            logging.info("✅ YOLO detection model loaded successfully")
+            self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
+            logging.info("✅ YOLO detection model loaded on CPU")
         except Exception as e:
             logging.warning(f"YOLO model not available: {e}")
 
-        # Initialize segmentation model
+        # Segmentation model on CPU
         try:
-            self.models_cache["seg"] = load_model(self.config.SEG_MODEL_PATH, compile=False)
-            logging.info("✅ Segmentation model loaded successfully")
+            self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
+            logging.info("✅ Segmentation model loaded on CPU")
         except Exception as e:
             logging.warning(f"Segmentation model not available: {e}")
 
-        # Initialize wound classification model
+        # Classification on CPU
         try:
-            self.models_cache["cls"] = pipeline(
-                "image-classification",
-                model="Hemg/Wound-classification",
-                token=self.config.HF_TOKEN,
-                device="cpu"
-            )
-            logging.info("✅ Wound classification model loaded successfully")
+            self.models_cache['cls'] = pipeline(
+                'image-classification',
+                model='Hemg/Wound-classification',
+                token=self.config.HF_TOKEN,
+                device='cpu'
+            )
+            logging.info("✅ Wound classification model loaded on CPU")
         except Exception as e:
             logging.warning(f"Wound classification model not available: {e}")
 
+        # Embedding for knowledge base
         try:
-            logging.info("✅ Embedding model loaded successfully")
+            self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
+                model_name='sentence-transformers/all-MiniLM-L6-v2',
+                model_kwargs={'device': 'cpu'}
+            )
+            logging.info("✅ Embedding model loaded on CPU")
         except Exception as e:
             logging.warning(f"Embedding model not available: {e}")
 
+        # Load knowledge base
+        self._load_knowledge_base()
 
-        """Load knowledge base from PDF guidelines"""
-        try:
+    def _load_knowledge_base(self):
+        """Load PDF guidelines into a FAISS vector store."""
+        docs = []
+        for pdf in self.config.GUIDELINE_PDFS:
+            if os.path.exists(pdf):
+                loader = PyPDFLoader(pdf)
+                docs.extend(loader.load())
+                logging.info(f"Loaded PDF: {pdf}")
+
+        if docs and 'embedding_model' in self.models_cache:
-            text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=1000,
-                chunk_overlap=100
-            )
-            chunks = text_splitter.split_documents(documents)
-
-            # Create vector store
-            vectorstore = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
-            self.knowledge_base_cache['vectorstore'] = vectorstore
-            logging.info(f"✅ Knowledge base loaded with {len(chunks)} chunks")
+            splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+            chunks = splitter.split_documents(docs)
+            vs = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
+            self.knowledge_base_cache['vectorstore'] = vs
+            logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
         else:
-            self.knowledge_base_cache['vectorstore'] = None
-            logging.warning("Knowledge base not available - no PDFs found or embedding model unavailable")
             self.knowledge_base_cache['vectorstore'] = None
+            logging.warning("Knowledge base unavailable")
 
-    @spaces.GPU(enable_queue=True, duration=120)
     def perform_visual_analysis(self, image_pil):
-        """
+        """Detect & segment on CPU; return only paths + metrics."""
         try:
+            img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
+            # YOLO detect
+            res = self.models_cache['det'].predict(img_cv, verbose=False)[0]
+            if not res.boxes:
+                raise ValueError("No wound detected")
+            # Bounding box
+            x1, y1, x2, y2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
+            region = img_cv[y1:y2, x1:x2]
-            # Save detection image
-            detection_image_cv = image_cv.copy()
-            cv2.rectangle(detection_image_cv, (x1, y1), (x2, y2), (0, 255, 0), 2)
-            os.makedirs(os.path.join(self.config.UPLOADS_DIR, "analysis"), exist_ok=True)
-            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
-            detection_image_path = os.path.join(self.config.UPLOADS_DIR, "analysis", f"detection_{timestamp}.png")
-            cv2.imwrite(detection_image_path, detection_image_cv)
-            detection_image_pil = Image.fromarray(cv2.cvtColor(detection_image_cv, cv2.COLOR_BGR2RGB))
-
-            # Initialize outputs
+            # Save detection overlay
+            det_vis = img_cv.copy()
+            cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            os.makedirs(f"{self.config.UPLOADS_DIR}/analysis", exist_ok=True)
+            ts = datetime.now().strftime('%Y%m%d_%H%M%S')
+            det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
+            cv2.imwrite(det_path, det_vis)
+            # Initialize metrics & seg
             length = breadth = area = 0
-            # Segmentation (optional)
+            seg_path = None
+            # Segmentation
             if 'seg' in self.models_cache:
-                cv2.imwrite(segmentation_image_path, segmented_visual)
-                segmentation_image_pil = Image.fromarray(cv2.cvtColor(segmented_visual, cv2.COLOR_BGR2RGB))
-
-                # Wound measurements from resized mask
-                contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-                if contours:
-                    cnt = max(contours, key=cv2.contourArea)
-                    x, y, w, h = cv2.boundingRect(cnt)
-                    length = round(h / self.px_per_cm, 2)
-                    breadth = round(w / self.px_per_cm, 2)
-                    area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
-
-            # Classification (optional)
-            wound_type = "Unknown"
+                h, w = self.models_cache['seg'].input_shape[1:3]
+                inp = cv2.resize(region, (w, h)) / 255.0
+                mask = (self.models_cache['seg'].predict(np.expand_dims(inp, 0))[0, :, :, 0] > 0.5).astype(np.uint8)
+                mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
+                ov = region.copy()
+                ov[mask_rs == 1] = [0, 0, 255]
+                seg_vis = cv2.addWeighted(region, 0.7, ov, 0.3, 0)
+                seg_path = f"{self.config.UPLOADS_DIR}/analysis/segmentation_{ts}.png"
+                cv2.imwrite(seg_path, seg_vis)
+                # measure
+                cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+                if cnts:
+                    cnt = max(cnts, key=cv2.contourArea)
+                    _, _, w0, h0 = cv2.boundingRect(cnt)
+                    length = round(h0 / self.px_per_cm, 2)
+                    breadth = round(w0 / self.px_per_cm, 2)
+                    area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
+            # Classification
+            wound_type = 'Unknown'
             if 'cls' in self.models_cache:
                 try:
-                    logging.warning(f"Wound classification error: {e}")
+                    label = self.models_cache['cls'](Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB)))
+                    wound_type = max(label, key=lambda x: x['score'])['label']
+                except Exception:
+                    pass
             return {
                 'wound_type': wound_type,
                 'length_cm': length,
                 'breadth_cm': breadth,
                 'surface_area_cm2': area,
-                'detection_confidence': float(
-                'detection_image_pil': detection_image_pil,
-                'segmentation_image_path': segmentation_image_path,
-                'segmentation_image_pil': segmentation_image_pil
+                'detection_confidence': float(res.boxes.conf[0].cpu().item()),
+                'detection_image_path': det_path,
+                'segmentation_image_path': seg_path
             }
         except Exception as e:
             logging.error(f"Visual analysis error: {e}")
             raise
 
     def query_guidelines(self, query: str):
-        """
-            docs = retriever.invoke(query)
-
-            if not docs:
-                return "No relevant guidelines found for the query"
-
-            # Format the results
-            formatted_results = []
-            for doc in docs:
-                source = doc.metadata.get('source', 'Unknown')
-                page = doc.metadata.get('page', 'N/A')
-                content = doc.page_content.strip()
-                formatted_results.append(f"Source: {source}, Page: {page}\nContent: {content}")
-
-            return "\n\n".join(formatted_results)
-
-        except Exception as e:
-            logging.error(f"Guidelines query error: {e}")
-            return f"Error querying guidelines: {str(e)}"
+        """Retrieve clinical guidelines from vectorstore."""
+        vs = self.knowledge_base_cache.get('vectorstore')
+        if not vs:
+            return "Clinical guidelines unavailable"
+        docs = vs.as_retriever(search_kwargs={'k': 10}).invoke(query)
+        return '\n\n'.join(f"Source: {d.metadata.get('source', '?')}, Page: {d.metadata.get('page', '?')}\n{d.page_content}" for d in docs)
 
     @spaces.GPU(enable_queue=True, duration=120)
     def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
-        """
+        """Run MedGemma on GPU; return markdown report."""
         if 'medgemma_pipe' not in self.models_cache:
             return self._generate_fallback_report(patient_info, visual_results, guideline_context)
-
-        max_tokens = max_new_tokens or self.config.MAX_NEW_TOKENS
-
-        # Get detection and segmentation images if available
-        detection_image = visual_results.get('detection_image_pil', None)
-        segmentation_image = visual_results.get('segmentation_image_pil', None)
-
-        # Create image paths for report
-        detection_path = visual_results.get('detection_image_path', '')
-        segmentation_path = visual_results.get('segmentation_image_path', '')
-
-        # Create detailed prompt for medical analysis with image paths
-        prompt = f"""
-# Wound Care Report
-
-## Patient Information
-{patient_info}
-
-## Visual Analysis Summary
-- Wound Type: {visual_results.get('wound_type', 'Unknown')}
-- Length: {visual_results.get('length_cm', 0)} cm
-- Breadth: {visual_results.get('breadth_cm', 0)} cm
-- Surface Area: {visual_results.get('surface_area_cm2', 0)} cm²
-- Detection Confidence: {visual_results.get('detection_confidence', 0):.2f}
-
-## Clinical Reference
-{guideline_context}
-
-You are SmartHeal-AI Agent, a world-class wound care AI specialist trained in clinical wound assessment and guideline-based treatment planning.
-Your task is to process the following structured inputs (patient data, wound measurements, clinical guidelines, and image) and perform **clinical reasoning and decision-making** to generate a complete wound care report.
----
-🔍 **YOUR PROCESS — FOLLOW STRICTLY:**
-### Step 1: Clinical Reasoning (Chain-of-Thought)
-Use the provided information to think step-by-step about:
-- Patient’s risk factors (e.g. diabetes, age, healing limitations)
-- Wound characteristics (size, tissue appearance, moisture, infection signs)
-- Visual clues from the image (location, granulation, maceration, inflammation, surrounding skin)
-
----
-### Step 2: Structured Clinical Report
-Generate the following report sections using markdown and medical terminology:
-**1. Clinical Summary**
-- Describe wound appearance and tissue types (e.g., slough, necrotic, granulating, epithelializing)
-- Include size, wound bed condition, peri-wound skin, and signs of infection or biofilm
-- Mention inferred location (e.g., heel, forefoot) if image allows
-- Summarize patient's systemic risk profile
-**2. Medicinal & Dressing Recommendations**
-Based on your analysis:
-- Recommend specific **wound care dressings** (e.g., hydrocolloid, alginate, foam, antimicrobial silver, etc.) suitable to wound moisture level and infection risk
-- Propose **topical or systemic agents** ONLY if relevant — include name classes (e.g., antiseptic: povidone iodine, antibiotic ointments, enzymatic debriders)
-- Mention **techniques** (e.g., sharp debridement, NPWT, moisture balance, pressure offloading, dressing frequency)
-- Avoid repeating guidelines — **apply them**
-**3. Key Risk Factors**
-Explain how the patient’s condition (e.g., diabetic, poor circulation, advanced age, poor hygiene) may affect wound healing
-**4. Prognosis & Monitoring Advice**
-- Mention how often wound should be reassessed
-- Indicate signs to monitor for deterioration or improvement
-- Include when escalation to specialist is necessary
-
-**Note:** Every dressing change is a chance for wound reassessment. Always perform a thorough wound evaluation at each dressing change.
-"""
-
-        # Prepare messages for MedGemma with all available images
-        content_list = [{"type": "text", "text": prompt}]
-
-        # Add original image
-        if image_pil:
-            content_list.insert(0, {"type": "image", "image": image_pil})
-
-        # Add detection image if available
-        if detection_image:
-            content_list.insert(1, {"type": "image", "image": detection_image})
-
-        # Add segmentation image if available
-        if segmentation_image:
-            content_list.insert(2, {"type": "image", "image": segmentation_image})
-
-        messages = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are a world-class medical AI assistant specializing in wound care with expertise in wound assessment and treatment. Provide concise, evidence-based medical assessments focusing on: (1) Precise wound classification based on tissue type and appearance, (2) Specific treatment recommendations with exact product names or interventions when appropriate, (3) Objective evaluation of healing progression or deterioration indicators, and (4) Clear follow-up timelines. Avoid general statements and prioritize actionable insights based on the visual analysis measurements and patient context."}],
-            },
-            {
-                "role": "user",
-                "content": content_list
-            }
-        ]
-
-        # Generate report using MedGemma
-        output = self.models_cache['medgemma_pipe'](
-            text=messages,
-            max_new_tokens=1024,
-            do_sample=False,
-        )
-
-        generated_content = output[0]['generated_text'][-1].get('content', '').strip()
-
-        # Include image paths in the final report for display in UI
-        if generated_content:
-            # Add image paths to the report for frontend display
-            image_paths_section = f"""
-## Analysis Images
-- Original Image: {image_pil}
-- Detection Image: {detection_path}
-- Segmentation Image: {segmentation_path}
-"""
-            generated_content = image_paths_section + generated_content
-
-        return generated_content if generated_content else self._generate_fallback_report(patient_info, visual_results, guideline_context)
-
-        except Exception as e:
-            logging.error(f"MedGemma report generation error: {e}")
-            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
+        # build messages
+        msgs = [{'role': 'system', 'content': [{'type': 'text', 'text': default_system_prompt}]},
+                {'role': 'user', 'content': []}]
+        # images
+        if image_pil:
+            msgs[1]['content'].append({'type': 'image', 'image': image_pil})
+        for key in ('detection_image_path', 'segmentation_image_path'):
+            p = visual_results.get(key)
+            if p and os.path.exists(p):
+                msgs[1]['content'].append({'type': 'image', 'image': Image.open(p)})
+        # text prompt stub (expand as needed)
+        prompt = f"## Patient\n{patient_info}\n## Visual Type: {visual_results['wound_type']}"
+        msgs[1]['content'].append({'type': 'text', 'text': prompt})
+        out = self.models_cache['medgemma_pipe'](text=msgs, max_new_tokens=max_new_tokens or self.config.MAX_NEW_TOKENS)
+        report = out[0]['generated_text'][-1].get('content', '')
+        return report or self._generate_fallback_report(patient_info, visual_results, guideline_context)
 
     def _generate_fallback_report(self, patient_info, visual_results, guideline_context):
-        """
-        report = f"""
-# Wound Analysis Report
-## Patient Information
-{patient_info}
-
-## Visual Analysis Results
-- **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
-- **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
-- **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
-- **Detection Confidence**: {visual_results.get('detection_confidence', 0):.2f}
-
-## Analysis Images
-- **Detection Image**: {detection_path}
-- **Segmentation Image**: {segmentation_path}
-
-## Assessment
-Based on the visual analysis, this appears to be a {visual_results.get('wound_type', 'wound')} with measurable dimensions.
-
-## Recommendations
-- Continue monitoring wound healing progress
-- Maintain proper wound hygiene
-- Follow appropriate dressing protocols
-- Seek medical attention if signs of infection develop
-
-## Clinical Guidelines
-{guideline_context[:500]}...
-
-*Note: This is an automated analysis. Please consult with a healthcare professional for definitive diagnosis and treatment.*
-"""
-        return report
+        """Produce text-only fallback."""
+        dp = visual_results.get('detection_image_path', 'N/A')
+        sp = visual_results.get('segmentation_image_path', 'N/A')
+        return f"# Report\n{patient_info}\nType: {visual_results['wound_type']}\nDetection Image: {dp}\nSegmentation Image: {sp}\nGuidelines: {guideline_context[:200]}..."
 
     def save_and_commit_image(self, image_pil):
-        """Save
+        """Save locally and optionally to HuggingFace."""
+        os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
+        fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
+        path = os.path.join(self.config.UPLOADS_DIR, fn)
+        image_pil.convert('RGB').save(path)
-        # Upload to HuggingFace dataset if configured
         if self.config.HF_TOKEN and self.config.DATASET_ID:
             try:
                 api = HfApi()
-                api.upload_file(
-                    path_or_fileobj=local_path,
-                    path_in_repo=f"images/{filename}",
-                    repo_id=self.config.DATASET_ID,
-                    repo_type="dataset",
-                    commit_message=f"Upload wound image: {filename}"
-                )
-                logging.info("✅ Image uploaded to HuggingFace dataset")
+                api.upload_file(path_or_fileobj=path, path_in_repo=f"images/{fn}", repo_id=self.config.DATASET_ID, repo_type='dataset')
             except Exception as e:
-                logging.warning(f"HuggingFace upload failed: {e}")
-
-            return local_path
-
-        except Exception as e:
-            logging.error(f"Image saving error: {e}")
-            return None
+                logging.warning(f"HF upload failed: {e}")
+        return path
 
-        """Complete analysis pipeline with real-time models"""
+    def full_analysis_pipeline(self, image_pil, questionnaire_data):
+        """Orchestrate CPU steps + GPU report."""
         try:
-            # Format patient information
-            patient_info = ", ".join([f"{k}: {v}" for k, v in questionnaire_data.items() if v])
-
-            # Create query for guidelines
-            wound_type = visual_results.get('wound_type', 'wound')
-            moisture = questionnaire_data.get('moisture', 'unknown')
-            infection = questionnaire_data.get('infection', 'unknown')
-            diabetic = questionnaire_data.get('diabetic', 'unknown')
-
-            query = f"best practices for managing a {wound_type} with moisture level '{moisture}' and signs of infection '{infection}' in a patient who is diabetic '{diabetic}'"
-
-            # Query guidelines
-            guideline_context = self.query_guidelines(query)
-
-            # Generate final report
-            final_report = self.generate_final_report(patient_info, visual_results, guideline_context, image)
-
-            return {
-                'success': True,
-                'visual_analysis': visual_results,
-                'report': final_report,
-                'saved_image_path': saved_path,
-                'timestamp': datetime.now().isoformat()
-            }
+            saved = self.save_and_commit_image(image_pil)
+            vis = self.perform_visual_analysis(image_pil)
+            info = ", ".join(f"{k}:{v}" for k, v in questionnaire_data.items() if v)
+            gc = self.query_guidelines(info)
+            report = self.generate_final_report(info, vis, gc, image_pil)
+            return {'success': True, 'visual_analysis': vis, 'report': report, 'saved_image_path': saved}
         except Exception as e:
-            logging.error(f"
-            return {
-                'error': str(e),
-                'timestamp': datetime.now().isoformat()
-            }
+            logging.error(f"Pipeline error: {e}")
+            return {'success': False, 'error': str(e)}
 
     # Legacy methods for backward compatibility
     def analyze_wound(self, image, questionnaire_data):
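
After this change, full_analysis_pipeline() is the single entry point: save the image, run the CPU detection/segmentation/classification pass, retrieve guideline context, then call the GPU-decorated generate_final_report(). A minimal driver sketch, assuming the Space's models and Config are available; the image path and questionnaire keys below are illustrative (moisture, infection, and diabetic are the keys the previous version read):

# Hypothetical usage sketch — not part of the commit.
from PIL import Image
from src.ai_processor import AIProcessor

processor = AIProcessor()
result = processor.full_analysis_pipeline(
    Image.open("wound.jpg"),  # placeholder image path
    {"moisture": "moderate", "infection": "none", "diabetic": "yes"},
)
if result['success']:
    print(result['report'])
    print(result['visual_analysis'])  # wound_type, length_cm, breadth_cm, surface_area_cm2, ...
else:
    print("Analysis failed:", result['error'])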
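
For reference, generate_final_report() now feeds the 'image-text-to-text' pipeline a chat-style message list rather than one flat prompt string. A sketch of that payload shape, taken from the diff above; the image path and prompt text are placeholders, and medgemma_pipe stands in for self.models_cache['medgemma_pipe']:

# Payload shape consumed by the MedGemma pipeline (placeholders marked).
from PIL import Image

messages = [
    {'role': 'system', 'content': [{'type': 'text', 'text': default_system_prompt}]},
    {'role': 'user', 'content': [
        {'type': 'image', 'image': Image.open('uploads/analysis/detection_example.png')},  # placeholder
        {'type': 'text', 'text': '## Patient\nage:62, diabetic:yes\n## Visual Type: Diabetic ulcer'},  # placeholder
    ]},
]
out = medgemma_pipe(text=messages, max_new_tokens=1024)
report = out[0]['generated_text'][-1].get('content', '')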
|