File size: 4,380 Bytes
d64fd55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from dataclasses import dataclass, field
from typing import Dict, Literal, Optional, List
import logging
from observability import logger as obs_logger
from observability import components as obs_components
from .base import LLMClient

logger = logging.getLogger(__name__)


class CapabilityMismatchError(Exception):
    """Signals that a selected model cannot satisfy the agent's capability requirements."""


@dataclass
class ModelProfile:
    """Static description of an LLM offering: identity, feature support, and operational traits."""

    provider: str
    model_name: str
    supports_tools: bool
    supports_strict_json: bool
    latency_class: Literal["fast", "medium", "slow"]
    cost_class: Literal["cheap", "medium", "expensive"]
    stability_rating: int  # integer score on a 1 (flaky) to 5 (rock solid) scale

    def __post_init__(self):
        # Reject ratings outside the documented 1-5 scale at construction time.
        if self.stability_rating < 1 or self.stability_rating > 5:
            raise ValueError(
                f"stability_rating must be between 1 and 5, got {self.stability_rating}"
            )


# Catalogue of available models keyed by a short internal identifier.
# Spec tuples: (provider, model_name, tools, strict_json, latency, cost, stability).
MODEL_REGISTRY: Dict[str, ModelProfile] = {
    key: ModelProfile(
        provider=spec[0],
        model_name=spec[1],
        supports_tools=spec[2],
        supports_strict_json=spec[3],
        latency_class=spec[4],
        cost_class=spec[5],
        stability_rating=spec[6],
    )
    for key, spec in {
        "hf_gpt_oss_20b": (
            "litellm", "huggingface/openai/gpt-oss-20b",
            False, False, "medium", "cheap", 3,
        ),
        "gemini_flash": (
            "gemini", "gemini-3-flash-preview",
            True, True, "fast", "cheap", 5,
        ),
        "openai_gpt5": (
            "litellm", "openai/gpt-5-mini",
            True, True, "medium", "expensive", 5,
        ),
    }.items()
}


def select_model_for_agent(agent_name: str) -> ModelProfile:
    """
    Resolve the ModelProfile to use for a given agent.

    Maps agents to models through an explicit table, then validates that the
    chosen model satisfies the agent's declared capability requirements
    (tool support and strict-JSON support).

    Args:
        agent_name: Key identifying the requesting agent.

    Returns:
        The validated ModelProfile for the agent. Agents without declared
        requirements (and agents absent from the mapping) fall back to the
        "gemini_flash" profile.

    Raises:
        ValueError: If the mapped model key is missing from MODEL_REGISTRY.
        CapabilityMismatchError: If the mapped model lacks a capability the
            agent requires.
    """
    # Local import avoids a circular dependency at module load time.
    from .agent_capabilities import AGENT_CAPABILITIES

    requirements = AGENT_CAPABILITIES.get(agent_name)
    if not requirements:
        # Lazy %-style args: the message is only formatted if the record is emitted.
        logger.warning(
            "No capability requirements defined for agent: %s. Using fallback.",
            agent_name,
        )
        # Default fallback if unknown agent
        return MODEL_REGISTRY["gemini_flash"]

    # Manual agent -> model-key mapping as requested.
    mapping = {
        "InsightsAgent": "gemini_flash",
        "PlanAgent": "gemini_flash",
        "VisualizationAgent": "openai_gpt5",
        "Router": "gemini_flash",  # Changed from hf_gpt_oss_20b which lacks strict JSON
        "ChatAgent": "hf_gpt_oss_20b",
        "BriefService": "gemini_flash",
    }

    model_key = mapping.get(agent_name, "gemini_flash")
    model_profile = MODEL_REGISTRY.get(model_key)

    if not model_profile:
        raise ValueError(f"Model key '{model_key}' not found in registry for agent '{agent_name}'")

    # Capability validation: collect every unmet requirement so the error
    # message reports all problems at once rather than the first one found.
    mismatches = []
    if requirements.tools_required and not model_profile.supports_tools:
        mismatches.append("tools_required=True but supports_tools=False")

    if requirements.strict_json_required and not model_profile.supports_strict_json:
        mismatches.append("strict_json_required=True but supports_strict_json=False")

    if mismatches:
        error_msg = f"Capability mismatch for agent '{agent_name}' with model '{model_key}': {', '.join(mismatches)}"
        obs_logger.log_event(
            level="error",
            message=error_msg,
            event="capability_mismatch",
            component=obs_components.LLM,
            agent_name=agent_name,
            model_key=model_key,
            mismatches=mismatches,
        )
        raise CapabilityMismatchError(error_msg)

    # Structured success log for observability dashboards.
    obs_logger.log_event(
        level="info",
        message=f"Model selected for agent '{agent_name}': {model_key}",
        event="model_selected",
        component=obs_components.LLM,
        agent_name=agent_name,
        selected_model=model_key,
        provider=model_profile.provider,
        model_name=model_profile.model_name,
        required_capabilities={
            "tools_required": requirements.tools_required,
            "strict_json_required": requirements.strict_json_required,
            "latency_preference": requirements.latency_preference,
        },
    )

    return model_profile