Spaces:
Sleeping
Sleeping
Commit ·
71cf583
1
Parent(s): fc5088a
fix: handle provider/model format in router (strip prefix before calling provider)
Browse files
backend/app/models/__pycache__/router.cpython-314.pyc
CHANGED
|
Binary files a/backend/app/models/__pycache__/router.cpython-314.pyc and b/backend/app/models/__pycache__/router.cpython-314.pyc differ
|
|
|
backend/app/models/router.py
CHANGED
|
@@ -311,19 +311,42 @@ class SmartModelRouter:
|
|
| 311 |
return models
|
| 312 |
|
| 313 |
def get_provider_for_model(self, model: str) -> BaseProvider | None:
|
| 314 |
-
"""Get the provider for a specific model.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
for provider in self.providers.values():
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
|
| 328 |
return None
|
| 329 |
|
|
@@ -522,8 +545,10 @@ class SmartModelRouter:
|
|
| 522 |
|
| 523 |
for i, (model_id, provider) in enumerate(models_to_try):
|
| 524 |
try:
|
| 525 |
-
|
| 526 |
-
|
|
|
|
|
|
|
| 527 |
|
| 528 |
# Track cost
|
| 529 |
self.cost_tracker.track(response)
|
|
|
|
| 311 |
return models
|
| 312 |
|
| 313 |
def get_provider_for_model(self, model: str) -> BaseProvider | None:
|
| 314 |
+
"""Get the provider for a specific model.
|
| 315 |
+
|
| 316 |
+
Supports both formats:
|
| 317 |
+
- "gemini-1.5-flash" (bare model name)
|
| 318 |
+
- "google/gemini-1.5-flash" (provider/model format)
|
| 319 |
+
"""
|
| 320 |
+
# Strip provider prefix if present (e.g., "google/gemini-1.5-flash" -> "gemini-1.5-flash")
|
| 321 |
+
model_name = model
|
| 322 |
+
if "/" in model:
|
| 323 |
+
provider_prefix, model_name = model.split("/", 1)
|
| 324 |
+
# Try to match provider directly first
|
| 325 |
+
if provider_prefix in self.providers:
|
| 326 |
+
provider = self.providers[provider_prefix]
|
| 327 |
+
try:
|
| 328 |
+
if provider.get_model_info(model_name):
|
| 329 |
+
return provider
|
| 330 |
+
except Exception:
|
| 331 |
+
pass
|
| 332 |
+
# Check aliases
|
| 333 |
+
if hasattr(provider, "MODEL_ALIASES"):
|
| 334 |
+
if model_name in provider.MODEL_ALIASES: # type: ignore
|
| 335 |
+
return provider
|
| 336 |
+
|
| 337 |
+
# Fallback: try all providers with both original and stripped names
|
| 338 |
for provider in self.providers.values():
|
| 339 |
+
for name in [model, model_name]:
|
| 340 |
+
try:
|
| 341 |
+
if provider.get_model_info(name):
|
| 342 |
+
return provider
|
| 343 |
+
except Exception:
|
| 344 |
+
pass
|
| 345 |
+
|
| 346 |
+
# Check aliases
|
| 347 |
+
if hasattr(provider, "MODEL_ALIASES"):
|
| 348 |
+
if name in provider.MODEL_ALIASES: # type: ignore
|
| 349 |
+
return provider
|
| 350 |
|
| 351 |
return None
|
| 352 |
|
|
|
|
| 545 |
|
| 546 |
for i, (model_id, provider) in enumerate(models_to_try):
|
| 547 |
try:
|
| 548 |
+
# Strip provider prefix if present (e.g., "google/gemini-1.5-flash" -> "gemini-1.5-flash")
|
| 549 |
+
model_name = model_id.split("/", 1)[1] if "/" in model_id else model_id
|
| 550 |
+
logger.info(f"Attempting completion with {provider.PROVIDER_NAME}/{model_name}")
|
| 551 |
+
response = await provider.complete(messages, model_name, **kwargs)
|
| 552 |
|
| 553 |
# Track cost
|
| 554 |
self.cost_tracker.track(response)
|