File size: 5,257 Bytes
a1bf219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""Configuration validation utilities."""

from typing import Dict, List, Optional, Tuple

# Routing policies that validate_routing_policy accepts verbatim; any other
# value has to look like a "provider/model" path to pass validation.
VALID_ROUTING_POLICIES = [
    "auto", ":fastest", ":cheapest",
    "groq", "together", "replicate", "cerebras", "fireworks", "deepinfra",
    "meta-llama/Llama-3.3-70B-Instruct",
]

# Model tiers accepted by validate_model_tier.
VALID_MODEL_TIERS = [
    "fast",
    "capable",
    "vision",
]

# Provider names accepted by validate_llm_provider.
VALID_LLM_PROVIDERS = [
    "openai",
    "anthropic",
    "huggingface",
    "qwen",
]


def validate_routing_policy(policy: str) -> Tuple[bool, Optional[str]]:
    """
    Validate a routing policy.

    A policy is valid when it is one of ``VALID_ROUTING_POLICIES`` or a model
    path made of at least two non-blank, "/"-separated segments
    (e.g. ``"provider/model"``).

    Args:
        policy: Routing policy string

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if not policy:
        return False, "Routing policy cannot be empty"

    # Config values can arrive with arbitrary types; report a non-string as a
    # validation error instead of raising from the "/" handling below.
    if not isinstance(policy, str):
        return False, f"Routing policy must be a string, got {type(policy).__name__}"

    # Check if it's a known policy
    if policy in VALID_ROUTING_POLICIES:
        return True, None

    # Check if it's a model path (provider/model format). Every segment must be
    # non-blank, so values such as "/", "groq/" or "a//b" are rejected.
    if "/" in policy and all(part.strip() for part in policy.split("/")):
        return True, None

    # Unknown policy
    return (
        False,
        f"Invalid routing policy: {policy}. Must be one of {VALID_ROUTING_POLICIES} or a provider/model path",
    )


def validate_model_tier(tier: str) -> Tuple[bool, Optional[str]]:
    """
    Check that ``tier`` is one of the supported model tiers.

    Args:
        tier: Model tier string, compared against VALID_MODEL_TIERS

    Returns:
        (True, None) when the tier is supported, otherwise (False, message)
        where the message lists the accepted tiers.
    """
    if tier in VALID_MODEL_TIERS:
        return True, None
    return False, f"Invalid model tier: {tier}. Must be one of {VALID_MODEL_TIERS}"


def validate_agent_routing_config(
    config: Dict[str, Dict[str, str]],
) -> Tuple[bool, Optional[str]]:
    """
    Validate per-agent routing settings.

    Every value in ``config`` must itself be a dictionary that contains a
    ``routing_policy`` entry (checked with validate_routing_policy) and may
    contain a ``model_tier`` entry (checked with validate_model_tier). Agents
    are examined in iteration order and the first problem found is reported.

    Args:
        config: Dictionary mapping agent names to routing config

    Returns:
        Tuple of (is_valid, error_message); messages coming from the nested
        validators are prefixed with the agent name.
    """
    if not isinstance(config, dict):
        return False, "Configuration must be a dictionary"

    for name, settings in config.items():
        if not isinstance(settings, dict):
            return False, f"Configuration for {name} must be a dictionary"
        if "routing_policy" not in settings:
            return False, f"Missing routing_policy for {name}"

        ok, message = validate_routing_policy(settings["routing_policy"])
        if ok and "model_tier" in settings:
            # model_tier is optional, so it is only validated when supplied.
            ok, message = validate_model_tier(settings["model_tier"])
        if not ok:
            return False, f"{name}: {message}"

    return True, None


def validate_llm_provider(provider: str) -> Tuple[bool, Optional[str]]:
    """
    Check that ``provider`` names a supported LLM provider.

    Args:
        provider: Provider name, looked up in VALID_LLM_PROVIDERS

    Returns:
        (True, None) for a supported provider, otherwise (False, message)
        where the message lists the accepted providers.
    """
    is_supported = provider in VALID_LLM_PROVIDERS
    error = (
        None
        if is_supported
        else f"Invalid LLM provider: {provider}. Must be one of {VALID_LLM_PROVIDERS}"
    )
    return is_supported, error


def validate_configuration(config: Dict) -> Tuple[bool, Optional[str]]:
    """
    Validate complete configuration.

    Only the sections present in ``config`` are checked, in this order:
    ``llm_provider``, ``routing_policy``, ``agent_routing_config`` and
    ``indicator_parameters``. The first problem found is reported.

    Args:
        config: Full configuration dictionary

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    # Without this guard a string would pass silently ("key" in "text" is a
    # substring test) and None would raise TypeError on the first lookup.
    if not isinstance(config, dict):
        return False, "Configuration must be a dictionary"

    # Validate LLM provider
    if "llm_provider" in config:
        is_valid, error = validate_llm_provider(config["llm_provider"])
        if not is_valid:
            return False, error

    # Validate routing policy (global)
    if "routing_policy" in config:
        is_valid, error = validate_routing_policy(config["routing_policy"])
        if not is_valid:
            return False, error

    # Validate per-agent routing config
    if "agent_routing_config" in config:
        is_valid, error = validate_agent_routing_config(config["agent_routing_config"])
        if not is_valid:
            return False, error

    # Validate indicator parameters
    if "indicator_parameters" in config:
        is_valid, error = _validate_indicator_parameters(config["indicator_parameters"])
        if not is_valid:
            return False, error

    return True, None


def _validate_indicator_parameters(params: Dict) -> Tuple[bool, Optional[str]]:
    """Range-check indicator periods and require MACD fast < MACD slow."""
    if not isinstance(params, dict):
        return False, "indicator_parameters must be a dictionary"

    # (key, label used in the error message, inclusive min, inclusive max),
    # checked in this order so the first reported error is deterministic.
    ranges = (
        ("rsi_period", "RSI period", 2, 100),
        ("macd_fast", "MACD fast period", 2, 50),
        ("macd_slow", "MACD slow period", 10, 100),
        ("macd_signal", "MACD signal period", 2, 50),
    )
    for key, label, low, high in ranges:
        if key not in params:
            continue
        value = params[key]
        # bool is a subclass of int; reject it explicitly instead of relying on
        # True/False (1/0) happening to fall outside every range.
        if (
            isinstance(value, bool)
            or not isinstance(value, int)
            or not low <= value <= high
        ):
            return False, f"{label} must be between {low} and {high}"

    # The fast/slow ranges overlap, so the per-key checks alone would accept
    # e.g. fast=30, slow=20; MACD's fast period must be the shorter one.
    if "macd_fast" in params and "macd_slow" in params:
        if params["macd_fast"] >= params["macd_slow"]:
            return False, "MACD fast period must be less than MACD slow period"

    return True, None