"""
Smart Rate Limiter with Adaptive Delays and Caching
Helps maximize chat usage within free tier limits
"""
import time
import logging
import hashlib
from typing import Optional, Dict, Any

logger = logging.getLogger(__name__)

class RateLimiter:
"""
Adaptive rate limiter that:
1. Tracks API usage per provider
2. Implements smart delays
3. Caches responses for repeated queries
4. Provides usage statistics
"""
    def __init__(self, provider: str = "gemini"):
        self.provider = provider
        self.request_times = []
        self.token_usage = {"input": 0, "output": 0, "total": 0}
        self.last_request_time = None

        # Load configuration (fall back to defaults if the config file is missing)
        try:
            import rate_limit_config as config
        except ImportError:
            # Fallback defaults mirroring the expected rate_limit_config attributes
            class config:
                GEMINI_RPM = 15
                GEMINI_MIN_DELAY = 2.0
                GEMINI_BURST_DELAY = 8.0
                GROQ_RPM = 30
                GROQ_MIN_DELAY = 1.0
                GROQ_BURST_DELAY = 10.0
                ENABLE_CACHE = True
                CACHE_TTL = 300

        # Provider-specific limits
        self.limits = {
            "gemini": {
                "rpm": config.GEMINI_RPM,
                "min_delay": config.GEMINI_MIN_DELAY,
                "burst_delay": config.GEMINI_BURST_DELAY,
            },
            "groq": {
                "rpm": config.GROQ_RPM,
                "min_delay": config.GROQ_MIN_DELAY,
                "burst_delay": config.GROQ_BURST_DELAY,
            },
        }

        self.response_cache = {} if config.ENABLE_CACHE else None
        self.cache_ttl = config.CACHE_TTL

    def get_cache_key(self, query: str, context_hash: str = "") -> str:
        """Generate a cache key for a query (MD5 is used for speed, not security)."""
        combined = f"{query}:{context_hash}"
        return hashlib.md5(combined.encode()).hexdigest()
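
    # Illustrative note: identical (query, context_hash) pairs always map to
    # the same key, so a repeated question can be served from the cache:
    #   limiter.get_cache_key("hi", "ctx") == limiter.get_cache_key("hi", "ctx")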
    def get_cached_response(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """Return a cached response if one exists and has not expired."""
        if self.response_cache is None:
            return None
        if cache_key in self.response_cache:
            cached_data, timestamp = self.response_cache[cache_key]
            if time.time() - timestamp < self.cache_ttl:
                logger.info("🎯 Cache hit! Saved an API call.")
                return cached_data
            else:
                # Entry expired; remove it
                del self.response_cache[cache_key]
        return None
    def cache_response(self, cache_key: str, response: Dict[str, Any]):
        """Cache a response"""
        if self.response_cache is None:
            return
        self.response_cache[cache_key] = (response, time.time())

        # Keep the cache size manageable
        if len(self.response_cache) > 100:
            # Evict the 20 oldest entries (sorted by timestamp)
            sorted_items = sorted(self.response_cache.items(), key=lambda x: x[1][1])
            for key, _ in sorted_items[:20]:
                del self.response_cache[key]
    def calculate_smart_delay(self) -> float:
        """
        Calculate the optimal delay based on recent usage.
        Returns the delay in seconds.
        """
        limits = self.limits.get(self.provider, self.limits["gemini"])

        # Drop request timestamps older than one minute
        cutoff = time.time() - 60
        self.request_times = [t for t in self.request_times if t > cutoff]

        # Back off progressively as we approach the per-minute limit
        requests_last_minute = len(self.request_times)
        if requests_last_minute >= limits["rpm"] * 0.9:  # 90% of limit
            logger.warning(f"⚠️ Approaching rate limit ({requests_last_minute}/{limits['rpm']} RPM)")
            return limits["burst_delay"]
        elif requests_last_minute >= limits["rpm"] * 0.7:  # 70% of limit
            return limits["min_delay"] * 1.5
        else:
            return limits["min_delay"]
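
    # Worked example with the default Gemini limits (15 RPM, min 2.0 s,
    # burst 8.0 s): 14 requests in the last minute (>= 90% of 15) returns
    # 8.0 s; 11 requests (>= 70%) returns 2.0 * 1.5 = 3.0 s; otherwise 2.0 s.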
    def wait_if_needed(self):
        """
        Smart wait that adapts to usage patterns.
        Only waits when necessary to avoid rate limits.
        """
        if self.last_request_time is None:
            # First request: no delay needed, just record it
            self.last_request_time = time.time()
            self.request_times.append(time.time())
            return

        delay = self.calculate_smart_delay()
        elapsed = time.time() - self.last_request_time
        if elapsed < delay:
            wait_time = delay - elapsed
            logger.info(f"⏱️ Smart delay: waiting {wait_time:.1f}s to avoid rate limit...")
            time.sleep(wait_time)

        self.last_request_time = time.time()
        self.request_times.append(time.time())
    def record_usage(self, input_tokens: int = 0, output_tokens: int = 0):
        """Track token usage for statistics"""
        self.token_usage["input"] += input_tokens
        self.token_usage["output"] += output_tokens
        self.token_usage["total"] += input_tokens + output_tokens
    def get_usage_stats(self) -> Dict[str, Any]:
        """Get current usage statistics"""
        cutoff = time.time() - 60
        recent_requests = len([t for t in self.request_times if t > cutoff])
        return {
            "provider": self.provider,
            "requests_last_minute": recent_requests,
            "total_tokens": self.token_usage["total"],
            "input_tokens": self.token_usage["input"],
            "output_tokens": self.token_usage["output"],
            "cache_size": len(self.response_cache) if self.response_cache is not None else 0,
        }
    def reset_stats(self):
        """Reset usage statistics"""
        self.token_usage = {"input": 0, "output": 0, "total": 0}
        self.request_times = []
        logger.info("📊 Usage statistics reset")
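

# Illustrative sketch (hypothetical helper, not part of the module above):
# how the caching and delay methods are meant to wrap a real API call.
# `call_api` is an assumed callable supplied by the caller.
def cached_call_example(limiter: RateLimiter, query: str, call_api) -> Dict[str, Any]:
    """Check the cache first; otherwise apply the smart delay and call the API."""
    key = limiter.get_cache_key(query)
    cached = limiter.get_cached_response(key)
    if cached is not None:
        return cached  # served from cache; no API call spent
    limiter.wait_if_needed()  # adaptive delay before the real request
    response = call_api(query)  # assumed to return a Dict[str, Any]
    limiter.cache_response(key, response)
    return response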

# Global rate limiters (one per provider)
_rate_limiters: Dict[str, RateLimiter] = {}


def get_rate_limiter(provider: str) -> RateLimiter:
    """Get or create the rate limiter for a provider."""
    if provider not in _rate_limiters:
        _rate_limiters[provider] = RateLimiter(provider)
    return _rate_limiters[provider]
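

if __name__ == "__main__":
    # Minimal usage sketch: exercise the limiter with illustrative token counts.
    logging.basicConfig(level=logging.INFO)
    limiter = get_rate_limiter("gemini")
    for i in range(3):
        limiter.wait_if_needed()  # adaptive delay between requests
        limiter.record_usage(input_tokens=100, output_tokens=50)  # fake counts
        print(f"request {i + 1}: {limiter.get_usage_stats()}")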