NeerajCodz commited on
Commit
71cf583
·
1 Parent(s): fc5088a

fix: handle provider/model format in router (strip prefix before calling provider)

Browse files
backend/app/models/__pycache__/router.cpython-314.pyc CHANGED
Binary files a/backend/app/models/__pycache__/router.cpython-314.pyc and b/backend/app/models/__pycache__/router.cpython-314.pyc differ
 
backend/app/models/router.py CHANGED
@@ -311,19 +311,42 @@ class SmartModelRouter:
311
  return models
312
 
313
  def get_provider_for_model(self, model: str) -> BaseProvider | None:
314
- """Get the provider for a specific model."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  for provider in self.providers.values():
316
- try:
317
- if provider.get_model_info(model):
318
- return provider
319
- except Exception:
320
- # Model not found in this provider, continue to next
321
- pass
322
-
323
- # Check aliases for Anthropic and Google
324
- if hasattr(provider, "MODEL_ALIASES"):
325
- if model in provider.MODEL_ALIASES: # type: ignore
326
- return provider
327
 
328
  return None
329
 
@@ -522,8 +545,10 @@ class SmartModelRouter:
522
 
523
  for i, (model_id, provider) in enumerate(models_to_try):
524
  try:
525
- logger.info(f"Attempting completion with {provider.PROVIDER_NAME}/{model_id}")
526
- response = await provider.complete(messages, model_id, **kwargs)
 
 
527
 
528
  # Track cost
529
  self.cost_tracker.track(response)
 
311
  return models
312
 
313
  def get_provider_for_model(self, model: str) -> BaseProvider | None:
314
+ """Get the provider for a specific model.
315
+
316
+ Supports both formats:
317
+ - "gemini-1.5-flash" (bare model name)
318
+ - "google/gemini-1.5-flash" (provider/model format)
319
+ """
320
+ # Strip provider prefix if present (e.g., "google/gemini-1.5-flash" -> "gemini-1.5-flash")
321
+ model_name = model
322
+ if "/" in model:
323
+ provider_prefix, model_name = model.split("/", 1)
324
+ # Try to match provider directly first
325
+ if provider_prefix in self.providers:
326
+ provider = self.providers[provider_prefix]
327
+ try:
328
+ if provider.get_model_info(model_name):
329
+ return provider
330
+ except Exception:
331
+ pass
332
+ # Check aliases
333
+ if hasattr(provider, "MODEL_ALIASES"):
334
+ if model_name in provider.MODEL_ALIASES: # type: ignore
335
+ return provider
336
+
337
+ # Fallback: try all providers with both original and stripped names
338
  for provider in self.providers.values():
339
+ for name in [model, model_name]:
340
+ try:
341
+ if provider.get_model_info(name):
342
+ return provider
343
+ except Exception:
344
+ pass
345
+
346
+ # Check aliases
347
+ if hasattr(provider, "MODEL_ALIASES"):
348
+ if name in provider.MODEL_ALIASES: # type: ignore
349
+ return provider
350
 
351
  return None
352
 
 
545
 
546
  for i, (model_id, provider) in enumerate(models_to_try):
547
  try:
548
+ # Strip provider prefix if present (e.g., "google/gemini-1.5-flash" -> "gemini-1.5-flash")
549
+ model_name = model_id.split("/", 1)[1] if "/" in model_id else model_id
550
+ logger.info(f"Attempting completion with {provider.PROVIDER_NAME}/{model_name}")
551
+ response = await provider.complete(messages, model_name, **kwargs)
552
 
553
  # Track cost
554
  self.cost_tracker.track(response)