""" Model configurations for the multi-agent algebra chatbot. Includes rate limits, model parameters, and factory functions. """ import os import time import asyncio from typing import Optional, Dict, Any, Callable, TypeVar from functools import wraps from dataclasses import dataclass, field from langchain_groq import ChatGroq @dataclass class ModelConfig: """Configuration for a specific model.""" id: str temperature: float = 0.6 max_tokens: int = 4096 context_length: int = 128000 # Default context window top_p: float = 1.0 streaming: bool = True # Rate limits rpm: int = 30 # Requests per minute rpd: int = 1000 # Requests per day tpm: int = 10000 # Tokens per minute tpd: int = 300000 # Tokens per day # Model configurations based on rate limit table MODEL_CONFIGS: Dict[str, ModelConfig] = { "kimi-k2": ModelConfig( id="moonshotai/kimi-k2-instruct-0905", temperature=0.0, max_tokens=16384, context_length=262144, # 256K tokens top_p=1.0, rpm=60, rpd=1000, tpm=10000, tpd=300000 ), "llama-4-maverick": ModelConfig( id="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0.0, max_tokens=8192, context_length=128000, rpm=30, rpd=1000, tpm=6000, tpd=500000 ), "llama-4-scout": ModelConfig( id="meta-llama/llama-4-scout-17b-16e-instruct", temperature=0.0, max_tokens=8192, context_length=128000, rpm=30, rpd=1000, tpm=30000, tpd=500000 ), "qwen3-32b": ModelConfig( id="qwen/qwen3-32b", temperature=0.0, max_tokens=8192, context_length=32768, # 32K tokens rpm=60, rpd=1000, tpm=6000, tpd=500000 ), "gpt-oss-120b": ModelConfig( id="openai/gpt-oss-120b", temperature=0.0, max_tokens=8192, context_length=128000, rpm=30, rpd=1000, tpm=8000, tpd=200000 ), "wolfram": ModelConfig( id="wolfram-alpha-api", temperature=0.0, max_tokens=0, context_length=0, rpm=30, rpd=2000, tpm=100000, tpd=1000000 ), } @dataclass class ModelRateLimitTracker: """Track rate limits for a specific model.""" model_name: str config: ModelConfig minute_requests: int = 0 minute_tokens: int = 0 day_requests: int = 0 day_tokens: int = 0 last_minute_reset: float = field(default_factory=time.time) last_day_reset: float = field(default_factory=time.time) def _reset_if_needed(self): """Reset counters if time windows have passed.""" now = time.time() if now - self.last_minute_reset >= 60: self.minute_requests = 0 self.minute_tokens = 0 self.last_minute_reset = now if now - self.last_day_reset >= 86400: self.day_requests = 0 self.day_tokens = 0 self.last_day_reset = now def can_request(self, estimated_tokens: int = 100) -> tuple[bool, str]: """Check if a request can be made within rate limits.""" self._reset_if_needed() if self.minute_requests >= self.config.rpm: return False, f"Rate limit: {self.model_name} exceeded {self.config.rpm} RPM" if self.day_requests >= self.config.rpd: return False, f"Rate limit: {self.model_name} exceeded {self.config.rpd} RPD" if self.minute_tokens + estimated_tokens > self.config.tpm: return False, f"Rate limit: {self.model_name} would exceed {self.config.tpm} TPM" if self.day_tokens + estimated_tokens > self.config.tpd: return False, f"Rate limit: {self.model_name} would exceed {self.config.tpd} TPD" return True, "" def record_request(self, tokens_used: int): """Record a completed request.""" self._reset_if_needed() self.minute_requests += 1 self.day_requests += 1 self.minute_tokens += tokens_used self.day_tokens += tokens_used class ModelManager: """Manages model instances and rate limiting.""" def __init__(self): self.trackers: Dict[str, ModelRateLimitTracker] = {} self._api_key = os.getenv("GROQ_API_KEY") def _get_tracker(self, model_name: str) -> ModelRateLimitTracker: """Get or create a rate limit tracker for a model.""" if model_name not in self.trackers: config = MODEL_CONFIGS.get(model_name) if not config: raise ValueError(f"Unknown model: {model_name}") self.trackers[model_name] = ModelRateLimitTracker(model_name, config) return self.trackers[model_name] def get_model(self, model_name: str) -> ChatGroq: """Get a ChatGroq instance for the specified model.""" config = MODEL_CONFIGS.get(model_name) if not config: raise ValueError(f"Unknown model: {model_name}") return ChatGroq( api_key=self._api_key, model=config.id, temperature=config.temperature, max_tokens=config.max_tokens, streaming=config.streaming, max_retries=3, # Retry network errors ) def check_rate_limit(self, model_name: str, estimated_tokens: int = 100) -> tuple[bool, str]: """Check if a model can handle a request.""" tracker = self._get_tracker(model_name) return tracker.can_request(estimated_tokens) def record_usage(self, model_name: str, tokens_used: int): """Record token usage for a model.""" tracker = self._get_tracker(model_name) tracker.record_request(tokens_used) async def invoke_with_fallback( self, primary_model: str, fallback_model: Optional[str], messages: list, estimated_tokens: int = 100 ) -> tuple[str, str, int]: """ Invoke a model with optional fallback on rate limit or error. Returns: (response_content, model_used, tokens_used) """ # Try primary model can_use, error = self.check_rate_limit(primary_model, estimated_tokens) if can_use: try: llm = self.get_model(primary_model) response = await llm.ainvoke(messages) tokens = len(response.content) // 4 # Rough estimate self.record_usage(primary_model, tokens) return response.content, primary_model, tokens except Exception as e: if fallback_model: pass # Try fallback else: raise e # Try fallback if available if fallback_model: can_use, error = self.check_rate_limit(fallback_model, estimated_tokens) if can_use: llm = self.get_model(fallback_model) response = await llm.ainvoke(messages) tokens = len(response.content) // 4 self.record_usage(fallback_model, tokens) return response.content, fallback_model, tokens raise Exception(error or "All models rate limited") # Global model manager instance model_manager = ModelManager() def get_model(model_name: str) -> ChatGroq: """Convenience function to get a model instance.""" return model_manager.get_model(model_name)