Spaces:
Running
Running
File size: 5,980 Bytes
bc10fd9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | import time
import threading
from collections import deque
from typing import Optional
import google.generativeai as genai
# ─── Model Pool (only models with actual quota) ──────────────────────────────
# Entries are tried in list order, so the model with the most quota is first.
# Fields per entry:
#   key  - model id passed to the API
#   name - human-readable display name
#   rpm  - allowed requests per minute (enforced by ModelManager)
#   rpd  - allowed requests per day (enforced by ModelManager)
#   tpm  - tokens per minute; not read by ModelManager in this file
MODEL_POOL = [
    dict(
        key="gemini-3.1-flash-lite",
        name="Gemini 3.1 Flash Lite",
        rpm=15,
        rpd=500,
        tpm=250_000,
    ),
    dict(
        key="gemini-2.5-flash-lite",  # gemini-2.5-flash-lite-preview-06-17 if needed
        name="Gemini 2.5 Flash Lite",
        rpm=10,
        rpd=20,
        tpm=250_000,
    ),
    dict(
        key="gemini-2.5-flash",
        name="Gemini 2.5 Flash",
        rpm=5,
        rpd=20,
        tpm=250_000,
    ),
    dict(
        key="gemini-2.0-flash",  # labelled "Gemini 3 Flash" in the UI
        name="Gemini 3 Flash",
        rpm=5,
        rpd=20,
        tpm=250_000,
    ),
]
class ModelManager:
    """
    Track per-model request limits (RPM + RPD) and fall back to the next
    model in ``MODEL_POOL`` when one is exhausted.

    Limits are enforced with in-memory sliding windows (60 s and 86 400 s) of
    call timestamps, so counts start from zero whenever the process restarts.
    The ``tpm`` field of pool entries is not tracked. All mutable state is
    guarded by a single ``threading.Lock``.
    """

    def __init__(self):
        self._lock = threading.Lock()
        # For each model key: deque of time.time() timestamps of counted
        # calls, oldest on the left.
        self._minute_calls: dict[str, deque] = {m["key"]: deque() for m in MODEL_POOL}
        self._day_calls: dict[str, deque] = {m["key"]: deque() for m in MODEL_POOL}
        # Epoch time until which each model is paused after an API 429.
        self._cooldown_until: dict[str, float] = {m["key"]: 0.0 for m in MODEL_POOL}

    def _prune(self, dq: deque, window_seconds: int) -> None:
        """Drop timestamps older than ``window_seconds`` from the left of *dq*."""
        cutoff = time.time() - window_seconds
        while dq and dq[0] < cutoff:
            dq.popleft()

    def _is_available(self, model: dict) -> bool:
        """Return True if *model* is not cooling down and is under its RPM and RPD.

        Prunes the model's windows as a side effect. Caller must hold
        ``self._lock``.
        """
        key = model["key"]
        now = time.time()
        # Hard cooldown (e.g. after a 429)
        if now < self._cooldown_until[key]:
            return False
        self._prune(self._minute_calls[key], 60)
        self._prune(self._day_calls[key], 86_400)
        rpm_ok = len(self._minute_calls[key]) < model["rpm"]
        rpd_ok = len(self._day_calls[key]) < model["rpd"]
        return rpm_ok and rpd_ok

    def _record_call(self, key: str) -> float:
        """Count one call against *key*'s minute and day windows.

        Returns the timestamp used, so the reservation can be undone with
        ``_release_call``. Caller must hold ``self._lock``.
        """
        now = time.time()
        self._minute_calls[key].append(now)
        self._day_calls[key].append(now)
        return now

    def _release_call(self, key: str, ts: float) -> None:
        """Undo a ``_record_call`` for a request that raised.

        The entry may already have been pruned if the request outlived the
        window, so a missing timestamp is ignored. Caller must hold
        ``self._lock``.
        """
        for dq in (self._minute_calls[key], self._day_calls[key]):
            try:
                dq.remove(ts)
            except ValueError:
                pass  # already pruned out of this window

    def _set_cooldown(self, key: str, seconds: int = 65) -> None:
        """Pause *key* for *seconds*; call after receiving a 429.

        Caller must hold ``self._lock``.
        """
        self._cooldown_until[key] = time.time() + seconds
        print(f"[ModelManager] {key} in cooldown for {seconds}s")

    def get_available_model(self) -> Optional[dict]:
        """Return the first model (in pool order) that has quota, or None.

        This does not reserve a slot, so the answer can be stale by the time
        the caller acts on it.
        """
        with self._lock:
            for model in MODEL_POOL:
                if self._is_available(model):
                    return model
        return None

    def _reserve_next(self, tried: set[str]) -> Optional[tuple[dict, float]]:
        """Atomically pick the first untried available model and reserve a slot.

        Returns ``(model, timestamp)`` or None if no untried model has quota.
        Checking and recording under one lock acquisition prevents concurrent
        callers from oversubscribing the same model.
        """
        with self._lock:
            for model in MODEL_POOL:
                if model["key"] not in tried and self._is_available(model):
                    return model, self._record_call(model["key"])
        return None

    def call_with_fallback(self, system_prompt: str) -> tuple[str, str]:
        """
        Send *system_prompt* to the first model with spare quota, in pool order.

        A request slot is reserved under the lock *before* the network call
        and released again if the call raises, so only successful requests
        stay counted. On ``ResourceExhausted`` (429 / quota error) the model is
        put in a 65 s cooldown and the next model is tried; any other error
        skips that model for this call.

        Args:
            system_prompt: Prompt text passed to ``generate_content``.

        Returns:
            ``(response_text, model_key)`` from the first successful model.
            The model is configured with ``response_mime_type`` JSON.

        Raises:
            RuntimeError: If no model had quota to begin with, or if every
                attempted model failed. The latter is chained from the last
                underlying error.
        """
        import google.api_core.exceptions as gex

        tried: set[str] = set()
        last_error: Optional[Exception] = None

        while (reservation := self._reserve_next(tried)) is not None:
            model_info, ts = reservation
            key = model_info["key"]
            tried.add(key)
            try:
                genai_model = genai.GenerativeModel(
                    key,
                    generation_config={"response_mime_type": "application/json"},
                )
                response = genai_model.generate_content(system_prompt)
            except gex.ResourceExhausted as e:
                print(f"[ModelManager] 429 on {key}: {e}")
                with self._lock:
                    self._release_call(key, ts)
                    self._set_cooldown(key, seconds=65)
                last_error = e
                continue  # try next model
            except Exception as e:
                print(f"[ModelManager] Error on {key}: {e}")
                with self._lock:
                    self._release_call(key, ts)
                last_error = e
                continue  # skip broken model, try next

            print(f"[ModelManager] Used: {key}")
            try:
                return response.text, key
            except Exception as e:
                # The request itself completed, so its slot stays counted;
                # only the text extraction failed. Move on to the next model.
                print(f"[ModelManager] Error on {key}: {e}")
                last_error = e

        if not tried:
            raise RuntimeError("All models are rate-limited. Try again later.")
        raise RuntimeError("All models failed or are rate-limited.") from last_error

    def status(self) -> list[dict]:
        """Return current usage snapshot for all models (useful for /api/models endpoint).

        Each entry reports the configured limits, calls counted in the current
        minute/day windows, whether the model is selectable right now, and
        the remaining cooldown rounded to whole seconds.
        """
        result = []
        with self._lock:
            now = time.time()
            for m in MODEL_POOL:
                key = m["key"]
                # Prune explicitly: _is_available returns early during a
                # cooldown and would then skip pruning.
                self._prune(self._minute_calls[key], 60)
                self._prune(self._day_calls[key], 86_400)
                cooldown_remaining = max(0, self._cooldown_until[key] - now)
                result.append({
                    "key": key,
                    "name": m["name"],
                    "rpm_limit": m["rpm"],
                    "rpd_limit": m["rpd"],
                    "rpm_used": len(self._minute_calls[key]),
                    "rpd_used": len(self._day_calls[key]),
                    "available": self._is_available(m),
                    "cooldown_seconds": round(cooldown_remaining),
                })
        return result
# Module-level singleton — import this in main.py so every caller shares
# one set of rate-limit counters.
model_manager = ModelManager()