"""Model registry and manual agent-to-model selection with capability validation."""

import logging
from dataclasses import dataclass
from typing import Dict, Literal

from observability import components as obs_components
from observability import logger as obs_logger

from .base import LLMClient

logger = logging.getLogger(__name__)


class CapabilityMismatchError(Exception):
    """Raised when a selected model does not meet the agent's capability requirements."""

    pass


@dataclass
class ModelProfile:
    """Static description of an LLM: provider routing info plus capability flags.

    Attributes:
        provider: Backend used to reach the model (e.g. "litellm", "gemini").
        model_name: Provider-specific model identifier string.
        supports_tools: Whether the model can do tool/function calling.
        supports_strict_json: Whether the model reliably emits strict JSON output.
        latency_class: Rough latency tier.
        cost_class: Rough cost tier.
        stability_rating: Subjective reliability score, must be in [1, 5].
    """

    provider: str
    model_name: str
    supports_tools: bool
    supports_strict_json: bool
    latency_class: Literal["fast", "medium", "slow"]
    cost_class: Literal["cheap", "medium", "expensive"]
    stability_rating: int  # 1–5

    def __post_init__(self) -> None:
        # Enforce the documented 1–5 range at construction time.
        if not (1 <= self.stability_rating <= 5):
            raise ValueError(
                f"stability_rating must be between 1 and 5, got {self.stability_rating}"
            )


# All models the selector is allowed to hand out, keyed by a short internal name.
MODEL_REGISTRY: Dict[str, ModelProfile] = {
    "hf_gpt_oss_20b": ModelProfile(
        provider="litellm",
        model_name="huggingface/openai/gpt-oss-20b",
        supports_tools=False,
        supports_strict_json=False,
        latency_class="medium",
        cost_class="cheap",
        stability_rating=3,
    ),
    "gemini_flash": ModelProfile(
        provider="gemini",
        model_name="gemini-3-flash-preview",
        supports_tools=True,
        supports_strict_json=True,
        latency_class="fast",
        cost_class="cheap",
        stability_rating=5,
    ),
    "openai_gpt5": ModelProfile(
        provider="litellm",
        model_name="openai/gpt-5-mini",
        supports_tools=True,
        supports_strict_json=True,
        latency_class="medium",
        cost_class="expensive",
        stability_rating=5,
    ),
}

# Manual agent -> registry-key mapping. Hoisted to module level so it is built
# once rather than on every select_model_for_agent() call.
_AGENT_MODEL_MAPPING: Dict[str, str] = {
    "InsightsAgent": "gemini_flash",
    "PlanAgent": "gemini_flash",
    "VisualizationAgent": "openai_gpt5",
    "Router": "gemini_flash",  # Changed from hf_gpt_oss_20b which lacks strict JSON
    "ChatAgent": "hf_gpt_oss_20b",
    "BriefService": "gemini_flash",
}

# Fallback registry key used for agents with no explicit mapping/requirements.
_FALLBACK_MODEL_KEY = "gemini_flash"


def select_model_for_agent(agent_name: str) -> ModelProfile:
    """
    Manually maps an agent to a model profile based on requirements.
    Validates that the selected model meets the agent's capability needs.

    Args:
        agent_name: Logical agent identifier (e.g. "Router", "ChatAgent").

    Returns:
        The ModelProfile chosen for the agent.

    Raises:
        CapabilityMismatchError: If the mapped model lacks a required capability.
        ValueError: If the mapped registry key is missing from MODEL_REGISTRY.
    """
    # Local import to avoid a circular dependency at module load time.
    from .agent_capabilities import AGENT_CAPABILITIES

    requirements = AGENT_CAPABILITIES.get(agent_name)
    if not requirements:
        # Unknown agent: nothing to validate against, so hand back the default.
        logger.warning(
            "No capability requirements defined for agent: %s. Using fallback.",
            agent_name,
        )
        # Default fallback if unknown agent
        return MODEL_REGISTRY[_FALLBACK_MODEL_KEY]

    model_key = _AGENT_MODEL_MAPPING.get(agent_name, _FALLBACK_MODEL_KEY)
    model_profile = MODEL_REGISTRY.get(model_key)
    if not model_profile:
        # Defensive: mapping values should always be valid registry keys.
        raise ValueError(
            f"Model key '{model_key}' not found in registry for agent '{agent_name}'"
        )

    # Capability Validation: collect every mismatch so the error is complete.
    mismatches = []
    if requirements.tools_required and not model_profile.supports_tools:
        mismatches.append("tools_required=True but supports_tools=False")
    if requirements.strict_json_required and not model_profile.supports_strict_json:
        mismatches.append("strict_json_required=True but supports_strict_json=False")

    if mismatches:
        error_msg = (
            f"Capability mismatch for agent '{agent_name}' with model "
            f"'{model_key}': {', '.join(mismatches)}"
        )
        obs_logger.log_event(
            level="error",
            message=error_msg,
            event="capability_mismatch",
            component=obs_components.LLM,
            agent_name=agent_name,
            model_key=model_key,
            mismatches=mismatches,
        )
        raise CapabilityMismatchError(error_msg)

    # Success Log
    obs_logger.log_event(
        level="info",
        message=f"Model selected for agent '{agent_name}': {model_key}",
        event="model_selected",
        component=obs_components.LLM,
        agent_name=agent_name,
        selected_model=model_key,
        provider=model_profile.provider,
        model_name=model_profile.model_name,
        required_capabilities={
            "tools_required": requirements.tools_required,
            "strict_json_required": requirements.strict_json_required,
            "latency_preference": requirements.latency_preference,
        },
    )

    return model_profile