"""Shared Gemini model tier configuration for Cepheus.""" from __future__ import annotations import os import re import time from typing import Any, TypeVar T = TypeVar("T") # Recommended core models (Google Gemini API slugs) DEFAULT_MODEL = "gemini-3.5-flash" PRO_MODEL = "gemini-3.1-pro-preview" LITE_MODEL = "gemini-3.1-flash-lite" # Map friendly names / legacy env values to current API slugs. MODEL_ALIASES: dict[str, str] = { "gemini-flash-latest": DEFAULT_MODEL, "gemini-3.1-pro": PRO_MODEL, "gemini-3.1-pro-latest": PRO_MODEL, } def get_model(tier: str = "default") -> str: """Return the configured model slug for default, pro, or lite tier.""" if tier == "pro": raw = os.getenv("GEMINI_MODEL_PRO", PRO_MODEL) elif tier == "lite": raw = os.getenv("GEMINI_MODEL_LITE", LITE_MODEL) else: raw = os.getenv("GEMINI_MODEL", DEFAULT_MODEL) return MODEL_ALIASES.get(raw, raw) def fallback_chain(tier: str = "default") -> list[str]: """Primary model for the tier, then the remaining tiers as fallbacks. Ordering: requested tier first, then the other working tiers (so a quota-exhausted Pro/Flash automatically degrades to Flash-Lite, which has the most generous quota). Duplicates are removed while preserving order. """ primary = get_model(tier) default = get_model("default") lite = get_model("lite") ordered = [primary, default, lite] seen: set[str] = set() chain: list[str] = [] for model in ordered: if model and model not in seen: seen.add(model) chain.append(model) return chain def parse_retry_delay(message: str) -> float: """Extract the server-suggested retry delay (seconds) from a 429 error message.""" match = re.search(r"retryDelay['\":\s]+(\d+(?:\.\d+)?)s", message) if match: try: return float(match.group(1)) + 1 except ValueError: pass return 5.0 def is_rate_limit(exc: Exception) -> bool: text = str(exc) return "429" in text or "RESOURCE_EXHAUSTED" in text def api_key_configured() -> bool: """True when GEMINI_API_KEY is set (non-empty) in the environment.""" return bool(os.getenv("GEMINI_API_KEY", "").strip()) def is_not_found(exc: Exception) -> bool: text = str(exc) return "404" in text or "NOT_FOUND" in text _last_api_call_time = 0.0 _MIN_SPACING_SECONDS = 4.0 def generate_with_fallback( client: Any, *, tier: str = "default", contents: Any, config: Any, rounds: int = 2, ) -> Any: """Call generate_content, degrading across the model chain on quota errors. Strategy (fast + bulletproof): - Try every model in the chain once (Pro → Flash → Flash-Lite). The first success returns immediately. A model that is quota-exhausted (429) or unavailable (404) is skipped instantly — no blocking sleeps — so a valid key always lands on a model with available quota. - If the entire chain is rate-limited, wait briefly and retry the chain once more (`rounds`) to ride out a transient spike. Raises the last exception only when every model in every round fails. """ global _last_api_call_time # Enforce minimum spacing between calls to prevent 429 rate limits now = time.time() elapsed = now - _last_api_call_time if elapsed < _MIN_SPACING_SECONDS: sleep_needed = _MIN_SPACING_SECONDS - elapsed time.sleep(sleep_needed) chain = fallback_chain(tier) last_exc: Exception | None = None for round_idx in range(max(1, rounds)): all_rate_limited = True for model in chain: try: _last_api_call_time = time.time() return client.models.generate_content( model=model, contents=contents, config=config, ) except Exception as exc: last_exc = exc if not is_rate_limit(exc): all_rate_limited = False # 404/not-found or other transient: move to next model immediately. continue # Only retry the whole chain if everything was a genuine rate-limit. if not all_rate_limited or round_idx == rounds - 1: break delay = parse_retry_delay(str(last_exc)) if last_exc else 5.0 time.sleep(min(delay, 8.0)) if last_exc is None: raise RuntimeError("generate_with_fallback called without attempting a model") raise last_exc