# claude-code-proxy/core/model_capabilities.py
# Commit 4974012 by Yash030: "Add smart task-aware routing (Phase 1)" (5.91 kB)
"""Model capability registry for intelligent routing."""
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence
@dataclass(frozen=True, slots=True)
class ModelCapabilities:
"""Capabilities of a specific model for routing decisions."""
provider_id: str
model_id: str
model_ref: str # provider/model format
vision: bool = False # Can process images
coding: bool = False # Good at code generation/analysis
reasoning: bool = False # Strong reasoning/thinking
general_text: bool = True # General text generation
multimodal_input: bool = False # Can handle multiple input types
multimodal_output: bool = False # Can produce multiple output types
max_tokens: int = 4096
speed: str = "medium" # "fast", "medium", "slow"
priority: int = 100 # Higher = preferred for its capabilities
# Registry of every known model and its capabilities, keyed by model_ref.
# Each entry's key is taken from its own model_ref, so the two cannot drift
# apart. The list is static for now; model discovery could populate it later.
# NOTE(review): for the NVIDIA NIM entries, model_id omits the vendor prefix
# that model_ref carries (e.g. "step-3.5-flash" vs
# "nvidia_nim/stepfun-ai/step-3.5-flash"), so model_ref is not simply
# f"{provider_id}/{model_id}" there -- confirm which form callers expect
# before relying on model_id.
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {
    entry.model_ref: entry
    for entry in (
        # Zen/minimax models
        ModelCapabilities(
            provider_id="zen",
            model_id="minimax-m2.5-free",
            model_ref="zen/minimax-m2.5-free",
            coding=True,
            reasoning=True,
            general_text=True,
            max_tokens=32000,
            speed="fast",
            priority=80,
        ),
        # NVIDIA NIM models
        ModelCapabilities(
            provider_id="nvidia_nim",
            model_id="step-3.5-flash",
            model_ref="nvidia_nim/stepfun-ai/step-3.5-flash",
            coding=True,
            reasoning=True,
            general_text=True,
            max_tokens=32000,
            speed="fast",
            priority=70,
        ),
        ModelCapabilities(
            provider_id="nvidia_nim",
            model_id="qwen3-coder-480b-a35b-instruct",
            model_ref="nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
            coding=True,
            reasoning=True,
            general_text=True,
            max_tokens=32000,
            speed="slow",
            priority=90,
        ),
        ModelCapabilities(
            provider_id="nvidia_nim",
            model_id="mistral-large-3-675b-instruct-2512",
            model_ref="nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512",
            coding=True,
            reasoning=True,
            general_text=True,
            max_tokens=32000,
            speed="slow",
            priority=85,
        ),
        ModelCapabilities(
            provider_id="nvidia_nim",
            model_id="dracarys-llama-3.1-70b-instruct",
            model_ref="nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct",
            coding=True,
            reasoning=True,
            general_text=True,
            max_tokens=32000,
            speed="medium",
            priority=75,
        ),
        ModelCapabilities(
            provider_id="nvidia_nim",
            model_id="glm4.7",
            model_ref="nvidia_nim/z-ai/glm4.7",
            coding=True,
            reasoning=True,
            general_text=True,
            max_tokens=32000,
            speed="medium",
            priority=70,
        ),
        ModelCapabilities(
            provider_id="nvidia_nim",
            model_id="seed-oss-36b-instruct",
            model_ref="nvidia_nim/bytedance/seed-oss-36b-instruct",
            coding=True,
            reasoning=True,
            general_text=True,
            max_tokens=32000,
            speed="medium",
            priority=65,
        ),
        ModelCapabilities(
            provider_id="nvidia_nim",
            model_id="mistral-nemotron",
            model_ref="nvidia_nim/mistralai/mistral-nemotron",
            coding=True,
            reasoning=True,
            general_text=True,
            max_tokens=32000,
            speed="medium",
            priority=60,
        ),
    )
}
def get_model_capabilities(model_ref: str) -> ModelCapabilities | None:
    """Look up the registered capabilities for ``model_ref``.

    Args:
        model_ref: Registry key in ``provider/model`` form.

    Returns:
        The matching ModelCapabilities, or None when the ref is not registered.
    """
    try:
        return MODEL_CAPABILITIES[model_ref]
    except KeyError:
        return None
def find_models_with_capability(
    capability: str,
    *,
    registry: Mapping[str, ModelCapabilities] | None = None,
) -> list[ModelCapabilities]:
    """Return every model whose boolean ``capability`` flag is True.

    Only attributes that are exactly ``True`` count as a capability. Non-flag
    fields such as ``speed``, ``priority`` or ``model_id`` are truthy for
    every model, so they never match. Unknown names never match either.

    Args:
        capability: Name of a boolean capability flag, e.g. ``"coding"``.
        registry: Mapping of model ref to capabilities to search. Defaults
            to ``MODEL_CAPABILITIES``; pass another mapping to search
            models obtained elsewhere (e.g. by discovery).

    Returns:
        Matching models ordered by descending priority. Models with equal
        priority keep their registry order, because ``sorted`` is stable.
    """
    source = MODEL_CAPABILITIES if registry is None else registry
    return sorted(
        (
            caps
            for caps in source.values()
            if getattr(caps, capability, False) is True
        ),
        key=lambda caps: caps.priority,
        reverse=True,
    )
def find_best_model_for_task(
    required_capabilities: set[str],
    available_models: Sequence[str] | None = None,
    *,
    registry: Mapping[str, ModelCapabilities] | None = None,
) -> ModelCapabilities | None:
    """Find the highest-priority model that has every required capability.

    A capability counts only when the model's attribute of that name is
    exactly ``True``, so non-flag fields (``speed``, ``priority``, ...) and
    unknown names never match. An empty ``required_capabilities`` matches
    every model.

    Args:
        required_capabilities: Capability names needed, e.g. ``{"coding", "vision"}``.
        available_models: Optional allowlist of model refs to choose from.
            ``None`` means "all registered models". An empty sequence means
            "no models available", so the result is ``None``. Refs missing
            from the registry are ignored.
        registry: Mapping of model ref to capabilities to search. Defaults
            to ``MODEL_CAPABILITIES``.

    Returns:
        The best matching ModelCapabilities, or None when nothing matches.
        Ties on priority go to the model that appears first (registry order,
        or allowlist order when ``available_models`` is given).
    """
    source = MODEL_CAPABILITIES if registry is None else registry

    # Compare against None so an explicit empty allowlist restricts the
    # search instead of silently widening it to the whole registry.
    if available_models is None:
        models_to_check = list(source.values())
    else:
        models_to_check = [source[ref] for ref in available_models if ref in source]

    candidates = [
        caps
        for caps in models_to_check
        if all(getattr(caps, cap, False) is True for cap in required_capabilities)
    ]
    if not candidates:
        return None
    # max() returns the first maximal element, preserving first-wins tie-breaking.
    return max(candidates, key=lambda caps: caps.priority)
def get_capability_match_score(
    model_caps: ModelCapabilities,
    required: set[str],
) -> tuple[int, int]:
    """Score how well a model covers the required capabilities.

    A capability counts as matched only when the model's attribute of that
    name is exactly ``True``. Non-flag fields such as ``speed`` or
    ``priority`` are truthy for every model, so they never count, and
    unknown names never count either.

    Args:
        model_caps: The model to score.
        required: Capability names requested for the task.

    Returns:
        ``(matched_count, priority)``. Tuples compare element-wise, so
        sorting by this key ranks coverage first and priority second.
    """
    matched = sum(getattr(model_caps, cap, False) is True for cap in required)
    return (matched, model_caps.priority)