# runner-ai-intelligence / src/llm/model_registry.py
# HF Space deploy snapshot (minimal allow-list) — avfranco, commit d64fd55
from dataclasses import dataclass, field
from typing import Dict, Literal, Optional, List
import logging
from observability import logger as obs_logger
from observability import components as obs_components
from .base import LLMClient
logger = logging.getLogger(__name__)
class CapabilityMismatchError(Exception):
    """Signals that the chosen model lacks a capability the agent requires."""
@dataclass
class ModelProfile:
    """Static metadata describing one LLM option in the registry."""

    provider: str  # backend that serves the model (e.g. "litellm", "gemini")
    model_name: str  # provider-specific model identifier
    supports_tools: bool  # True if the model supports tool/function calling
    supports_strict_json: bool  # True if the model can emit strict JSON output
    latency_class: Literal["fast", "medium", "slow"]
    cost_class: Literal["cheap", "medium", "expensive"]
    stability_rating: int  # reliability score on a 1 (flaky) to 5 (solid) scale

    def __post_init__(self):
        # Reject ratings outside the documented 1-5 scale at construction time.
        rating = self.stability_rating
        if rating < 1 or rating > 5:
            raise ValueError(
                f"stability_rating must be between 1 and 5, got {self.stability_rating}"
            )
# Registry of all models available for agent assignment, keyed by a short
# internal identifier. Profiles are consumed by select_model_for_agent().
MODEL_REGISTRY: Dict[str, ModelProfile] = {
    # Open-weights model via LiteLLM/HF; no tool calling or strict JSON.
    "hf_gpt_oss_20b": ModelProfile(
        provider="litellm",
        model_name="huggingface/openai/gpt-oss-20b",
        supports_tools=False,
        supports_strict_json=False,
        latency_class="medium",
        cost_class="cheap",
        stability_rating=3,
    ),
    # Default workhorse: fast, cheap, full capability support.
    "gemini_flash": ModelProfile(
        provider="gemini",
        model_name="gemini-3-flash-preview",
        supports_tools=True,
        supports_strict_json=True,
        latency_class="fast",
        cost_class="cheap",
        stability_rating=5,
    ),
    # Higher-cost option with full capability support.
    "openai_gpt5": ModelProfile(
        provider="litellm",
        model_name="openai/gpt-5-mini",
        supports_tools=True,
        supports_strict_json=True,
        latency_class="medium",
        cost_class="expensive",
        stability_rating=5,
    ),
}
# Manual agent -> model-key mapping. Hoisted to module level so the dict is
# built once at import time instead of on every call.
_AGENT_MODEL_MAPPING = {
    "InsightsAgent": "gemini_flash",
    "PlanAgent": "gemini_flash",
    "VisualizationAgent": "openai_gpt5",
    "Router": "gemini_flash",  # Changed from hf_gpt_oss_20b which lacks strict JSON
    "ChatAgent": "hf_gpt_oss_20b",
    "BriefService": "gemini_flash",
}


def select_model_for_agent(agent_name: str) -> ModelProfile:
    """
    Manually maps an agent to a model profile based on requirements.
    Validates that the selected model meets the agent's capability needs.

    Args:
        agent_name: Name of the agent requesting a model.

    Returns:
        The ModelProfile selected for the agent. Agents with no declared
        capability requirements fall back to "gemini_flash" and skip
        validation (there is nothing to validate against).

    Raises:
        ValueError: If the mapped model key is missing from MODEL_REGISTRY.
        CapabilityMismatchError: If the selected model lacks a capability
            the agent requires (tool calling or strict JSON output).
    """
    # Local import — presumably avoids a circular dependency at module
    # load time (TODO confirm against .agent_capabilities).
    from .agent_capabilities import AGENT_CAPABILITIES

    requirements = AGENT_CAPABILITIES.get(agent_name)
    if not requirements:
        # Lazy %-args: the message is only formatted if the record is emitted.
        logger.warning(
            "No capability requirements defined for agent: %s. Using fallback.",
            agent_name,
        )
        # Default fallback if unknown agent
        return MODEL_REGISTRY["gemini_flash"]

    model_key = _AGENT_MODEL_MAPPING.get(agent_name, "gemini_flash")
    model_profile = MODEL_REGISTRY.get(model_key)
    if not model_profile:
        raise ValueError(f"Model key '{model_key}' not found in registry for agent '{agent_name}'")

    # Capability validation: collect every mismatch so the error reports all
    # problems at once instead of only the first one found.
    mismatches = []
    if requirements.tools_required and not model_profile.supports_tools:
        mismatches.append("tools_required=True but supports_tools=False")
    if requirements.strict_json_required and not model_profile.supports_strict_json:
        mismatches.append("strict_json_required=True but supports_strict_json=False")

    if mismatches:
        error_msg = f"Capability mismatch for agent '{agent_name}' with model '{model_key}': {', '.join(mismatches)}"
        obs_logger.log_event(
            level="error",
            message=error_msg,
            event="capability_mismatch",
            component=obs_components.LLM,
            agent_name=agent_name,
            model_key=model_key,
            mismatches=mismatches,
        )
        raise CapabilityMismatchError(error_msg)

    # Success log: record the selection alongside the requirements it satisfied.
    obs_logger.log_event(
        level="info",
        message=f"Model selected for agent '{agent_name}': {model_key}",
        event="model_selected",
        component=obs_components.LLM,
        agent_name=agent_name,
        selected_model=model_key,
        provider=model_profile.provider,
        model_name=model_profile.model_name,
        required_capabilities={
            "tools_required": requirements.tools_required,
            "strict_json_required": requirements.strict_json_required,
            "latency_preference": requirements.latency_preference,
        },
    )
    return model_profile