File size: 11,710 Bytes
abb6d23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
# SPDX-License-Identifier: GPL-3.0-or-later
#
# Toolify: Empower any LLM with function calling capabilities.
# Copyright (C) 2025 FunnyCups (https://github.com/funnycups)

import os
import yaml
from typing import List, Dict, Any, Set, Optional
from pydantic import BaseModel, Field, field_validator


class ServerConfig(BaseModel):
    """Server configuration"""
    port: int = Field(default=8000, ge=1, le=65535, description="Server port")
    host: str = Field(default="0.0.0.0", description="Server host address")
    timeout: int = Field(default=180, ge=1, description="Request timeout (seconds)")
    upstream_retry_attempts: int = Field(default=3, ge=0, le=10, description="Max retry attempts for upstream connection errors and timeouts (0 to disable)")
    upstream_retry_base_delay: float = Field(default=0.5, ge=0.1, le=10.0, description="Base delay in seconds for exponential backoff between retries")


class UpstreamService(BaseModel):
    """Upstream service configuration"""
    name: str = Field(description="Service name")
    base_url: str = Field(description="Service base URL")
    api_key: str = Field(description="API key")
    models: List[str] = Field(default_factory=list, description="List of supported models")
    description: str = Field(default="", description="Service description")
    is_default: bool = Field(default=False, description="Is default service")
    
    @field_validator('base_url')
    def validate_base_url(cls, v):
        if not v.startswith(('http://', 'https://')):
            raise ValueError('base_url must start with http:// or https://')
        return v.rstrip('/')
    
    @field_validator('api_key')
    def validate_api_key(cls, v):
        if not v or v.strip() == "":
            raise ValueError('api_key cannot be empty')
        return v
    
    @field_validator('models')
    def validate_models(cls, v):
        # Allow empty models list when model_passthrough is enabled
        # This will be validated at AppConfig level
        if v:
            for model in v:
                if not model or model.strip() == "":
                    raise ValueError('model name cannot be empty')
        return v if v else []


class ClientAuthConfig(BaseModel):
    """Client authentication configuration"""
    allowed_keys: List[str] = Field(description="List of allowed client API keys")
    
    @field_validator('allowed_keys')
    def validate_allowed_keys(cls, v):
        if not v or len(v) == 0:
            raise ValueError('allowed_keys cannot be empty')
        for key in v:
            if not key or key.strip() == "":
                raise ValueError('API key cannot be empty')
        return v


class FeaturesConfig(BaseModel):
    """Feature configuration"""
    enable_function_calling: bool = Field(default=True, description="Enable function calling")
    log_level: str = Field(default="INFO", description="Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL, or DISABLED")
    convert_developer_to_system: bool = Field(default=True, description="Convert developer role to system role")
    prompt_template: Optional[str] = Field(default=None, description="Custom prompt template for function calling")
    key_passthrough: bool = Field(default=False, description="If true, directly forward client-provided API key to upstream instead of using configured upstream key")
    model_passthrough: bool = Field(default=False, description="If true, forward all requests directly to the 'openai' upstream service, ignoring model-based routing")
    
    # Function calling error retry configuration
    enable_fc_error_retry: bool = Field(default=False, description="Enable automatic retry when function call parsing fails")
    fc_error_retry_max_attempts: int = Field(default=3, ge=1, le=10, description="Maximum number of retry attempts for function call error correction")
    fc_error_retry_prompt_template: Optional[str] = Field(default=None, description="Custom prompt template for function call error correction. Must contain {error_details} and {original_response} placeholders if provided.")

    @field_validator('log_level')
    def validate_log_level(cls, v):
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "DISABLED"]
        if v.upper() not in valid_levels:
            raise ValueError(f"log_level must be one of {valid_levels}")
        return v.upper()

    @field_validator('prompt_template')
    def validate_prompt_template(cls, v):
        if v:
            if "{tools_list}" not in v or "{trigger_signal}" not in v:
                raise ValueError("prompt_template must contain {tools_list} and {trigger_signal} placeholders")
        return v

    @field_validator('fc_error_retry_prompt_template')
    def validate_fc_error_retry_prompt_template(cls, v):
        if v:
            if "{error_details}" not in v or "{original_response}" not in v:
                raise ValueError("fc_error_retry_prompt_template must contain {error_details} and {original_response} placeholders")
        return v


class AppConfig(BaseModel):
    """Application full configuration"""
    server: ServerConfig = Field(default_factory=ServerConfig)
    upstream_services: List[UpstreamService] = Field(description="List of upstream services")
    client_authentication: ClientAuthConfig = Field(description="Client authentication configuration")
    features: FeaturesConfig = Field(default_factory=FeaturesConfig)
    
    @field_validator('upstream_services')
    def validate_upstream_services(cls, v, info):
        if not v or len(v) == 0:
            raise ValueError('upstream_services cannot be empty')
        
        # Get features config to check model_passthrough mode
        features = info.data.get('features', FeaturesConfig())
        model_passthrough = features.model_passthrough if hasattr(features, 'model_passthrough') else False
        
        # In model_passthrough mode, check for 'openai' service existence
        if model_passthrough:
            openai_service = next((s for s in v if s.name == 'openai'), None)
            if not openai_service:
                raise ValueError("When model_passthrough is enabled, an upstream service named 'openai' must be configured")
        else:
            # In normal mode, validate that services have models
            for service in v:
                if not service.models or len(service.models) == 0:
                    raise ValueError(f"Service '{service.name}' must have at least one model when model_passthrough is disabled")
        
        default_services = [service for service in v if service.is_default]
        if len(default_services) == 0:
            raise ValueError('Must have at least one default upstream service (is_default: true)')
        if len(default_services) > 1:
            raise ValueError('Only one upstream service can be marked as default')
        
        all_models = set()
        all_aliases = set()
        
        for service in v:
            for model in service.models:
                if model in all_models:
                    raise ValueError(f'Duplicate model entry found: {model}')
                all_models.add(model)
                
                if ':' in model:
                    parts = model.split(':', 1)
                    if len(parts) == 2:
                        alias, real_model = parts
                        if not alias.strip() or not real_model.strip():
                            raise ValueError(f"Invalid alias format in '{model}'. Both parts must not be empty.")
                        all_aliases.add(alias)
                    else:
                        raise ValueError(f"Invalid model format with colon: {model}")

        regular_models = {m for m in all_models if ':' not in m}
        conflicts = all_aliases.intersection(regular_models)
        if conflicts:
            raise ValueError(f"Alias names {conflicts} conflict with model names.")
                
        return v


class ConfigLoader:
    """Configuration loader"""
    
    def __init__(self, config_path: str = "config.yaml"):
        self.config_path = config_path
        self._config: AppConfig = None
    
    def load_config(self) -> AppConfig:
        """Load configuration file"""
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(
                f"Configuration file '{self.config_path}' not found. "
                f"Please copy 'config.example.yaml' to '{self.config_path}' and modify the configuration as needed."
            )
        
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ValueError(f"Configuration file format error: {e}")
        except Exception as e:
            raise ValueError(f"Failed to read configuration file: {e}")
        
        if not config_data:
            raise ValueError("Configuration file is empty")
        
        try:
            self._config = AppConfig(**config_data)
            return self._config
        except Exception as e:
            raise ValueError(f"Configuration validation failed: {e}")
    
    @property
    def config(self) -> AppConfig:
        """Get configuration object"""
        if self._config is None:
            self.load_config()
        return self._config
    
    def get_model_to_service_mapping(self) -> tuple[Dict[str, Dict[str, Any]], Dict[str, List[str]]]:
        """Get model to service mapping and model aliases"""
        config = self.config
        model_mapping = {}
        alias_mapping = {}
        
        for service in config.upstream_services:
            service_info = {
                "name": service.name,
                "base_url": service.base_url,
                "api_key": service.api_key,
                "description": service.description,
                "is_default": service.is_default
            }
            
            for model_entry in service.models:
                model_mapping[model_entry] = service_info
                if ':' in model_entry:
                    parts = model_entry.split(':', 1)
                    if len(parts) == 2:
                        alias, _ = parts
                        if alias not in alias_mapping:
                            alias_mapping[alias] = []
                        alias_mapping[alias].append(model_entry)
        
        return model_mapping, alias_mapping
    
    def get_default_service(self) -> Dict[str, Any]:
        """Get default service configuration"""
        config = self.config
        for service in config.upstream_services:
            if service.is_default:
                return {
                    "name": service.name,
                    "base_url": service.base_url,
                    "api_key": service.api_key,
                    "description": service.description,
                    "is_default": service.is_default
                }
        raise ValueError("No default service configured")
    
    def get_allowed_client_keys(self) -> Set[str]:
        """Get set of allowed client keys"""
        return set(self.config.client_authentication.allowed_keys)
    
    def get_log_level(self) -> str:
        """Get configured log level"""
        return self.config.features.log_level
    
    def get_features_config(self) -> Dict[str, Any]:
        """Get feature configuration"""
        return {
            "function_calling": self.config.features.enable_function_calling,
            "log_level": self.config.features.log_level,
            "convert_developer_to_system": self.config.features.convert_developer_to_system
        }


config_loader = ConfigLoader()