SmartHeal commited on
Commit
75ebbe5
·
verified ·
1 Parent(s): 043da85

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +423 -394
src/ai_processor.py CHANGED
@@ -1,99 +1,225 @@
 
 
 
1
  import os
 
2
  import logging
 
 
 
3
  import cv2
4
  import numpy as np
5
  from PIL import Image
6
- from datetime import datetime
7
- import gradio as gr
8
- import spaces
9
- import torch
10
- import time
11
-
12
- from huggingface_hub import HfApi, HfFolder
13
- from langchain_community.document_loaders import PyPDFLoader
14
- from langchain.text_splitter import RecursiveCharacterTextSplitter
15
- from langchain_community.embeddings import HuggingFaceEmbeddings
16
- from langchain_community.vectorstores import FAISS
17
 
18
  # =============== LOGGING SETUP ===============
19
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
 
21
  # =============== CONFIGURATION ===============
22
  UPLOADS_DIR = "uploads"
23
- if not os.path.exists(UPLOADS_DIR):
24
- os.makedirs(UPLOADS_DIR)
25
- logging.info(f"Created uploads directory: {UPLOADS_DIR}")
26
 
27
- HF_TOKEN = os.getenv("HF_TOKEN")
28
  YOLO_MODEL_PATH = "src/best.pt"
29
- SEG_MODEL_PATH = "src/segmentation_model.h5"
30
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
31
- DATASET_ID = "SmartHeal/wound-image-uploads"
32
- MAX_NEW_TOKENS = 1024 # Reduced for stability
33
- PIXELS_PER_CM = 38
34
 
35
  # =============== GLOBAL CACHES ===============
36
- models_cache = {}
37
- knowledge_base_cache = {}
38
 
39
- # =============== LAZY LOADING FUNCTIONS (CPU-SAFE) ===============
40
- def load_yolo_model(yolo_model_path):
41
- """Lazy import and load YOLO model to avoid CUDA initialization."""
42
  from ultralytics import YOLO
43
- return YOLO(yolo_model_path)
44
 
45
- def load_segmentation_model(seg_model_path):
46
- """Lazy import and load segmentation model."""
47
  import tensorflow as tf
48
- tf.config.set_visible_devices([], 'GPU') # Force CPU for TensorFlow
49
  from tensorflow.keras.models import load_model
50
- return load_model(seg_model_path, compile=False)
51
 
52
- def load_classification_pipeline(hf_token):
53
- """Lazy import and load classification pipeline (CPU only)."""
54
  from transformers import pipeline
55
- return pipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  "image-classification",
57
  model="Hemg/Wound-classification",
58
- token=hf_token,
59
- device="cpu"
60
  )
61
 
62
  def load_embedding_model():
63
- """Load embedding model for knowledge base."""
64
- return HuggingFaceEmbeddings(
65
- model_name="sentence-transformers/all-MiniLM-L6-v2",
66
- model_kwargs={"device": "cpu"}
67
- )
68
 
69
- # =============== MODEL INITIALIZATION ===============
70
- def initialize_cpu_models():
71
- """Initialize all CPU-only models once."""
72
- global models_cache
73
-
74
  if HF_TOKEN:
75
- HfFolder.save_token(HF_TOKEN)
76
- logging.info("✅ HuggingFace token set")
77
-
 
 
 
 
78
  if "det" not in models_cache:
79
  try:
80
- models_cache["det"] = load_yolo_model(YOLO_MODEL_PATH)
81
- logging.info("✅ YOLO model loaded (CPU only)")
82
  except Exception as e:
83
  logging.error(f"YOLO load failed: {e}")
84
 
85
  if "seg" not in models_cache:
86
  try:
87
- models_cache["seg"] = load_segmentation_model(SEG_MODEL_PATH)
88
- logging.info(" Segmentation model loaded (CPU)")
 
 
 
 
89
  except Exception as e:
 
90
  logging.warning(f"Segmentation model not available: {e}")
91
 
92
  if "cls" not in models_cache:
93
  try:
94
- models_cache["cls"] = load_classification_pipeline(HF_TOKEN)
95
  logging.info("✅ Classification pipeline loaded (CPU)")
96
  except Exception as e:
 
97
  logging.warning(f"Classification pipeline not available: {e}")
98
 
99
  if "embedding_model" not in models_cache:
@@ -101,119 +227,48 @@ def initialize_cpu_models():
101
  models_cache["embedding_model"] = load_embedding_model()
102
  logging.info("✅ Embedding model loaded (CPU)")
103
  except Exception as e:
 
104
  logging.warning(f"Embedding model not available: {e}")
105
 
106
- def setup_knowledge_base():
107
- """Load PDF documents and create FAISS vector store."""
108
- global knowledge_base_cache
109
  if "vector_store" in knowledge_base_cache:
110
  return
111
 
112
  docs = []
113
- for pdf_path in GUIDELINE_PDFS:
114
- if os.path.exists(pdf_path):
115
- try:
116
- loader = PyPDFLoader(pdf_path)
117
- docs.extend(loader.load())
118
- logging.info(f"Loaded PDF: {pdf_path}")
119
- except Exception as e:
120
- logging.warning(f"Failed to load PDF {pdf_path}: {e}")
121
-
122
- if docs and "embedding_model" in models_cache:
123
- splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
124
- chunks = splitter.split_documents(docs)
125
- knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
126
- logging.info(f"✅ Knowledge base ready with {len(chunks)} chunks")
 
 
 
 
 
 
 
 
 
 
127
  else:
128
  knowledge_base_cache["vector_store"] = None
129
- logging.warning("Knowledge base unavailable")
130
 
131
- # Initialize models on app startup
132
  initialize_cpu_models()
133
  setup_knowledge_base()
134
 
135
- # =============== GPU-DECORATED MEDGEMMA FUNCTION WITH TIMEOUT HANDLING ===============
136
- @spaces.GPU(enable_queue=True, duration=90) # Reduced duration for stability
137
- def generate_medgemma_report_with_timeout(
138
- patient_info,
139
- visual_results,
140
- guideline_context,
141
- image_pil,
142
- max_new_tokens=None,
143
- ):
144
- """GPU-only function for MedGemma report generation with improved timeout handling."""
145
- import torch
146
- from transformers import pipeline
147
-
148
- try:
149
- # Clear GPU cache first
150
- if torch.cuda.is_available():
151
- torch.cuda.empty_cache()
152
-
153
- # Use a shorter, more focused prompt to reduce processing time
154
- prompt = f"""
155
- You are a medical AI assistant. Analyze this wound image and patient data to provide a clinical assessment.
156
-
157
- Patient: {patient_info}
158
- Wound: {visual_results.get('wound_type', 'Unknown')} - {visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)}cm
159
-
160
- Provide a structured report with:
161
- 1. Clinical Summary (wound appearance, size, location)
162
- 2. Treatment Recommendations (dressings, care protocols)
163
- 3. Risk Assessment (healing factors)
164
- 4. Monitoring Plan (follow-up schedule)
165
-
166
- Keep response concise but medically comprehensive.
167
- """
168
-
169
- # Initialize pipeline with optimized settings
170
- pipe = pipeline(
171
- "image-text-to-text",
172
- model="google/medgemma-4b-it",
173
- torch_dtype=torch.bfloat16,
174
- device_map="auto",
175
- token=HF_TOKEN,
176
- model_kwargs={"low_cpu_mem_usage": True, "use_cache": True}
177
- )
178
-
179
- messages = [
180
- {
181
- "role": "user",
182
- "content": [
183
- {"type": "image", "image": image_pil},
184
- {"type": "text", "text": prompt},
185
- ]
186
- }
187
- ]
188
-
189
- # Generate with conservative settings
190
- start_time = time.time()
191
- output = pipe(
192
- text=messages,
193
- max_new_tokens=max_new_tokens or 800, # Reduced for stability
194
- do_sample=False,
195
- temperature=0.7,
196
- pad_token_id=pipe.tokenizer.eos_token_id
197
- )
198
-
199
- processing_time = time.time() - start_time
200
- logging.info(f"✅ MedGemma processing completed in {processing_time:.2f} seconds")
201
-
202
- if output and len(output) > 0:
203
- result = output[0]["generated_text"][-1].get("content", "").strip()
204
- return result if result else "⚠️ Empty response generated"
205
- else:
206
- return "⚠️ No output generated"
207
-
208
- except Exception as e:
209
- logging.error(f"❌ MedGemma generation error: {e}")
210
- return f"❌ Report generation failed: {str(e)}"
211
- finally:
212
- # Clear GPU memory
213
- if torch.cuda.is_available():
214
- torch.cuda.empty_cache()
215
-
216
- # =============== AI PROCESSOR CLASS ===============
217
  class AIProcessor:
218
  def __init__(self):
219
  self.models_cache = models_cache
@@ -223,237 +278,204 @@ class AIProcessor:
223
  self.dataset_id = DATASET_ID
224
  self.hf_token = HF_TOKEN
225
 
226
- def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
227
- """Performs the full visual analysis pipeline."""
 
 
 
 
 
 
228
  try:
229
- # Convert PIL to OpenCV format
230
- image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
231
-
232
- # YOLO Detection
233
- results = self.models_cache["det"].predict(image_cv, verbose=False, device="cpu")
234
- if not results or not results[0].boxes:
 
 
 
235
  raise ValueError("No wound could be detected.")
236
-
237
  box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
238
- detected_region_cv = image_cv[box[1]:box[3], box[0]:box[2]]
239
-
240
- # Segmentation
241
- input_size = self.models_cache["seg"].input_shape[1:3]
242
- resized = cv2.resize(detected_region_cv, (input_size[1], input_size[0]))
243
- mask_pred = self.models_cache["seg"].predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
244
- mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)
245
-
246
- # Calculate measurements
247
- contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
248
- length, breadth, area = (0, 0, 0)
249
- if contours:
250
- cnt = max(contours, key=cv2.contourArea)
251
- x, y, w, h = cv2.boundingRect(cnt)
252
- length, breadth, area = round(h / self.px_per_cm, 2), round(w / self.px_per_cm, 2), round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
253
-
254
- # Classification
255
- detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
256
- wound_type = max(self.models_cache["cls"](detected_image_pil), key=lambda x: x["score"])["label"]
257
-
258
- # Save visualization images
259
- os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
261
-
262
- # Detection visualization
263
  det_vis = image_cv.copy()
264
- cv2.rectangle(det_vis, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
265
- det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
266
  cv2.imwrite(det_path, det_vis)
267
-
268
- # Original image
269
- original_path = f"{self.uploads_dir}/analysis/original_{ts}.png"
270
  cv2.imwrite(original_path, image_cv)
271
 
272
- # Segmentation visualization
273
- seg_path = None
274
- if contours:
275
- mask_resized = cv2.resize(mask_np * 255, (detected_region_cv.shape[1], detected_region_cv.shape[0]), interpolation=cv2.INTER_NEAREST)
276
- overlay = detected_region_cv.copy()
277
- overlay[mask_resized > 127] = [0, 0, 255] # Red overlay for wound area
278
- seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
279
- seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
280
- cv2.imwrite(seg_path, seg_vis)
281
-
282
- visual_results = {
283
- "wound_type": wound_type,
284
- "length_cm": length,
285
- "breadth_cm": breadth,
286
  "surface_area_cm2": area,
287
- "detection_confidence": float(results[0].boxes.conf[0].cpu().item()) if results[0].boxes.conf is not None else 0.0,
 
 
288
  "detection_image_path": det_path,
289
  "segmentation_image_path": seg_path,
290
- "original_image_path": original_path
291
  }
292
- return visual_results
293
-
294
  except Exception as e:
295
  logging.error(f"Visual analysis failed: {e}")
296
- raise e
297
 
298
  def query_guidelines(self, query: str) -> str:
299
- """Query the knowledge base for relevant information."""
300
  try:
301
- vector_store = self.knowledge_base_cache.get("vector_store")
302
- if not vector_store:
303
  return "Knowledge base is not available."
304
-
305
- retriever = vector_store.as_retriever(search_kwargs={"k": 5}) # Reduced for efficiency
306
- docs = retriever.invoke(query)
307
- return "\n\n".join([f"Source: {doc.metadata.get('source', 'N/A')}\nContent: {doc.page_content[:300]}..." for doc in docs])
308
-
 
 
 
 
 
 
 
 
 
309
  except Exception as e:
310
- logging.error(f"Guidelines query failed: {e}")
311
  return f"Guidelines query failed: {str(e)}"
312
 
313
- def generate_final_report(
314
- self, patient_info: str, visual_results: dict, guideline_context: str,
315
- image_pil: Image.Image, max_new_tokens: int = None
316
- ) -> str:
317
- """Generate final report using MedGemma with timeout handling."""
318
- try:
319
- # Try MedGemma with timeout handling
320
- report = generate_medgemma_report_with_timeout(
321
- patient_info, visual_results, guideline_context, image_pil, max_new_tokens
322
- )
323
-
324
- # Check if report is valid
325
- if report and report.strip() and not report.startswith("❌") and not report.startswith("⚠️"):
326
- return report
327
- else:
328
- logging.warning("MedGemma returned invalid response, using fallback")
329
- return self._generate_fallback_report(patient_info, visual_results, guideline_context)
330
-
331
- except Exception as e:
332
- logging.error(f"MedGemma report generation failed: {e}")
333
- return self._generate_fallback_report(patient_info, visual_results, guideline_context)
334
-
335
- def _generate_fallback_report(
336
- self, patient_info: str, visual_results: dict, guideline_context: str
337
- ) -> str:
338
- """Generate comprehensive fallback report if MedGemma fails."""
339
-
340
- report = f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
341
 
342
  ## 📋 Patient Information
343
  {patient_info}
344
 
345
  ## 🔍 Visual Analysis Results
346
  - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
347
- - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
348
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
349
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
350
 
351
- ## 📊 Analysis Images Available
352
- - **Original Image**: {visual_results.get('original_image_path', 'Available')}
353
- - **Detection Visualization**: {visual_results.get('detection_image_path', 'Available')}
354
- - **Segmentation Overlay**: {visual_results.get('segmentation_image_path', 'Available')}
355
-
356
- ## 🎯 Clinical Assessment Summary
357
-
358
- ### Wound Classification
359
- Based on automated analysis, this wound has been classified as **{visual_results.get('wound_type', 'Unspecified')}** with the following characteristics:
360
- - Size: {visual_results.get('length_cm', 0)} × {visual_results.get('breadth_cm', 0)} cm
361
- - Total area: {visual_results.get('surface_area_cm2', 0)} cm²
362
- - Detection confidence: {visual_results.get('detection_confidence', 0):.1%}
363
-
364
- ### Clinical Observations
365
- The automated visual analysis provides quantitative measurements that should be verified through clinical examination. The wound type classification helps guide initial treatment considerations.
366
-
367
- ## 💊 Treatment Recommendations
368
-
369
- ### Wound Care Protocol
370
- 1. **Assessment**: Comprehensive clinical evaluation by qualified healthcare professional
371
- 2. **Cleaning**: Gentle wound cleansing with appropriate solution
372
- 3. **Debridement**: Remove necrotic tissue if present (professional assessment required)
373
- 4. **Dressing Selection**: Choose appropriate dressing based on wound characteristics:
374
- - Moisture level assessment
375
- - Infection risk evaluation
376
- - Patient comfort and mobility
377
-
378
- ### Monitoring Plan
379
- - **Initial Phase**: Daily assessment for first week
380
- - **Ongoing Care**: Reassessment every 2-3 days or as clinically indicated
381
- - **Documentation**: Regular photo documentation and measurement tracking
382
- - **Progress Evaluation**: Weekly review of healing progression
383
-
384
- ## ⚠️ Risk Factors & Considerations
385
-
386
- ### Patient-Specific Factors
387
- Review patient history for factors that may impact healing:
388
- - Age and general health status
389
- - Diabetes or metabolic conditions
390
- - Circulation and vascular health
391
- - Nutritional status
392
- - Mobility and pressure relief
393
-
394
- ### Warning Signs
395
- Monitor for signs requiring immediate attention:
396
- - Increased pain, redness, or swelling
397
- - Purulent drainage or odor
398
- - Fever or systemic signs of infection
399
- - Wound expansion or deterioration
400
- - Delayed healing beyond expected timeframe
401
-
402
- ## 📚 Clinical Guidelines Context
403
- {guideline_context[:800]}{'...' if len(guideline_context) > 800 else ''}
404
-
405
- ## 🏥 Next Steps
406
-
407
- ### Immediate Actions
408
- 1. **Professional Consultation**: Schedule appointment with wound care specialist
409
- 2. **Baseline Documentation**: Establish comprehensive baseline assessment
410
- 3. **Treatment Plan**: Develop individualized care protocol
411
- 4. **Patient Education**: Provide wound care instructions and warning signs
412
-
413
- ### Follow-up Schedule
414
- - **Week 1**: Daily monitoring and assessment
415
- - **Week 2-4**: Every 2-3 days or as indicated
416
- - **Monthly**: Comprehensive reassessment and plan review
417
- - **As Needed**: Immediate evaluation for any concerning changes
418
-
419
- ## ⚖️ Important Medical Disclaimer
420
-
421
- **This automated analysis is provided for informational and educational purposes only.**
422
-
423
- - This report does not constitute medical diagnosis or treatment advice
424
- - All measurements are computer-generated estimates requiring clinical verification
425
- - Professional medical evaluation is essential for proper diagnosis and treatment
426
- - This AI tool should supplement, not replace, clinical judgment
427
- - Always consult qualified healthcare professionals for medical decisions
428
-
429
- ### Clinical Correlation Required
430
- - Verify all measurements with standard clinical tools
431
- - Correlate findings with patient symptoms and history
432
- - Consider factors not captured in automated analysis
433
- - Follow institutional protocols and guidelines
434
-
435
- ---
436
- *Generated by SmartHeal AI - Advanced Wound Care Analysis System*
437
- *Report Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
438
- *Version: AI-Processor v1.2 with Enhanced Fallback Reporting*
439
  """
440
- return report
441
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  def save_and_commit_image(self, image_pil: Image.Image) -> str:
443
- """Save image locally and optionally commit to HF dataset."""
444
  try:
445
  os.makedirs(self.uploads_dir, exist_ok=True)
446
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
447
- filename = f"{timestamp}.png"
448
  path = os.path.join(self.uploads_dir, filename)
449
-
450
- # Save image
451
  image_pil.convert("RGB").save(path)
452
  logging.info(f"✅ Image saved locally: {path}")
453
-
454
- # Upload to HuggingFace dataset if configured
455
  if self.hf_token and self.dataset_id:
456
  try:
 
 
457
  api = HfApi()
458
  api.upload_file(
459
  path_or_fileobj=path,
@@ -461,85 +483,92 @@ Monitor for signs requiring immediate attention:
461
  repo_id=self.dataset_id,
462
  repo_type="dataset",
463
  token=self.hf_token,
464
- commit_message=f"Upload wound image: {filename}"
465
  )
466
  logging.info("✅ Image committed to HF dataset")
467
  except Exception as e:
468
  logging.warning(f"HF upload failed: {e}")
469
-
470
  return path
471
-
472
  except Exception as e:
473
- logging.error(f"Failed to save image: {e}")
474
  return ""
475
 
476
- def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
477
- """Run full analysis pipeline."""
 
478
  try:
479
- # Save image first
480
  saved_path = self.save_and_commit_image(image_pil)
481
- logging.info(f"Image saved: {saved_path}")
482
-
483
- # Perform visual analysis
484
  visual_results = self.perform_visual_analysis(image_pil)
485
- logging.info(f"Visual analysis completed: {visual_results}")
486
-
487
- # Process questionnaire data
488
- patient_info = f"Age: {questionnaire_data.get('age', 'N/A')}, Diabetic: {questionnaire_data.get('diabetic', 'N/A')}, Allergies: {questionnaire_data.get('allergies', 'N/A')}, Date of Wound Sustained: {questionnaire_data.get('date_of_injury', 'N/A')}, Professional Care: {questionnaire_data.get('professional_care', 'N/A')}, Oozing/Bleeding: {questionnaire_data.get('oozing_bleeding', 'N/A')}, Infection: {questionnaire_data.get('infection', 'N/A')}, Moisture: {questionnaire_data.get('moisture', 'N/A')}"
489
-
 
 
 
 
 
 
 
 
 
490
  # Query guidelines
491
- query = f"best practices for managing a {visual_results['wound_type']} with moisture level '{questionnaire_data.get('moisture', 'unknown')}' and signs of infection '{questionnaire_data.get('infection', 'unknown')}' in a patient who is diabetic '{questionnaire_data.get('diabetic', 'unknown')}'"
 
 
 
 
492
  guideline_context = self.query_guidelines(query)
493
- logging.info("Guidelines queried successfully")
494
-
495
  # Generate final report
496
- report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
497
- logging.info("Report generated successfully")
498
-
 
 
499
  return {
500
- 'success': True,
501
- 'visual_analysis': visual_results,
502
- 'report': report,
503
- 'saved_image_path': saved_path,
504
- 'guideline_context': guideline_context[:500] + "..." if len(guideline_context) > 500 else guideline_context
505
  }
506
-
507
  except Exception as e:
508
  logging.error(f"Pipeline error: {e}")
509
  return {
510
- 'success': False,
511
- 'error': str(e),
512
- 'visual_analysis': {},
513
- 'report': f"Analysis failed: {str(e)}",
514
- 'saved_image_path': None,
515
- 'guideline_context': ""
516
  }
517
 
518
- def analyze_wound(self, image, questionnaire_data: dict) -> dict:
519
- """Main analysis entry point - maintains original function name."""
520
  try:
521
- # Handle different image input formats
522
  if isinstance(image, str):
523
- if os.path.exists(image):
524
- image_pil = Image.open(image)
525
- else:
526
  raise ValueError(f"Image file not found: {image}")
 
527
  elif isinstance(image, Image.Image):
528
  image_pil = image
529
  elif isinstance(image, np.ndarray):
530
  image_pil = Image.fromarray(image)
531
  else:
532
  raise ValueError(f"Unsupported image type: {type(image)}")
533
-
534
- return self.full_analysis_pipeline(image_pil, questionnaire_data)
535
-
536
  except Exception as e:
537
  logging.error(f"Wound analysis error: {e}")
538
  return {
539
- 'success': False,
540
- 'error': str(e),
541
- 'visual_analysis': {},
542
- 'report': f"Analysis initialization failed: {str(e)}",
543
- 'saved_image_path': None,
544
- 'guideline_context': ""
545
- }
 
1
+ # smartheal_ai_processor.py
2
+ # Full, functional module with conditional Spaces GPU support and CPU fallbacks.
3
+
4
  import os
5
+ import time
6
  import logging
7
+ from datetime import datetime
8
+ from typing import Optional, Dict, List
9
+
10
  import cv2
11
  import numpy as np
12
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # =============== LOGGING SETUP ===============
15
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
16
 
17
  # =============== CONFIGURATION ===============
18
  UPLOADS_DIR = "uploads"
19
+ os.makedirs(UPLOADS_DIR, exist_ok=True)
 
 
20
 
21
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
22
  YOLO_MODEL_PATH = "src/best.pt"
23
+ SEG_MODEL_PATH = "src/segmentation_model.h5" # optional
24
  GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
25
+ DATASET_ID = "SmartHeal/wound-image-uploads" # optional (set HF_TOKEN too)
26
+ PIXELS_PER_CM = 38 # heuristic
 
27
 
28
  # =============== GLOBAL CACHES ===============
29
+ models_cache: Dict[str, object] = {}
30
+ knowledge_base_cache: Dict[str, object] = {}
31
 
32
+ # ---------- Optional imports guarded ----------
33
+ def _import_ultralytics():
 
34
  from ultralytics import YOLO
35
+ return YOLO
36
 
37
+ def _import_tf_loader():
 
38
  import tensorflow as tf
39
+ tf.config.set_visible_devices([], "GPU") # force CPU
40
  from tensorflow.keras.models import load_model
41
+ return load_model
42
 
43
+ def _import_hf_cls():
 
44
  from transformers import pipeline
45
+ return pipeline
46
+
47
+ def _import_embeddings():
48
+ from langchain_community.embeddings import HuggingFaceEmbeddings
49
+ return HuggingFaceEmbeddings
50
+
51
+ def _import_langchain_pdf():
52
+ from langchain_community.document_loaders import PyPDFLoader
53
+ return PyPDFLoader
54
+
55
+ def _import_langchain_faiss():
56
+ from langchain_community.vectorstores import FAISS
57
+ return FAISS
58
+
59
+ def _import_hf_hub():
60
+ from huggingface_hub import HfApi, HfFolder
61
+ return HfApi, HfFolder
62
+
63
+ # =============== SPACES GPU CONDITIONAL ===============
64
+ def _spaces_gpu_available() -> bool:
65
+ try:
66
+ import torch
67
+ return bool(torch.cuda.is_available())
68
+ except Exception:
69
+ return False
70
+
71
+ def _spaces_lib_available() -> bool:
72
+ try:
73
+ import spaces # noqa
74
+ return True
75
+ except Exception:
76
+ return False
77
+
78
+ HAVE_SPACES_GPU = _spaces_gpu_available() and _spaces_lib_available()
79
+
80
+ if HAVE_SPACES_GPU:
81
+ import spaces # define only if available & GPU present
82
+
83
+ @spaces.GPU(enable_queue=True, duration=90)
84
+ def generate_medgemma_report_with_timeout(
85
+ patient_info: str,
86
+ visual_results: Dict,
87
+ guideline_context: str,
88
+ image_pil: Image.Image,
89
+ max_new_tokens: Optional[int] = None,
90
+ ) -> str:
91
+ """Runs on Spaces GPU only; callers keep one signature on both paths."""
92
+ import torch
93
+ from transformers import pipeline
94
+ try:
95
+ torch.cuda.empty_cache()
96
+
97
+ prompt = f"""
98
+ You are a medical AI assistant. Analyze this wound image and patient data.
99
+
100
+ Patient: {patient_info}
101
+ Wound: {visual_results.get('wound_type', 'Unknown')} - {visual_results.get('length_cm', 0)}×{visual_results.get('breadth_cm', 0)} cm
102
+
103
+ Provide a structured report with:
104
+ 1. Clinical Summary
105
+ 2. Treatment Recommendations
106
+ 3. Risk Assessment
107
+ 4. Monitoring Plan
108
+ """.strip()
109
+
110
+ pipe = pipeline(
111
+ "image-text-to-text",
112
+ model="google/medgemma-4b-it",
113
+ torch_dtype=torch.bfloat16,
114
+ device_map="auto",
115
+ token=HF_TOKEN,
116
+ model_kwargs={"low_cpu_mem_usage": True, "use_cache": True},
117
+ )
118
+
119
+ messages = [
120
+ {
121
+ "role": "user",
122
+ "content": [
123
+ {"type": "image", "image": image_pil},
124
+ {"type": "text", "text": prompt},
125
+ ],
126
+ }
127
+ ]
128
+
129
+ t0 = time.time()
130
+ out = pipe(
131
+ text=messages,
132
+ max_new_tokens=max_new_tokens or 800,
133
+ do_sample=False,
134
+ temperature=0.7,
135
+ pad_token_id=pipe.tokenizer.eos_token_id,
136
+ )
137
+ logging.info(f"✅ MedGemma completed in {time.time() - t0:.2f}s")
138
+
139
+ if out and len(out) > 0:
140
+ # Defensive extraction
141
+ try:
142
+ return out[0]["generated_text"][-1].get("content", "").strip() or "⚠️ Empty response"
143
+ except Exception:
144
+ return (out[0].get("generated_text", "") or "").strip() or "⚠️ Empty response"
145
+ return "⚠️ No output generated"
146
+ except Exception as e:
147
+ logging.error(f"❌ MedGemma generation error: {e}")
148
+ return f"❌ Report generation failed: {str(e)}"
149
+ finally:
150
+ try:
151
+ torch.cuda.empty_cache()
152
+ except Exception:
153
+ pass
154
+ else:
155
+ def generate_medgemma_report_with_timeout(
156
+ patient_info: str,
157
+ visual_results: Dict,
158
+ guideline_context: str,
159
+ image_pil: Image.Image,
160
+ max_new_tokens: Optional[int] = None,
161
+ ) -> str:
162
+ """CPU-only path: return a warning so caller uses fallback."""
163
+ return "⚠️ GPU not available"
164
+
165
+ # =============== MODEL INITIALIZATION (CPU-SAFE) ===============
166
+ def load_yolo_model():
167
+ YOLO = _import_ultralytics()
168
+ return YOLO(YOLO_MODEL_PATH)
169
+
170
+ def load_segmentation_model():
171
+ load_model = _import_tf_loader()
172
+ return load_model(SEG_MODEL_PATH, compile=False)
173
+
174
+ def load_classification_pipeline():
175
+ pipe = _import_hf_cls()
176
+ return pipe(
177
  "image-classification",
178
  model="Hemg/Wound-classification",
179
+ token=HF_TOKEN,
180
+ device="cpu",
181
  )
182
 
183
  def load_embedding_model():
184
+ Emb = _import_embeddings()
185
+ return Emb(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
 
 
 
186
 
187
+ def initialize_cpu_models() -> None:
188
+ """Initialize all CPU-only models once with robust fallbacks."""
189
+ # Hugging Face auth (optional)
 
 
190
  if HF_TOKEN:
191
+ try:
192
+ HfApi, HfFolder = _import_hf_hub()
193
+ HfFolder.save_token(HF_TOKEN)
194
+ logging.info("✅ HuggingFace token set")
195
+ except Exception as e:
196
+ logging.warning(f"HF token save failed: {e}")
197
+
198
  if "det" not in models_cache:
199
  try:
200
+ models_cache["det"] = load_yolo_model()
201
+ logging.info("✅ YOLO model loaded (CPU)")
202
  except Exception as e:
203
  logging.error(f"YOLO load failed: {e}")
204
 
205
  if "seg" not in models_cache:
206
  try:
207
+ if os.path.exists(SEG_MODEL_PATH):
208
+ models_cache["seg"] = load_segmentation_model()
209
+ logging.info("✅ Segmentation model loaded (CPU)")
210
+ else:
211
+ models_cache["seg"] = None
212
+ logging.warning("Segmentation model file not found; skipping seg.")
213
  except Exception as e:
214
+ models_cache["seg"] = None
215
  logging.warning(f"Segmentation model not available: {e}")
216
 
217
  if "cls" not in models_cache:
218
  try:
219
+ models_cache["cls"] = load_classification_pipeline()
220
  logging.info("✅ Classification pipeline loaded (CPU)")
221
  except Exception as e:
222
+ models_cache["cls"] = None
223
  logging.warning(f"Classification pipeline not available: {e}")
224
 
225
  if "embedding_model" not in models_cache:
 
227
  models_cache["embedding_model"] = load_embedding_model()
228
  logging.info("✅ Embedding model loaded (CPU)")
229
  except Exception as e:
230
+ models_cache["embedding_model"] = None
231
  logging.warning(f"Embedding model not available: {e}")
232
 
233
+ def setup_knowledge_base() -> None:
234
+ """Load PDFs and create FAISS vector store (optional)."""
 
235
  if "vector_store" in knowledge_base_cache:
236
  return
237
 
238
  docs = []
239
+ try:
240
+ PyPDFLoader = _import_langchain_pdf()
241
+ for pdf in GUIDELINE_PDFS:
242
+ if os.path.exists(pdf):
243
+ try:
244
+ loader = PyPDFLoader(pdf)
245
+ docs.extend(loader.load())
246
+ logging.info(f"Loaded PDF: {pdf}")
247
+ except Exception as e:
248
+ logging.warning(f"Failed to load PDF {pdf}: {e}")
249
+ except Exception as e:
250
+ logging.warning(f"LangChain PDF loader unavailable: {e}")
251
+
252
+ if docs and models_cache.get("embedding_model"):
253
+ try:
254
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
255
+ FAISS = _import_langchain_faiss()
256
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
257
+ chunks = splitter.split_documents(docs)
258
+ knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
259
+ logging.info(f"✅ Knowledge base ready with {len(chunks)} chunks")
260
+ except Exception as e:
261
+ knowledge_base_cache["vector_store"] = None
262
+ logging.warning(f"Knowledge base unavailable: {e}")
263
  else:
264
  knowledge_base_cache["vector_store"] = None
265
+ logging.warning("Knowledge base disabled (no docs or embeddings).")
266
 
267
+ # Initialize on import
268
  initialize_cpu_models()
269
  setup_knowledge_base()
270
 
271
+ # =============== AI PROCESSOR ===============
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  class AIProcessor:
273
  def __init__(self):
274
  self.models_cache = models_cache
 
278
  self.dataset_id = DATASET_ID
279
  self.hf_token = HF_TOKEN
280
 
281
+ # ---------- Image utilities ----------
282
+ def _ensure_analysis_dir(self) -> str:
283
+ out_dir = os.path.join(self.uploads_dir, "analysis")
284
+ os.makedirs(out_dir, exist_ok=True)
285
+ return out_dir
286
+
287
+ def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
288
+ """YOLO detect → (optional) Keras seg → (optional) HF classifier → save visuals."""
289
  try:
290
+ image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
291
+
292
+ det = self.models_cache.get("det")
293
+ if det is None:
294
+ raise RuntimeError("YOLO model not loaded")
295
+
296
+ # YOLO on CPU
297
+ results = det.predict(image_cv, verbose=False, device="cpu")
298
+ if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
299
  raise ValueError("No wound could be detected.")
300
+
301
  box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
302
+ x1, y1, x2, y2 = [int(v) for v in box]
303
+ x1, y1 = max(0, x1), max(0, y1)
304
+ x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
305
+ detected_region_cv = image_cv[y1:y2, x1:x2]
306
+
307
+ # Optional segmentation
308
+ seg_model = self.models_cache.get("seg")
309
+ length = breadth = area = 0.0
310
+ seg_path = None
311
+ if seg_model is not None and detected_region_cv.size > 0:
312
+ try:
313
+ input_size = seg_model.input_shape[1:3]
314
+ resized = cv2.resize(detected_region_cv, (input_size[1], input_size[0]))
315
+ mask_pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
316
+ mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)
317
+
318
+ contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
319
+ if contours:
320
+ cnt = max(contours, key=cv2.contourArea)
321
+ x, y, w, h = cv2.boundingRect(cnt)
322
+ length = round(h / self.px_per_cm, 2)
323
+ breadth = round(w / self.px_per_cm, 2)
324
+ area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
325
+
326
+ # overlay visualization
327
+ mask_resized = cv2.resize(
328
+ mask_np * 255,
329
+ (detected_region_cv.shape[1], detected_region_cv.shape[0]),
330
+ interpolation=cv2.INTER_NEAREST,
331
+ )
332
+ overlay = detected_region_cv.copy()
333
+ overlay[mask_resized > 127] = [0, 0, 255]
334
+ seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
335
+
336
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
337
+ out_dir = self._ensure_analysis_dir()
338
+ seg_path = os.path.join(out_dir, f"segmentation_{ts}.png")
339
+ cv2.imwrite(seg_path, seg_vis)
340
+ except Exception as e:
341
+ logging.warning(f"Segmentation step skipped: {e}")
342
+
343
+ # Optional classification
344
+ wound_type = "Unknown"
345
+ cls_pipe = self.models_cache.get("cls")
346
+ if cls_pipe is not None:
347
+ try:
348
+ detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
349
+ preds = cls_pipe(detected_image_pil)
350
+ if preds:
351
+ wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
352
+ except Exception as e:
353
+ logging.warning(f"Classification step failed: {e}")
354
+
355
+ # Save detection & original
356
+ out_dir = self._ensure_analysis_dir()
357
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
 
 
358
  det_vis = image_cv.copy()
359
+ cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
360
+ det_path = os.path.join(out_dir, f"detection_{ts}.png")
361
  cv2.imwrite(det_path, det_vis)
362
+
363
+ original_path = os.path.join(out_dir, f"original_{ts}.png")
 
364
  cv2.imwrite(original_path, image_cv)
365
 
366
+ return {
367
+ "wound_type": wound_type,
368
+ "length_cm": length,
369
+ "breadth_cm": breadth,
 
 
 
 
 
 
 
 
 
 
370
  "surface_area_cm2": area,
371
+ "detection_confidence": float(results[0].boxes.conf[0].cpu().item())
372
+ if getattr(results[0].boxes, "conf", None) is not None
373
+ else 0.0,
374
  "detection_image_path": det_path,
375
  "segmentation_image_path": seg_path,
376
+ "original_image_path": original_path,
377
  }
 
 
378
  except Exception as e:
379
  logging.error(f"Visual analysis failed: {e}")
380
+ raise
381
 
382
  def query_guidelines(self, query: str) -> str:
383
+ """Query the knowledge base (optional)."""
384
  try:
385
+ vs = self.knowledge_base_cache.get("vector_store")
386
+ if not vs:
387
  return "Knowledge base is not available."
388
+ # support both old and new retriever APIs
389
+ try:
390
+ retriever = vs.as_retriever(search_kwargs={"k": 5})
391
+ docs = retriever.get_relevant_documents(query) # LC >= 0.2
392
+ except Exception:
393
+ retriever = vs.as_retriever(search_kwargs={"k": 5})
394
+ # older invoke API
395
+ docs = retriever.invoke(query)
396
+ lines: List[str] = []
397
+ for d in docs:
398
+ src = (d.metadata or {}).get("source", "N/A")
399
+ txt = (d.page_content or "")[:300]
400
+ lines.append(f"Source: {src}\nContent: {txt}...")
401
+ return "\n\n".join(lines) if lines else "No relevant guideline snippets found."
402
  except Exception as e:
403
+ logging.warning(f"Guidelines query failed: {e}")
404
  return f"Guidelines query failed: {str(e)}"
405
 
406
+ # ---------- Report builders ----------
407
+ def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
408
+ """Plaintext/markdown fallback when MedGemma is unavailable."""
409
+ return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
  ## 📋 Patient Information
412
  {patient_info}
413
 
414
  ## 🔍 Visual Analysis Results
415
  - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
416
+ - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
417
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
418
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
419
 
420
+ ## 📊 Analysis Images
421
+ - **Original**: {visual_results.get('original_image_path', 'N/A')}
422
+ - **Detection**: {visual_results.get('detection_image_path', 'N/A')}
423
+ - **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
424
+
425
+ ## 🎯 Clinical Summary
426
+ Automated analysis provides quantitative measurements; verify via clinical examination.
427
+
428
+ ## 💊 Recommendations
429
+ - Cleanse wound gently; select dressing per exudate/infection risk
430
+ - Debride necrotic tissue if indicated (clinical decision)
431
+ - Document with serial photos and measurements
432
+
433
+ ## 📅 Monitoring
434
+ - Daily in week 1, then every 2-3 days (or as indicated)
435
+ - Weekly progress review
436
+
437
+ ## 📚 Guideline Context
438
+ {(guideline_context or '')[:800]}{'...' if guideline_context and len(guideline_context) > 800 else ''}
439
+
440
+ **Disclaimer:** Automated, for decision support only. Verify clinically.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  """
 
442
 
443
+ def generate_final_report(
444
+ self,
445
+ patient_info: str,
446
+ visual_results: Dict,
447
+ guideline_context: str,
448
+ image_pil: Image.Image,
449
+ max_new_tokens: Optional[int] = None,
450
+ ) -> str:
451
+ """Try MedGemma (GPU) → fallback report."""
452
+ try:
453
+ report = generate_medgemma_report_with_timeout(
454
+ patient_info, visual_results, guideline_context, image_pil, max_new_tokens
455
+ )
456
+ if report and report.strip() and not report.startswith(("⚠️", "❌")):
457
+ return report
458
+ logging.warning("MedGemma unavailable/invalid; using fallback.")
459
+ return self._generate_fallback_report(patient_info, visual_results, guideline_context)
460
+ except Exception as e:
461
+ logging.error(f"Report generation failed: {e}")
462
+ return self._generate_fallback_report(patient_info, visual_results, guideline_context)
463
+
464
+ # ---------- HF dataset commit ----------
465
  def save_and_commit_image(self, image_pil: Image.Image) -> str:
466
+ """Save image locally and optionally upload to HF dataset."""
467
  try:
468
  os.makedirs(self.uploads_dir, exist_ok=True)
469
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
470
+ filename = f"{ts}.png"
471
  path = os.path.join(self.uploads_dir, filename)
 
 
472
  image_pil.convert("RGB").save(path)
473
  logging.info(f"✅ Image saved locally: {path}")
474
+
 
475
  if self.hf_token and self.dataset_id:
476
  try:
477
+ HfApi, HfFolder = _import_hf_hub()
478
+ HfFolder.save_token(self.hf_token)
479
  api = HfApi()
480
  api.upload_file(
481
  path_or_fileobj=path,
 
483
  repo_id=self.dataset_id,
484
  repo_type="dataset",
485
  token=self.hf_token,
486
+ commit_message=f"Upload wound image: {filename}",
487
  )
488
  logging.info("✅ Image committed to HF dataset")
489
  except Exception as e:
490
  logging.warning(f"HF upload failed: {e}")
491
+
492
  return path
 
493
  except Exception as e:
494
+ logging.error(f"Failed to save/commit image: {e}")
495
  return ""
496
 
497
+ # ---------- Orchestrator ----------
498
+ def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
499
+ """End-to-end analysis with robust fallbacks."""
500
  try:
 
501
  saved_path = self.save_and_commit_image(image_pil)
502
+
 
 
503
  visual_results = self.perform_visual_analysis(image_pil)
504
+
505
+ # Patient info summary text
506
+ pi = questionnaire_data or {}
507
+ patient_info = (
508
+ f"Age: {pi.get('age', 'N/A')}, "
509
+ f"Diabetic: {pi.get('diabetic', 'N/A')}, "
510
+ f"Allergies: {pi.get('allergies', 'N/A')}, "
511
+ f"Date of Wound: {pi.get('date_of_injury', 'N/A')}, "
512
+ f"Professional Care: {pi.get('professional_care', 'N/A')}, "
513
+ f"Oozing/Bleeding: {pi.get('oozing_bleeding', 'N/A')}, "
514
+ f"Infection: {pi.get('infection', 'N/A')}, "
515
+ f"Moisture: {pi.get('moisture', 'N/A')}"
516
+ )
517
+
518
  # Query guidelines
519
+ query = (
520
+ f"best practices for managing a {visual_results.get('wound_type','Unknown')} "
521
+ f"with moisture '{pi.get('moisture','unknown')}' and infection '{pi.get('infection','unknown')}' "
522
+ f"in a diabetic status '{pi.get('diabetic','unknown')}'"
523
+ )
524
  guideline_context = self.query_guidelines(query)
525
+
 
526
  # Generate final report
527
+ report = self.generate_final_report(patient_info=patient_info,
528
+ visual_results=visual_results,
529
+ guideline_context=guideline_context,
530
+ image_pil=image_pil)
531
+
532
  return {
533
+ "success": True,
534
+ "visual_analysis": visual_results,
535
+ "report": report,
536
+ "saved_image_path": saved_path,
537
+ "guideline_context": (guideline_context or "")[:500] + ("..." if guideline_context and len(guideline_context) > 500 else ""),
538
  }
 
539
  except Exception as e:
540
  logging.error(f"Pipeline error: {e}")
541
  return {
542
+ "success": False,
543
+ "error": str(e),
544
+ "visual_analysis": {},
545
+ "report": f"Analysis failed: {str(e)}",
546
+ "saved_image_path": None,
547
+ "guideline_context": "",
548
  }
549
 
550
+ def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
551
+ """Public entrypoint used by your UI."""
552
  try:
 
553
  if isinstance(image, str):
554
+ if not os.path.exists(image):
 
 
555
  raise ValueError(f"Image file not found: {image}")
556
+ image_pil = Image.open(image)
557
  elif isinstance(image, Image.Image):
558
  image_pil = image
559
  elif isinstance(image, np.ndarray):
560
  image_pil = Image.fromarray(image)
561
  else:
562
  raise ValueError(f"Unsupported image type: {type(image)}")
563
+
564
+ return self.full_analysis_pipeline(image_pil, questionnaire_data or {})
 
565
  except Exception as e:
566
  logging.error(f"Wound analysis error: {e}")
567
  return {
568
+ "success": False,
569
+ "error": str(e),
570
+ "visual_analysis": {},
571
+ "report": f"Analysis initialization failed: {str(e)}",
572
+ "saved_image_path": None,
573
+ "guideline_context": "",
574
+ }