SmartHeal committed on
Commit
9f4b663
·
verified ·
1 Parent(s): c950f29

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +45 -33
src/ai_processor.py CHANGED
@@ -1,11 +1,9 @@
1
- import os
2
  import io
3
  import base64
4
  import logging
5
  import cv2
6
  import numpy as np
7
  from PIL import Image
8
- import torch
9
  from datetime import datetime
10
  from transformers import pipeline
11
  from ultralytics import YOLO
@@ -40,26 +38,12 @@ class AIProcessor:
40
  self._initialize_models()
41
 
42
  def _initialize_models(self):
43
- """Initialize AI models; only MedGemma uses GPU."""
44
- # Set HuggingFace token
45
  if self.config.HF_TOKEN:
46
  HfFolder.save_token(self.config.HF_TOKEN)
47
  logging.info("HuggingFace token set successfully")
48
 
49
- # MedGemma pipeline on GPU
50
- try:
51
- self.models_cache['medgemma_pipe'] = pipeline(
52
- 'image-text-to-text',
53
- model='google/medgemma-4b-it',
54
- device='cuda',
55
- torch_dtype=torch.bfloat16,
56
- offload_folder='offload',
57
- token=self.config.HF_TOKEN
58
- )
59
- logging.info("✅ MedGemma pipeline loaded on GPU")
60
- except Exception as e:
61
- logging.warning(f"MedGemma pipeline not available: {e}")
62
-
63
  # YOLO detection on CPU
64
  try:
65
  self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
@@ -86,7 +70,7 @@ class AIProcessor:
86
  except Exception as e:
87
  logging.warning(f"Wound classification model not available: {e}")
88
 
89
- # Embedding for knowledge base
90
  try:
91
  self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
92
  model_name='sentence-transformers/all-MiniLM-L6-v2',
@@ -190,21 +174,43 @@ class AIProcessor:
190
  @spaces.GPU(enable_queue=True, duration=120)
191
  def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
192
  """Run MedGemma on GPU; return markdown report."""
 
193
  if 'medgemma_pipe' not in self.models_cache:
194
- return self._generate_fallback_report(patient_info, visual_results, guideline_context)
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  # build messages
196
- msgs = [{ 'role':'system', 'content':[{'type':'text','text': default_system_prompt}] },
197
- { 'role':'user', 'content':[]}]
 
 
198
  # images
199
- if image_pil: msgs[1]['content'].append({'type':'image','image':image_pil})
 
200
  for key in ('detection_image_path','segmentation_image_path'):
201
  p = visual_results.get(key)
202
  if p and os.path.exists(p):
203
- msgs[1]['content'].append({'type':'image', 'image': Image.open(p)})
204
- # text prompt stub (expand as needed)
205
- prompt = f"## Patient\n{patient_info}\n## Visual Type: {visual_results['wound_type']}"
206
  msgs[1]['content'].append({'type':'text','text':prompt})
207
- out = self.models_cache['medgemma_pipe'](text=msgs, max_new_tokens=max_new_tokens or self.config.MAX_NEW_TOKENS)
 
 
 
 
 
208
  report = out[0]['generated_text'][-1].get('content','')
209
  return report or self._generate_fallback_report(patient_info, visual_results, guideline_context)
210
 
@@ -220,10 +226,15 @@ class AIProcessor:
220
  fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
221
  path = os.path.join(self.config.UPLOADS_DIR, fn)
222
  image_pil.convert('RGB').save(path)
223
- if self.config.HF_TOKEN and self.config.DATASET_ID:
224
  try:
225
  api = HfApi()
226
- api.upload_file(path_or_fileobj=path, path_in_repo=f"images/{fn}", repo_id=self.config.DATASET_ID, repo_type='dataset')
 
 
 
 
 
227
  except Exception as e:
228
  logging.warning(f"HF upload failed: {e}")
229
  return path
@@ -236,17 +247,18 @@ class AIProcessor:
236
  info = ", ".join(f"{k}:{v}" for k,v in questionnaire_data.items() if v)
237
  gc = self.query_guidelines(info)
238
  report = self.generate_final_report(info, vis, gc, image_pil)
239
- return {'success':True, 'visual_analysis':vis, 'report':report, 'saved_image_path':saved}
240
  except Exception as e:
241
  logging.error(f"Pipeline error: {e}")
242
- return {'success':False, 'error':str(e)}
243
-
244
 
245
  def analyze_wound(self, image, questionnaire_data):
246
  """Legacy wrapper."""
247
- if isinstance(image, str): image = Image.open(image)
 
248
  return self.full_analysis_pipeline(image, questionnaire_data)
249
 
 
250
  def _assess_risk_legacy(self, questionnaire_data):
251
  """Legacy risk assessment for backward compatibility"""
252
  risk_factors = []
 
 
1
import base64
import io
import logging
import os
from datetime import datetime

import cv2
import numpy as np
import torch
from PIL import Image
from transformers import pipeline
from ultralytics import YOLO
 
38
  self._initialize_models()
39
 
40
  def _initialize_models(self):
41
+ """Initialize CPU-only AI models; MedGemma is loaded on demand within GPU context."""
42
+ # Set HuggingFace token early
43
  if self.config.HF_TOKEN:
44
  HfFolder.save_token(self.config.HF_TOKEN)
45
  logging.info("HuggingFace token set successfully")
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # YOLO detection on CPU
48
  try:
49
  self.models_cache['det'] = YOLO(self.config.YOLO_MODEL_PATH)
 
70
  except Exception as e:
71
  logging.warning(f"Wound classification model not available: {e}")
72
 
73
+ # Embedding for knowledge base on CPU
74
  try:
75
  self.models_cache['embedding_model'] = HuggingFaceEmbeddings(
76
  model_name='sentence-transformers/all-MiniLM-L6-v2',
 
174
@spaces.GPU(enable_queue=True, duration=120)
def generate_final_report(self, patient_info, visual_results, guideline_context, image_pil, max_new_tokens=None):
    """Generate a markdown wound-care report with MedGemma on GPU.

    The MedGemma pipeline is lazily loaded the first time this method runs,
    inside the @spaces.GPU context, so CUDA is never initialized in the main
    process. Falls back to the template report when the model cannot be
    loaded or produces no content.

    Args:
        patient_info: Free-text summary of the patient questionnaire.
        visual_results: Dict of visual-analysis outputs; may contain
            'wound_type', 'detection_image_path', 'segmentation_image_path'.
        guideline_context: Retrieved guideline text (used only by the
            fallback report).
        image_pil: Original wound photo as a PIL.Image, or None.
        max_new_tokens: Generation cap; defaults to config.MAX_NEW_TOKENS.

    Returns:
        Markdown report string; never empty (falls back to the template
        report when generation yields nothing).
    """
    # BUG FIX: this commit removed the top-level `import os` / `import torch`
    # while this method still needs both — import locally so the method is
    # self-contained inside the GPU worker.
    import os
    import torch

    # Lazy-load MedGemma on first use within the GPU context.
    if 'medgemma_pipe' not in self.models_cache:
        try:
            self.models_cache['medgemma_pipe'] = pipeline(
                'image-text-to-text',
                model='google/medgemma-4b-it',
                device='cuda',
                torch_dtype=torch.bfloat16,
                offload_folder='offload',
                token=self.config.HF_TOKEN,
            )
            logging.info("✅ MedGemma pipeline loaded on GPU")
        except Exception as e:
            logging.warning(f"MedGemma pipeline not available: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    # Build the chat-style message list: system prompt, then a user turn
    # that carries images and the text prompt.
    msgs = [
        {'role': 'system', 'content': [{'type': 'text', 'text': default_system_prompt}]},
        {'role': 'user', 'content': []},
    ]
    # Attach the original photo plus any annotated images that exist on disk.
    if image_pil:
        msgs[1]['content'].append({'type': 'image', 'image': image_pil})
    for key in ('detection_image_path', 'segmentation_image_path'):
        p = visual_results.get(key)
        if p and os.path.exists(p):
            msgs[1]['content'].append({'type': 'image', 'image': Image.open(p)})
    # BUG FIX: visual_results['wound_type'] raised KeyError when the
    # classifier was unavailable; degrade gracefully instead.
    prompt = f"## Patient\n{patient_info}\n## Wound Type: {visual_results.get('wound_type', 'Unknown')}"
    msgs[1]['content'].append({'type': 'text', 'text': prompt})

    out = self.models_cache['medgemma_pipe'](
        text=msgs,
        max_new_tokens=max_new_tokens or self.config.MAX_NEW_TOKENS,
        do_sample=False  # deterministic output for a medical report
    )
    # Pipeline returns the full chat transcript; the last message is the reply.
    report = out[0]['generated_text'][-1].get('content', '')
    # Empty generation → fall back to the template report.
    return report or self._generate_fallback_report(patient_info, visual_results, guideline_context)
216
 
 
226
  fn = f"{datetime.now():%Y%m%d_%H%M%S}.png"
227
  path = os.path.join(self.config.UPLOADS_DIR, fn)
228
  image_pil.convert('RGB').save(path)
229
+ if self.config.HF_TOKEN and hasattr(self.config,'DATASET_ID') and self.config.DATASET_ID:
230
  try:
231
  api = HfApi()
232
+ api.upload_file(
233
+ path_or_fileobj=path,
234
+ path_in_repo=f"images/{fn}",
235
+ repo_id=self.config.DATASET_ID,
236
+ repo_type='dataset'
237
+ )
238
  except Exception as e:
239
  logging.warning(f"HF upload failed: {e}")
240
  return path
 
247
  info = ", ".join(f"{k}:{v}" for k,v in questionnaire_data.items() if v)
248
  gc = self.query_guidelines(info)
249
  report = self.generate_final_report(info, vis, gc, image_pil)
250
+ return {'success':True,'visual_analysis':vis,'report':report,'saved_image_path':saved}
251
  except Exception as e:
252
  logging.error(f"Pipeline error: {e}")
253
+ return {'success':False,'error':str(e)}
 
254
 
255
def analyze_wound(self, image, questionnaire_data):
    """Legacy entry point: accept a path or an image and delegate.

    A string argument is treated as a filesystem path and opened with
    PIL; any other value is forwarded unchanged to the full pipeline.
    """
    img = Image.open(image) if isinstance(image, str) else image
    return self.full_analysis_pipeline(img, questionnaire_data)
260
 
261
+
262
  def _assess_risk_legacy(self, questionnaire_data):
263
  """Legacy risk assessment for backward compatibility"""
264
  risk_factors = []