File size: 11,116 Bytes
7d5083d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""
Model manager for Anthropic models.

This module provides utilities for managing multiple Anthropic models,
environment-based switching, and model performance tracking.
"""

from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from datetime import datetime, timedelta

from config.settings import get_settings, AnthropicModel
from config.logging import get_logger
from core.anthropic_client import AnthropicClient, AnthropicResponse

logger = get_logger(__name__)


@dataclass
class ModelPerformance:
    """Performance metrics for a specific model."""
    model_name: str  # Model identifier (the AnthropicModel enum's string value).
    total_requests: int = 0  # All attempts, successful or not.
    total_tokens: int = 0  # Sum of 'total_tokens' reported in usage on successes.
    total_errors: int = 0  # Requests that raised an exception.
    avg_response_time: float = 0.0  # Seconds; exponential moving average after the first sample.
    last_used: Optional[datetime] = None  # Timestamp of the most recent request (naive local time).
    success_rate: float = 100.0  # Percentage: (total_requests - total_errors) / total_requests * 100.


class ModelManager:
    """
    Manager for multiple Anthropic models with performance tracking.

    Features:
    - Environment-based model selection
    - Performance tracking per model
    - Automatic model switching based on performance
    - Model health monitoring
    """

    def __init__(self):
        """Initialize model manager from application settings."""
        self.settings = get_settings()
        self.current_model = self.settings.anthropic_model
        # Rolling per-model metrics, keyed by the enum's string value.
        self.performance_stats: Dict[str, ModelPerformance] = {}
        # Lazily-created API clients, keyed by the enum's string value.
        self.clients: Dict[str, AnthropicClient] = {}

        # Pre-register stats for every known model so later lookups
        # (e.g. in health_check) never raise KeyError.
        for model in AnthropicModel:
            self.performance_stats[model.value] = ModelPerformance(model_name=model.value)

        logger.info(f"Model manager initialized with default model: {self.current_model.value}")

    async def get_client(self, model: Optional[AnthropicModel] = None) -> AnthropicClient:
        """
        Get or create client for specified model.

        Args:
            model: Model to get client for (uses current if not specified)

        Returns:
            AnthropicClient instance for the model
        """
        target_model = model or self.current_model
        model_key = target_model.value

        # One cached client per model; created on first use.
        if model_key not in self.clients:
            self.clients[model_key] = AnthropicClient(model=target_model)
            logger.info(f"Created new client for model: {model_key}")

        return self.clients[model_key]

    async def generate_completion(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 4000,
        temperature: float = 0.0,
        system_prompt: Optional[str] = None,
        model: Optional[AnthropicModel] = None,
        fallback_on_error: bool = True
    ) -> AnthropicResponse:
        """
        Generate completion with automatic fallback and performance tracking.

        Args:
            messages: List of message dictionaries
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: Optional system prompt
            model: Model to use (uses current if not specified)
            fallback_on_error: Whether to try fallback model on error

        Returns:
            AnthropicResponse with completion and metadata

        Raises:
            Exception: The primary model's error is re-raised when fallback is
                disabled, not applicable, or also fails.
        """
        target_model = model or self.current_model
        start_time = datetime.now()

        try:
            client = await self.get_client(target_model)
            response = await client.generate_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                system_prompt=system_prompt,
                model=target_model
            )

            # Update performance stats
            self._update_performance_stats(target_model.value, response, start_time, success=True)

            return response

        except Exception as e:
            # Record the primary model's failure before attempting fallback.
            self._update_performance_stats(target_model.value, None, start_time, success=False)

            # Hoisted: compute the fallback model once instead of three times.
            fallback_model = self._get_fallback_model()

            # Try fallback model if enabled and not already using fallback
            if fallback_on_error and target_model != fallback_model:
                logger.warning(f"Model {target_model.value} failed, trying fallback: {str(e)}")

                # BUG FIX: restart the clock so the fallback model's latency
                # metrics are not inflated by time already spent on the failed
                # primary attempt.
                fallback_start = datetime.now()
                try:
                    client = await self.get_client(fallback_model)
                    response = await client.generate_completion(
                        messages=messages,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        system_prompt=system_prompt,
                        model=fallback_model
                    )

                    # Update performance stats for fallback
                    self._update_performance_stats(fallback_model.value, response, fallback_start, success=True)

                    logger.info(f"Fallback model {fallback_model.value} succeeded")
                    return response

                except Exception as fallback_error:
                    logger.error(f"Fallback model also failed: {str(fallback_error)}")
                    self._update_performance_stats(fallback_model.value, None, fallback_start, success=False)

            # Re-raise original exception if no fallback or fallback failed
            raise

    def switch_model(self, model: AnthropicModel) -> None:
        """
        Switch the current default model.

        Args:
            model: New model to use as default
        """
        old_model = self.current_model.value
        self.current_model = model
        logger.info(f"Switched default model from {old_model} to {model.value}")

    def get_model_from_env(self) -> AnthropicModel:
        """
        Get model from environment variable.

        Returns:
            AnthropicModel based on environment configuration
        """
        return self.settings.anthropic_model

    def get_performance_stats(self) -> Dict[str, Dict[str, Any]]:
        """
        Get performance statistics for all models.

        Returns:
            Dictionary with performance stats for each model
        """
        return {
            model_name: {
                "total_requests": perf.total_requests,
                "total_tokens": perf.total_tokens,
                "total_errors": perf.total_errors,
                "avg_response_time": perf.avg_response_time,
                "success_rate": perf.success_rate,
                "last_used": perf.last_used.isoformat() if perf.last_used else None,
            }
            for model_name, perf in self.performance_stats.items()
        }

    def get_best_performing_model(self) -> AnthropicModel:
        """
        Get the best performing model based on success rate and response time.

        Returns:
            AnthropicModel with best performance (the current model when no
            model has handled any requests yet)
        """
        best_model = self.current_model
        best_score = 0.0

        for model_name, perf in self.performance_stats.items():
            # Models with no traffic provide no evidence to score.
            if perf.total_requests == 0:
                continue

            # Higher success rate and lower latency both raise the score;
            # the 0.1s floor prevents division blow-ups for tiny timings.
            score = perf.success_rate * (1.0 / max(perf.avg_response_time, 0.1))

            if score > best_score:
                # BUG FIX: validate the model name BEFORE recording the score.
                # Previously best_score was raised even when the name was not a
                # valid AnthropicModel, letting a stale stats entry block valid
                # models from ever being selected.
                try:
                    candidate = AnthropicModel(model_name)
                except ValueError:
                    continue
                best_score = score
                best_model = candidate

        return best_model

    def _get_fallback_model(self) -> AnthropicModel:
        """Get fallback model (currently Claude 3.5 Haiku for speed)."""
        if self.current_model == AnthropicModel.CLAUDE_3_5_SONNET:
            return AnthropicModel.CLAUDE_3_5_HAIKU
        else:
            return AnthropicModel.CLAUDE_3_5_SONNET

    def _update_performance_stats(
        self,
        model_name: str,
        response: Optional[AnthropicResponse],
        start_time: datetime,
        success: bool
    ) -> None:
        """
        Update performance statistics for a model.

        Args:
            model_name: Enum string value of the model the request targeted
            response: The successful response, or None on failure
            start_time: When the request began (used to compute latency)
            success: Whether the request completed without error
        """
        if model_name not in self.performance_stats:
            self.performance_stats[model_name] = ModelPerformance(model_name=model_name)

        perf = self.performance_stats[model_name]
        perf.total_requests += 1
        perf.last_used = datetime.now()

        if success and response:
            # Update token count
            # NOTE(review): assumes response.usage is dict-like with a
            # 'total_tokens' key — confirm against AnthropicResponse.
            perf.total_tokens += response.usage.get('total_tokens', 0)

            # Update average response time
            response_time = (datetime.now() - start_time).total_seconds()
            if perf.avg_response_time == 0:
                perf.avg_response_time = response_time
            else:
                # Exponential moving average smooths out one-off spikes.
                perf.avg_response_time = 0.9 * perf.avg_response_time + 0.1 * response_time
        else:
            perf.total_errors += 1

        # Update success rate (percentage of all requests seen so far).
        perf.success_rate = ((perf.total_requests - perf.total_errors) / perf.total_requests) * 100

    async def health_check(self) -> Dict[str, Any]:
        """
        Perform health check on all available models.

        Returns:
            Dictionary with health status for each model
        """
        health_status = {}

        for model in AnthropicModel:
            try:
                client = await self.get_client(model)
                is_healthy = await client.validate_connection()

                health_status[model.value] = {
                    "healthy": is_healthy,
                    "last_check": datetime.now().isoformat(),
                    "performance": self.performance_stats[model.value].__dict__
                }

            except Exception as e:
                # An unreachable model is reported as unhealthy, not raised.
                health_status[model.value] = {
                    "healthy": False,
                    "error": str(e),
                    "last_check": datetime.now().isoformat(),
                    "performance": self.performance_stats[model.value].__dict__
                }

        return health_status

    async def close_all_clients(self) -> None:
        """Close all model clients."""
        for client in self.clients.values():
            await client.close()
        self.clients.clear()
        logger.info("All model clients closed")


# Global model manager instance, created lazily by get_model_manager().
_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """Return the process-wide ModelManager, constructing it on first use."""
    global _model_manager

    manager = _model_manager
    if manager is None:
        # First access: build the singleton and publish it for later callers.
        manager = ModelManager()
        _model_manager = manager
    return manager


async def close_model_manager() -> None:
    """Shut down the global model manager, closing every cached client."""
    global _model_manager

    # Guard clause: nothing to do if the singleton was never created.
    if _model_manager is None:
        return

    await _model_manager.close_all_clients()
    _model_manager = None