Spaces:
Sleeping
Sleeping
File size: 11,116 Bytes
7d5083d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 |
"""
Model manager for Anthropic models.
This module provides utilities for managing multiple Anthropic models,
environment-based switching, and model performance tracking.
"""
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from datetime import datetime, timedelta
from config.settings import get_settings, AnthropicModel
from config.logging import get_logger
from core.anthropic_client import AnthropicClient, AnthropicResponse
logger = get_logger(__name__)
@dataclass
class ModelPerformance:
    """Performance metrics for a specific model."""
    # Model identifier these metrics belong to (an AnthropicModel enum value string).
    model_name: str
    # Total completion attempts recorded, successes and failures combined.
    total_requests: int = 0
    # Cumulative token usage summed from successful responses' usage counts.
    total_tokens: int = 0
    # Number of failed requests.
    total_errors: int = 0
    # Response latency in seconds; seeded by the first sample, then smoothed
    # as an exponential moving average (0.9 old / 0.1 new).
    avg_response_time: float = 0.0
    # Timestamp of the most recent request (naive datetime.now()), None if never used.
    last_used: Optional[datetime] = None
    # Percentage of successful requests; starts at 100 until the first request lands.
    success_rate: float = 100.0
class ModelManager:
    """
    Manager for multiple Anthropic models with performance tracking.

    Features:
    - Environment-based model selection
    - Performance tracking per model
    - Automatic model switching based on performance
    - Model health monitoring
    """

    def __init__(self):
        """Initialize model manager from application settings."""
        self.settings = get_settings()
        self.current_model = self.settings.anthropic_model
        self.performance_stats: Dict[str, ModelPerformance] = {}
        self.clients: Dict[str, AnthropicClient] = {}
        # Pre-seed stats for every known model so later lookups
        # (e.g. in health_check) never hit a missing key.
        for model in AnthropicModel:
            self.performance_stats[model.value] = ModelPerformance(model_name=model.value)
        logger.info(f"Model manager initialized with default model: {self.current_model.value}")

    async def get_client(self, model: Optional[AnthropicModel] = None) -> AnthropicClient:
        """
        Get or create a cached client for the specified model.

        Args:
            model: Model to get client for (uses current default if not specified)

        Returns:
            AnthropicClient instance for the model
        """
        target_model = model or self.current_model
        model_key = target_model.value
        if model_key not in self.clients:
            self.clients[model_key] = AnthropicClient(model=target_model)
            logger.info(f"Created new client for model: {model_key}")
        return self.clients[model_key]

    async def generate_completion(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 4000,
        temperature: float = 0.0,
        system_prompt: Optional[str] = None,
        model: Optional[AnthropicModel] = None,
        fallback_on_error: bool = True
    ) -> AnthropicResponse:
        """
        Generate completion with automatic fallback and performance tracking.

        Args:
            messages: List of message dictionaries
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            system_prompt: Optional system prompt
            model: Model to use (uses current default if not specified)
            fallback_on_error: Whether to try the fallback model on error

        Returns:
            AnthropicResponse with completion and metadata

        Raises:
            Exception: The primary model's error is re-raised when fallback
                is disabled, not applicable, or also fails.
        """
        target_model = model or self.current_model
        start_time = datetime.now()
        try:
            client = await self.get_client(target_model)
            response = await client.generate_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                system_prompt=system_prompt,
                model=target_model
            )
            self._update_performance_stats(target_model.value, response, start_time, success=True)
            return response
        except Exception as e:
            self._update_performance_stats(target_model.value, None, start_time, success=False)
            # Compute the fallback once (the original re-derived it up to three
            # times); skip fallback when the failing model already IS the fallback.
            fallback_model = self._get_fallback_model()
            if fallback_on_error and target_model != fallback_model:
                logger.warning(f"Model {target_model.value} failed, trying fallback: {str(e)}")
                try:
                    client = await self.get_client(fallback_model)
                    response = await client.generate_completion(
                        messages=messages,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        system_prompt=system_prompt,
                        model=fallback_model
                    )
                    self._update_performance_stats(fallback_model.value, response, start_time, success=True)
                    logger.info(f"Fallback model {fallback_model.value} succeeded")
                    return response
                except Exception as fallback_error:
                    logger.error(f"Fallback model also failed: {str(fallback_error)}")
                    self._update_performance_stats(fallback_model.value, None, start_time, success=False)
            # Bare raise re-raises the ORIGINAL exception caught as `e`
            # (the inner handler has already exited, restoring exc context).
            raise

    def switch_model(self, model: AnthropicModel) -> None:
        """
        Switch the current default model.

        Args:
            model: New model to use as default
        """
        old_model = self.current_model.value
        self.current_model = model
        logger.info(f"Switched default model from {old_model} to {model.value}")

    def get_model_from_env(self) -> AnthropicModel:
        """
        Get model from environment variable.

        Returns:
            AnthropicModel based on environment configuration
        """
        return self.settings.anthropic_model

    def get_performance_stats(self) -> Dict[str, Dict[str, Any]]:
        """
        Get performance statistics for all models.

        Returns:
            Dictionary mapping model name to a JSON-friendly stats dict
            (last_used serialized as ISO-8601 string, or None).
        """
        return {
            model_name: {
                "total_requests": perf.total_requests,
                "total_tokens": perf.total_tokens,
                "total_errors": perf.total_errors,
                "avg_response_time": perf.avg_response_time,
                "success_rate": perf.success_rate,
                "last_used": perf.last_used.isoformat() if perf.last_used else None
            }
            for model_name, perf in self.performance_stats.items()
        }

    def get_best_performing_model(self) -> AnthropicModel:
        """
        Get the best performing model based on success rate and response time.

        Returns:
            AnthropicModel with the highest score; falls back to the current
            model when no model has recorded any requests.
        """
        best_model = self.current_model
        best_score = 0.0
        for model_name, perf in self.performance_stats.items():
            if perf.total_requests == 0:
                continue
            # Fix: validate the model name BEFORE committing best_score. The
            # original raised best_score first and only then tried the enum
            # lookup, so an unrecognized stats key could raise the bar and
            # shadow a valid lower-scoring model.
            try:
                candidate = AnthropicModel(model_name)
            except ValueError:
                continue
            # Higher success rate and lower latency both improve the score;
            # latency is clamped to 0.1s to avoid division blow-ups.
            score = perf.success_rate * (1.0 / max(perf.avg_response_time, 0.1))
            if score > best_score:
                best_score = score
                best_model = candidate
        return best_model

    def _get_fallback_model(self) -> AnthropicModel:
        """Get fallback model (currently Claude 3.5 Haiku for speed)."""
        if self.current_model == AnthropicModel.CLAUDE_3_5_SONNET:
            return AnthropicModel.CLAUDE_3_5_HAIKU
        return AnthropicModel.CLAUDE_3_5_SONNET

    def _update_performance_stats(
        self,
        model_name: str,
        response: Optional[AnthropicResponse],
        start_time: datetime,
        success: bool
    ) -> None:
        """Update performance statistics for a model after one request."""
        if model_name not in self.performance_stats:
            self.performance_stats[model_name] = ModelPerformance(model_name=model_name)
        perf = self.performance_stats[model_name]
        perf.total_requests += 1
        perf.last_used = datetime.now()
        if success and response:
            perf.total_tokens += response.usage.get('total_tokens', 0)
            response_time = (datetime.now() - start_time).total_seconds()
            if perf.avg_response_time == 0:
                # First sample seeds the average directly.
                perf.avg_response_time = response_time
            else:
                # Exponential moving average smooths out latency spikes.
                perf.avg_response_time = 0.9 * perf.avg_response_time + 0.1 * response_time
        else:
            perf.total_errors += 1
        # total_requests >= 1 here, so the division is safe.
        perf.success_rate = ((perf.total_requests - perf.total_errors) / perf.total_requests) * 100

    async def health_check(self) -> Dict[str, Any]:
        """
        Perform health check on all available models.

        Returns:
            Dictionary with health status for each model
        """
        health_status = {}
        for model in AnthropicModel:
            try:
                client = await self.get_client(model)
                is_healthy = await client.validate_connection()
                health_status[model.value] = {
                    "healthy": is_healthy,
                    "last_check": datetime.now().isoformat(),
                    # NOTE: __dict__ exposes raw field values (e.g. a datetime in
                    # last_used); use get_performance_stats() for a JSON-safe view.
                    "performance": self.performance_stats[model.value].__dict__
                }
            except Exception as e:
                health_status[model.value] = {
                    "healthy": False,
                    "error": str(e),
                    "last_check": datetime.now().isoformat(),
                    "performance": self.performance_stats[model.value].__dict__
                }
        return health_status

    async def close_all_clients(self) -> None:
        """Close all model clients and clear the client cache."""
        for client in self.clients.values():
            await client.close()
        self.clients.clear()
        logger.info("All model clients closed")
# Process-wide singleton, created lazily by get_model_manager() and torn
# down by close_model_manager().
_model_manager: Optional[ModelManager] = None
def get_model_manager() -> ModelManager:
    """Return the process-wide ModelManager, constructing it on first use."""
    global _model_manager
    manager = _model_manager
    if manager is None:
        manager = ModelManager()
        _model_manager = manager
    return manager
async def close_model_manager() -> None:
    """Shut down and discard the process-wide ModelManager, if one exists."""
    global _model_manager
    manager = _model_manager
    if manager is None:
        return
    await manager.close_all_clients()
    _model_manager = None