| """Model routing for Claude-compatible requests.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from loguru import logger | |
| from config.provider_ids import SUPPORTED_PROVIDER_IDS | |
| from config.settings import Settings | |
| from core.model_capabilities import find_best_model_for_task | |
| from core.session_tracker import SessionTracker | |
| from core.task_detector import TaskDetector | |
| from providers.rate_limit import GlobalRateLimiter | |
| from .gateway_model_ids import decode_gateway_model_id | |
| from .models.anthropic import MessagesRequest, TokenCountRequest | |
| # Default NIM models to include in auto routing (in order of preference) | |
| DEFAULT_NIM_AUTO_MODELS = [ | |
| "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct", | |
| "nvidia_nim/z-ai/glm4.7", | |
| "nvidia_nim/stepfun-ai/step-3.5-flash", | |
| "nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512", | |
| "nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct", | |
| "nvidia_nim/bytedance/seed-oss-36b-instruct", | |
| "nvidia_nim/mistralai/mistral-nemotron", | |
| ] | |


@dataclass
class ResolvedModel:
    original_model: str
    provider_id: str
    provider_model: str
    provider_model_ref: str
    thinking_enabled: bool


@dataclass
class RoutedMessagesRequest:
    request: MessagesRequest
    resolved: ResolvedModel


@dataclass
class RoutedTokenCountRequest:
    request: TokenCountRequest
    resolved: ResolvedModel


class ModelRouter:
    """Resolve incoming Claude model names to configured provider/model pairs."""

    def __init__(self, settings: Settings):
        self._settings = settings

    def _is_auto(self, model_name: str) -> bool:
        """Return whether the model name refers to the virtual 'auto' model."""
        name_lower = model_name.lower()
        return name_lower in ("auto", "anthropic/auto")

    def _normalize_candidate_ref(self, raw_ref: str) -> str | None:
        """Normalize auto candidate refs to ``provider/model`` when possible."""
        candidate = (raw_ref or "").strip()
        if not candidate:
            return None
        provider_id, separator, remainder = candidate.partition("/")
        if separator and provider_id in SUPPORTED_PROVIDER_IDS and remainder:
            return f"{provider_id}/{remainder}"
        # Treat bare model ids and vendor/model ids as NVIDIA NIM models.
        return f"nvidia_nim/{candidate}"

    def resolve(self, claude_model_name: str) -> ResolvedModel:
        # The special virtual model 'auto' maps to the configured default
        # MODEL and enables provider-side fallbacks. Resolve it to the
        # configured model while preserving the original requested name.
        if self._is_auto(claude_model_name):
            # If the user configured an explicit AUTO_MODEL_ORDER, try each
            # provider/model pair in order and pick the first provider that is
            # plausibly configured. Fall back to the single configured MODEL.
            order_csv = (self._settings.auto_model_order or "").strip()
            if order_csv:
                for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
                    if "/" not in cand:
                        # Entries must be provider-prefixed; skip malformed ones.
                        continue
                    provider_id = Settings.parse_provider_type(cand)
                    provider_model = Settings.parse_model_name(cand)
                    if self._settings.provider_is_configured(provider_id):
                        thinking_enabled = self._settings.resolve_thinking(
                            claude_model_name
                        )
                        return ResolvedModel(
                            original_model=claude_model_name,
                            provider_id=provider_id,
                            provider_model=provider_model,
                            provider_model_ref=cand,
                            thinking_enabled=thinking_enabled,
                        )
            # No explicit order matched or none configured; fall back to the
            # default MODEL.
            provider_model_ref = self._settings.model
            provider_id = Settings.parse_provider_type(provider_model_ref)
            provider_model = Settings.parse_model_name(provider_model_ref)
            thinking_enabled = self._settings.resolve_thinking(claude_model_name)
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=provider_id,
                provider_model=provider_model,
                provider_model_ref=provider_model_ref,
                thinking_enabled=thinking_enabled,
            )
        (
            direct_provider_id,
            direct_provider_model,
            force_thinking_enabled,
        ) = self._direct_provider_model(claude_model_name)
        if direct_provider_id is not None and direct_provider_model is not None:
            thinking_enabled = (
                force_thinking_enabled
                if force_thinking_enabled is not None
                else self._settings.resolve_thinking(direct_provider_model)
            )
            logger.debug(
                "MODEL DIRECT: '{}' -> provider='{}' model='{}' thinking={}",
                claude_model_name,
                direct_provider_id,
                direct_provider_model,
                thinking_enabled,
            )
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=direct_provider_id,
                provider_model=direct_provider_model,
                provider_model_ref=claude_model_name,
                thinking_enabled=thinking_enabled,
            )
        provider_model_ref = self._settings.resolve_model(claude_model_name)
        thinking_enabled = self._settings.resolve_thinking(claude_model_name)
        provider_id = Settings.parse_provider_type(provider_model_ref)
        provider_model = Settings.parse_model_name(provider_model_ref)
        if provider_model != claude_model_name:
            logger.debug(
                "MODEL MAPPING: '{}' -> '{}'", claude_model_name, provider_model
            )
        return ResolvedModel(
            original_model=claude_model_name,
            provider_id=provider_id,
            provider_model=provider_model,
            provider_model_ref=provider_model_ref,
            thinking_enabled=thinking_enabled,
        )

    def resolve_candidates(self, claude_model_name: str) -> list[ResolvedModel]:
        """Resolve a model name to a prioritized list of candidates.

        Used by the 'auto' routing logic to implement provider-side failover.
        Considers session load for fair resource sharing across multiple
        clients.

        Priority order:

        1. AUTO_MODEL_ORDER (if configured)
        2. MODEL (primary)
        3. NVIDIA NIM fallbacks (NVIDIA_NIM_FALLBACK_MODELS if configured,
           otherwise DEFAULT_NIM_AUTO_MODELS)
        4. MODEL_OPUS, MODEL_SONNET, MODEL_HAIKU
        """
        if not self._is_auto(claude_model_name):
            return [self.resolve(claude_model_name)]
        healthy_candidates: list[ResolvedModel] = []
        blocked_candidates: list[ResolvedModel] = []
        seen: set[str] = set()
        session_tracker = SessionTracker.get_instance()

        def add_candidate(ref: str | None, source: str) -> None:
            normalized_ref = self._normalize_candidate_ref(ref or "")
            if normalized_ref is None or normalized_ref in seen:
                return
            provider_id = Settings.parse_provider_type(normalized_ref)
            provider_model = Settings.parse_model_name(normalized_ref)
            if self._settings.provider_is_configured(provider_id):
                seen.add(normalized_ref)
                resolved = ResolvedModel(
                    original_model=claude_model_name,
                    provider_id=provider_id,
                    provider_model=provider_model,
                    provider_model_ref=normalized_ref,
                    thinking_enabled=self._settings.resolve_thinking(
                        claude_model_name
                    ),
                )
                limiter = GlobalRateLimiter.get_scoped_instance(provider_id)
                is_blocked = limiter.is_blocked()
                # The Zen provider has no rate limits, so never treat it as
                # blocked.
                if provider_id == "zen":
                    is_blocked = False
                # Check model health (recent failures).
                is_healthy = limiter.is_healthy(normalized_ref)
                if is_blocked or not is_healthy:
                    reason = "BLOCKED" if is_blocked else "UNHEALTHY"
                    logger.debug(
                        "Routing: candidate '{}' (from {}) is {} (health={})",
                        normalized_ref,
                        source,
                        reason,
                        is_healthy,
                    )
                    blocked_candidates.append(resolved)
                else:
                    # Healthy candidate; final ordering happens below via
                    # provider_priority.
                    logger.debug(
                        "Routing: added candidate '{}' (from {})",
                        normalized_ref,
                        source,
                    )
                    healthy_candidates.append(resolved)
            else:
                logger.debug(
                    "Routing: candidate '{}' (from {}) is NOT CONFIGURED",
                    normalized_ref,
                    source,
                )

        # 1. AUTO_MODEL_ORDER (user-configured priority).
        order_csv = (self._settings.auto_model_order or "").strip()
        if order_csv:
            for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
                add_candidate(cand, "AUTO_MODEL_ORDER")
        # 2. Primary MODEL.
        add_candidate(self._settings.model, "MODEL")
        # 3. NVIDIA NIM fallbacks: use the configured list or the defaults.
        nim_csv = (self._settings.nvidia_nim_fallback_models or "").strip()
        if nim_csv:
            for cand in [c.strip() for c in nim_csv.split(",") if c.strip()]:
                add_candidate(cand, "NVIDIA_NIM_FALLBACK_MODELS")
        else:
            # Use the default NIM models when no explicit fallback is
            # configured.
            for cand in DEFAULT_NIM_AUTO_MODELS:
                add_candidate(cand, "DEFAULT_NIM_AUTO_MODELS")
        # 4. Model-specific overrides.
        add_candidate(self._settings.model_opus, "MODEL_OPUS")
        add_candidate(self._settings.model_sonnet, "MODEL_SONNET")
        add_candidate(self._settings.model_haiku, "MODEL_HAIKU")

        # Smart ordering: Zen goes first (no rate limits), then sort by load.
        def provider_priority(c: ResolvedModel) -> tuple[int, int]:
            # Priority: zen before others, then by active request count.
            is_zen = 0 if c.provider_id == "zen" else 1
            active = session_tracker._provider_active.get(c.provider_id, 0)
            return (is_zen, active)

        healthy_candidates.sort(key=provider_priority)
        all_candidates = healthy_candidates + blocked_candidates
        logger.info(
            "Routing: resolved '{}' to {} candidates: {}",
            claude_model_name,
            len(all_candidates),
            ", ".join(c.provider_model_ref for c in all_candidates),
        )
        return all_candidates
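
    # Illustrative candidate order (hypothetical configuration: for example
    # MODEL="zen/glm-4.6", no AUTO_MODEL_ORDER, no NVIDIA_NIM_FALLBACK_MODELS,
    # both providers configured, all models healthy):
    #
    #     resolve_candidates("auto")
    #     -> [zen/glm-4.6,                                      # MODEL, zen first
    #         nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct,   # defaults, in
    #         nvidia_nim/z-ai/glm4.7,                           # list order
    #         ...]                           # blocked/unhealthy refs sort last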

    def _direct_provider_model(
        self, model_name: str
    ) -> tuple[str | None, str | None, bool | None]:
        """Parse a direct provider reference into (provider_id,
        provider_model, force_thinking_enabled), or return (None, None, None)
        when the name is not a direct reference."""
        decoded = decode_gateway_model_id(model_name)
        if decoded is not None:
            if decoded.provider_id not in SUPPORTED_PROVIDER_IDS:
                return None, None, None
            return (
                decoded.provider_id,
                decoded.provider_model,
                decoded.force_thinking_enabled,
            )
        provider_id, separator, provider_model = model_name.partition("/")
        if not separator:
            return None, None, None
        if provider_id not in SUPPORTED_PROVIDER_IDS:
            return None, None, None
        if not provider_model:
            return None, None, None
        return provider_id, provider_model, None
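
    # Example parses (assumes "nvidia_nim" is a supported provider id and
    # "foo" is not; gateway-encoded ids are handled by decode_gateway_model_id
    # before the plain partition fallback):
    #
    #     "nvidia_nim/mistralai/mistral-nemotron"
    #         -> ("nvidia_nim", "mistralai/mistral-nemotron", None)
    #     "foo/bar"      -> (None, None, None)   # unknown provider
    #     "plain-model"  -> (None, None, None)   # no provider prefix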

    def resolve_messages_request(
        self, request: MessagesRequest
    ) -> RoutedMessagesRequest:
        """Return an internal routed request context."""
        resolved = self.resolve(request.model)
        routed = request.model_copy(deep=True)
        routed.model = resolved.provider_model
        return RoutedMessagesRequest(request=routed, resolved=resolved)

    def resolve_token_count_request(
        self, request: TokenCountRequest
    ) -> RoutedTokenCountRequest:
        """Return an internal token-count request context."""
        resolved = self.resolve(request.model)
        routed = request.model_copy(
            update={"model": resolved.provider_model}, deep=True
        )
        return RoutedTokenCountRequest(request=routed, resolved=resolved)
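
    # Usage sketch (hypothetical request payload; MessagesRequest is the
    # Pydantic model imported above, and the field values are examples only):
    #
    #     request = MessagesRequest(model="auto", max_tokens=1024, messages=[...])
    #     routed = router.resolve_messages_request(request)
    #     routed.request.model        # provider-side model id, e.g. "glm-4.6"
    #     routed.resolved.provider_id # e.g. "zen"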

    def resolve_with_task_awareness(
        self,
        claude_model_name: str,
        messages: list,
    ) -> ResolvedModel:
        """Resolve a model with task-based capability matching.

        For the 'auto' model, detect task requirements and route to the
        best-capable model.
        """
        if not self._is_auto(claude_model_name):
            return self.resolve(claude_model_name)
        # Detect which capabilities are needed.
        detector = TaskDetector()
        requirements = detector.detect_requirements(messages)
        logger.info(
            "Task-aware routing: detected requirements={} confidence={:.2f}",
            requirements.required_capabilities,
            requirements.confidence,
        )
        # Get available candidates.
        candidates = self.resolve_candidates(claude_model_name)
        if not candidates:
            # Fall back to the default resolution.
            return self.resolve(claude_model_name)
        # If confidence is low or only general text is needed, use load-based
        # selection.
        if requirements.confidence < 0.7 or (
            not requirements.requires_vision
            and not requirements.requires_coding
            and not requirements.requires_reasoning
        ):
            logger.debug(
                "Task-aware routing: low confidence or general task, "
                "using load-based selection"
            )
            return candidates[0]
        # Find the best model matching the required capabilities.
        required_caps = set()
        if requirements.requires_coding:
            required_caps.add("coding")
        if requirements.requires_reasoning:
            required_caps.add("reasoning")
        if requirements.requires_vision:
            required_caps.add("vision")
        if required_caps:
            model_refs = [c.provider_model_ref for c in candidates]
            best = find_best_model_for_task(required_caps, model_refs)
            if best:
                # Find the matching candidate.
                for cand in candidates:
                    if cand.provider_model_ref == best.model_ref:
                        logger.info(
                            "Task-aware routing: selected {} for capabilities={}",
                            best.model_ref,
                            required_caps,
                        )
                        return cand
        # Default to the first (load-balanced) candidate.
        return candidates[0]
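
    # Illustrative flow (assumed TaskDetector output; the attribute names are
    # as used above, the values are examples only):
    #
    #     messages = [{"role": "user", "content": "Refactor this function..."}]
    #     requirements.requires_coding -> True, requirements.confidence -> 0.85
    #     => required_caps == {"coding"}; find_best_model_for_task() picks the
    #        highest-scoring coding-capable ref among the healthy candidates.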

    def get_routing_hint(self, messages: list) -> str:
        """Get a hint about what kind of model would be best."""
        detector = TaskDetector()
        requirements = detector.detect_requirements(messages)
        return detector.get_priority_hint(requirements)
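

# Minimal end-to-end sketch (assumes a populated Settings instance; the
# environment value and model ids are illustrative):
#
#     settings = Settings()  # e.g. MODEL="zen/glm-4.6" via the environment
#     router = ModelRouter(settings)
#     resolved = router.resolve("auto")
#     print(resolved.provider_id, resolved.provider_model_ref)
#     # Failover-aware callers iterate the prioritized candidates instead:
#     for cand in router.resolve_candidates("auto"):
#         ...  # try each candidate until one succeeds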