solution_challenge_backend / backend /gemini_config.py
github-actions
Deploy to Hugging Face
c794b6b
Raw
History Blame Contribute Delete
4.6 kB
"""Shared Gemini model tier configuration for Cepheus."""
from __future__ import annotations
import os
import re
import time
from typing import Any, TypeVar
T = TypeVar("T")

# Recommended core models, written as Google Gemini API slugs.
DEFAULT_MODEL = "gemini-3.5-flash"
PRO_MODEL = "gemini-3.1-pro-preview"
LITE_MODEL = "gemini-3.1-flash-lite"

# Friendly names / legacy environment values mapped to the current API slugs.
# Both Pro spellings resolve to the same preview slug.
MODEL_ALIASES: dict[str, str] = {
    "gemini-flash-latest": DEFAULT_MODEL,
    **dict.fromkeys(("gemini-3.1-pro", "gemini-3.1-pro-latest"), PRO_MODEL),
}
def get_model(tier: str = "default") -> str:
    """Return the configured model slug for the ``default``, ``pro`` or ``lite`` tier.

    The slug is read from GEMINI_MODEL_PRO (``pro``), GEMINI_MODEL_LITE
    (``lite``) or GEMINI_MODEL (any other tier value). The value is stripped.
    When the variable is unset, empty or whitespace-only, the built-in
    constant for that tier is used. Names listed in MODEL_ALIASES are
    translated to their current API slug; anything else is returned as-is.
    """
    if tier == "pro":
        env_name, fallback = "GEMINI_MODEL_PRO", PRO_MODEL
    elif tier == "lite":
        env_name, fallback = "GEMINI_MODEL_LITE", LITE_MODEL
    else:
        env_name, fallback = "GEMINI_MODEL", DEFAULT_MODEL
    # An empty assignment such as ``GEMINI_MODEL=`` would otherwise yield ""
    # (an invalid model name), so blank values are treated as "not set".
    raw = os.getenv(env_name, "").strip() or fallback
    return MODEL_ALIASES.get(raw, raw)
def fallback_chain(tier: str = "default") -> list[str]:
    """Return the models to try, in order, for *tier*.

    The slug configured for the requested tier comes first, followed by the
    default tier and then the lite tier. With this ordering a quota-exhausted
    Pro/Flash model degrades to Flash-Lite. Empty slugs are dropped, and each
    remaining slug appears only once, at its first position.
    """
    candidates = (get_model(tier), get_model("default"), get_model("lite"))
    # dict keys keep insertion order, so this de-duplicates while preserving
    # the first occurrence of every slug.
    return list(dict.fromkeys(slug for slug in candidates if slug))
def parse_retry_delay(message: str) -> float:
    """Return the number of seconds to wait before retrying after a 429.

    Looks for a ``retryDelay`` field (e.g. ``'retryDelay': '34s'``) in the
    error text and pads the server's suggestion by one second. When no such
    field is present, a flat 5.0 seconds is returned.
    """
    found = re.search(r"retryDelay['\":\s]+(\d+(?:\.\d+)?)s", message)
    if found is None:
        return 5.0
    # The captured group is digits with an optional fractional part, so
    # float() always succeeds here.
    return float(found.group(1)) + 1
def is_rate_limit(exc: Exception) -> bool:
    """Return True when *exc* looks like a quota / rate-limit (HTTP 429) error.

    A numeric ``code`` attribute equal to 429 is trusted when the exception
    carries one. Otherwise the message is scanned for a standalone ``429`` or
    the ``RESOURCE_EXHAUSTED`` status string.
    """
    if getattr(exc, "code", None) == 429:
        return True
    text = str(exc)
    # Word boundaries stop unrelated numbers (token counts, request ids such
    # as "14290") from being mistaken for the 429 status code.
    return re.search(r"\b429\b", text) is not None or "RESOURCE_EXHAUSTED" in text
def api_key_configured() -> bool:
    """Report whether GEMINI_API_KEY holds a non-blank value.

    A variable that is missing, empty, or made only of whitespace counts as
    not configured.
    """
    key = os.environ.get("GEMINI_API_KEY") or ""
    return key.strip() != ""
def is_not_found(exc: Exception) -> bool:
    """Return True when *exc* looks like a not-found (HTTP 404) error.

    A numeric ``code`` attribute equal to 404 is trusted when the exception
    carries one. Otherwise the message is scanned for a standalone ``404`` or
    the ``NOT_FOUND`` status string.
    """
    if getattr(exc, "code", None) == 404:
        return True
    text = str(exc)
    # Word boundaries stop numbers that merely contain "404" (e.g. "14040")
    # from being treated as the status code.
    return re.search(r"\b404\b", text) is not None or "NOT_FOUND" in text
# time.monotonic() timestamp of the most recent generate_content attempt.
# -inf means "no attempt yet", so the very first call never waits.
_last_api_call_time = float("-inf")
# Minimum gap (seconds) between the previous attempt and the start of a new
# generate_with_fallback call: a simple client-side throttle against 429s.
_MIN_SPACING_SECONDS = 4.0
# Cap (seconds) on the back-off slept between retry rounds, whatever delay
# the server suggests.
_MAX_ROUND_DELAY_SECONDS = 8.0


def _wait_for_call_spacing() -> None:
    """Sleep until _MIN_SPACING_SECONDS have passed since the last attempt."""
    # A monotonic clock cannot jump backwards, so a system clock change can
    # never turn this into a negative interval and an hours-long sleep.
    elapsed = time.monotonic() - _last_api_call_time
    if elapsed < _MIN_SPACING_SECONDS:
        time.sleep(_MIN_SPACING_SECONDS - elapsed)


def generate_with_fallback(
    client: Any,
    *,
    tier: str = "default",
    contents: Any,
    config: Any,
    rounds: int = 2,
) -> Any:
    """Call ``client.models.generate_content``, degrading across the model chain.

    Strategy:
      - Wait, if needed, so consecutive calls are spaced by at least
        _MIN_SPACING_SECONDS.
      - Try every model from ``fallback_chain(tier)`` once, in order. The first
        success is returned immediately. Any failing model is skipped at once,
        with no sleep, and the next model is tried.
      - If every model in a round failed with a rate-limit error
        (``is_rate_limit``), sleep for the server-suggested delay (capped at
        _MAX_ROUND_DELAY_SECONDS) and retry the whole chain. At most
        ``max(1, rounds)`` rounds are run.

    Args:
        client: object exposing ``models.generate_content(model=, contents=,
            config=)``. Presumably a google-genai client; TODO confirm.
        tier: tier passed to ``fallback_chain`` to choose the model order.
        contents: forwarded unchanged to ``generate_content``.
        config: forwarded unchanged to ``generate_content``.
        rounds: maximum number of passes over the chain; values below 1 are
            treated as 1.

    Returns:
        Whatever ``generate_content`` returns for the first successful model.

    Raises:
        RuntimeError: if the model chain is empty, so nothing could be tried.
        Exception: the exception from the last attempted model when every
            attempt in every round failed.
    """
    global _last_api_call_time
    _wait_for_call_spacing()

    chain = fallback_chain(tier)
    if not chain:
        # Fail fast instead of sleeping through retry rounds with nothing to try.
        raise RuntimeError("generate_with_fallback called without attempting a model")

    total_rounds = max(1, rounds)
    last_exc: Exception | None = None
    for round_idx in range(total_rounds):
        all_rate_limited = True
        for model in chain:
            _last_api_call_time = time.monotonic()
            try:
                return client.models.generate_content(
                    model=model,
                    contents=contents,
                    config=config,
                )
            except Exception as exc:
                last_exc = exc
                if not is_rate_limit(exc):
                    # 404/not-found or any other error: the chain can't be
                    # saved by waiting, but the next model may still work.
                    all_rate_limited = False
        # Only retry the whole chain if every failure was a genuine
        # rate-limit, and never sleep after the final round.
        if not all_rate_limited or round_idx == total_rounds - 1:
            break
        delay = parse_retry_delay(str(last_exc)) if last_exc else 5.0
        time.sleep(min(delay, _MAX_ROUND_DELAY_SECONDS))

    if last_exc is None:
        raise RuntimeError("generate_with_fallback called without attempting a model")
    raise last_exc