# claude-code-proxy/api/model_router.py
"""Model routing for Claude-compatible requests."""
from __future__ import annotations
from dataclasses import dataclass
from loguru import logger
from config.provider_ids import SUPPORTED_PROVIDER_IDS
from config.settings import Settings
from core.model_capabilities import find_best_model_for_task
from core.session_tracker import SessionTracker
from core.task_detector import TaskDetector
from providers.rate_limit import GlobalRateLimiter
from .gateway_model_ids import decode_gateway_model_id
from .models.anthropic import MessagesRequest, TokenCountRequest

# Default NIM models to include in auto routing (in order of preference).
DEFAULT_NIM_AUTO_MODELS = [
    "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
    "nvidia_nim/z-ai/glm4.7",
    "nvidia_nim/stepfun-ai/step-3.5-flash",
    "nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512",
    "nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct",
    "nvidia_nim/bytedance/seed-oss-36b-instruct",
    "nvidia_nim/mistralai/mistral-nemotron",
]
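
# Note: these refs already carry the "nvidia_nim/" provider prefix, so
# (assuming "nvidia_nim" is in SUPPORTED_PROVIDER_IDS) _normalize_candidate_ref()
# passes them through unchanged. They are only used when the NVIDIA NIM
# provider is actually configured; see resolve_candidates() below.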


@dataclass(frozen=True, slots=True)
class ResolvedModel:
    original_model: str
    provider_id: str
    provider_model: str
    provider_model_ref: str
    thinking_enabled: bool


@dataclass(frozen=True, slots=True)
class RoutedMessagesRequest:
    request: MessagesRequest
    resolved: ResolvedModel


@dataclass(frozen=True, slots=True)
class RoutedTokenCountRequest:
    request: TokenCountRequest
    resolved: ResolvedModel
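

# Illustrative shape (hypothetical values): resolving the direct reference
# "nvidia_nim/mistralai/mistral-nemotron" produces roughly
#
#     ResolvedModel(
#         original_model="nvidia_nim/mistralai/mistral-nemotron",
#         provider_id="nvidia_nim",
#         provider_model="mistralai/mistral-nemotron",
#         provider_model_ref="nvidia_nim/mistralai/mistral-nemotron",
#         thinking_enabled=...,  # from Settings.resolve_thinking() unless forced
#     )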


class ModelRouter:
    """Resolve incoming Claude model names to configured provider/model pairs."""

    def __init__(self, settings: Settings):
        self._settings = settings

    def _is_auto(self, model_name: str) -> bool:
        """Return whether the model name refers to the virtual 'auto' model."""
        name_lower = model_name.lower()
        return name_lower in ("auto", "anthropic/auto")

    def _normalize_candidate_ref(self, raw_ref: str) -> str | None:
        """Normalize auto candidate refs to ``provider/model`` when possible."""
        candidate = (raw_ref or "").strip()
        if not candidate:
            return None
        provider_id, separator, remainder = candidate.partition("/")
        if separator and provider_id in SUPPORTED_PROVIDER_IDS and remainder:
            return f"{provider_id}/{remainder}"
        # Treat bare model ids and vendor/model ids as NVIDIA NIM models.
        return f"nvidia_nim/{candidate}"

    def resolve(self, claude_model_name: str) -> ResolvedModel:
        # The special virtual model 'auto' maps to the configured default MODEL
        # and enables provider-side fallbacks. Resolve it to the configured
        # model while preserving the original requested name.
        if self._is_auto(claude_model_name):
            # If the user configured an explicit AUTO_MODEL_ORDER, try each
            # provider/model pair in order and pick the first provider that is
            # plausibly configured. Fall back to the single configured MODEL.
            order_csv = (self._settings.auto_model_order or "").strip()
            if order_csv:
                for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
                    if "/" not in cand:
                        # Entries must be provider-prefixed ("provider/model");
                        # skip malformed ones.
                        continue
                    provider_id = Settings.parse_provider_type(cand)
                    provider_model = Settings.parse_model_name(cand)
                    if self._settings.provider_is_configured(provider_id):
                        thinking_enabled = self._settings.resolve_thinking(
                            claude_model_name
                        )
                        return ResolvedModel(
                            original_model=claude_model_name,
                            provider_id=provider_id,
                            provider_model=provider_model,
                            provider_model_ref=cand,
                            thinking_enabled=thinking_enabled,
                        )
            # No explicit order matched or none configured; fall back to the
            # default MODEL.
            provider_model_ref = self._settings.model
            provider_id = Settings.parse_provider_type(provider_model_ref)
            provider_model = Settings.parse_model_name(provider_model_ref)
            thinking_enabled = self._settings.resolve_thinking(claude_model_name)
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=provider_id,
                provider_model=provider_model,
                provider_model_ref=provider_model_ref,
                thinking_enabled=thinking_enabled,
            )
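
        # Not the virtual 'auto' model: first check whether the name is an
        # explicit provider-prefixed (or gateway-encoded) reference.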
        (
            direct_provider_id,
            direct_provider_model,
            force_thinking_enabled,
        ) = self._direct_provider_model(claude_model_name)
        if direct_provider_id is not None and direct_provider_model is not None:
            thinking_enabled = (
                force_thinking_enabled
                if force_thinking_enabled is not None
                else self._settings.resolve_thinking(direct_provider_model)
            )
            logger.debug(
                "MODEL DIRECT: '{}' -> provider='{}' model='{}' thinking={}",
                claude_model_name,
                direct_provider_id,
                direct_provider_model,
                thinking_enabled,
            )
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=direct_provider_id,
                provider_model=direct_provider_model,
                provider_model_ref=claude_model_name,
                thinking_enabled=thinking_enabled,
            )
        provider_model_ref = self._settings.resolve_model(claude_model_name)
        thinking_enabled = self._settings.resolve_thinking(claude_model_name)
        provider_id = Settings.parse_provider_type(provider_model_ref)
        provider_model = Settings.parse_model_name(provider_model_ref)
        if provider_model != claude_model_name:
            logger.debug(
                "MODEL MAPPING: '{}' -> '{}'", claude_model_name, provider_model
            )
        return ResolvedModel(
            original_model=claude_model_name,
            provider_id=provider_id,
            provider_model=provider_model,
            provider_model_ref=provider_model_ref,
            thinking_enabled=thinking_enabled,
        )
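
    # Illustrative flow (assuming MODEL="nvidia_nim/mistralai/mistral-nemotron"
    # and no AUTO_MODEL_ORDER): resolve("auto") falls through to the default
    # MODEL and yields provider_id="nvidia_nim",
    # provider_model="mistralai/mistral-nemotron".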

    def resolve_candidates(self, claude_model_name: str) -> list[ResolvedModel]:
        """Resolve a model name to a prioritized list of candidates.

        Used by the 'auto' routing logic to implement provider-side failover.
        Considers session load for fair resource sharing across multiple clients.

        Priority order:
        1. AUTO_MODEL_ORDER (if configured)
        2. MODEL (primary)
        3. NVIDIA NIM fallback models (if configured, or DEFAULT_NIM_AUTO_MODELS)
        4. MODEL_OPUS, MODEL_SONNET, MODEL_HAIKU
        """
        if not self._is_auto(claude_model_name):
            return [self.resolve(claude_model_name)]
        healthy_candidates: list[ResolvedModel] = []
        blocked_candidates: list[ResolvedModel] = []
        seen: set[str] = set()
        session_tracker = SessionTracker.get_instance()

        def add_candidate(ref: str | None, source: str) -> None:
            normalized_ref = self._normalize_candidate_ref(ref or "")
            if normalized_ref is None or normalized_ref in seen:
                return
            provider_id = Settings.parse_provider_type(normalized_ref)
            provider_model = Settings.parse_model_name(normalized_ref)
            if self._settings.provider_is_configured(provider_id):
                seen.add(normalized_ref)
                resolved = ResolvedModel(
                    original_model=claude_model_name,
                    provider_id=provider_id,
                    provider_model=provider_model,
                    provider_model_ref=normalized_ref,
                    thinking_enabled=self._settings.resolve_thinking(
                        claude_model_name
                    ),
                )
                limiter = GlobalRateLimiter.get_scoped_instance(provider_id)
                is_blocked = limiter.is_blocked()
                # The Zen provider has no rate limits, so never treat it as
                # blocked.
                if provider_id == "zen":
                    is_blocked = False
                # Check model health (recent failures).
                is_healthy = limiter.is_healthy(normalized_ref)
                if is_blocked or not is_healthy:
                    reason = "BLOCKED" if is_blocked else "UNHEALTHY"
                    logger.debug(
                        "Routing: candidate '{}' (from {}) is {} (health={})",
                        normalized_ref,
                        source,
                        reason,
                        is_healthy,
                    )
                    blocked_candidates.append(resolved)
                else:
                    # Healthy candidate; final ordering (Zen first, then by
                    # load) happens after all sources are collected.
                    logger.debug(
                        "Routing: added candidate '{}' (from {})",
                        normalized_ref,
                        source,
                    )
                    healthy_candidates.append(resolved)
            else:
                logger.debug(
                    "Routing: candidate '{}' (from {}) is NOT CONFIGURED",
                    normalized_ref,
                    source,
                )

        # 1. AUTO_MODEL_ORDER (user-configured priority)
        order_csv = (self._settings.auto_model_order or "").strip()
        if order_csv:
            for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
                add_candidate(cand, "AUTO_MODEL_ORDER")
        # 2. Primary MODEL
        add_candidate(self._settings.model, "MODEL")
        # 3. NVIDIA NIM fallbacks: use the configured list or the defaults.
        nim_csv = (self._settings.nvidia_nim_fallback_models or "").strip()
        if nim_csv:
            for cand in [c.strip() for c in nim_csv.split(",") if c.strip()]:
                add_candidate(cand, "NVIDIA_NIM_FALLBACK_MODELS")
        else:
            # Use the default NIM models when no explicit fallbacks are
            # configured.
            for cand in DEFAULT_NIM_AUTO_MODELS:
                add_candidate(cand, "DEFAULT_NIM_AUTO_MODELS")
        # 4. Model-specific overrides
        add_candidate(self._settings.model_opus, "MODEL_OPUS")
        add_candidate(self._settings.model_sonnet, "MODEL_SONNET")
        add_candidate(self._settings.model_haiku, "MODEL_HAIKU")

        # Smart ordering: Zen goes first (no rate limits), then sort by load.
        def provider_priority(c: ResolvedModel) -> tuple[int, int]:
            # Priority: zen before others, then by active request count.
            is_zen = 0 if c.provider_id == "zen" else 1
            active = session_tracker._provider_active.get(c.provider_id, 0)
            return (is_zen, active)

        healthy_candidates.sort(key=provider_priority)
        all_candidates = healthy_candidates + blocked_candidates
        logger.info(
            "Routing: resolved '{}' to {} candidates: {}",
            claude_model_name,
            len(all_candidates),
            ", ".join(c.provider_model_ref for c in all_candidates),
        )
        return all_candidates
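
    # Ordering illustration (hypothetical loads): given healthy candidates on
    # "zen" and "nvidia_nim", the zen candidate sorts first regardless of load;
    # candidates on different non-zen providers order by each provider's active
    # request count, and ties keep their source priority order (the sort is
    # stable).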

    def _direct_provider_model(
        self, model_name: str
    ) -> tuple[str | None, str | None, bool | None]:
        decoded = decode_gateway_model_id(model_name)
        if decoded is not None:
            if decoded.provider_id not in SUPPORTED_PROVIDER_IDS:
                return None, None, None
            return (
                decoded.provider_id,
                decoded.provider_model,
                decoded.force_thinking_enabled,
            )
        provider_id, separator, provider_model = model_name.partition("/")
        if not separator:
            return None, None, None
        if provider_id not in SUPPORTED_PROVIDER_IDS:
            return None, None, None
        if not provider_model:
            return None, None, None
        return provider_id, provider_model, None
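
    # Example: "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct".partition("/")
    # splits on the first separator only, giving provider_id="nvidia_nim" and
    # provider_model="qwen/qwen3-coder-480b-a35b-instruct".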

    def resolve_messages_request(
        self, request: MessagesRequest
    ) -> RoutedMessagesRequest:
        """Return an internal routed request context."""
        resolved = self.resolve(request.model)
        routed = request.model_copy(deep=True)
        routed.model = resolved.provider_model
        return RoutedMessagesRequest(request=routed, resolved=resolved)

    def resolve_token_count_request(
        self, request: TokenCountRequest
    ) -> RoutedTokenCountRequest:
        """Return an internal token-count request context."""
        resolved = self.resolve(request.model)
        routed = request.model_copy(
            update={"model": resolved.provider_model}, deep=True
        )
        return RoutedTokenCountRequest(request=routed, resolved=resolved)
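
    # Both helpers deep-copy the incoming request (model_copy(deep=True)) so
    # the caller's object is never mutated; only the copy's `model` field is
    # rewritten to the resolved provider model.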

    def resolve_with_task_awareness(
        self,
        claude_model_name: str,
        messages: list,
    ) -> ResolvedModel:
        """Resolve model with task-based capability matching.

        For the 'auto' model, detects task requirements and routes to the
        best-capable model.
        """
        if not self._is_auto(claude_model_name):
            return self.resolve(claude_model_name)
        # Detect which capabilities are needed.
        detector = TaskDetector()
        requirements = detector.detect_requirements(messages)
        logger.info(
            "Task-aware routing: detected requirements={} confidence={:.2f}",
            requirements.required_capabilities,
            requirements.confidence,
        )
        # Get available candidates.
        candidates = self.resolve_candidates(claude_model_name)
        if not candidates:
            # Fall back to the default resolution.
            return self.resolve(claude_model_name)
        # If confidence is low or only general text is needed, use load-based
        # selection.
        if requirements.confidence < 0.7 or (
            not requirements.requires_vision
            and not requirements.requires_coding
            and not requirements.requires_reasoning
        ):
            logger.debug(
                "Task-aware routing: low confidence, using load-based selection"
            )
            return candidates[0]
        # Find the best model matching the required capabilities.
        required_caps = set()
        if requirements.requires_coding:
            required_caps.add("coding")
        if requirements.requires_reasoning:
            required_caps.add("reasoning")
        if requirements.requires_vision:
            required_caps.add("vision")
        if required_caps:
            model_refs = [c.provider_model_ref for c in candidates]
            best = find_best_model_for_task(required_caps, model_refs)
            if best:
                # Find the matching candidate.
                for cand in candidates:
                    if cand.provider_model_ref == best.model_ref:
                        logger.info(
                            "Task-aware routing: selected {} for capabilities={}",
                            best.model_ref,
                            required_caps,
                        )
                        return cand
        # Default to the first candidate (load-balanced).
        return candidates[0]

    def get_routing_hint(self, messages: list) -> str:
        """Get a hint about what kind of model would be best."""
        detector = TaskDetector()
        requirements = detector.detect_requirements(messages)
        return detector.get_priority_hint(requirements)
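

# Example wiring (illustrative; assumes a configured Settings instance is
# available from the application's startup code):
#
#     router = ModelRouter(settings)
#     routed = router.resolve_messages_request(messages_request)
#     candidates = router.resolve_candidates("auto")  # failover order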