"""
Model manager for Anthropic models.

This module provides utilities for managing multiple Anthropic models,
environment-based switching, and model performance tracking.
"""

from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from datetime import datetime, timedelta

from config.settings import get_settings, AnthropicModel
from config.logging import get_logger
from core.anthropic_client import AnthropicClient, AnthropicResponse

logger = get_logger(__name__)


@dataclass
class ModelPerformance:
    """Performance metrics for a specific model."""

    model_name: str
    total_requests: int = 0           # all attempts, successful or not
    total_tokens: int = 0             # cumulative tokens over successful calls
    total_errors: int = 0
    avg_response_time: float = 0.0    # exponential moving average, in seconds
    last_used: Optional[datetime] = None
    success_rate: float = 100.0       # percentage, recomputed on every update


class ModelManager:
    """
    Manager for multiple Anthropic models with performance tracking.

    Features:
    - Environment-based model selection
    - Performance tracking per model
    - Automatic model switching based on performance
    - Model health monitoring
    """

    def __init__(self):
        """Initialize model manager from application settings."""
        self.settings = get_settings()
        self.current_model = self.settings.anthropic_model
        self.performance_stats: Dict[str, ModelPerformance] = {}
        self.clients: Dict[str, AnthropicClient] = {}

        # Pre-seed performance tracking so every known model has an entry
        # (health_check indexes performance_stats unconditionally).
        for model in AnthropicModel:
            self.performance_stats[model.value] = ModelPerformance(model_name=model.value)

        logger.info(f"Model manager initialized with default model: {self.current_model.value}")

    async def get_client(self, model: Optional[AnthropicModel] = None) -> AnthropicClient:
        """
        Get or create client for specified model.

        Args:
            model: Model to get client for (uses current if not specified)

        Returns:
            AnthropicClient instance for the model
        """
        target_model = model or self.current_model
        model_key = target_model.value

        # Clients are created lazily and cached per model.
        if model_key not in self.clients:
            self.clients[model_key] = AnthropicClient(model=target_model)
            logger.info(f"Created new client for model: {model_key}")

        return self.clients[model_key]

    async def generate_completion(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 4000,
        temperature: float = 0.0,
        system_prompt: Optional[str] = None,
        model: Optional[AnthropicModel] = None,
        fallback_on_error: bool = True
    ) -> AnthropicResponse:
        """
        Generate completion with automatic fallback and performance tracking.

        Args:
            messages: List of message dictionaries
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: Optional system prompt
            model: Model to use (uses current if not specified)
            fallback_on_error: Whether to try fallback model on error

        Returns:
            AnthropicResponse with completion and metadata

        Raises:
            Exception: The original error from the primary model when no
                fallback is attempted or the fallback also fails.
        """
        target_model = model or self.current_model
        start_time = datetime.now()

        try:
            client = await self.get_client(target_model)
            response = await client.generate_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                system_prompt=system_prompt,
                model=target_model
            )

            self._update_performance_stats(target_model.value, response, start_time, success=True)
            return response

        except Exception as e:
            self._update_performance_stats(target_model.value, None, start_time, success=False)

            # Fix: choose the fallback relative to the model that actually
            # failed, not the manager's current default. Previously, an
            # explicitly-requested model equal to the default's fallback got
            # no fallback attempt at all.
            fallback_model = self._get_fallback_model(target_model)
            if fallback_on_error and fallback_model != target_model:
                logger.warning(f"Model {target_model.value} failed, trying fallback: {str(e)}")
                try:
                    client = await self.get_client(fallback_model)
                    response = await client.generate_completion(
                        messages=messages,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        system_prompt=system_prompt,
                        model=fallback_model
                    )

                    self._update_performance_stats(fallback_model.value, response, start_time, success=True)
                    logger.info(f"Fallback model {fallback_model.value} succeeded")
                    return response

                except Exception as fallback_error:
                    logger.error(f"Fallback model also failed: {str(fallback_error)}")
                    # Record the failure against the fallback actually used.
                    self._update_performance_stats(fallback_model.value, None, start_time, success=False)

            # Re-raise original exception if no fallback or fallback failed
            raise

    def switch_model(self, model: AnthropicModel) -> None:
        """
        Switch the current default model.

        Args:
            model: New model to use as default
        """
        old_model = self.current_model.value
        self.current_model = model
        logger.info(f"Switched default model from {old_model} to {model.value}")

    def get_model_from_env(self) -> AnthropicModel:
        """
        Get model from environment variable.

        Returns:
            AnthropicModel based on environment configuration
        """
        return self.settings.anthropic_model

    def get_performance_stats(self) -> Dict[str, Dict[str, Any]]:
        """
        Get performance statistics for all models.

        Returns:
            Dictionary with performance stats for each model
        """
        stats = {}
        for model_name, perf in self.performance_stats.items():
            stats[model_name] = {
                "total_requests": perf.total_requests,
                "total_tokens": perf.total_tokens,
                "total_errors": perf.total_errors,
                "avg_response_time": perf.avg_response_time,
                "success_rate": perf.success_rate,
                # Serialize datetime for JSON-friendly output.
                "last_used": perf.last_used.isoformat() if perf.last_used else None
            }
        return stats

    def get_best_performing_model(self) -> AnthropicModel:
        """
        Get the best performing model based on success rate and response time.

        Models with no recorded requests are skipped; falls back to the
        current model when nothing has been measured yet.

        Returns:
            AnthropicModel with best performance
        """
        best_model = self.current_model
        best_score = 0.0

        for model_name, perf in self.performance_stats.items():
            if perf.total_requests == 0:
                continue

            # Fix: validate the enum member *before* touching best_score.
            # Previously best_score was raised first, so a stale/unknown
            # model name could block valid lower-scoring models from winning.
            try:
                candidate = AnthropicModel(model_name)
            except ValueError:
                continue

            # Score favours high success rate and low latency; the 0.1 floor
            # guards against division by (near-)zero response times.
            score = perf.success_rate * (1.0 / max(perf.avg_response_time, 0.1))
            if score > best_score:
                best_score = score
                best_model = candidate

        return best_model

    def _get_fallback_model(self, model: Optional[AnthropicModel] = None) -> AnthropicModel:
        """
        Get the fallback for *model* (defaults to the current model).

        Sonnet falls back to Haiku (for speed); any other model falls back
        to Sonnet.

        Args:
            model: The model needing a fallback; defaults to current_model.

        Returns:
            AnthropicModel to use as fallback.
        """
        reference = model or self.current_model
        if reference == AnthropicModel.CLAUDE_3_5_SONNET:
            return AnthropicModel.CLAUDE_3_5_HAIKU
        return AnthropicModel.CLAUDE_3_5_SONNET

    def _update_performance_stats(
        self,
        model_name: str,
        response: Optional[AnthropicResponse],
        start_time: datetime,
        success: bool
    ) -> None:
        """
        Update performance statistics for a model.

        Args:
            model_name: Enum value of the model being recorded.
            response: The response on success; None on failure.
            start_time: When the request started (for latency tracking).
            success: Whether the request succeeded.
        """
        if model_name not in self.performance_stats:
            self.performance_stats[model_name] = ModelPerformance(model_name=model_name)

        perf = self.performance_stats[model_name]
        perf.total_requests += 1
        perf.last_used = datetime.now()

        if success and response:
            perf.total_tokens += response.usage.get('total_tokens', 0)

            response_time = (datetime.now() - start_time).total_seconds()
            if perf.avg_response_time == 0:
                # First measurement seeds the average directly.
                perf.avg_response_time = response_time
            else:
                # Exponential moving average smooths out latency spikes.
                perf.avg_response_time = 0.9 * perf.avg_response_time + 0.1 * response_time
        else:
            perf.total_errors += 1

        # total_requests >= 1 here, so this division is safe.
        perf.success_rate = ((perf.total_requests - perf.total_errors) / perf.total_requests) * 100

    async def health_check(self) -> Dict[str, Any]:
        """
        Perform health check on all available models.

        Returns:
            Dictionary with health status for each model
        """
        health_status = {}

        for model in AnthropicModel:
            try:
                client = await self.get_client(model)
                is_healthy = await client.validate_connection()
                health_status[model.value] = {
                    "healthy": is_healthy,
                    "last_check": datetime.now().isoformat(),
                    "performance": self.performance_stats[model.value].__dict__
                }
            except Exception as e:
                # Best-effort: a failing model is reported, not raised.
                health_status[model.value] = {
                    "healthy": False,
                    "error": str(e),
                    "last_check": datetime.now().isoformat(),
                    "performance": self.performance_stats[model.value].__dict__
                }

        return health_status

    async def close_all_clients(self) -> None:
        """Close all model clients and clear the cache."""
        for client in self.clients.values():
            await client.close()
        self.clients.clear()
        logger.info("All model clients closed")


# Global model manager instance
_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """Get or create global model manager instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager


async def close_model_manager() -> None:
    """Close global model manager and release its clients."""
    global _model_manager
    if _model_manager is not None:
        await _model_manager.close_all_clients()
        _model_manager = None