SmartHeal committed on
Commit bf0aa04 · verified · 1 Parent(s): ff7f2b3

Update src/ai_processor.py

Files changed (1)
  1. src/ai_processor.py +232 -183
src/ai_processor.py CHANGED
@@ -1,62 +1,171 @@
  import os
- import io
- import base64
  import logging
- import numpy as np
  import cv2
  from PIL import Image
  from datetime import datetime
  from langchain_community.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain_community.embeddings import HuggingFaceEmbeddings
  from langchain_community.vectorstores import FAISS
- from huggingface_hub import HfApi, HfFolder
- import spaces
- from .config import Config
-
- # Inline system prompt for MedGemma GPU pipeline
- default_system_prompt = (
-     "You are a world-class medical AI assistant specializing in wound care "
-     "with expertise in wound assessment and treatment. Provide concise, "
-     "evidence-based medical assessments focusing on: (1) Precise wound "
-     "classification based on tissue type and appearance, (2) Specific "
-     "treatment recommendations with exact product names or interventions when "
-     "appropriate, (3) Objective evaluation of healing progression or deterioration "
-     "indicators, and (4) Clear follow-up timelines. Avoid general statements and "
-     "prioritize actionable insights based on the visual analysis measurements and "
-     "patient context."
- )
-
- # No torch or transformers-related imports at top-level!

  @spaces.GPU(enable_queue=True, duration=120)
  def generate_medgemma_report(
-     patient_info: str,
-     visual_results: dict,
-     guideline_context: str,
-     detection_image_path: str,
-     segmentation_image_path: str,
-     max_new_tokens: int = None
- ) -> str:
-     # --- All GPU-related imports and model loading here! ---
      import torch
      from transformers import pipeline
      from PIL import Image

-     # System prompt as before
-     global default_system_prompt

      # Lazy-load MedGemma pipeline on GPU
      if not hasattr(generate_medgemma_report, "_pipe"):
          try:
-             cfg = Config()
              generate_medgemma_report._pipe = pipeline(
-                 'image-text-to-text',
-                 model='google/medgemma-4b-it',
-                 device='cuda',  # Explicitly on GPU
-                 torch_dtype='auto',
-                 offload_folder='offload',
-                 token=cfg.HF_TOKEN
              )
              logging.info("✅ MedGemma pipeline loaded on GPU")
          except Exception as e:
@@ -67,202 +176,132 @@ def generate_medgemma_report(

      # Compose messages
      msgs = [
-         {'role': 'system', 'content': [{'type': 'text', 'text': default_system_prompt}]},
-         {'role': 'user', 'content': []},
      ]

      # Attach images if available
      for path in (detection_image_path, segmentation_image_path):
          if path and os.path.exists(path):
-             msgs[1]['content'].append({'type': 'image', 'image': Image.open(path)})

-     # Attach text
      prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type','Unknown')}"
-     msgs[1]['content'].append({'type': 'text', 'text': prompt})
-
-     out = pipe(
-         text=msgs,
-         max_new_tokens=max_new_tokens or Config().MAX_NEW_TOKENS,
-         do_sample=False
-     )
-     return out[0]['generated_text'][-1].get('content', '')


  class AIProcessor:
      def __init__(self):
-         self.models_cache = {}
-         self.knowledge_base_cache = {}
-         self.config = Config()
-         self.px_per_cm = self.config.PIXELS_PER_CM
-         self._initialize_models()
-         self._load_knowledge_base()
-
-     def _initialize_models(self):
-         """Load all CPU-only models here."""
-         # Set HuggingFace token
-         if self.config.HF_TOKEN:
-             HfFolder.save_token(self.config.HF_TOKEN)
-             logging.info("✅ HuggingFace token set")
-
-         # YOLO detection (CPU-only)
-         try:
-             from ultralytics import YOLO
-             self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
-             logging.info("✅ YOLO model loaded (CPU only)")
-         except Exception as e:
-             logging.error(f"YOLO load failed: {e}")
-             raise
-
-         # Segmentation model (CPU)
-         try:
-             from tensorflow.keras.models import load_model
-             self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
-             logging.info("✅ Segmentation model loaded (CPU)")
-         except Exception as e:
-             logging.warning(f"Segmentation model not available: {e}")
-
-         # Classification pipeline (CPU)
-         try:
-             from transformers import pipeline
-             self.models_cache['cls'] = pipeline(
-                 'image-classification',
-                 model='Hemg/Wound-classification',
-                 token=self.config.HF_TOKEN,
-                 device='cpu'
-             )
-             logging.info("✅ Classification pipeline loaded (CPU)")
-         except Exception as e:
-             logging.warning(f"Classification pipeline not available: {e}")
-
-         # Embedding model (CPU)
-         try:
-             self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
-                 model_name='sentence-transformers/all-MiniLM-L6-v2',
-                 model_kwargs={'device': 'cpu'}
-             )
-             logging.info("✅ Embedding model loaded (CPU)")
-         except Exception as e:
-             logging.warning(f"Embedding model not available: {e}")
-
-     def _load_knowledge_base(self):
-         """Load PDF guidelines into a FAISS vector store."""
-         docs = []
-         for pdf in self.config.GUIDELINE_PDFS:
-             if os.path.exists(pdf):
-                 loader = PyPDFLoader(pdf)
-                 docs.extend(loader.load())
-                 logging.info(f"Loaded PDF: {pdf}")
-
-         if docs and 'embedding_model' in self.models_cache:
-             splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-             chunks = splitter.split_documents(docs)
-             self.knowledge_base_cache['vectorstore'] = FAISS.from_documents(
-                 chunks, self.models_cache['embedding_model']
-             )
-             logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
-         else:
-             self.knowledge_base_cache['vectorstore'] = None
-             logging.warning("Knowledge base unavailable")

      def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
          """Detect & segment on CPU; return metrics + file paths."""
-         if 'det' not in self.models_cache:
              raise RuntimeError("YOLO model ('det') not loaded")

-         img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
-         res = self.models_cache['det'].predict(img_cv, verbose=False)[0]
          if not res.boxes:
              raise ValueError("No wound detected")

-         x1, y1, x2, y2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
          region = img_cv[y1:y2, x1:x2]

          # Save detection overlay
          det_vis = img_cv.copy()
-         cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0,255,0), 2)
-         os.makedirs(f"{self.config.UPLOADS_DIR}/analysis", exist_ok=True)
-         ts = datetime.now().strftime('%Y%m%d_%H%M%S')
-         det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
          cv2.imwrite(det_path, det_vis)

          # Segmentation
          length = breadth = area = 0
          seg_path = None
-         if 'seg' in self.models_cache:
-             h, w = self.models_cache['seg'].input_shape[1:3]
              inp = cv2.resize(region, (w, h)) / 255.0
-             mask = (self.models_cache['seg'].predict(inp[None])[0,:,:,0] > 0.5).astype(np.uint8)
-             mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
-             ov = region.copy(); ov[mask_rs==1] = [0,0,255]
              seg_vis = cv2.addWeighted(region, 0.7, ov, 0.3, 0)
-             seg_path = f"{self.config.UPLOADS_DIR}/analysis/segmentation_{ts}.png"
              cv2.imwrite(seg_path, seg_vis)
              cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
              if cnts:
                  cnt = max(cnts, key=cv2.contourArea)
                  _, _, w0, h0 = cv2.boundingRect(cnt)
                  length = round(h0 / self.px_per_cm, 2)
                  breadth = round(w0 / self.px_per_cm, 2)
-                 area = round(cv2.contourArea(cnt) / (self.px_per_cm**2), 2)

          # Classification
-         wound_type = 'Unknown'
-         if 'cls' in self.models_cache:
              try:
-                 preds = self.models_cache['cls'](Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB)))
-                 wound_type = max(preds, key=lambda x: x['score'])['label']
              except Exception:
                  pass

          return {
-             'wound_type': wound_type,
-             'length_cm': length,
-             'breadth_cm': breadth,
-             'surface_area_cm2': area,
-             'detection_confidence': float(res.boxes.conf[0].cpu().item()),
-             'detection_image_path': det_path,
-             'segmentation_image_path': seg_path
          }

      def query_guidelines(self, query: str) -> str:
-         vs = self.knowledge_base_cache.get('vectorstore')
          if not vs:
              return "Clinical guidelines unavailable"
-         docs = vs.as_retriever(search_kwargs={'k':10}).invoke(query)
-         return '\n\n'.join(
-             f"Source: {d.metadata.get('source','?')}, Page: {d.metadata.get('page','?')}\n{d.page_content}"
-             for d in docs
          )

      def generate_final_report(
-         self,
-         patient_info: str,
-         visual_results: dict,
-         guideline_context: str,
-         image_pil: Image.Image,
-         max_new_tokens: int = None
      ) -> str:
-         det = visual_results.get('detection_image_path', '')
-         seg = visual_results.get('segmentation_image_path', '')
-         # This GPU call is safe: it triggers all CUDA/model code *inside* the decorator context.
-         report = generate_medgemma_report(
-             patient_info, visual_results, guideline_context,
-             det, seg, max_new_tokens
-         )
          if report:
              return report
          return self._generate_fallback_report(patient_info, visual_results, guideline_context)

      def _generate_fallback_report(
-         self,
-         patient_info: str,
-         visual_results: dict,
-         guideline_context: str
      ) -> str:
          dp = visual_results.get('detection_image_path','N/A')
          sp = visual_results.get('segmentation_image_path','N/A')
          return (
-             f"# Report\n{patient_info}\n"
              f"Type: {visual_results.get('wound_type','Unknown')}\n"
              f"Detection Image: {dp}\n"
              f"Segmentation Image: {sp}\n"
@@ -270,40 +309,46 @@ class AIProcessor:
          )

      def save_and_commit_image(self, image_pil: Image.Image) -> str:
-         os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
          fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
-         path = os.path.join(self.config.UPLOADS_DIR, fn)
-         image_pil.convert('RGB').save(path)
-         if self.config.HF_TOKEN and getattr(self.config, 'DATASET_ID', None):
              try:
                  HfApi().upload_file(
                      path_or_fileobj=path,
                      path_in_repo=f"images/{fn}",
-                     repo_id=self.config.DATASET_ID,
-                     repo_type='dataset'
                  )
              except Exception as e:
                  logging.warning(f"HF upload failed: {e}")
          return path

      def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
          try:
              saved = self.save_and_commit_image(image_pil)
-             vis = self.perform_visual_analysis(image_pil)
-             info = ", ".join(f"{k}:{v}" for k,v in questionnaire_data.items() if v)
-             gc = self.query_guidelines(info)
-             report= self.generate_final_report(info, vis, gc, image_pil)
              return {'success': True, 'visual_analysis': vis, 'report': report, 'saved_image_path': saved}
          except Exception as e:
              logging.error(f"Pipeline error: {e}")
              return {'success': False, 'error': str(e)}

      def analyze_wound(self, image, questionnaire_data: dict) -> dict:
          if isinstance(image, str):
              image = Image.open(image)
          return self.full_analysis_pipeline(image, questionnaire_data)

      def _assess_risk_legacy(self, questionnaire_data: dict) -> dict:
          risk_factors, risk_score = [], 0
          try:
              age = questionnaire_data.get('patient_age', 0)
@@ -311,12 +356,15 @@ class AIProcessor:
                  risk_factors.append("Advanced age (>65)"); risk_score += 2
              elif age > 50:
                  risk_factors.append("Older adult (50-65)"); risk_score += 1
              dur = questionnaire_data.get('wound_duration', '').lower()
              if any(t in dur for t in ['month','year']):
                  risk_factors.append("Chronic wound (>4 weeks)"); risk_score += 3
              pain = questionnaire_data.get('pain_level', 0)
              if pain >= 7:
                  risk_factors.append("High pain level"); risk_score += 2
              hist = questionnaire_data.get('medical_history','').lower()
              if 'diabetes' in hist:
                  risk_factors.append("Diabetes mellitus"); risk_score += 3
@@ -324,8 +372,9 @@ class AIProcessor:
                  risk_factors.append("Vascular issues"); risk_score += 2
              if 'immune' in hist:
                  risk_factors.append("Immune compromise"); risk_score += 2
              level = ("High" if risk_score >= 7 else "Moderate" if risk_score >= 4 else "Low")
              return {'risk_score': risk_score, 'risk_level': level, 'risk_factors': risk_factors}
          except Exception as e:
              logging.error(f"Risk assessment error: {e}")
-             return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
 
  import os
  import logging
  import cv2
+ import numpy as np
  from PIL import Image
  from datetime import datetime
+ import gradio as gr
+ import spaces
+
+ from huggingface_hub import HfApi, HfFolder
  from langchain_community.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain_community.embeddings import HuggingFaceEmbeddings
  from langchain_community.vectorstores import FAISS

+ # =============== LOGGING SETUP ===============
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ # =============== CONFIGURATION ===============
+ UPLOADS_DIR = "uploads"
+ if not os.path.exists(UPLOADS_DIR):
+     os.makedirs(UPLOADS_DIR)
+     logging.info(f"Created uploads directory: {UPLOADS_DIR}")
+
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ YOLO_MODEL_PATH = "best.pt"
+ SEG_MODEL_PATH = "segmentation_model.h5"
+ GUIDELINE_PDFS = ["eHealth in Wound Care.pdf", "IWGDF Guideline.pdf", "evaluation.pdf"]
+ DATASET_ID = "SmartHeal/wound-image-uploads"
+ MAX_NEW_TOKENS = 2048
+ PIXELS_PER_CM = 38
+
+ # =============== GLOBAL CACHES ===============
+ models_cache = {}
+ knowledge_base_cache = {}
+
+ # =============== LAZY LOADING FUNCTIONS (CPU-SAFE) ===============
+ def load_yolo_model(yolo_model_path):
+     """Lazy import and load YOLO model to avoid CUDA initialization."""
+     from ultralytics import YOLO
+     return YOLO(yolo_model_path)
+
+ def load_segmentation_model(seg_model_path):
+     """Lazy import and load segmentation model."""
+     from tensorflow.keras.models import load_model
+     return load_model(seg_model_path, compile=False)
+
+ def load_classification_pipeline(hf_token):
+     """Lazy import and load classification pipeline (CPU only)."""
+     from transformers import pipeline
+     return pipeline(
+         "image-classification",
+         model="Hemg/Wound-classification",
+         token=hf_token,
+         device="cpu"
+     )
+
+ def load_embedding_model():
+     """Load embedding model for knowledge base."""
+     return HuggingFaceEmbeddings(
+         model_name="sentence-transformers/all-MiniLM-L6-v2",
+         model_kwargs={"device": "cpu"}
+     )
+
+ # =============== MODEL INITIALIZATION ===============
+ def initialize_cpu_models():
+     """Initialize all CPU-only models once."""
+     global models_cache
+
+     if HF_TOKEN:
+         HfFolder.save_token(HF_TOKEN)
+         logging.info("✅ HuggingFace token set")
+
+     if "det" not in models_cache:
+         try:
+             models_cache["det"] = load_yolo_model(YOLO_MODEL_PATH)
+             logging.info("✅ YOLO model loaded (CPU only)")
+         except Exception as e:
+             logging.error(f"YOLO load failed: {e}")
+
+     if "seg" not in models_cache:
+         try:
+             models_cache["seg"] = load_segmentation_model(SEG_MODEL_PATH)
+             logging.info("✅ Segmentation model loaded (CPU)")
+         except Exception as e:
+             logging.warning(f"Segmentation model not available: {e}")
+
+     if "cls" not in models_cache:
+         try:
+             models_cache["cls"] = load_classification_pipeline(HF_TOKEN)
+             logging.info("✅ Classification pipeline loaded (CPU)")
+         except Exception as e:
+             logging.warning(f"Classification pipeline not available: {e}")
+
+     if "embedding_model" not in models_cache:
+         try:
+             models_cache["embedding_model"] = load_embedding_model()
+             logging.info("✅ Embedding model loaded (CPU)")
+         except Exception as e:
+             logging.warning(f"Embedding model not available: {e}")
+
+ def setup_knowledge_base():
+     """Load PDF documents and create FAISS vector store."""
+     global knowledge_base_cache
+     if "vector_store" in knowledge_base_cache:
+         return
+
+     docs = []
+     for pdf_path in GUIDELINE_PDFS:
+         if os.path.exists(pdf_path):
+             try:
+                 loader = PyPDFLoader(pdf_path)
+                 docs.extend(loader.load())
+                 logging.info(f"Loaded PDF: {pdf_path}")
+             except Exception as e:
+                 logging.warning(f"Failed to load PDF {pdf_path}: {e}")
+
+     if docs and "embedding_model" in models_cache:
+         splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+         chunks = splitter.split_documents(docs)
+         knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
+         logging.info(f"✅ Knowledge base ready with {len(chunks)} chunks")
+     else:
+         knowledge_base_cache["vector_store"] = None
+         logging.warning("Knowledge base unavailable")
+
+ # Initialize models on app startup
+ initialize_cpu_models()
+ setup_knowledge_base()
+
+ # =============== GPU-DECORATED MEDGEMMA FUNCTION ===============
  @spaces.GPU(enable_queue=True, duration=120)
  def generate_medgemma_report(
+     patient_info,
+     visual_results,
+     guideline_context,
+     detection_image_path,
+     segmentation_image_path,
+     max_new_tokens=None,
+ ):
+     """GPU-only function for MedGemma report generation."""
+     # Import GPU libraries ONLY here
      import torch
      from transformers import pipeline
      from PIL import Image

+     default_system_prompt = (
+         "You are a world-class medical AI assistant specializing in wound care "
+         "with expertise in wound assessment and treatment. Provide concise, "
+         "evidence-based medical assessments focusing on: (1) Precise wound "
+         "classification based on tissue type and appearance, (2) Specific "
+         "treatment recommendations with exact product names or interventions when "
+         "appropriate, (3) Objective evaluation of healing progression or deterioration "
+         "indicators, and (4) Clear follow-up timelines. Avoid general statements and "
+         "prioritize actionable insights based on the visual analysis measurements and "
+         "patient context."
+     )

      # Lazy-load MedGemma pipeline on GPU
      if not hasattr(generate_medgemma_report, "_pipe"):
          try:
              generate_medgemma_report._pipe = pipeline(
+                 "image-text-to-text",
+                 model="google/medgemma-4b-it",
+                 device="cuda",
+                 torch_dtype=torch.bfloat16,
+                 offload_folder="offload",
+                 token=HF_TOKEN,
              )
              logging.info("✅ MedGemma pipeline loaded on GPU")
          except Exception as e:

      # Compose messages
      msgs = [
+         {"role": "system", "content": [{"type": "text", "text": default_system_prompt}]},
+         {"role": "user", "content": []},
      ]

      # Attach images if available
      for path in (detection_image_path, segmentation_image_path):
          if path and os.path.exists(path):
+             msgs[1]["content"].append({"type": "image", "image": Image.open(path)})

+     # Attach text prompt
      prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type','Unknown')}"
+     msgs[1]["content"].append({"type": "text", "text": prompt})

+     try:
+         out = generate_medgemma_report._pipe(text=msgs, max_new_tokens=max_new_tokens or MAX_NEW_TOKENS, do_sample=False)
+         return out[0]["generated_text"][-1].get("content", "")
+     except Exception as e:
+         logging.error(f"Failed to generate MedGemma report: {e}")
+         return f"❌ An error occurred: {e}"

+ # =============== AI PROCESSOR CLASS ===============
  class AIProcessor:
      def __init__(self):
+         self.models_cache = models_cache
+         self.knowledge_base_cache = knowledge_base_cache
+         self.px_per_cm = PIXELS_PER_CM
+         self.uploads_dir = UPLOADS_DIR
+         self.dataset_id = DATASET_ID
+         self.hf_token = HF_TOKEN

      def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
          """Detect & segment on CPU; return metrics + file paths."""
+         img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
+         yolo = self.models_cache.get("det")
+         if yolo is None:
              raise RuntimeError("YOLO model ('det') not loaded")

+         res = yolo.predict(img_cv, verbose=False, device="cpu")[0]
          if not res.boxes:
              raise ValueError("No wound detected")

+         x1, y1, x2, y2 = res.boxes.xyxy[0].cpu().numpy().astype(int)
          region = img_cv[y1:y2, x1:x2]

          # Save detection overlay
          det_vis = img_cv.copy()
+         cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
+         os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
+         ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+         det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
          cv2.imwrite(det_path, det_vis)

          # Segmentation
          length = breadth = area = 0
          seg_path = None
+         seg_model = self.models_cache.get("seg")
+         if seg_model:
+             h, w = seg_model.input_shape[1:3]
              inp = cv2.resize(region, (w, h)) / 255.0
+             mask = (seg_model.predict(inp[None])[0, :, :, 0] > 0.5).astype(np.uint8)
+             mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
+
+             ov = region.copy()
+             ov[mask_rs == 1] = [0, 0, 255]
              seg_vis = cv2.addWeighted(region, 0.7, ov, 0.3, 0)
+             seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
              cv2.imwrite(seg_path, seg_vis)
+
              cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
              if cnts:
                  cnt = max(cnts, key=cv2.contourArea)
                  _, _, w0, h0 = cv2.boundingRect(cnt)
                  length = round(h0 / self.px_per_cm, 2)
                  breadth = round(w0 / self.px_per_cm, 2)
+                 area = round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)

          # Classification
+         wound_type = "Unknown"
+         cls_pipe = self.models_cache.get("cls")
+         if cls_pipe:
              try:
+                 preds = cls_pipe(Image.fromarray(cv2.cvtColor(region, cv2.COLOR_BGR2RGB)))
+                 wound_type = max(preds, key=lambda x: x["score"])["label"]
              except Exception:
                  pass

          return {
+             "wound_type": wound_type,
+             "length_cm": length,
+             "breadth_cm": breadth,
+             "surface_area_cm2": area,
+             "detection_confidence": float(res.boxes.conf[0].cpu().item()),
+             "detection_image_path": det_path,
+             "segmentation_image_path": seg_path,
          }

      def query_guidelines(self, query: str) -> str:
+         """Query the knowledge base for relevant information."""
+         vs = self.knowledge_base_cache.get("vector_store")
          if not vs:
              return "Clinical guidelines unavailable"
+         docs = vs.as_retriever(search_kwargs={"k": 10}).invoke(query)
+         return "\n\n".join(
+             f"Source: {d.metadata.get('source','?')}, Page: {d.metadata.get('page','?')}\n{d.page_content}" for d in docs
          )

      def generate_final_report(
+         self, patient_info: str, visual_results: dict, guideline_context: str, image_pil: Image.Image, max_new_tokens: int = None
      ) -> str:
+         """Generate final report using MedGemma GPU pipeline."""
+         det = visual_results.get("detection_image_path", "")
+         seg = visual_results.get("segmentation_image_path", "")
+
+         report = generate_medgemma_report(patient_info, visual_results, guideline_context, det, seg, max_new_tokens)
          if report:
              return report
          return self._generate_fallback_report(patient_info, visual_results, guideline_context)

      def _generate_fallback_report(
+         self, patient_info: str, visual_results: dict, guideline_context: str
      ) -> str:
+         """Generate fallback report if MedGemma fails."""
          dp = visual_results.get('detection_image_path','N/A')
          sp = visual_results.get('segmentation_image_path','N/A')
          return (
+             f"# Fallback Report\n{patient_info}\n"
              f"Type: {visual_results.get('wound_type','Unknown')}\n"
              f"Detection Image: {dp}\n"
              f"Segmentation Image: {sp}\n"

          )

      def save_and_commit_image(self, image_pil: Image.Image) -> str:
+         """Save image locally and optionally commit to HF dataset."""
+         os.makedirs(self.uploads_dir, exist_ok=True)
          fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
+         path = os.path.join(self.uploads_dir, fn)
+         image_pil.convert("RGB").save(path)
+
+         if self.hf_token and self.dataset_id:
              try:
                  HfApi().upload_file(
                      path_or_fileobj=path,
                      path_in_repo=f"images/{fn}",
+                     repo_id=self.dataset_id,
+                     repo_type="dataset",
                  )
+                 logging.info("✅ Image committed to HF dataset")
              except Exception as e:
                  logging.warning(f"HF upload failed: {e}")
          return path

      def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
+         """Run full analysis pipeline."""
          try:
              saved = self.save_and_commit_image(image_pil)
+             vis = self.perform_visual_analysis(image_pil)
+             info = ", ".join(f"{k}:{v}" for k, v in questionnaire_data.items() if v)
+             gc = self.query_guidelines(info)
+             report = self.generate_final_report(info, vis, gc, image_pil)
              return {'success': True, 'visual_analysis': vis, 'report': report, 'saved_image_path': saved}
          except Exception as e:
              logging.error(f"Pipeline error: {e}")
              return {'success': False, 'error': str(e)}

      def analyze_wound(self, image, questionnaire_data: dict) -> dict:
+         """Main analysis entry point."""
          if isinstance(image, str):
              image = Image.open(image)
          return self.full_analysis_pipeline(image, questionnaire_data)

      def _assess_risk_legacy(self, questionnaire_data: dict) -> dict:
+         """Legacy risk assessment function."""
          risk_factors, risk_score = [], 0
          try:
              age = questionnaire_data.get('patient_age', 0)

                  risk_factors.append("Advanced age (>65)"); risk_score += 2
              elif age > 50:
                  risk_factors.append("Older adult (50-65)"); risk_score += 1
+
              dur = questionnaire_data.get('wound_duration', '').lower()
              if any(t in dur for t in ['month','year']):
                  risk_factors.append("Chronic wound (>4 weeks)"); risk_score += 3
+
              pain = questionnaire_data.get('pain_level', 0)
              if pain >= 7:
                  risk_factors.append("High pain level"); risk_score += 2
+
              hist = questionnaire_data.get('medical_history','').lower()
              if 'diabetes' in hist:
                  risk_factors.append("Diabetes mellitus"); risk_score += 3

                  risk_factors.append("Vascular issues"); risk_score += 2
              if 'immune' in hist:
                  risk_factors.append("Immune compromise"); risk_score += 2
+
              level = ("High" if risk_score >= 7 else "Moderate" if risk_score >= 4 else "Low")
              return {'risk_score': risk_score, 'risk_level': level, 'risk_factors': risk_factors}
          except Exception as e:
              logging.error(f"Risk assessment error: {e}")
+             return {'risk_score': 0, 'risk_level': 'Unknown', 'risk_factors': []}
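
For reference, a minimal driver sketch for the module as committed, assuming it is importable as src.ai_processor, that best.pt, segmentation_model.h5, and the guideline PDFs sit in the working directory, and that HF_TOKEN is exported. The image file name and questionnaire keys below are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch; paths and questionnaire keys are assumptions.
from PIL import Image
from src.ai_processor import AIProcessor

processor = AIProcessor()  # CPU models and the FAISS knowledge base are initialized at import time
result = processor.analyze_wound(
    Image.open("wound_photo.png"),  # a PIL image or a file path string is accepted
    {
        "patient_age": 68,
        "wound_duration": "3 months",
        "pain_level": 6,
        "medical_history": "type 2 diabetes",
    },
)
if result["success"]:
    print(result["visual_analysis"]["wound_type"], result["visual_analysis"]["surface_area_cm2"])
    print(result["report"])  # MedGemma output, or the fallback report if the GPU call fails
else:
    print("Analysis failed:", result["error"])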