from dataclasses import dataclass, field
from typing import Dict, Literal, Optional, List
import logging
from observability import logger as obs_logger
from observability import components as obs_components
from .base import LLMClient
logger = logging.getLogger(__name__)
class CapabilityMismatchError(Exception):
    """Signals that the model mapped to an agent is missing a capability the
    agent declared as required (tool calling and/or strict JSON output)."""
@dataclass
class ModelProfile:
    """Static description of one LLM endpoint: which backend serves it, what it
    can do, and coarse latency/cost/stability characteristics used for routing."""

    provider: str  # backend adapter, e.g. "litellm" or "gemini"
    model_name: str  # provider-qualified model identifier
    supports_tools: bool  # whether the model can perform tool/function calls
    supports_strict_json: bool  # whether the model can emit schema-strict JSON
    latency_class: Literal["fast", "medium", "slow"]
    cost_class: Literal["cheap", "medium", "expensive"]
    stability_rating: int  # subjective reliability score on a 1-5 scale

    def __post_init__(self) -> None:
        # Fail fast at construction time on ratings outside the 1-5 scale.
        if self.stability_rating < 1 or self.stability_rating > 5:
            raise ValueError(
                f"stability_rating must be between 1 and 5, got {self.stability_rating}"
            )
# Registry of every model endpoint the selector may hand out, keyed by a short
# internal identifier. These profiles drive the per-agent mapping performed in
# select_model_for_agent().
MODEL_REGISTRY: Dict[str, ModelProfile] = {
    "hf_gpt_oss_20b": ModelProfile(
        "litellm",
        "huggingface/openai/gpt-oss-20b",
        supports_tools=False,
        supports_strict_json=False,
        latency_class="medium",
        cost_class="cheap",
        stability_rating=3,
    ),
    "gemini_flash": ModelProfile(
        "gemini",
        "gemini-3-flash-preview",
        supports_tools=True,
        supports_strict_json=True,
        latency_class="fast",
        cost_class="cheap",
        stability_rating=5,
    ),
    "openai_gpt5": ModelProfile(
        "litellm",
        "openai/gpt-5-mini",
        supports_tools=True,
        supports_strict_json=True,
        latency_class="medium",
        cost_class="expensive",
        stability_rating=5,
    ),
}
def select_model_for_agent(agent_name: str) -> ModelProfile:
    """
    Resolve the model profile for an agent via an explicit, hand-maintained
    agent -> model mapping, then validate that the chosen model satisfies the
    agent's declared capability requirements.

    Args:
        agent_name: Logical agent identifier (e.g. "Router", "ChatAgent").

    Returns:
        The validated ModelProfile for the agent. Agents with no declared
        requirements (and agents absent from the mapping) fall back to the
        "gemini_flash" profile.

    Raises:
        CapabilityMismatchError: If the mapped model lacks a capability the
            agent requires (tool calling or strict JSON output).
        ValueError: If the mapping points at a key missing from MODEL_REGISTRY.
    """
    # Imported lazily to avoid a circular import at module load time.
    from .agent_capabilities import AGENT_CAPABILITIES

    requirements = AGENT_CAPABILITIES.get(agent_name)
    if not requirements:
        # Lazy %-style args: the message is only formatted if WARNING is emitted.
        logger.warning(
            "No capability requirements defined for agent: %s. Using fallback.",
            agent_name,
        )
        # Default fallback if unknown agent
        return MODEL_REGISTRY["gemini_flash"]

    # Manual mapping as requested
    mapping = {
        "InsightsAgent": "gemini_flash",
        "PlanAgent": "gemini_flash",
        "VisualizationAgent": "openai_gpt5",
        "Router": "gemini_flash",  # Changed from hf_gpt_oss_20b which lacks strict JSON
        "ChatAgent": "hf_gpt_oss_20b",
        "BriefService": "gemini_flash",
    }
    model_key = mapping.get(agent_name, "gemini_flash")
    model_profile = MODEL_REGISTRY.get(model_key)
    if not model_profile:
        raise ValueError(f"Model key '{model_key}' not found in registry for agent '{agent_name}'")

    # Capability Validation: collect every unmet requirement so the error
    # message reports all of them at once, not just the first.
    mismatches = []
    if requirements.tools_required and not model_profile.supports_tools:
        mismatches.append("tools_required=True but supports_tools=False")
    if requirements.strict_json_required and not model_profile.supports_strict_json:
        mismatches.append("strict_json_required=True but supports_strict_json=False")

    if mismatches:
        error_msg = f"Capability mismatch for agent '{agent_name}' with model '{model_key}': {', '.join(mismatches)}"
        # Emit a structured observability event before raising so the mismatch
        # is visible even if the caller swallows the exception.
        obs_logger.log_event(
            level="error",
            message=error_msg,
            event="capability_mismatch",
            component=obs_components.LLM,
            agent_name=agent_name,
            model_key=model_key,
            mismatches=mismatches,
        )
        raise CapabilityMismatchError(error_msg)

    # Success Log: record which model was picked and why, for traceability.
    obs_logger.log_event(
        level="info",
        message=f"Model selected for agent '{agent_name}': {model_key}",
        event="model_selected",
        component=obs_components.LLM,
        agent_name=agent_name,
        selected_model=model_key,
        provider=model_profile.provider,
        model_name=model_profile.model_name,
        required_capabilities={
            "tools_required": requirements.tools_required,
            "strict_json_required": requirements.strict_json_required,
            "latency_preference": requirements.latency_preference,
        },
    )
    return model_profile