ghitaben committed on
Commit
de5f46b
·
1 Parent(s): 9004314

Fix — new call chain in loader.py

Browse files
Files changed (1) hide show
  1. src/loader.py +82 -25
src/loader.py CHANGED
@@ -169,7 +169,74 @@ def get_text_model(
169
  return _get_local_causal_lm(model_name)
170
 
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  @_gpu
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  def run_inference(
174
  prompt: str,
175
  model_name: TextModelName = "medgemma_4b",
@@ -177,20 +244,18 @@ def run_inference(
177
  temperature: float = 0.2,
178
  **kwargs: Any,
179
  ) -> str:
180
- """Run inference with the specified model. This is the primary entry point for agents."""
181
  logger.info(f"Running inference with {model_name}, max_tokens={max_new_tokens}, temp={temperature}")
182
  try:
183
- model = get_text_model(model_name=model_name)
184
- logger.info(f"Model {model_name} loaded successfully")
185
- result = model(prompt, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
186
- logger.info(f"Inference complete, response length: {len(result)} chars")
187
- return result
188
- except Exception as e:
189
  logger.error(f"Inference failed for {model_name}: {e}", exc_info=True)
190
  raise
191
 
192
 
193
- @_gpu
194
  def run_inference_with_image(
195
  prompt: str,
196
  image: Any, # PIL.Image.Image
@@ -202,25 +267,17 @@ def run_inference_with_image(
202
  """
203
  Run vision-language inference passing a PIL image alongside the text prompt.
204
 
205
- Falls back to text-only inference if the resolved model is not multimodal
206
- (e.g. when medgemma_4b is remapped to a text-only model in the env config).
207
  """
208
  logger.info(f"Running vision inference with {model_name}, max_tokens={max_new_tokens}")
209
  try:
210
- model_path = _get_model_path(model_name)
211
- if not _is_multimodal(model_path):
212
- logger.warning(
213
- f"{model_name} ({model_path}) is not a multimodal model; "
214
- "falling back to text-only inference."
215
- )
216
- return run_inference(prompt, model_name, max_new_tokens, temperature, **kwargs)
217
-
218
- model_fn = _get_local_multimodal(model_name)
219
- result = model_fn(
220
- prompt, max_new_tokens=max_new_tokens, temperature=temperature, image=image, **kwargs
221
- )
222
- logger.info(f"Vision inference complete, response length: {len(result)} chars")
223
- return result
224
- except Exception as e:
225
  logger.error(f"Vision inference failed for {model_name}: {e}", exc_info=True)
226
  raise
 
 
 
169
  return _get_local_causal_lm(model_name)
170
 
171
 
172
+ def _is_zerogpu_error(e: Exception) -> bool:
173
+ """Return True for errors that indicate ZeroGPU failed to allocate / init a GPU."""
174
+ msg = str(e)
175
+ return "No CUDA GPUs are available" in msg or "CUDA" in msg
176
+
177
+
178
def _inference_core(
    prompt: str,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    **kwargs: Any,
) -> str:
    """Device-agnostic text inference: resolve the model, generate, return text.

    Carries no GPU decorator, so it runs on whatever device the loaded
    model ends up on (GPU inside the ZeroGPU wrapper, CPU on fallback).
    """
    generate = get_text_model(model_name=model_name)
    logger.info(f"Model {model_name} ready")
    reply = generate(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **kwargs,
    )
    logger.info(f"Inference complete, response length: {len(reply)} chars")
    return reply
191
+
192
+
193
def _inference_with_image_core(
    prompt: str,
    image: Any,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    **kwargs: Any,
) -> str:
    """Device-agnostic vision-language inference on a PIL image plus text prompt.

    Carries no GPU decorator. If the resolved checkpoint is not multimodal,
    degrades gracefully to text-only inference via `_inference_core`.
    """
    resolved_path = _get_model_path(model_name)

    # Guard: a text-only checkpoint cannot take an image — warn and degrade.
    if not _is_multimodal(resolved_path):
        logger.warning(
            f"{model_name} ({resolved_path}) is not a multimodal model; "
            "falling back to text-only inference."
        )
        return _inference_core(prompt, model_name, max_new_tokens, temperature, **kwargs)

    vision_fn = _get_local_multimodal(model_name)
    reply = vision_fn(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        image=image,
        **kwargs,
    )
    logger.info(f"Vision inference complete, response length: {len(reply)} chars")
    return reply
215
+
216
+
217
@_gpu
def _run_inference_gpu(
    prompt: str,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    **kwargs: Any,
) -> str:
    """ZeroGPU-decorated thin wrapper: delegate text inference to `_inference_core`."""
    return _inference_core(
        prompt,
        model_name=model_name,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **kwargs,
    )
226
+
227
+
228
@_gpu
def _run_inference_with_image_gpu(
    prompt: str,
    image: Any,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    **kwargs: Any,
) -> str:
    """ZeroGPU-decorated thin wrapper: delegate vision inference to `_inference_with_image_core`."""
    return _inference_with_image_core(
        prompt,
        image,
        model_name=model_name,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **kwargs,
    )
238
+
239
+
240
def run_inference(
    prompt: str,
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    **kwargs: Any,
) -> str:
    """Run inference with the specified model. Tries ZeroGPU first, falls back to CPU.

    Primary entry point for agents.

    Args:
        prompt: Text prompt to send to the model.
        model_name: Key identifying which text model to load.
        max_new_tokens: Generation budget forwarded to the model.
        temperature: Sampling temperature forwarded to the model.
        **kwargs: Extra generation options forwarded unchanged.

    Returns:
        The model's text response.

    Raises:
        Exception: Any non-ZeroGPU failure is logged and re-raised.
    """
    logger.info(f"Running inference with {model_name}, max_tokens={max_new_tokens}, temp={temperature}")
    try:
        return _run_inference_gpu(prompt, model_name, max_new_tokens, temperature, **kwargs)
    except Exception as e:
        # BUGFIX: was `except RuntimeError` — any other failure (ValueError,
        # OSError during model load, quota errors, ...) escaped without the
        # error log below, unlike the pre-refactor version which caught
        # Exception. Catch broadly; retry on CPU only for GPU-availability
        # errors, and log + re-raise everything else.
        if _is_zerogpu_error(e):
            logger.warning("ZeroGPU unavailable (%s) — retrying on CPU", e)
            return _inference_core(prompt, model_name, max_new_tokens, temperature, **kwargs)
        logger.error(f"Inference failed for {model_name}: {e}", exc_info=True)
        raise
257
 
258
 
 
259
def run_inference_with_image(
    prompt: str,
    image: Any,  # PIL.Image.Image
    model_name: TextModelName = "medgemma_4b",
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    **kwargs: Any,
) -> str:
    """
    Run vision-language inference passing a PIL image alongside the text prompt.

    Falls back to text-only inference if the resolved model is not multimodal.
    Tries ZeroGPU first, falls back to CPU on ZeroGPU init failure.

    Args:
        prompt: Text prompt to send to the model.
        image: PIL image passed to the multimodal pipeline.
        model_name: Key identifying which model to load.
        max_new_tokens: Generation budget forwarded to the model.
        temperature: Sampling temperature forwarded to the model.
        **kwargs: Extra generation options forwarded unchanged.

    Returns:
        The model's text response.

    Raises:
        Exception: Any non-ZeroGPU failure is logged and re-raised.
    """
    logger.info(f"Running vision inference with {model_name}, max_tokens={max_new_tokens}")
    try:
        return _run_inference_with_image_gpu(prompt, image, model_name, max_new_tokens, temperature, **kwargs)
    except Exception as e:
        # BUGFIX: was `except RuntimeError` — non-RuntimeError failures
        # escaped without the error log below, unlike the pre-refactor
        # version which caught Exception. Retry on CPU only for
        # GPU-availability errors; log + re-raise everything else.
        if _is_zerogpu_error(e):
            logger.warning("ZeroGPU unavailable (%s) retrying vision inference on CPU", e)
            return _inference_with_image_core(prompt, image, model_name, max_new_tokens, temperature, **kwargs)
        logger.error(f"Vision inference failed for {model_name}: {e}", exc_info=True)
        raise
282
+
283
+