SmartHeal committed on
Commit
e35f4e1
·
verified ·
1 Parent(s): 2b40a77

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +180 -419
src/ai_processor.py CHANGED
@@ -1,12 +1,12 @@
1
  import os
 
 
2
  import logging
3
  import cv2
4
  import numpy as np
5
  from PIL import Image
6
  import torch
7
- import json
8
  from datetime import datetime
9
- import tensorflow as tf
10
  from transformers import pipeline
11
  from ultralytics import YOLO
12
  from tensorflow.keras.models import load_model
@@ -16,9 +16,21 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
16
  from langchain_community.vectorstores import FAISS
17
  from huggingface_hub import HfApi, HfFolder
18
  import spaces
19
-
20
  from .config import Config
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  class AIProcessor:
23
  def __init__(self):
24
  self.models_cache = {}
@@ -28,458 +40,207 @@ class AIProcessor:
28
  self._initialize_models()
29
 
30
  def _initialize_models(self):
31
- """Initialize all AI models including real-time models"""
32
- try:
33
- # Set HuggingFace token
34
- if self.config.HF_TOKEN:
35
- HfFolder.save_token(self.config.HF_TOKEN)
36
- logging.info("HuggingFace token set successfully")
37
 
38
- # Initialize MedGemma pipeline for medical text generation
39
- try:
40
- self.models_cache["medgemma_pipe"] = pipeline(
41
- "image-text-to-text",
42
- model="google/medgemma-4b-it",
43
- torch_dtype=torch.bfloat16,
44
- offload_folder="offload",
45
- device_map="cuda",
46
- token=self.config.HF_TOKEN
47
- )
48
- logging.info("✅ MedGemma pipeline loaded successfully")
49
- except Exception as e:
50
- logging.warning(f"MedGemma pipeline not available: {e}")
51
-
52
- # Initialize YOLO model for wound detection
53
- try:
54
- self.models_cache["det"] = YOLO(self.config.YOLO_MODEL_PATH)
55
- logging.info("✅ YOLO detection model loaded successfully")
56
- except Exception as e:
57
- logging.warning(f"YOLO model not available: {e}")
58
-
59
- # Initialize segmentation model
60
- try:
61
- self.models_cache["seg"] = load_model(self.config.SEG_MODEL_PATH, compile=False)
62
- logging.info("✅ Segmentation model loaded successfully")
63
- except Exception as e:
64
- logging.warning(f"Segmentation model not available: {e}")
65
-
66
- # Initialize wound classification model
67
- try:
68
- self.models_cache["cls"] = pipeline(
69
- "image-classification",
70
- model="Hemg/Wound-classification",
71
- token=self.config.HF_TOKEN,
72
- device="cpu"
73
- )
74
- logging.info("✅ Wound classification model loaded successfully")
75
- except Exception as e:
76
- logging.warning(f"Wound classification model not available: {e}")
77
 
78
- # Initialize embedding model for knowledge base
79
- try:
80
- self.models_cache["embedding_model"] = HuggingFaceEmbeddings(
81
- model_name="sentence-transformers/all-MiniLM-L6-v2",
82
- model_kwargs={'device': 'cpu'}
83
- )
84
- logging.info("✅ Embedding model loaded successfully")
85
- except Exception as e:
86
- logging.warning(f"Embedding model not available: {e}")
87
 
88
- logging.info("✅ All models loaded.")
89
- self._load_knowledge_base()
 
 
 
 
90
 
 
 
 
 
 
 
 
 
 
91
  except Exception as e:
92
- logging.error(f"Error initializing AI models: {e}")
93
 
94
- def _load_knowledge_base(self):
95
- """Load knowledge base from PDF guidelines"""
96
  try:
97
- documents = []
98
- for pdf_path in self.config.GUIDELINE_PDFS:
99
- if os.path.exists(pdf_path):
100
- loader = PyPDFLoader(pdf_path)
101
- docs = loader.load()
102
- documents.extend(docs)
103
- logging.info(f"Loaded PDF: {pdf_path}")
104
 
105
- if documents and 'embedding_model' in self.models_cache:
106
- # Split documents into chunks
107
- text_splitter = RecursiveCharacterTextSplitter(
108
- chunk_size=1000,
109
- chunk_overlap=100
110
- )
111
- chunks = text_splitter.split_documents(documents)
112
-
113
- # Create vector store
114
- vectorstore = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
115
- self.knowledge_base_cache['vectorstore'] = vectorstore
116
- logging.info(f"✅ Knowledge base loaded with {len(chunks)} chunks")
117
- else:
118
- self.knowledge_base_cache['vectorstore'] = None
119
- logging.warning("Knowledge base not available - no PDFs found or embedding model unavailable")
120
 
121
- except Exception as e:
122
- logging.warning(f"Knowledge base loading error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  self.knowledge_base_cache['vectorstore'] = None
 
124
 
125
- @spaces.GPU(enable_queue=True, duration=120)
126
  def perform_visual_analysis(self, image_pil):
127
- """Perform comprehensive visual analysis of wound image."""
128
  try:
129
- # Convert PIL to OpenCV format
130
- image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
131
-
132
- # YOLO detection
133
- if 'det' not in self.models_cache:
134
- raise ValueError("YOLO detection model not available.")
135
-
136
- results = self.models_cache['det'].predict(image_cv, verbose=False, device="cpu")
137
-
138
- if not results or not results[0].boxes:
139
- raise ValueError("No wound detected in the image.")
140
-
141
- # Extract bounding box
142
- box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
143
- x1, y1, x2, y2 = box
144
- region_cv = image_cv[y1:y2, x1:x2]
145
-
146
- # Save detection image
147
- detection_image_cv = image_cv.copy()
148
- cv2.rectangle(detection_image_cv, (x1, y1), (x2, y2), (0, 255, 0), 2)
149
- os.makedirs(os.path.join(self.config.UPLOADS_DIR, "analysis"), exist_ok=True)
150
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
151
- detection_image_path = os.path.join(self.config.UPLOADS_DIR, "analysis", f"detection_{timestamp}.png")
152
- cv2.imwrite(detection_image_path, detection_image_cv)
153
- detection_image_pil = Image.fromarray(cv2.cvtColor(detection_image_cv, cv2.COLOR_BGR2RGB))
154
-
155
- # Initialize outputs
156
  length = breadth = area = 0
157
- segmentation_image_pil = None
158
- segmentation_image_path = None
159
-
160
- # Segmentation (optional)
161
  if 'seg' in self.models_cache:
162
- input_size = self.models_cache['seg'].input_shape[1:3] # (height, width)
163
- resized_region = cv2.resize(region_cv, (input_size[1], input_size[0]))
164
-
165
- seg_input = np.expand_dims(resized_region / 255.0, 0)
166
- mask_pred = self.models_cache['seg'].predict(seg_input, verbose=0)[0]
167
- mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)
168
-
169
- # Resize mask back to original region size
170
- mask_resized = cv2.resize(mask_np, (region_cv.shape[1], region_cv.shape[0]), interpolation=cv2.INTER_NEAREST)
171
-
172
- # Overlay mask on region for visualization
173
- overlay = region_cv.copy()
174
- overlay[mask_resized == 1] = [0, 0, 255] # Red overlay
175
-
176
- # Blend overlay for final output
177
- segmented_visual = cv2.addWeighted(region_cv, 0.7, overlay, 0.3, 0)
178
-
179
- # Save segmentation image
180
- segmentation_image_path = os.path.join(self.config.UPLOADS_DIR, "analysis", f"segmentation_{timestamp}.png")
181
- cv2.imwrite(segmentation_image_path, segmented_visual)
182
- segmentation_image_pil = Image.fromarray(cv2.cvtColor(segmented_visual, cv2.COLOR_BGR2RGB))
183
-
184
- # Wound measurements from resized mask
185
- contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
186
- if contours:
187
- cnt = max(contours, key=cv2.contourArea)
188
- x, y, w, h = cv2.boundingRect(cnt)
189
- length = round(h / self.px_per_cm, 2)
190
- breadth = round(w / self.px_per_cm, 2)
191
- area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
192
-
193
- # Classification (optional)
194
- wound_type = "Unknown"
195
  if 'cls' in self.models_cache:
196
  try:
197
- region_pil = Image.fromarray(cv2.cvtColor(region_cv, cv2.COLOR_BGR2RGB))
198
- cls_result = self.models_cache['cls'](region_pil)
199
- wound_type = max(cls_result, key=lambda x: x['score'])['label']
200
- except Exception as e:
201
- logging.warning(f"Wound classification error: {e}")
202
-
203
  return {
204
  'wound_type': wound_type,
205
  'length_cm': length,
206
  'breadth_cm': breadth,
207
  'surface_area_cm2': area,
208
- 'detection_confidence': float(results[0].boxes[0].conf.cpu().item()),
209
- 'bounding_box': box.tolist(),
210
- 'detection_image_path': detection_image_path,
211
- 'detection_image_pil': detection_image_pil,
212
- 'segmentation_image_path': segmentation_image_path,
213
- 'segmentation_image_pil': segmentation_image_pil
214
  }
215
-
216
  except Exception as e:
217
  logging.error(f"Visual analysis error: {e}")
218
- raise ValueError(f"Visual analysis failed: {str(e)}")
219
-
220
 
221
  def query_guidelines(self, query: str):
222
- """Query the knowledge base for relevant guidelines"""
223
- try:
224
- vector_store = self.knowledge_base_cache.get("vectorstore")
225
- if not vector_store:
226
- return "Knowledge base unavailable - clinical guidelines not loaded"
227
-
228
- # Retrieve relevant documents
229
- retriever = vector_store.as_retriever(search_kwargs={"k": 10})
230
- docs = retriever.invoke(query)
231
-
232
- if not docs:
233
- return "No relevant guidelines found for the query"
234
-
235
- # Format the results
236
- formatted_results = []
237
- for doc in docs:
238
- source = doc.metadata.get('source', 'Unknown')
239
- page = doc.metadata.get('page', 'N/A')
240
- content = doc.page_content.strip()
241
- formatted_results.append(f"Source: {source}, Page: {page}\nContent: {content}")
242
-
243
- return "\n\n".join(formatted_results)
244
-
245
- except Exception as e:
246
- logging.error(f"Guidelines query error: {e}")
247
- return f"Error querying guidelines: {str(e)}"
248
-
249
- @spaces.GPU(enable_queue=True, duration=120)
250
  def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
251
- """Generate comprehensive medical report using MedGemma"""
252
- try:
253
- if 'medgemma_pipe' not in self.models_cache:
254
- return self._generate_fallback_report(patient_info, visual_results, guideline_context)
255
-
256
- max_tokens = max_new_tokens or self.config.MAX_NEW_TOKENS
257
-
258
- # Get detection and segmentation images if available
259
- detection_image = visual_results.get('detection_image_pil', None)
260
- segmentation_image = visual_results.get('segmentation_image_pil', None)
261
-
262
- # Create image paths for report
263
- detection_path = visual_results.get('detection_image_path', '')
264
- segmentation_path = visual_results.get('segmentation_image_path', '')
265
-
266
- # Create detailed prompt for medical analysis with image paths
267
- prompt = f"""
268
- # Wound Care Report
269
-
270
- ## Patient Information
271
- {patient_info}
272
-
273
- ## Visual Analysis Summary
274
- - Wound Type: {visual_results.get('wound_type', 'Unknown')}
275
- - Length: {visual_results.get('length_cm', 0)} cm
276
- - Breadth: {visual_results.get('breadth_cm', 0)} cm
277
- - Surface Area: {visual_results.get('surface_area_cm2', 0)} cm²
278
- - Detection Confidence: {visual_results.get('detection_confidence', 0):.2f}
279
-
280
- ## Clinical Reference
281
- {guideline_context}
282
-
283
- You are SmartHeal-AI Agent, a world-class wound care AI specialist trained in clinical wound assessment and guideline-based treatment planning.
284
- Your task is to process the following structured inputs (patient data, wound measurements, clinical guidelines, and image) and perform **clinical reasoning and decision-making** to generate a complete wound care report.
285
- ---
286
- 🔍 **YOUR PROCESS — FOLLOW STRICTLY:**
287
- ### Step 1: Clinical Reasoning (Chain-of-Thought)
288
- Use the provided information to think step-by-step about:
289
- - Patient’s risk factors (e.g. diabetes, age, healing limitations)
290
- - Wound characteristics (size, tissue appearance, moisture, infection signs)
291
- - Visual clues from the image (location, granulation, maceration, inflammation, surrounding skin)
292
-
293
- ---
294
- -Step 2: Structured Clinical Report
295
- Generate the following report sections using markdown and medical terminology:
296
- **1. Clinical Summary**
297
- - Describe wound appearance and tissue types (e.g., slough, necrotic, granulating, epithelializing)
298
- - Include size, wound bed condition, peri-wound skin, and signs of infection or biofilm
299
- - Mention inferred location (e.g., heel, forefoot) if image allows
300
- - Summarize patient's systemic risk profile
301
- **2. Medicinal & Dressing Recommendations**
302
- Based on your analysis:
303
- - Recommend specific **wound care dressings** (e.g., hydrocolloid, alginate, foam, antimicrobial silver, etc.) suitable to wound moisture level and infection risk
304
- - Propose **topical or systemic agents** ONLY if relevant — include name classes (e.g., antiseptic: povidone iodine, antibiotic ointments, enzymatic debriders)
305
- - Mention **techniques** (e.g., sharp debridement, NPWT, moisture balance, pressure offloading, dressing frequency)
306
- - Avoid repeating guidelines — **apply them**
307
- **3. Key Risk Factors**
308
- Explain how the patient’s condition (e.g., diabetic, poor circulation, advanced age, poor hygiene) may affect wound healing
309
- **4. Prognosis & Monitoring Advice**
310
- - Mention how often wound should be reassessed
311
- - Indicate signs to monitor for deterioration or improvement
312
- - Include when escalation to specialist is necessary
313
-
314
- **Note:** Every dressing change is a chance for wound reassessment. Always perform a thorough wound evaluation at each dressing change.
315
- """
316
-
317
- # Prepare messages for MedGemma with all available images
318
- content_list = [{"type": "text", "text": prompt}]
319
-
320
- # Add original image
321
- if image_pil:
322
- content_list.insert(0, {"type": "image", "image": image_pil})
323
-
324
- # Add detection image if available
325
- if detection_image:
326
- content_list.insert(1, {"type": "image", "image": detection_image})
327
-
328
- # Add segmentation image if available
329
- if segmentation_image:
330
- content_list.insert(2, {"type": "image", "image": segmentation_image})
331
-
332
- messages = [
333
- {
334
- "role": "system",
335
- "content": [{"type": "text", "text": "You are a world-class medical AI assistant specializing in wound care with expertise in wound assessment and treatment. Provide concise, evidence-based medical assessments focusing on: (1) Precise wound classification based on tissue type and appearance, (2) Specific treatment recommendations with exact product names or interventions when appropriate, (3) Objective evaluation of healing progression or deterioration indicators, and (4) Clear follow-up timelines. Avoid general statements and prioritize actionable insights based on the visual analysis measurements and patient context."}],
336
- },
337
- {
338
- "role": "user",
339
- "content": content_list
340
- }
341
- ]
342
-
343
- # Generate report using MedGemma
344
- output = self.models_cache['medgemma_pipe'](
345
- text=messages,
346
- max_new_tokens=1024,
347
- do_sample=False,
348
- )
349
-
350
- generated_content = output[0]['generated_text'][-1].get('content', '').strip()
351
-
352
- # Include image paths in the final report for display in UI
353
- if generated_content:
354
- # Add image paths to the report for frontend display
355
- image_paths_section = f"""
356
- ## Analysis Images
357
- - Original Image: {image_pil}
358
- - Detection Image: {detection_path}
359
- - Segmentation Image: {segmentation_path}
360
- """
361
- generated_content = image_paths_section + generated_content
362
-
363
- return generated_content if generated_content else self._generate_fallback_report(patient_info, visual_results, guideline_context)
364
-
365
- except Exception as e:
366
- logging.error(f"MedGemma report generation error: {e}")
367
  return self._generate_fallback_report(patient_info, visual_results, guideline_context)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
  def _generate_fallback_report(self, patient_info, visual_results, guideline_context):
370
- """Generate a fallback report when MedGemma is not available"""
371
- # Get image paths for report
372
- detection_path = visual_results.get('detection_image_path', 'Not available')
373
- segmentation_path = visual_results.get('segmentation_image_path', 'Not available')
374
-
375
- report = f"""
376
- # Wound Analysis Report
377
- ## Patient Information
378
- {patient_info}
379
-
380
- ## Visual Analysis Results
381
- - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
382
- - **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
383
- - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
384
- - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.2f}
385
-
386
- ## Analysis Images
387
- - **Detection Image**: {detection_path}
388
- - **Segmentation Image**: {segmentation_path}
389
-
390
- ## Assessment
391
- Based on the visual analysis, this appears to be a {visual_results.get('wound_type', 'wound')} with measurable dimensions.
392
-
393
- ## Recommendations
394
- - Continue monitoring wound healing progress
395
- - Maintain proper wound hygiene
396
- - Follow appropriate dressing protocols
397
- - Seek medical attention if signs of infection develop
398
-
399
- ## Clinical Guidelines
400
- {guideline_context[:500]}...
401
-
402
- *Note: This is an automated analysis. Please consult with a healthcare professional for definitive diagnosis and treatment.*
403
- """
404
- return report
405
 
406
  def save_and_commit_image(self, image_pil):
407
- """Save image locally and optionally upload to HuggingFace dataset"""
408
- try:
409
- # Ensure uploads directory exists
410
- os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
411
-
412
- # Generate filename with timestamp
413
- filename = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.png"
414
- local_path = os.path.join(self.config.UPLOADS_DIR, filename)
415
-
416
- # Save image locally
417
- image_pil.convert("RGB").save(local_path)
418
- logging.info(f"Image saved locally: {local_path}")
419
-
420
- # Upload to HuggingFace dataset if configured
421
- if self.config.HF_TOKEN and self.config.DATASET_ID:
422
- try:
423
- api = HfApi()
424
- api.upload_file(
425
- path_or_fileobj=local_path,
426
- path_in_repo=f"images/{filename}",
427
- repo_id=self.config.DATASET_ID,
428
- repo_type="dataset",
429
- commit_message=f"Upload wound image: {filename}"
430
- )
431
- logging.info("✅ Image uploaded to HuggingFace dataset")
432
- except Exception as e:
433
- logging.warning(f"HuggingFace upload failed: {e}")
434
-
435
- return local_path
436
-
437
- except Exception as e:
438
- logging.error(f"Image saving error: {e}")
439
- return None
440
 
441
-
442
- def full_analysis_pipeline(self, image, questionnaire_data):
443
- """Complete analysis pipeline with real-time models"""
444
  try:
445
- # Save the image
446
- saved_path = self.save_and_commit_image(image)
447
-
448
- # Perform visual analysis
449
- visual_results = self.perform_visual_analysis(image)
450
-
451
- # Format patient information
452
- patient_info = ", ".join([f"{k}: {v}" for k, v in questionnaire_data.items() if v])
453
-
454
- # Create query for guidelines
455
- wound_type = visual_results.get('wound_type', 'wound')
456
- moisture = questionnaire_data.get('moisture', 'unknown')
457
- infection = questionnaire_data.get('infection', 'unknown')
458
- diabetic = questionnaire_data.get('diabetic', 'unknown')
459
-
460
- query = f"best practices for managing a {wound_type} with moisture level '{moisture}' and signs of infection '{infection}' in a patient who is diabetic '{diabetic}'"
461
-
462
- # Query guidelines
463
- guideline_context = self.query_guidelines(query)
464
-
465
- # Generate final report
466
- final_report = self.generate_final_report(patient_info, visual_results, guideline_context, image)
467
-
468
- return {
469
- 'success': True,
470
- 'visual_analysis': visual_results,
471
- 'report': final_report,
472
- 'saved_image_path': saved_path,
473
- 'timestamp': datetime.now().isoformat()
474
- }
475
-
476
  except Exception as e:
477
- logging.error(f"Full analysis pipeline error: {e}")
478
- return {
479
- 'success': False,
480
- 'error': str(e),
481
- 'timestamp': datetime.now().isoformat()
482
- }
483
 
484
  # Legacy methods for backward compatibility
485
  def analyze_wound(self, image, questionnaire_data):
 
1
  import os
2
+ import io
3
+ import base64
4
  import logging
5
  import cv2
6
  import numpy as np
7
  from PIL import Image
8
  import torch
 
9
  from datetime import datetime
 
10
  from transformers import pipeline
11
  from ultralytics import YOLO
12
  from tensorflow.keras.models import load_model
 
16
  from langchain_community.vectorstores import FAISS
17
  from huggingface_hub import HfApi, HfFolder
18
  import spaces
 
19
  from .config import Config
20
 
21
# System-role instructions for the MedGemma report request.
# The fragments are joined with single spaces, so none of them carries its
# own leading or trailing whitespace.
default_system_prompt = " ".join((
    "You are a world-class medical AI assistant specializing in wound care",
    "with expertise in wound assessment and treatment. Provide concise,",
    "evidence-based medical assessments focusing on: (1) Precise wound",
    "classification based on tissue type and appearance, (2) Specific",
    "treatment recommendations with exact product names or interventions when",
    "appropriate, (3) Objective evaluation of healing progression or deterioration",
    "indicators, and (4) Clear follow-up timelines. Avoid general statements and",
    "prioritize actionable insights based on the visual analysis measurements and",
    "patient context.",
))
33
+
34
  class AIProcessor:
35
  def __init__(self):
36
  self.models_cache = {}
 
40
  self._initialize_models()
41
 
42
  def _initialize_models(self):
43
+ """Initialize AI models; only MedGemma uses GPU."""
44
+ # Set HuggingFace token
45
+ if self.config.HF_TOKEN:
46
+ HfFolder.save_token(self.config.HF_TOKEN)
47
+ logging.info("HuggingFace token set successfully")
 
48
 
49
+ # MedGemma pipeline on GPU
50
+ try:
51
+ self.models_cache['medgemma_pipe'] = pipeline(
52
+ 'image-text-to-text',
53
+ model='google/medgemma-4b-it',
54
+ device='cuda',
55
+ torch_dtype=torch.bfloat16,
56
+ offload_folder='offload',
57
+ token=self.config.HF_TOKEN
58
+ )
59
+ logging.info("✅ MedGemma pipeline loaded on GPU")
60
+ except Exception as e:
61
+ logging.warning(f"MedGemma pipeline not available: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ # YOLO detection on CPU
64
+ try:
65
+ self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
66
+ logging.info("✅ YOLO detection model loaded on CPU")
67
+ except Exception as e:
68
+ logging.warning(f"YOLO model not available: {e}")
 
 
 
69
 
70
+ # Segmentation model on CPU
71
+ try:
72
+ self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
73
+ logging.info("✅ Segmentation model loaded on CPU")
74
+ except Exception as e:
75
+ logging.warning(f"Segmentation model not available: {e}")
76
 
77
+ # Classification on CPU
78
+ try:
79
+ self.models_cache['cls'] = pipeline(
80
+ 'image-classification',
81
+ model='Hemg/Wound-classification',
82
+ token=self.config.HF_TOKEN,
83
+ device='cpu'
84
+ )
85
+ logging.info("✅ Wound classification model loaded on CPU")
86
  except Exception as e:
87
+ logging.warning(f"Wound classification model not available: {e}")
88
 
89
+ # Embedding for knowledge base
 
90
  try:
91
+ self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
92
+ model_name='sentence-transformers/all-MiniLM-L6-v2',
93
+ model_kwargs={'device': 'cpu'}
94
+ )
95
+ logging.info("✅ Embedding model loaded on CPU")
96
+ except Exception as e:
97
+ logging.warning(f"Embedding model not available: {e}")
98
 
99
+ # Load knowledge base
100
+ self._load_knowledge_base()
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ def _load_knowledge_base(self):
103
+ """Load PDF guidelines into a FAISS vector store."""
104
+ docs = []
105
+ for pdf in self.config.GUIDELINE_PDFS:
106
+ if os.path.exists(pdf):
107
+ loader = PyPDFLoader(pdf)
108
+ docs.extend(loader.load())
109
+ logging.info(f"Loaded PDF: {pdf}")
110
+
111
+ if docs and 'embedding_model' in self.models_cache:
112
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
113
+ chunks = splitter.split_documents(docs)
114
+ vs = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
115
+ self.knowledge_base_cache['vectorstore'] = vs
116
+ logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
117
+ else:
118
  self.knowledge_base_cache['vectorstore'] = None
119
+ logging.warning("Knowledge base unavailable")
120
 
 
121
  def perform_visual_analysis(self, image_pil):
122
+ """Detect & segment on CPU; return only paths + metrics."""
123
  try:
124
+ img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
125
+ # YOLO detect
126
+ res = self.models_cache['det'].predict(img_cv, verbose=False)[0]
127
+ if not res.boxes:
128
+ raise ValueError("No wound detected")
129
+ # Bounding box
130
+ x1, y1, x2, y2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
131
+ region = img_cv[y1:y2, x1:x2]
132
+ # Save detection overlay
133
+ det_vis = img_cv.copy()
134
+ cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0,255,0), 2)
135
+ os.makedirs(f"{self.config.UPLOADS_DIR}/analysis", exist_ok=True)
136
+ ts = datetime.now().strftime('%Y%m%d_%H%M%S')
137
+ det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
138
+ cv2.imwrite(det_path, det_vis)
139
+ # Initialize metrics & seg
 
 
 
 
 
 
 
 
 
 
 
140
  length = breadth = area = 0
141
+ seg_path = None
142
+ # Segmentation
 
 
143
  if 'seg' in self.models_cache:
144
+ h, w = self.models_cache['seg'].input_shape[1:3]
145
+ inp = cv2.resize(region, (w,h)) / 255.0
146
+ mask = (self.models_cache['seg'].predict(np.expand_dims(inp,0))[0,:,:,0] > 0.5).astype(np.uint8)
147
+ mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
148
+ ov = region.copy()
149
+ ov[mask_rs==1] = [0,0,255]
150
+ seg_vis = cv2.addWeighted(region,0.7,ov,0.3,0)
151
+ seg_path = f"{self.config.UPLOADS_DIR}/analysis/segmentation_{ts}.png"
152
+ cv2.imwrite(seg_path, seg_vis)
153
+ # measure
154
+ cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
155
+ if cnts:
156
+ cnt = max(cnts, key=cv2.contourArea)
157
+ _,_,w0,h0 = cv2.boundingRect(cnt)
158
+ length = round(h0/self.px_per_cm,2)
159
+ breadth= round(w0/self.px_per_cm,2)
160
+ area = round(cv2.contourArea(cnt)/(self.px_per_cm**2),2)
161
+ # Classification
162
+ wound_type = 'Unknown'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  if 'cls' in self.models_cache:
164
  try:
165
+ label = self.models_cache['cls'](Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB)))
166
+ wound_type = max(label, key=lambda x: x['score'])['label']
167
+ except Exception:
168
+ pass
 
 
169
  return {
170
  'wound_type': wound_type,
171
  'length_cm': length,
172
  'breadth_cm': breadth,
173
  'surface_area_cm2': area,
174
+ 'detection_confidence': float(res.boxes.conf[0].cpu().item()),
175
+ 'detection_image_path': det_path,
176
+ 'segmentation_image_path': seg_path
 
 
 
177
  }
 
178
  except Exception as e:
179
  logging.error(f"Visual analysis error: {e}")
180
+ raise
 
181
 
182
  def query_guidelines(self, query: str):
183
+ """Retrieve clinical guidelines from vectorstore."""
184
+ vs = self.knowledge_base_cache.get('vectorstore')
185
+ if not vs:
186
+ return "Clinical guidelines unavailable"
187
+ docs = vs.as_retriever(search_kwargs={'k':10}).invoke(query)
188
+ return '\n\n'.join(f"Source: {d.metadata.get('source','?')}, Page: {d.metadata.get('page','?')}\n{d.page_content}" for d in docs)
189
+
190
+ @spaces.GPU(enable_queue=True, duration=120)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
192
+ """Run MedGemma on GPU; return markdown report."""
193
+ if 'medgemma_pipe' not in self.models_cache:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  return self._generate_fallback_report(patient_info, visual_results, guideline_context)
195
+ # build messages
196
+ msgs = [{ 'role':'system', 'content':[{'type':'text','text': default_system_prompt}] },
197
+ { 'role':'user', 'content':[]}]
198
+ # images
199
+ if image_pil: msgs[1]['content'].append({'type':'image','image':image_pil})
200
+ for key in ('detection_image_path','segmentation_image_path'):
201
+ p = visual_results.get(key)
202
+ if p and os.path.exists(p):
203
+ msgs[1]['content'].append({'type':'image', 'image': Image.open(p)})
204
+ # text prompt stub (expand as needed)
205
+ prompt = f"## Patient\n{patient_info}\n## Visual Type: {visual_results['wound_type']}"
206
+ msgs[1]['content'].append({'type':'text','text':prompt})
207
+ out = self.models_cache['medgemma_pipe'](text=msgs, max_new_tokens=max_new_tokens or self.config.MAX_NEW_TOKENS)
208
+ report = out[0]['generated_text'][-1].get('content','')
209
+ return report or self._generate_fallback_report(patient_info, visual_results, guideline_context)
210
 
211
  def _generate_fallback_report(self, patient_info, visual_results, guideline_context):
212
+ """Produce text-only fallback."""
213
+ dp = visual_results.get('detection_image_path','N/A')
214
+ sp = visual_results.get('segmentation_image_path','N/A')
215
+ return f"# Report\n{patient_info}\nType: {visual_results['wound_type']}\nDetection Image: {dp}\nSegmentation Image: {sp}\nGuidelines: {guideline_context[:200]}..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  def save_and_commit_image(self, image_pil):
218
+ """Save locally and optionally to HuggingFace."""
219
+ os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
220
+ fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
221
+ path = os.path.join(self.config.UPLOADS_DIR, fn)
222
+ image_pil.convert('RGB').save(path)
223
+ if self.config.HF_TOKEN and self.config.DATASET_ID:
224
+ try:
225
+ api = HfApi()
226
+ api.upload_file(path_or_fileobj=path, path_in_repo=f"images/{fn}", repo_id=self.config.DATASET_ID, repo_type='dataset')
227
+ except Exception as e:
228
+ logging.warning(f"HF upload failed: {e}")
229
+ return path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
+ def full_analysis_pipeline(self, image_pil, questionnaire_data):
232
+ """Orchestrate CPU steps + GPU report."""
 
233
  try:
234
+ saved = self.save_and_commit_image(image_pil)
235
+ vis = self.perform_visual_analysis(image_pil)
236
+ info = ", ".join(f"{k}:{v}" for k,v in questionnaire_data.items() if v)
237
+ gc = self.query_guidelines(info)
238
+ report = self.generate_final_report(info, vis, gc, image_pil)
239
+ return {'success':True, 'visual_analysis':vis, 'report':report, 'saved_image_path':saved}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  except Exception as e:
241
+ logging.error(f"Pipeline error: {e}")
242
+ return {'success':False, 'error':str(e)}
243
+
 
 
 
244
 
245
  # Legacy methods for backward compatibility
246
  def analyze_wound(self, image, questionnaire_data):