Spaces:
Running
Running
| from dataclasses import dataclass, field | |
| from typing import Dict, Literal, Optional, List | |
| import logging | |
| from observability import logger as obs_logger | |
| from observability import components as obs_components | |
| from .base import LLMClient | |
| logger = logging.getLogger(__name__) | |
class CapabilityMismatchError(Exception):
    """Signal that a chosen model fails to satisfy an agent's required capabilities."""
@dataclass
class ModelProfile:
    """Static metadata describing one LLM endpoint available to agents.

    Fields mirror the selection criteria used by ``select_model_for_agent``:
    capability flags plus coarse latency, cost, and stability classes.

    NOTE: the ``@dataclass`` decorator was missing; without it the
    keyword-argument construction used by ``MODEL_REGISTRY`` raises
    ``TypeError`` and ``__post_init__`` never runs.
    """

    provider: str  # backend adapter identifier, e.g. "litellm" or "gemini"
    model_name: str  # provider-specific model id passed to the client
    supports_tools: bool  # True if the model supports tool/function calling
    supports_strict_json: bool  # True if the model can guarantee strict JSON output
    latency_class: Literal["fast", "medium", "slow"]
    cost_class: Literal["cheap", "medium", "expensive"]
    stability_rating: int  # 1 (least stable) through 5 (most stable)

    def __post_init__(self) -> None:
        """Validate invariants after dataclass initialization.

        Raises:
            ValueError: if ``stability_rating`` falls outside 1..5.
        """
        if not (1 <= self.stability_rating <= 5):
            raise ValueError(
                f"stability_rating must be between 1 and 5, got {self.stability_rating}"
            )
# Registry of known model endpoints, keyed by a short internal handle.
# These keys are what the agent->model mapping in select_model_for_agent
# resolves to; every value must satisfy ModelProfile's invariants.
MODEL_REGISTRY: Dict[str, ModelProfile] = {
    # Open-weights model served via LiteLLM/HuggingFace: no tool calling
    # and no strict-JSON mode, so only suitable for free-form chat.
    "hf_gpt_oss_20b": ModelProfile(
        provider="litellm",
        model_name="huggingface/openai/gpt-oss-20b",
        supports_tools=False,
        supports_strict_json=False,
        latency_class="medium",
        cost_class="cheap",
        stability_rating=3,
    ),
    # Default/fallback choice: fast, cheap, and supports both tools and
    # strict JSON, hence the top stability rating.
    "gemini_flash": ModelProfile(
        provider="gemini",
        model_name="gemini-3-flash-preview",
        supports_tools=True,
        supports_strict_json=True,
        latency_class="fast",
        cost_class="cheap",
        stability_rating=5,
    ),
    # Full-capability but expensive option, reserved for agents that
    # need it (e.g. VisualizationAgent in the manual mapping).
    "openai_gpt5": ModelProfile(
        provider="litellm",
        model_name="openai/gpt-5-mini",
        supports_tools=True,
        supports_strict_json=True,
        latency_class="medium",
        cost_class="expensive",
        stability_rating=5,
    ),
}
# Registry key used whenever an agent is unknown or has no explicit mapping.
_FALLBACK_MODEL_KEY = "gemini_flash"

# Static agent -> registry-key mapping, hoisted to module level so the
# dict is not rebuilt on every selection call.
_AGENT_MODEL_MAP: Dict[str, str] = {
    "InsightsAgent": "gemini_flash",
    "PlanAgent": "gemini_flash",
    "VisualizationAgent": "openai_gpt5",
    "Router": "gemini_flash",  # Changed from hf_gpt_oss_20b which lacks strict JSON
    "ChatAgent": "hf_gpt_oss_20b",
    "BriefService": "gemini_flash",
}


def select_model_for_agent(agent_name: str) -> ModelProfile:
    """Resolve and validate the model profile to use for *agent_name*.

    Looks up the agent's capability requirements, maps the agent to a
    registry entry via the manual mapping, and verifies the chosen model
    meets the required capabilities. Agents with no declared requirements
    fall back to the default model without validation.

    Args:
        agent_name: Logical agent identifier (e.g. "Router").

    Returns:
        The validated ModelProfile for the agent.

    Raises:
        ValueError: if the mapped registry key is missing from MODEL_REGISTRY.
        CapabilityMismatchError: if the mapped model lacks a required
            capability (tool calling or strict JSON).
    """
    # Imported lazily to avoid a circular import at module load time.
    from .agent_capabilities import AGENT_CAPABILITIES

    requirements = AGENT_CAPABILITIES.get(agent_name)
    if not requirements:
        # Lazy %-style args: formatting is skipped when WARNING is disabled.
        logger.warning(
            "No capability requirements defined for agent: %s. Using fallback.",
            agent_name,
        )
        # Default fallback if unknown agent
        return MODEL_REGISTRY[_FALLBACK_MODEL_KEY]

    model_key = _AGENT_MODEL_MAP.get(agent_name, _FALLBACK_MODEL_KEY)
    model_profile = MODEL_REGISTRY.get(model_key)
    if not model_profile:
        raise ValueError(f"Model key '{model_key}' not found in registry for agent '{agent_name}'")

    # Capability validation: collect every mismatch so the error and the
    # observability event report all problems at once, not just the first.
    mismatches = []
    if requirements.tools_required and not model_profile.supports_tools:
        mismatches.append("tools_required=True but supports_tools=False")
    if requirements.strict_json_required and not model_profile.supports_strict_json:
        mismatches.append("strict_json_required=True but supports_strict_json=False")
    if mismatches:
        error_msg = f"Capability mismatch for agent '{agent_name}' with model '{model_key}': {', '.join(mismatches)}"
        obs_logger.log_event(
            level="error",
            message=error_msg,
            event="capability_mismatch",
            component=obs_components.LLM,
            agent_name=agent_name,
            model_key=model_key,
            mismatches=mismatches,
        )
        raise CapabilityMismatchError(error_msg)

    # Success: emit a structured event describing the selection and the
    # requirements it was checked against.
    obs_logger.log_event(
        level="info",
        message=f"Model selected for agent '{agent_name}': {model_key}",
        event="model_selected",
        component=obs_components.LLM,
        agent_name=agent_name,
        selected_model=model_key,
        provider=model_profile.provider,
        model_name=model_profile.model_name,
        required_capabilities={
            "tools_required": requirements.tools_required,
            "strict_json_required": requirements.strict_json_required,
            "latency_preference": requirements.latency_preference,
        },
    )
    return model_profile