Spaces:
Sleeping
Sleeping
File size: 7,436 Bytes
ba5110e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
"""
Model configurations for the multi-agent algebra chatbot.
Includes rate limits, model parameters, and factory functions.
"""
import asyncio
import logging
import os
import time
from dataclasses import dataclass, field
from functools import wraps
from typing import Optional, Dict, Any, Callable, TypeVar

from langchain_groq import ChatGroq
@dataclass
class ModelConfig:
    """Generation settings and rate-limit quotas for a single model.

    Only ``id`` is required; every other field falls back to the default
    declared below.
    """

    # Model identifier (handed to the chat client as its ``model`` argument).
    id: str

    # --- Generation parameters ---
    temperature: float = 0.6
    max_tokens: int = 4096
    context_length: int = 128000  # Default context window
    top_p: float = 1.0
    streaming: bool = True

    # --- Rate-limit quotas ---
    rpm: int = 30  # Requests allowed per minute
    rpd: int = 1000  # Requests allowed per day
    tpm: int = 10000  # Tokens allowed per minute
    tpd: int = 300000  # Tokens allowed per day
# Per-model settings, keyed by the short name that callers pass to
# ModelManager / get_model. Quotas are taken from the rate limit table.
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "kimi-k2": ModelConfig(
        id="moonshotai/kimi-k2-instruct-0905",
        temperature=0.0,
        max_tokens=16384,
        context_length=262144,  # 256K-token window
        top_p=1.0,
        rpm=60,
        rpd=1000,
        tpm=10000,
        tpd=300000,
    ),
    "llama-4-maverick": ModelConfig(
        id="meta-llama/llama-4-maverick-17b-128e-instruct",
        temperature=0.0,
        max_tokens=8192,
        context_length=128000,
        rpm=30,
        rpd=1000,
        tpm=6000,
        tpd=500000,
    ),
    "llama-4-scout": ModelConfig(
        id="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.0,
        max_tokens=8192,
        context_length=128000,
        rpm=30,
        rpd=1000,
        tpm=30000,
        tpd=500000,
    ),
    "qwen3-32b": ModelConfig(
        id="qwen/qwen3-32b",
        temperature=0.0,
        max_tokens=8192,
        context_length=32768,  # 32K-token window
        rpm=60,
        rpd=1000,
        tpm=6000,
        tpd=500000,
    ),
    "gpt-oss-120b": ModelConfig(
        id="openai/gpt-oss-120b",
        temperature=0.0,
        max_tokens=8192,
        context_length=128000,
        rpm=30,
        rpd=1000,
        tpm=8000,
        tpd=200000,
    ),
    # NOTE(review): this entry has zero max_tokens/context_length, so it is
    # presumably used only for quota tracking of a non-LLM API — confirm.
    "wolfram": ModelConfig(
        id="wolfram-alpha-api",
        temperature=0.0,
        max_tokens=0,
        context_length=0,
        rpm=30,
        rpd=2000,
        tpm=100000,
        tpd=1000000,
    ),
}
@dataclass
class ModelRateLimitTracker:
    """Per-model usage counters checked against the model's quotas.

    Request and token counts are kept for a minute window and a day window.
    A window's counters are zeroed once 60 s (minute) or 86400 s (day) have
    elapsed since that window was last reset.
    """

    model_name: str
    config: ModelConfig
    minute_requests: int = 0
    minute_tokens: int = 0
    day_requests: int = 0
    day_tokens: int = 0
    # Timestamps (time.time()) of the most recent reset of each window.
    last_minute_reset: float = field(default_factory=time.time)
    last_day_reset: float = field(default_factory=time.time)

    def _reset_if_needed(self):
        """Zero the counters of any window whose duration has elapsed."""
        current = time.time()

        minute_age = current - self.last_minute_reset
        if minute_age >= 60:
            self.minute_requests = 0
            self.minute_tokens = 0
            self.last_minute_reset = current

        day_age = current - self.last_day_reset
        if day_age >= 86400:
            self.day_requests = 0
            self.day_tokens = 0
            self.last_day_reset = current

    def can_request(self, estimated_tokens: int = 100) -> tuple[bool, str]:
        """Report whether one more request fits inside every quota.

        Args:
            estimated_tokens: Tokens the upcoming request is expected to use.

        Returns:
            ``(True, "")`` when the request is allowed, otherwise
            ``(False, reason)`` naming the first quota that blocks it.
            Quotas are checked in the order RPM, RPD, TPM, TPD.
        """
        self._reset_if_needed()

        quota = self.config
        name = self.model_name
        denial: Optional[str] = None

        if self.minute_requests >= quota.rpm:
            denial = f"Rate limit: {name} exceeded {quota.rpm} RPM"
        elif self.day_requests >= quota.rpd:
            denial = f"Rate limit: {name} exceeded {quota.rpd} RPD"
        elif self.minute_tokens + estimated_tokens > quota.tpm:
            denial = f"Rate limit: {name} would exceed {quota.tpm} TPM"
        elif self.day_tokens + estimated_tokens > quota.tpd:
            denial = f"Rate limit: {name} would exceed {quota.tpd} TPD"

        if denial is None:
            return True, ""
        return False, denial

    def record_request(self, tokens_used: int):
        """Count one finished request and its tokens in both windows."""
        self._reset_if_needed()

        # Request counters.
        self.minute_requests = self.minute_requests + 1
        self.day_requests = self.day_requests + 1

        # Token counters.
        self.minute_tokens = self.minute_tokens + tokens_used
        self.day_tokens = self.day_tokens + tokens_used
class ModelManager:
    """Build Groq chat models and enforce per-model rate limits.

    One ``ModelRateLimitTracker`` is kept per model name and is created the
    first time that model is checked or recorded. The API key is read from
    the ``GROQ_API_KEY`` environment variable when the manager is constructed.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.trackers: Dict[str, ModelRateLimitTracker] = {}
        self._api_key = os.getenv("GROQ_API_KEY")

    def _get_tracker(self, model_name: str) -> ModelRateLimitTracker:
        """Return the tracker for ``model_name``, creating it on first use.

        Raises:
            ValueError: If ``model_name`` is not a key of ``MODEL_CONFIGS``.
        """
        if model_name not in self.trackers:
            config = MODEL_CONFIGS.get(model_name)
            if not config:
                raise ValueError(f"Unknown model: {model_name}")
            self.trackers[model_name] = ModelRateLimitTracker(model_name, config)
        return self.trackers[model_name]

    def get_model(self, model_name: str) -> ChatGroq:
        """Create a new ChatGroq client configured for ``model_name``.

        Raises:
            ValueError: If ``model_name`` is not a key of ``MODEL_CONFIGS``.
        """
        config = MODEL_CONFIGS.get(model_name)
        if not config:
            raise ValueError(f"Unknown model: {model_name}")
        # NOTE(review): config.top_p is not forwarded to the client here —
        # confirm whether that is intentional.
        return ChatGroq(
            api_key=self._api_key,
            model=config.id,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=config.streaming,
            max_retries=3,  # Retry network errors
        )

    def check_rate_limit(self, model_name: str, estimated_tokens: int = 100) -> tuple[bool, str]:
        """Return ``(allowed, reason)`` for a request of ``estimated_tokens``."""
        return self._get_tracker(model_name).can_request(estimated_tokens)

    def record_usage(self, model_name: str, tokens_used: int):
        """Add one request and ``tokens_used`` tokens to the model's counters."""
        self._get_tracker(model_name).record_request(tokens_used)

    @staticmethod
    def _count_tokens(response: Any) -> int:
        """Return the token count to charge for ``response``.

        Uses the total reported in ``response.usage_metadata`` when present;
        otherwise falls back to the rough ``len(content) // 4`` estimate.
        """
        usage = getattr(response, "usage_metadata", None)
        if isinstance(usage, dict):
            total = usage.get("total_tokens")
            if isinstance(total, int) and total > 0:
                return total
        return len(response.content) // 4  # Rough estimate

    async def _invoke_once(self, model_name: str, messages: list) -> tuple[str, str, int]:
        """Call one model, record its usage, and return the result triple."""
        llm = self.get_model(model_name)
        response = await llm.ainvoke(messages)
        tokens = self._count_tokens(response)
        self.record_usage(model_name, tokens)
        return response.content, model_name, tokens

    async def invoke_with_fallback(
        self,
        primary_model: str,
        fallback_model: Optional[str],
        messages: list,
        estimated_tokens: int = 100
    ) -> tuple[str, str, int]:
        """Invoke a model with optional fallback on rate limit or error.

        The primary model is tried first if its rate limits allow. When it is
        rate limited or raises, and ``fallback_model`` is given, the fallback
        is tried under the same rate-limit check.

        Args:
            primary_model: Name of the preferred model.
            fallback_model: Name of the model to try next, or ``None``.
            messages: Messages passed unchanged to the model's ``ainvoke``.
            estimated_tokens: Expected token cost used for the rate-limit check.

        Returns:
            ``(response_content, model_used, tokens_used)``.

        Raises:
            ValueError: If a model name is unknown.
            Exception: If no model could be used. The message lists every
                reason collected, and a primary-model error, if any, is
                attached as ``__cause__``. Errors raised by the model call
                itself propagate when no further fallback is available.
        """
        failures: list[str] = []
        primary_error: Optional[Exception] = None

        # Try primary model
        can_use, reason = self.check_rate_limit(primary_model, estimated_tokens)
        if can_use:
            try:
                return await self._invoke_once(primary_model, messages)
            except Exception as exc:
                if not fallback_model:
                    raise
                # Keep the error instead of dropping it, so it can be logged
                # now and chained onto whatever is raised later.
                primary_error = exc
                failures.append(f"{primary_model} failed: {exc!r}")
                self._logger.warning(
                    "Primary model %s failed (%r); trying fallback %s",
                    primary_model, exc, fallback_model,
                )
        else:
            failures.append(reason)

        # Try fallback if available
        if fallback_model:
            can_use, reason = self.check_rate_limit(fallback_model, estimated_tokens)
            if can_use:
                try:
                    return await self._invoke_once(fallback_model, messages)
                except Exception as exc:
                    if primary_error is not None:
                        raise exc from primary_error
                    raise
            failures.append(reason)

        message = "; ".join(item for item in failures if item)
        raise Exception(message or "All models rate limited") from primary_error
# Global model manager instance.
# Created when this module is imported; ModelManager.__init__ reads
# GROQ_API_KEY from the environment at that moment.
model_manager = ModelManager()
def get_model(model_name: str) -> ChatGroq:
    """Return a chat model for ``model_name`` from the global manager.

    Delegates to ``model_manager.get_model``, which raises ``ValueError``
    for a name that is not configured.
    """
    manager = model_manager
    return manager.get_model(model_name)
|