"""Model capability registry for intelligent routing.""" from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: from collections.abc import Sequence @dataclass(frozen=True, slots=True) class ModelCapabilities: """Capabilities of a specific model for routing decisions.""" provider_id: str model_id: str model_ref: str # provider/model format vision: bool = False # Can process images supports_base64_images: bool = False # Accepts data: URLs with base64 supports_remote_images: bool = False # Accepts http/https URLs supports_pdfs: bool = False # Can process PDF documents max_images: int = 0 # Max images per request (0 = unlimited) max_image_size_mb: float = 10.0 # Max size per image in MB coding: bool = False # Good at code generation/analysis reasoning: bool = False # Strong reasoning/thinking general_text: bool = True # General text generation multimodal_input: bool = False # Can handle multiple input types multimodal_output: bool = False # Can produce multiple output types max_tokens: int = 4096 speed: str = "medium" # "fast", "medium", "slow" priority: int = 100 # Higher = preferred for its capabilities # Registry of all available models and their capabilities # This can be extended with actual model discovery later MODEL_CAPABILITIES: dict[str, ModelCapabilities] = { # Zen/minimax models "zen/minimax-m2.5-free": ModelCapabilities( provider_id="zen", model_id="minimax-m2.5-free", model_ref="zen/minimax-m2.5-free", coding=True, reasoning=True, general_text=True, max_tokens=32000, speed="fast", priority=80, ), # NVIDIA NIM models "nvidia_nim/stepfun-ai/step-3.5-flash": ModelCapabilities( provider_id="nvidia_nim", model_id="step-3.5-flash", model_ref="nvidia_nim/stepfun-ai/step-3.5-flash", coding=True, reasoning=True, general_text=True, max_tokens=32000, speed="fast", priority=70, ), "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct": ModelCapabilities( provider_id="nvidia_nim", model_id="qwen3-coder-480b-a35b-instruct", model_ref="nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct", coding=True, reasoning=True, general_text=True, max_tokens=32000, speed="slow", priority=90, ), "nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512": ModelCapabilities( provider_id="nvidia_nim", model_id="mistral-large-3-675b-instruct-2512", model_ref="nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512", vision=True, supports_base64_images=True, supports_remote_images=False, max_images=16, max_image_size_mb=10.0, multimodal_input=True, coding=True, reasoning=True, general_text=True, max_tokens=32000, speed="slow", priority=90, ), "nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct": ModelCapabilities( provider_id="nvidia_nim", model_id="dracarys-llama-3.1-70b-instruct", model_ref="nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct", coding=True, reasoning=True, general_text=True, max_tokens=32000, speed="medium", priority=75, ), "nvidia_nim/z-ai/glm4.7": ModelCapabilities( provider_id="nvidia_nim", model_id="glm4.7", model_ref="nvidia_nim/z-ai/glm4.7", coding=True, reasoning=True, general_text=True, max_tokens=32000, speed="medium", priority=70, ), "nvidia_nim/bytedance/seed-oss-36b-instruct": ModelCapabilities( provider_id="nvidia_nim", model_id="seed-oss-36b-instruct", model_ref="nvidia_nim/bytedance/seed-oss-36b-instruct", coding=True, reasoning=True, general_text=True, max_tokens=32000, speed="medium", priority=65, ), "nvidia_nim/mistralai/mistral-nemotron": ModelCapabilities( provider_id="nvidia_nim", model_id="mistral-nemotron", 
model_ref="nvidia_nim/mistralai/mistral-nemotron", coding=True, reasoning=True, general_text=True, max_tokens=32000, speed="medium", priority=60, ), # Cerebras models (key only has access to llama3.1-8b currently) # Note: qwen-3-235b-a22b-instruct-2507 exists but is rate-limited # Note: zai-glm-4.7 and gpt-oss-120b are not accessible with current key "cerebras/llama3.1-8b": ModelCapabilities( provider_id="cerebras", model_id="llama3.1-8b", model_ref="cerebras/llama3.1-8b", coding=True, reasoning=False, general_text=True, max_tokens=32000, speed="fast", priority=60, ), # Silicon Flow models "silicon/Qwen/Qwen3.6-35B-A3B": ModelCapabilities( provider_id="silicon", model_id="Qwen/Qwen3.6-35B-A3B", model_ref="silicon/Qwen/Qwen3.6-35B-A3B", vision=True, supports_base64_images=True, max_images=1, multimodal_input=True, coding=True, reasoning=True, general_text=True, max_tokens=262144, speed="medium", priority=85, ), "silicon/Qwen/Qwen3.6-27B": ModelCapabilities( provider_id="silicon", model_id="Qwen/Qwen3.6-27B", model_ref="silicon/Qwen/Qwen3.6-27B", vision=True, supports_base64_images=True, max_images=1, multimodal_input=True, coding=True, reasoning=True, general_text=True, max_tokens=262144, speed="medium", priority=82, ), "silicon/Qwen/Qwen3.5-35B-A3B": ModelCapabilities( provider_id="silicon", model_id="Qwen/Qwen3.5-35B-A3B", model_ref="silicon/Qwen/Qwen3.5-35B-A3B", vision=True, supports_base64_images=True, max_images=1, multimodal_input=True, coding=True, reasoning=True, general_text=True, max_tokens=262144, speed="medium", priority=80, ), "silicon/Qwen/Qwen3.5-27B": ModelCapabilities( provider_id="silicon", model_id="Qwen/Qwen3.5-27B", model_ref="silicon/Qwen/Qwen3.5-27B", vision=True, supports_base64_images=True, max_images=1, multimodal_input=True, coding=True, reasoning=True, general_text=True, max_tokens=262144, speed="medium", priority=78, ), "silicon/google/gemma-4-26B-A4B-it": ModelCapabilities( provider_id="silicon", model_id="google/gemma-4-26B-A4B-it", model_ref="silicon/google/gemma-4-26B-A4B-it", coding=True, reasoning=True, general_text=True, max_tokens=262144, speed="fast", priority=75, ), "silicon/google/gemma-4-31B-it": ModelCapabilities( provider_id="silicon", model_id="google/gemma-4-31B-it", model_ref="silicon/google/gemma-4-31B-it", coding=True, reasoning=True, general_text=True, max_tokens=262144, speed="fast", priority=76, ), # Groq models "groq/llama-3.3-70b-versatile": ModelCapabilities( provider_id="groq", model_id="llama-3.3-70b-versatile", model_ref="groq/llama-3.3-70b-versatile", coding=True, reasoning=True, general_text=True, max_tokens=32768, speed="fast", priority=85, ), "groq/llama-3.1-8b-instant": ModelCapabilities( provider_id="groq", model_id="llama-3.1-8b-instant", model_ref="groq/llama-3.1-8b-instant", coding=True, general_text=True, max_tokens=131072, speed="fast", priority=90, ), "groq/qwen3-32b": ModelCapabilities( provider_id="groq", model_id="qwen3-32b", model_ref="groq/qwen3-32b", coding=True, reasoning=True, general_text=True, max_tokens=40960, speed="medium", priority=88, ), } def get_model_capabilities(model_ref: str) -> ModelCapabilities | None: """Get capabilities for a specific model reference.""" return MODEL_CAPABILITIES.get(model_ref) def find_models_with_capability(capability: str) -> list[ModelCapabilities]: """Find all models that have a specific capability.""" results = [] for caps in MODEL_CAPABILITIES.values(): if getattr(caps, capability, False): results.append(caps) # Sort by priority (higher = better) results.sort(key=lambda x: 
x.priority, reverse=True) return results def find_best_model_for_task( required_capabilities: set[str], available_models: Sequence[str] | None = None, ) -> ModelCapabilities | None: """Find the best model matching required capabilities. Args: required_capabilities: Set of capability names needed (e.g., {"coding", "vision"}) available_models: Optional list of model refs to filter by Returns: Best matching ModelCapabilities or None """ candidates = [] models_to_check = ( [MODEL_CAPABILITIES[m] for m in available_models if m in MODEL_CAPABILITIES] if available_models else list(MODEL_CAPABILITIES.values()) ) for caps in models_to_check: # Check if model has all required capabilities if all(getattr(caps, cap, False) for cap in required_capabilities): candidates.append(caps) if not candidates: return None # Sort by priority and return best candidates.sort(key=lambda x: x.priority, reverse=True) return candidates[0] def get_capability_match_score( model_caps: ModelCapabilities, required: set[str], ) -> tuple[int, int]: """Calculate match score for routing. Returns (matched_count, priority) for sorting. """ matched = sum(1 for cap in required if getattr(model_caps, cap, False)) return (matched, model_caps.priority)
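

# Illustrative usage sketch, not part of the module's public contract: it just
# exercises the three lookup helpers against the registry above. The specific
# model refs and capability names queried here are examples, not requirements.
if __name__ == "__main__":
    # Direct lookup by provider/model reference; returns None for unknown refs.
    caps = get_model_capabilities("groq/llama-3.3-70b-versatile")
    print(caps.model_ref if caps else "unknown model")

    # All vision-capable models, highest priority first.
    for model in find_models_with_capability("vision"):
        print(model.model_ref, model.priority)

    # Best model that can both write code and accept images; None is returned
    # when no registered model satisfies every required capability.
    best = find_best_model_for_task({"coding", "vision"})
    if best is not None:
        print("best:", best.model_ref)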