SmartHeal committed on
Commit
a923317
·
verified ·
1 Parent(s): c421c59

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +120 -67
src/ai_processor.py CHANGED
@@ -1,5 +1,5 @@
1
- # Disable GPU for all CPU-only model loading to avoid triggering CUDA init in the main process
2
  import os
 
3
  os.environ['CUDA_VISIBLE_DEVICES'] = ''
4
 
5
  import io
@@ -20,6 +20,7 @@ from huggingface_hub import HfApi, HfFolder
20
  import spaces
21
  from .config import Config
22
 
 
23
  default_system_prompt = (
24
  "You are a world-class medical AI assistant specializing in wound care "
25
  "with expertise in wound assessment and treatment. Provide concise, "
@@ -32,6 +33,59 @@ default_system_prompt = (
32
  "patient context."
33
  )
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  class AIProcessor:
36
  def __init__(self):
37
  self.models_cache = {}
@@ -39,30 +93,31 @@ class AIProcessor:
39
  self.config = Config()
40
  self.px_per_cm = self.config.PIXELS_PER_CM
41
  self._initialize_models()
 
42
 
43
  def _initialize_models(self):
44
- """Initialize all CPU-only models here; GPU models loaded later in GPU context."""
45
- # HuggingFace token
46
  if self.config.HF_TOKEN:
47
  HfFolder.save_token(self.config.HF_TOKEN)
48
  logging.info("✅ HuggingFace token set")
49
 
50
- # YOLO detection (CPU only)
51
  try:
52
  self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
53
  logging.info("✅ YOLO model loaded (CPU only)")
54
  except Exception as e:
55
- logging.error(f"Failed to load YOLO model: {e}")
56
  raise
57
 
58
- # Segmentation model (CPU)
59
  try:
60
  self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
61
  logging.info("✅ Segmentation model loaded (CPU)")
62
  except Exception as e:
63
  logging.warning(f"Segmentation model not available: {e}")
64
 
65
- # Classification model (CPU)
66
  try:
67
  self.models_cache['cls'] = pipeline(
68
  'image-classification',
@@ -78,40 +133,38 @@ class AIProcessor:
78
  try:
79
  self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
80
  model_name='sentence-transformers/all-MiniLM-L6-v2',
81
- model_kwargs={'device': 'cpu'}
82
  )
83
  logging.info("✅ Embedding model loaded (CPU)")
84
  except Exception as e:
85
  logging.warning(f"Embedding model not available: {e}")
86
 
87
- # Load PDF guidelines into FAISS
88
- self._load_knowledge_base()
89
-
90
  def _load_knowledge_base(self):
 
91
  docs = []
92
  for pdf in self.config.GUIDELINE_PDFS:
93
  if os.path.exists(pdf):
94
  loader = PyPDFLoader(pdf)
95
  docs.extend(loader.load())
96
- logging.info(f"Loaded guideline PDF: {pdf}")
97
 
98
  if docs and 'embedding_model' in self.models_cache:
99
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
100
  chunks = splitter.split_documents(docs)
101
  vs = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
102
  self.knowledge_base_cache['vectorstore'] = vs
103
- logging.info(f"✅ Knowledge base loaded with {len(chunks)} chunks")
104
  else:
105
  self.knowledge_base_cache['vectorstore'] = None
106
  logging.warning("Knowledge base unavailable")
107
 
108
- def perform_visual_analysis(self, image_pil):
109
- """Detect & segment on CPU; return metrics and file paths."""
110
  if 'det' not in self.models_cache:
111
  raise RuntimeError("YOLO model ('det') not loaded")
112
 
113
  img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
114
- res = self.models_cache['det'].predict(img_cv, device='cpu', verbose=False)[0]
115
  if not res.boxes:
116
  raise ValueError("No wound detected")
117
 
@@ -126,19 +179,18 @@ class AIProcessor:
126
  det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
127
  cv2.imwrite(det_path, det_vis)
128
 
129
- # Segmentation (if available)
130
  length = breadth = area = 0
131
  seg_path = None
132
  if 'seg' in self.models_cache:
133
  h, w = self.models_cache['seg'].input_shape[1:3]
134
  inp = cv2.resize(region, (w,h)) / 255.0
135
- mask = (self.models_cache['seg'].predict(np.expand_dims(inp,0))[0,:,:,0] > 0.5).astype(np.uint8)
136
  mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
137
  ov = region.copy(); ov[mask_rs==1] = [0,0,255]
138
  seg_vis = cv2.addWeighted(region,0.7,ov,0.3,0)
139
  seg_path = f"{self.config.UPLOADS_DIR}/analysis/segmentation_{ts}.png"
140
  cv2.imwrite(seg_path, seg_vis)
141
-
142
  cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
143
  if cnts:
144
  cnt = max(cnts, key=cv2.contourArea)
@@ -168,7 +220,7 @@ class AIProcessor:
168
  'segmentation_image_path': seg_path
169
  }
170
 
171
- def query_guidelines(self, query: str):
172
  vs = self.knowledge_base_cache.get('vectorstore')
173
  if not vs:
174
  return "Clinical guidelines unavailable"
@@ -178,63 +230,55 @@ class AIProcessor:
178
  for d in docs
179
  )
180
 
181
- @spaces.GPU(enable_queue=True, duration=120)
182
- def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
183
- """Run MedGemma on GPU; return markdown report."""
184
- if 'medgemma_pipe' not in self.models_cache:
185
- try:
186
- self.models_cache['medgemma_pipe'] = pipeline(
187
- 'image-text-to-text',
188
- model='google/medgemma-4b-it',
189
- device='auto',
190
- torch_dtype='auto',
191
- offload_folder='offload',
192
- token=self.config.HF_TOKEN
193
- )
194
- logging.info("✅ MedGemma pipeline loaded on GPU")
195
- except Exception as e:
196
- logging.warning(f"MedGemma pipeline load failed: {e}")
197
- return self._generate_fallback_report(patient_info, visual_results, guideline_context)
198
-
199
- msgs = [
200
- {'role':'system','content':[{'type':'text','text':default_system_prompt}]},
201
- {'role':'user','content':[]}
202
- ]
203
- if image_pil:
204
- msgs[1]['content'].append({'type':'image','image':image_pil})
205
- for key in ('detection_image_path','segmentation_image_path'):
206
- p = visual_results.get(key)
207
- if p and os.path.exists(p):
208
- msgs[1]['content'].append({'type':'image','image':Image.open(p)})
209
- prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results['wound_type']}"
210
- msgs[1]['content'].append({'type':'text','text':prompt})
211
-
212
- out = self.models_cache['medgemma_pipe'](
213
- text=msgs,
214
- max_new_tokens=max_new_tokens or self.config.MAX_NEW_TOKENS,
215
- do_sample=False
216
  )
217
- report = out[0]['generated_text'][-1].get('content','')
218
- return report or self._generate_fallback_report(patient_info, visual_results, guideline_context)
 
219
 
220
- def _generate_fallback_report(self, patient_info, visual_results, guideline_context):
 
 
 
 
 
221
  dp = visual_results.get('detection_image_path','N/A')
222
  sp = visual_results.get('segmentation_image_path','N/A')
223
  return (
224
- f"# Report\n{patient_info}\nType: {visual_results['wound_type']}\n"
225
- f"Detection Image: {dp}\nSegmentation Image: {sp}\n"
 
 
226
  f"Guidelines: {guideline_context[:200]}..."
227
  )
228
 
229
- def save_and_commit_image(self, image_pil):
230
  os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
231
  fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
232
  path = os.path.join(self.config.UPLOADS_DIR, fn)
233
  image_pil.convert('RGB').save(path)
234
  if self.config.HF_TOKEN and getattr(self.config, 'DATASET_ID', None):
235
  try:
236
- api = HfApi()
237
- api.upload_file(
238
  path_or_fileobj=path,
239
  path_in_repo=f"images/{fn}",
240
  repo_id=self.config.DATASET_ID,
@@ -244,19 +288,28 @@ class AIProcessor:
244
  logging.warning(f"HF upload failed: {e}")
245
  return path
246
 
247
- def full_analysis_pipeline(self, image_pil, questionnaire_data):
 
 
 
 
248
  try:
249
  saved = self.save_and_commit_image(image_pil)
250
  vis = self.perform_visual_analysis(image_pil)
251
- info = ", ".join(f"{k}:{v}" for k, v in questionnaire_data.items() if v)
252
  gc = self.query_guidelines(info)
253
  report = self.generate_final_report(info, vis, gc, image_pil)
254
- return {'success': True, 'visual_analysis': vis, 'report': report, 'saved_image_path': saved}
 
 
 
 
 
255
  except Exception as e:
256
  logging.error(f"Pipeline error: {e}")
257
  return {'success': False, 'error': str(e)}
258
 
259
- def analyze_wound(self, image, questionnaire_data):
260
  if isinstance(image, str):
261
  image = Image.open(image)
262
  return self.full_analysis_pipeline(image, questionnaire_data)
 
 
1
  import os
2
+ # Ensure all CPU-only models never touch CUDA
3
  os.environ['CUDA_VISIBLE_DEVICES'] = ''
4
 
5
  import io
 
20
  import spaces
21
  from .config import Config
22
 
23
+ # System prompt for MedGemma
24
  default_system_prompt = (
25
  "You are a world-class medical AI assistant specializing in wound care "
26
  "with expertise in wound assessment and treatment. Provide concise, "
 
33
  "patient context."
34
  )
35
 
36
+ @spaces.GPU(enable_queue=True, duration=120)
37
+ def generate_medgemma_report(
38
+ patient_info: str,
39
+ visual_results: dict,
40
+ guideline_context: str,
41
+ detection_image_path: str,
42
+ segmentation_image_path: str,
43
+ max_new_tokens: int = None
44
+ ) -> str:
45
+ """
46
+ Runs on GPU. Lazy-loads the MedGemma pipeline and returns the markdown report.
47
+ Accepts only primitive types and file-paths, so pickling works.
48
+ """
49
+ # Lazy-load pipeline
50
+ if not hasattr(generate_medgemma_report, "_pipe"):
51
+ try:
52
+ cfg = Config()
53
+ generate_medgemma_report._pipe = pipeline(
54
+ 'image-text-to-text',
55
+ model='google/medgemma-4b-it',
56
+ device='auto',
57
+ torch_dtype='auto',
58
+ offload_folder='offload',
59
+ token=cfg.HF_TOKEN
60
+ )
61
+ logging.info("✅ MedGemma pipeline loaded on GPU")
62
+ except Exception as e:
63
+ logging.warning(f"MedGemma pipeline load failed: {e}")
64
+ return None
65
+
66
+ pipe = generate_medgemma_report._pipe
67
+
68
+ # Assemble messages
69
+ msgs = [
70
+ {'role':'system','content':[{'type':'text','text':default_system_prompt}]},
71
+ {'role':'user','content':[]}
72
+ ]
73
+ # Attach images
74
+ for path in (detection_image_path, segmentation_image_path):
75
+ if path and os.path.exists(path):
76
+ msgs[1]['content'].append({'type':'image','image': Image.open(path)})
77
+ # Attach text
78
+ prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type','Unknown')}"
79
+ msgs[1]['content'].append({'type':'text','text': prompt})
80
+
81
+ out = pipe(
82
+ text=msgs,
83
+ max_new_tokens=max_new_tokens or Config().MAX_NEW_TOKENS,
84
+ do_sample=False
85
+ )
86
+ return out[0]['generated_text'][-1].get('content','')
87
+
88
+
89
  class AIProcessor:
90
  def __init__(self):
91
  self.models_cache = {}
 
93
  self.config = Config()
94
  self.px_per_cm = self.config.PIXELS_PER_CM
95
  self._initialize_models()
96
+ self._load_knowledge_base()
97
 
98
  def _initialize_models(self):
99
+ """Load all CPU-only models here."""
100
+ # Set HuggingFace token
101
  if self.config.HF_TOKEN:
102
  HfFolder.save_token(self.config.HF_TOKEN)
103
  logging.info("✅ HuggingFace token set")
104
 
105
+ # YOLO detection (CPU)
106
  try:
107
  self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
108
  logging.info("✅ YOLO model loaded (CPU only)")
109
  except Exception as e:
110
+ logging.error(f"YOLO load failed: {e}")
111
  raise
112
 
113
+ # Segmentation (CPU)
114
  try:
115
  self.models_cache['seg'] = load_model(self.config.SEG_MODEL_PATH, compile=False)
116
  logging.info("✅ Segmentation model loaded (CPU)")
117
  except Exception as e:
118
  logging.warning(f"Segmentation model not available: {e}")
119
 
120
+ # Classification (CPU)
121
  try:
122
  self.models_cache['cls'] = pipeline(
123
  'image-classification',
 
133
  try:
134
  self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
135
  model_name='sentence-transformers/all-MiniLM-L6-v2',
136
+ model_kwargs={'device':'cpu'}
137
  )
138
  logging.info("✅ Embedding model loaded (CPU)")
139
  except Exception as e:
140
  logging.warning(f"Embedding model not available: {e}")
141
 
 
 
 
142
  def _load_knowledge_base(self):
143
+ """Load PDF guidelines into a FAISS vector store."""
144
  docs = []
145
  for pdf in self.config.GUIDELINE_PDFS:
146
  if os.path.exists(pdf):
147
  loader = PyPDFLoader(pdf)
148
  docs.extend(loader.load())
149
+ logging.info(f"Loaded PDF: {pdf}")
150
 
151
  if docs and 'embedding_model' in self.models_cache:
152
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
153
  chunks = splitter.split_documents(docs)
154
  vs = FAISS.from_documents(chunks, self.models_cache['embedding_model'])
155
  self.knowledge_base_cache['vectorstore'] = vs
156
+ logging.info(f"✅ Knowledge base loaded ({len(chunks)} chunks)")
157
  else:
158
  self.knowledge_base_cache['vectorstore'] = None
159
  logging.warning("Knowledge base unavailable")
160
 
161
+ def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
162
+ """Detect & segment on CPU; return metrics + file paths."""
163
  if 'det' not in self.models_cache:
164
  raise RuntimeError("YOLO model ('det') not loaded")
165
 
166
  img_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
167
+ res = self.models_cache['det'].predict(img_cv, verbose=False)[0]
168
  if not res.boxes:
169
  raise ValueError("No wound detected")
170
 
 
179
  det_path = f"{self.config.UPLOADS_DIR}/analysis/detection_{ts}.png"
180
  cv2.imwrite(det_path, det_vis)
181
 
182
+ # Segmentation metrics
183
  length = breadth = area = 0
184
  seg_path = None
185
  if 'seg' in self.models_cache:
186
  h, w = self.models_cache['seg'].input_shape[1:3]
187
  inp = cv2.resize(region, (w,h)) / 255.0
188
+ mask = (self.models_cache['seg'].predict(inp[None])[0,:,:,0] > 0.5).astype(np.uint8)
189
  mask_rs = cv2.resize(mask, (region.shape[1], region.shape[0]), interpolation=cv2.INTER_NEAREST)
190
  ov = region.copy(); ov[mask_rs==1] = [0,0,255]
191
  seg_vis = cv2.addWeighted(region,0.7,ov,0.3,0)
192
  seg_path = f"{self.config.UPLOADS_DIR}/analysis/segmentation_{ts}.png"
193
  cv2.imwrite(seg_path, seg_vis)
 
194
  cnts, _ = cv2.findContours(mask_rs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
195
  if cnts:
196
  cnt = max(cnts, key=cv2.contourArea)
 
220
  'segmentation_image_path': seg_path
221
  }
222
 
223
+ def query_guidelines(self, query: str) -> str:
224
  vs = self.knowledge_base_cache.get('vectorstore')
225
  if not vs:
226
  return "Clinical guidelines unavailable"
 
230
  for d in docs
231
  )
232
 
233
+ def generate_final_report(
234
+ self,
235
+ patient_info: str,
236
+ visual_results: dict,
237
+ guideline_context: str,
238
+ image_pil: Image.Image,
239
+ max_new_tokens: int = None
240
+ ) -> str:
241
+ """
242
+ Signature unchanged. Gathers arguments, calls GPU function, and falls back if needed.
243
+ """
244
+ det = visual_results.get('detection_image_path', '')
245
+ seg = visual_results.get('segmentation_image_path', '')
246
+ report = generate_medgemma_report(
247
+ patient_info,
248
+ visual_results,
249
+ guideline_context,
250
+ det,
251
+ seg,
252
+ max_new_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  )
254
+ if report:
255
+ return report
256
+ return self._generate_fallback_report(patient_info, visual_results, guideline_context)
257
 
258
+ def _generate_fallback_report(
259
+ self,
260
+ patient_info: str,
261
+ visual_results: dict,
262
+ guideline_context: str
263
+ ) -> str:
264
  dp = visual_results.get('detection_image_path','N/A')
265
  sp = visual_results.get('segmentation_image_path','N/A')
266
  return (
267
+ f"# Report\n{patient_info}\n"
268
+ f"Type: {visual_results.get('wound_type','Unknown')}\n"
269
+ f"Detection Image: {dp}\n"
270
+ f"Segmentation Image: {sp}\n"
271
  f"Guidelines: {guideline_context[:200]}..."
272
  )
273
 
274
+ def save_and_commit_image(self, image_pil: Image.Image) -> str:
275
  os.makedirs(self.config.UPLOADS_DIR, exist_ok=True)
276
  fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
277
  path = os.path.join(self.config.UPLOADS_DIR, fn)
278
  image_pil.convert('RGB').save(path)
279
  if self.config.HF_TOKEN and getattr(self.config, 'DATASET_ID', None):
280
  try:
281
+ HfApi().upload_file(
 
282
  path_or_fileobj=path,
283
  path_in_repo=f"images/{fn}",
284
  repo_id=self.config.DATASET_ID,
 
288
  logging.warning(f"HF upload failed: {e}")
289
  return path
290
 
291
+ def full_analysis_pipeline(
292
+ self,
293
+ image_pil: Image.Image,
294
+ questionnaire_data: dict
295
+ ) -> dict:
296
  try:
297
  saved = self.save_and_commit_image(image_pil)
298
  vis = self.perform_visual_analysis(image_pil)
299
+ info = ", ".join(f"{k}:{v}" for k,v in questionnaire_data.items() if v)
300
  gc = self.query_guidelines(info)
301
  report = self.generate_final_report(info, vis, gc, image_pil)
302
+ return {
303
+ 'success': True,
304
+ 'visual_analysis': vis,
305
+ 'report': report,
306
+ 'saved_image_path': saved
307
+ }
308
  except Exception as e:
309
  logging.error(f"Pipeline error: {e}")
310
  return {'success': False, 'error': str(e)}
311
 
312
+ def analyze_wound(self, image, questionnaire_data: dict) -> dict:
313
  if isinstance(image, str):
314
  image = Image.open(image)
315
  return self.full_analysis_pipeline(image, questionnaire_data)