# claude-code-proxy/core/model_capabilities.py
"""Model capability registry for intelligent routing."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence
@dataclass(frozen=True, slots=True)
class ModelCapabilities:
    """Capabilities of a specific model for routing decisions."""

    provider_id: str
    model_id: str
    model_ref: str  # Full "provider/model" reference; also the registry key

    vision: bool = False  # Can process images
    supports_base64_images: bool = False  # Accepts data: URLs with base64 payloads
    supports_remote_images: bool = False  # Accepts http/https image URLs
    supports_pdfs: bool = False  # Can process PDF documents
    max_images: int = 0  # Max images per request (0 = unlimited)
    max_image_size_mb: float = 10.0  # Max size per image in MB

    coding: bool = False  # Good at code generation/analysis
    reasoning: bool = False  # Strong reasoning/thinking
    general_text: bool = True  # General text generation
    multimodal_input: bool = False  # Can handle multiple input types
    multimodal_output: bool = False  # Can produce multiple output types

    max_tokens: int = 4096  # Default completion token limit
    speed: str = "medium"  # "fast", "medium", or "slow"
    priority: int = 100  # Higher = preferred for its capabilities


# Registry of all available models and their capabilities, keyed by model_ref.
# This static table can later be extended or replaced by actual model discovery.
MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {
# Zen/minimax models
"zen/minimax-m2.5-free": ModelCapabilities(
provider_id="zen",
model_id="minimax-m2.5-free",
model_ref="zen/minimax-m2.5-free",
coding=True,
reasoning=True,
general_text=True,
max_tokens=32000,
speed="fast",
priority=80,
),
# NVIDIA NIM models
"nvidia_nim/stepfun-ai/step-3.5-flash": ModelCapabilities(
provider_id="nvidia_nim",
model_id="step-3.5-flash",
model_ref="nvidia_nim/stepfun-ai/step-3.5-flash",
coding=True,
reasoning=True,
general_text=True,
max_tokens=32000,
speed="fast",
priority=70,
),
"nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct": ModelCapabilities(
provider_id="nvidia_nim",
model_id="qwen3-coder-480b-a35b-instruct",
model_ref="nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
coding=True,
reasoning=True,
general_text=True,
max_tokens=32000,
speed="slow",
priority=90,
),
"nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512": ModelCapabilities(
provider_id="nvidia_nim",
model_id="mistral-large-3-675b-instruct-2512",
model_ref="nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512",
vision=True,
supports_base64_images=True,
supports_remote_images=False,
max_images=16,
max_image_size_mb=10.0,
multimodal_input=True,
coding=True,
reasoning=True,
general_text=True,
max_tokens=32000,
speed="slow",
priority=90,
),
"nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct": ModelCapabilities(
provider_id="nvidia_nim",
model_id="dracarys-llama-3.1-70b-instruct",
model_ref="nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct",
coding=True,
reasoning=True,
general_text=True,
max_tokens=32000,
speed="medium",
priority=75,
),
"nvidia_nim/z-ai/glm4.7": ModelCapabilities(
provider_id="nvidia_nim",
model_id="glm4.7",
model_ref="nvidia_nim/z-ai/glm4.7",
coding=True,
reasoning=True,
general_text=True,
max_tokens=32000,
speed="medium",
priority=70,
),
"nvidia_nim/bytedance/seed-oss-36b-instruct": ModelCapabilities(
provider_id="nvidia_nim",
model_id="seed-oss-36b-instruct",
model_ref="nvidia_nim/bytedance/seed-oss-36b-instruct",
coding=True,
reasoning=True,
general_text=True,
max_tokens=32000,
speed="medium",
priority=65,
),
"nvidia_nim/mistralai/mistral-nemotron": ModelCapabilities(
provider_id="nvidia_nim",
model_id="mistral-nemotron",
model_ref="nvidia_nim/mistralai/mistral-nemotron",
coding=True,
reasoning=True,
general_text=True,
max_tokens=32000,
speed="medium",
priority=60,
),
# Cerebras models (key only has access to llama3.1-8b currently)
# Note: qwen-3-235b-a22b-instruct-2507 exists but is rate-limited
# Note: zai-glm-4.7 and gpt-oss-120b are not accessible with current key
"cerebras/llama3.1-8b": ModelCapabilities(
provider_id="cerebras",
model_id="llama3.1-8b",
model_ref="cerebras/llama3.1-8b",
coding=True,
reasoning=False,
general_text=True,
max_tokens=32000,
speed="fast",
priority=60,
),
# Silicon Flow models
"silicon/Qwen/Qwen3.6-35B-A3B": ModelCapabilities(
provider_id="silicon",
model_id="Qwen/Qwen3.6-35B-A3B",
model_ref="silicon/Qwen/Qwen3.6-35B-A3B",
vision=True,
supports_base64_images=True,
max_images=1,
multimodal_input=True,
coding=True,
reasoning=True,
general_text=True,
max_tokens=262144,
speed="medium",
priority=85,
),
"silicon/Qwen/Qwen3.6-27B": ModelCapabilities(
provider_id="silicon",
model_id="Qwen/Qwen3.6-27B",
model_ref="silicon/Qwen/Qwen3.6-27B",
vision=True,
supports_base64_images=True,
max_images=1,
multimodal_input=True,
coding=True,
reasoning=True,
general_text=True,
max_tokens=262144,
speed="medium",
priority=82,
),
"silicon/Qwen/Qwen3.5-35B-A3B": ModelCapabilities(
provider_id="silicon",
model_id="Qwen/Qwen3.5-35B-A3B",
model_ref="silicon/Qwen/Qwen3.5-35B-A3B",
vision=True,
supports_base64_images=True,
max_images=1,
multimodal_input=True,
coding=True,
reasoning=True,
general_text=True,
max_tokens=262144,
speed="medium",
priority=80,
),
"silicon/Qwen/Qwen3.5-27B": ModelCapabilities(
provider_id="silicon",
model_id="Qwen/Qwen3.5-27B",
model_ref="silicon/Qwen/Qwen3.5-27B",
vision=True,
supports_base64_images=True,
max_images=1,
multimodal_input=True,
coding=True,
reasoning=True,
general_text=True,
max_tokens=262144,
speed="medium",
priority=78,
),
"silicon/google/gemma-4-26B-A4B-it": ModelCapabilities(
provider_id="silicon",
model_id="google/gemma-4-26B-A4B-it",
model_ref="silicon/google/gemma-4-26B-A4B-it",
coding=True,
reasoning=True,
general_text=True,
max_tokens=262144,
speed="fast",
priority=75,
),
"silicon/google/gemma-4-31B-it": ModelCapabilities(
provider_id="silicon",
model_id="google/gemma-4-31B-it",
model_ref="silicon/google/gemma-4-31B-it",
coding=True,
reasoning=True,
general_text=True,
max_tokens=262144,
speed="fast",
priority=76,
),
# Groq models
"groq/llama-3.3-70b-versatile": ModelCapabilities(
provider_id="groq",
model_id="llama-3.3-70b-versatile",
model_ref="groq/llama-3.3-70b-versatile",
coding=True,
reasoning=True,
general_text=True,
max_tokens=32768,
speed="fast",
priority=85,
),
"groq/llama-3.1-8b-instant": ModelCapabilities(
provider_id="groq",
model_id="llama-3.1-8b-instant",
model_ref="groq/llama-3.1-8b-instant",
coding=True,
general_text=True,
max_tokens=131072,
speed="fast",
priority=90,
),
"groq/qwen3-32b": ModelCapabilities(
provider_id="groq",
model_id="qwen3-32b",
model_ref="groq/qwen3-32b",
coding=True,
reasoning=True,
general_text=True,
max_tokens=40960,
speed="medium",
priority=88,
),
}
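

# Sanity-check helper (an illustrative sketch, not part of the original API):
# the lookup functions below treat each registry key as a model_ref, so every
# key should equal its entry's model_ref field. The helper name is ours.
def _validate_registry(registry: dict[str, ModelCapabilities]) -> None:
    """Raise ValueError if any registry key differs from its entry's model_ref."""
    for key, caps in registry.items():
        if key != caps.model_ref:
            msg = f"Registry key {key!r} != model_ref {caps.model_ref!r}"
            raise ValueError(msg)


_validate_registry(MODEL_CAPABILITIES)
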
def get_model_capabilities(model_ref: str) -> ModelCapabilities | None:
    """Get capabilities for a specific model reference, or None if unknown."""
    return MODEL_CAPABILITIES.get(model_ref)


def find_models_with_capability(capability: str) -> list[ModelCapabilities]:
    """Find all models that have a specific capability.

    Results are sorted by priority, descending (higher = preferred).
    """
    results = [
        caps for caps in MODEL_CAPABILITIES.values() if getattr(caps, capability, False)
    ]
    results.sort(key=lambda x: x.priority, reverse=True)
    return results


def find_best_model_for_task(
    required_capabilities: set[str],
    available_models: Sequence[str] | None = None,
) -> ModelCapabilities | None:
    """Find the best model matching all required capabilities.

    Args:
        required_capabilities: Capability names needed (e.g., {"coding", "vision"}).
        available_models: Optional list of model refs to restrict the search to.

    Returns:
        The highest-priority matching ModelCapabilities, or None if no model
        satisfies every required capability.
    """
candidates = []
models_to_check = (
[MODEL_CAPABILITIES[m] for m in available_models if m in MODEL_CAPABILITIES]
if available_models
else list(MODEL_CAPABILITIES.values())
)
for caps in models_to_check:
# Check if model has all required capabilities
if all(getattr(caps, cap, False) for cap in required_capabilities):
candidates.append(caps)
if not candidates:
return None
# Sort by priority and return best
candidates.sort(key=lambda x: x.priority, reverse=True)
return candidates[0]
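

# Example (sketch): with the registry above, a request needing both "coding"
# and "vision" resolves to the highest-priority vision-capable coder:
#
#     find_best_model_for_task({"coding", "vision"})
#     # -> the nvidia_nim mistral-large-3 entry (priority 90)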


def get_capability_match_score(
    model_caps: ModelCapabilities,
    required: set[str],
) -> tuple[int, int]:
    """Calculate a match score for routing.

    Returns:
        A (matched_count, priority) tuple, usable as a sort key so that models
        matching more required capabilities rank first, with priority as the
        tie-breaker.
    """
    matched = sum(1 for cap in required if getattr(model_caps, cap, False))
    return (matched, model_caps.priority)
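

if __name__ == "__main__":
    # Minimal usage sketch of the helpers above. Model refs come straight from
    # MODEL_CAPABILITIES, so this runs against the registry as defined; treat
    # the output as illustrative rather than authoritative routing policy.
    caps = get_model_capabilities("groq/qwen3-32b")
    print("lookup:", caps.model_ref if caps else None)

    # All vision-capable models, highest priority first.
    for m in find_models_with_capability("vision"):
        print("vision:", m.model_ref, m.priority)

    # Best model that can both code and accept images (all-or-nothing filter).
    best = find_best_model_for_task({"coding", "vision"})
    print("best coder+vision:", best.model_ref if best else "no match")

    # Soft routing: rank every model by (capabilities matched, priority),
    # allowing partial matches instead of the strict filter above.
    required = {"coding", "reasoning", "vision"}
    ranked = sorted(
        MODEL_CAPABILITIES.values(),
        key=lambda c: get_capability_match_score(c, required),
        reverse=True,
    )
    print("soft-routing winner:", ranked[0].model_ref)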