SmartHeal committed
Commit a391b19 · verified · 1 Parent(s): 18a5b30

Update src/ai_processor.py

Files changed (1):
  src/ai_processor.py (+69 -48)
src/ai_processor.py CHANGED
@@ -3,14 +3,12 @@
 # Turn on deep logging: export LOGLEVEL=DEBUG SMARTHEAL_DEBUG=1
 
 import os
-import time
 import logging
 from datetime import datetime
 from typing import Optional, Dict, List, Tuple
 
-# ---- Environment defaults ----
+# ---- Environment defaults (do NOT globally hint CUDA here) ----
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
-os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
 LOGLEVEL = os.getenv("LOGLEVEL", "INFO").upper()
 SMARTHEAL_DEBUG = os.getenv("SMARTHEAL_DEBUG", "0") == "1"
@@ -35,6 +33,7 @@ from spaces import GPU as _SPACES_GPU
 def smartheal_gpu_stub(ping: int = 0) -> str:
     return "ready"
 
+# ---- Paths / constants ----
 UPLOADS_DIR = "uploads"
 os.makedirs(UPLOADS_DIR, exist_ok=True)
@@ -54,15 +53,37 @@ SEG_THRESH = float(os.getenv("SEG_THRESH", "0.5"))
 models_cache: Dict[str, object] = {}
 knowledge_base_cache: Dict[str, object] = {}
 
-# ---------- Lazy imports ----------
+# ---------- Utilities to prevent CUDA in main process ----------
+from contextlib import contextmanager
+
+@contextmanager
+def _no_cuda_env():
+    """
+    Mask GPUs so any library imported/constructed in the main process
+    cannot see CUDA (required for Spaces Stateless GPU).
+    """
+    prev = os.environ.get("CUDA_VISIBLE_DEVICES")
+    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+    try:
+        yield
+    finally:
+        if prev is None:
+            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+        else:
+            os.environ["CUDA_VISIBLE_DEVICES"] = prev
+
+# ---------- Lazy imports (wrapped where needed) ----------
 def _import_ultralytics():
-    from ultralytics import YOLO
+    # Prevent Ultralytics from probing CUDA on import
+    with _no_cuda_env():
+        from ultralytics import YOLO
     return YOLO
 
 def _import_tf_loader():
     import tensorflow as tf
     try:
-        tf.config.set_visible_devices([], "GPU")  # keep TF on CPU
+        # Keep TF on CPU only
+        tf.config.set_visible_devices([], "GPU")
     except Exception:
         pass
     from tensorflow.keras.models import load_model
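
The guard above is the heart of this commit: it temporarily points CUDA_VISIBLE_DEVICES at an invalid device ID so any CUDA-aware library touched in the main process sees no GPUs, then restores the previous value on exit. A minimal, self-contained sketch of that save/restore behavior (the asserts are illustrative, not part of the commit):

import os
from contextlib import contextmanager

@contextmanager
def _no_cuda_env():
    # Same save/mask/restore pattern as the commit's version.
    prev = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    try:
        yield
    finally:
        if prev is None:
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = prev

os.environ["CUDA_VISIBLE_DEVICES"] = "0"              # pretend one GPU is visible
with _no_cuda_env():
    assert os.environ["CUDA_VISIBLE_DEVICES"] == "-1"  # masked inside the guard
assert os.environ["CUDA_VISIBLE_DEVICES"] == "0"       # prior value restored on exit
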
@@ -122,6 +143,27 @@ Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
 """
 
 # ---------- VLM (MedGemma replaced with Qwen2-VL) ----------
+@_SPACES_GPU(enable_queue=True)
+def _vlm_infer_gpu(messages, model_id: str, max_new_tokens: int, token: Optional[str]):
+    """
+    Runs entirely inside a Spaces GPU worker. It's the ONLY place we allow CUDA init.
+    """
+    from transformers import pipeline
+    pipe = pipeline(
+        task="image-text-to-text",
+        model=model_id,
+        device_map="auto",  # CUDA init happens here, safely in GPU worker
+        token=token,
+        trust_remote_code=True,
+        model_kwargs={"low_cpu_mem_usage": True},
+    )
+    out = pipe(text=messages, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.2)
+    try:
+        txt = out[0]["generated_text"][-1].get("content", "")
+    except Exception:
+        txt = out[0].get("generated_text", "")
+    return (txt or "").strip() or "⚠️ Empty response"
+
 def generate_medgemma_report(  # kept name so callers don't change
     patient_info: str,
     visual_results: Dict,
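
The new _vlm_infer_gpu follows the Spaces ZeroGPU contract: the decorated function is the only code allowed to initialize CUDA, and the main process exchanges only plain-Python values with it. A hedged sketch of that contract with a stub body (the function name and body below are illustrative, not the commit's code; spaces.GPU is the decorator this file imports as _SPACES_GPU):

import spaces

@spaces.GPU
def run_on_gpu(prompt: str) -> str:
    # Heavyweight imports stay inside the worker, so the main
    # process never touches torch or CUDA.
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"ran on {device}: {prompt}"

# Main process: no torch/transformers imports, no CUDA probe, just a call.
print(run_on_gpu("hello"))
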
@@ -131,6 +173,7 @@ def generate_medgemma_report(  # kept name so callers don't change
 ) -> str:
     """
     MedGemma replacement using Qwen/Qwen2-VL-2B-Instruct via image-text-to-text.
+    Loads & runs ONLY inside a GPU worker to satisfy Stateless GPU constraints.
     """
     if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
         return "⚠️ VLM disabled"
@@ -138,20 +181,6 @@ def generate_medgemma_report(  # kept name so callers don't change
     model_id = os.getenv("SMARTHEAL_VLM_MODEL", "Qwen/Qwen2-VL-2B-Instruct")
     max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))
 
-    try:
-        from transformers import pipeline
-        pipe = pipeline(
-            task="image-text-to-text",
-            model=model_id,
-            device_map=None,  # keep CPU by default for Spaces stability
-            token=HF_TOKEN,
-            trust_remote_code=True,
-            model_kwargs={"low_cpu_mem_usage": True},
-        )
-    except Exception as e:
-        logging.error(f"❌ Could not load VLM ({model_id}): {e}")
-        return "⚠️ VLM error"
-
     uprompt = SMARTHEAL_USER_PREFIX.format(
         patient_info=patient_info,
         wound_type=visual_results.get("wound_type", "Unknown"),
@@ -163,34 +192,28 @@ def generate_medgemma_report(  # kept name so callers don't change
         guideline_context=(guideline_context or "")[:900],
     )
 
+    messages = [
+        {"role": "system", "content": [{"type": "text", "text": SMARTHEAL_SYSTEM_PROMPT}]},
+        {"role": "user", "content": [
+            {"type": "image", "image": image_pil},
+            {"type": "text", "text": uprompt},
+        ]},
+    ]
+
     try:
-        messages = [
-            {"role": "system", "content": [{"type": "text", "text": SMARTHEAL_SYSTEM_PROMPT}]},
-            {"role": "user", "content": [
-                {"type": "image", "image": image_pil},
-                {"type": "text", "text": uprompt},
-            ]},
-        ]
-        out = pipe(text=messages, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.2)
-        if out and len(out) > 0:
-            try:
-                text = out[0]["generated_text"][-1].get("content", "")
-            except Exception:
-                text = out[0].get("generated_text", "")
-            text = (text or "").strip()
-            return text if text else "⚠️ Empty response"
-        return "⚠️ No output generated"
+        # IMPORTANT: do not import transformers or touch CUDA here. Only call the GPU worker.
+        return _vlm_infer_gpu(messages, model_id, max_new_tokens, HF_TOKEN)
     except Exception as e:
-        logging.error(f"VLM generation error: {e}")
+        logging.error(f"VLM call failed: {e}")
         return "⚠️ VLM error"
 
-UPLOADS_DIR = "uploads"
-os.makedirs(UPLOADS_DIR, exist_ok=True)
-
 # ---------- Initialize CPU models ----------
 def load_yolo_model():
     YOLO = _import_ultralytics()
-    return YOLO(YOLO_MODEL_PATH)
+    # Construct model with CUDA masked to avoid auto-selecting cuda:0
+    with _no_cuda_env():
+        model = YOLO(YOLO_MODEL_PATH)
+    return model
 
 def load_segmentation_model():
     load_model = _import_tf_loader()
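
With pipeline construction moved into the worker, generate_medgemma_report keeps its public name, so callers need no changes. An illustrative call (all argument values are hypothetical; the keyword names mirror identifiers visible in this diff, and any parameters elided from the hunks are omitted):

from PIL import Image

report = generate_medgemma_report(
    patient_info="65-year-old, diabetic, lower-leg wound",  # hypothetical
    visual_results={"wound_type": "Ulcer"},                 # hypothetical
    image_pil=Image.open("uploads/sample.jpg"),             # hypothetical path
    guideline_context="",
)
print(report)  # report text, or a "⚠️ ..." sentinel on failure
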
@@ -216,7 +239,7 @@ def initialize_cpu_models() -> None:
     if "det" not in models_cache:
        try:
             models_cache["det"] = load_yolo_model()
-            logging.info("✅ YOLO loaded (CPU)")
+            logging.info("✅ YOLO loaded (CPU; CUDA masked in main)")
         except Exception as e:
             logging.error(f"YOLO load failed: {e}")
@@ -653,6 +676,7 @@ class AIProcessor:
         det_model = self.models_cache.get("det")
         if det_model is None:
             raise RuntimeError("YOLO model not loaded")
+        # Force CPU inference and avoid CUDA touch
         results = det_model.predict(image_cv, verbose=False, device="cpu")
         if (not results) or (not getattr(results[0], "boxes", None)) or (len(results[0].boxes) == 0):
             try:
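
Passing device="cpu" to predict() pins Ultralytics inference to the CPU regardless of which devices the process can see. A standalone sketch of this detection path and its box readout (weights and image paths are assumed):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")                       # assumed detection weights
results = model.predict("uploads/sample.jpg", device="cpu", verbose=False)
boxes = getattr(results[0], "boxes", None)
if boxes is not None and len(boxes) > 0:
    print(boxes.xyxy.tolist())                   # [x1, y1, x2, y2] per detection
    print(boxes.conf.tolist())                   # confidence per detection
else:
    print("no detections")
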
@@ -797,12 +821,9 @@ class AIProcessor:
         vs = self.knowledge_base_cache.get("vector_store")
         if not vs:
             return "Knowledge base is not available."
-        try:
-            retriever = vs.as_retriever(search_kwargs={"k": 5})
-            docs = retriever.get_relevant_documents(query)
-        except Exception:
-            retriever = vs.as_retriever(search_kwargs={"k": 5})
-            docs = retriever.invoke(query)
+        retriever = vs.as_retriever(search_kwargs={"k": 5})
+        # Modern API (avoid get_relevant_documents deprecation)
+        docs = retriever.invoke(query)
         lines: List[str] = []
         for d in docs:
             src = (d.metadata or {}).get("source", "N/A")
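
retriever.invoke(query) is the current LangChain retriever entry point and get_relevant_documents() is deprecated, so the old try/except fallback can go. A short sketch against an assumed FAISS store (any LangChain vector store exposing as_retriever() behaves the same way):

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

vs = FAISS.from_texts(["elevate the limb", "debride per protocol"], HuggingFaceEmbeddings())
retriever = vs.as_retriever(search_kwargs={"k": 5})
docs = retriever.invoke("wound care guidance")   # returns List[Document]
for d in docs:
    print((d.metadata or {}).get("source", "N/A"), d.page_content[:80])
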
 