# Provenance: uploaded via huggingface_hub by AswinMathew (commit 7190fd0, verified)
"""LLM service — Gemini (primary) with OpenRouter fallback and retry logic."""
import asyncio
import base64
import json
import logging
import re
from pathlib import Path
import httpx
from app.config import get_settings
# Module-level logger and cached application settings (API keys live here).
logger = logging.getLogger(__name__)
settings = get_settings()

# REST API base URLs for the two providers.
GEMINI_BASE = "https://generativelanguage.googleapis.com/v1beta"
OPENROUTER_BASE = "https://openrouter.ai/api/v1"

# Default Gemini model + fallback
GEMINI_PRIMARY = "gemini-3.1-flash-lite-preview"
GEMINI_FALLBACK = "gemini-2.5-flash"

# Free models on OpenRouter that support JSON well.
# Tried in order by call_openrouter() when no explicit model is given.
_OPENROUTER_FALLBACK_MODELS = [
    "google/gemma-3-27b-it:free",
    "meta-llama/llama-3.3-70b-instruct:free",
    "deepseek/deepseek-chat-v3-0324:free",
]
def _extract_gemini_text(data: dict) -> str:
"""Extract text from Gemini response, handling thinking models (2.5+)."""
candidates = data.get("candidates", [])
if not candidates:
raise ValueError("No response from Gemini")
content = candidates[0].get("content", {})
parts = content.get("parts", [])
# Thinking models may have multiple parts — skip thinking parts, get output text only
texts = [p["text"] for p in parts if "text" in p and not p.get("thought")]
if not texts:
# Thinking model used all tokens on thought — no output
raise ValueError("Gemini returned empty response (thinking tokens exhausted)")
return "\n".join(texts)
async def call_gemini(
    prompt: str,
    system_instruction: str | None = None,
    model: str = GEMINI_PRIMARY,
    temperature: float = 0.7,
    max_tokens: int = 8192,
    response_json: bool = False,
) -> str:
    """Call Google Gemini API with retry.

    Fallback chain: primary Gemini model → fallback Gemini model → OpenRouter.
    Per model: up to 3 attempts. 429 → exponential backoff and retry;
    5xx → skip straight to the next model; timeout → retry.

    Args:
        prompt: User prompt text.
        system_instruction: Optional system prompt (sent as ``systemInstruction``).
        model: Gemini model tried first (then GEMINI_FALLBACK if different).
        temperature: Sampling temperature.
        max_tokens: Output cap. Requests above 16384 are sent uncapped and
            get a longer HTTP timeout (300s instead of 90s).
        response_json: Ask Gemini for ``application/json`` output.

    Returns:
        The generated text with thinking parts stripped.

    Raises:
        ValueError: If GEMINI_API_KEY is missing, or every provider failed.
        httpx.HTTPStatusError: On non-retryable HTTP errors (e.g. other 4xx).
    """
    if not settings.gemini_api_key:
        raise ValueError("GEMINI_API_KEY not configured")
    # Build model fallback chain (skip duplicates)
    models_to_try = [model]
    if GEMINI_FALLBACK != model:
        models_to_try.append(GEMINI_FALLBACK)
    last_err = None
    for current_model in models_to_try:
        url = f"{GEMINI_BASE}/models/{current_model}:generateContent?key={settings.gemini_api_key}"
        body: dict = {
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "temperature": temperature,
            },
        }
        # Only set maxOutputTokens if explicitly small (skip for large requests to let model output freely)
        if max_tokens <= 16384:
            body["generationConfig"]["maxOutputTokens"] = max_tokens
        if system_instruction:
            body["systemInstruction"] = {"parts": [{"text": system_instruction}]}
        if response_json:
            body["generationConfig"]["responseMimeType"] = "application/json"
        for attempt in range(3):
            try:
                # Large (uncapped) generations can run long — extend timeout.
                timeout = 300.0 if max_tokens > 16384 else 90.0
                async with httpx.AsyncClient(timeout=timeout) as client:
                    resp = await client.post(url, json=body)
                    if resp.status_code == 429:
                        # Rate limited: exponential backoff (10s, 20s, 30s cap).
                        wait = min(2 ** attempt * 10, 30)
                        logger.warning("Gemini/%s 429 — retry in %ds (%d/3)", current_model, wait, attempt + 1)
                        await asyncio.sleep(wait)
                        last_err = httpx.HTTPStatusError(
                            "429", request=resp.request, response=resp
                        )
                        continue
                    if resp.status_code >= 500:
                        # Server-side failure — retrying the same model rarely helps.
                        logger.warning("Gemini/%s %d — trying next fallback", current_model, resp.status_code)
                        last_err = httpx.HTTPStatusError(
                            str(resp.status_code), request=resp.request, response=resp
                        )
                        break  # Skip retries, go to next model
                    resp.raise_for_status()
                    data = resp.json()
                    # Log finish reason and token usage for debugging
                    candidates = data.get("candidates", [])
                    if candidates:
                        finish_reason = candidates[0].get("finishReason", "unknown")
                        logger.info("Gemini/%s finish_reason=%s", current_model, finish_reason)
                    usage = data.get("usageMetadata", {})
                    if usage:
                        logger.info("Gemini/%s tokens: prompt=%s, output=%s, total=%s",
                                    current_model,
                                    usage.get("promptTokenCount", "?"),
                                    usage.get("candidatesTokenCount", "?"),
                                    usage.get("totalTokenCount", "?"))
                    return _extract_gemini_text(data)
            except httpx.HTTPStatusError as e:
                # NOTE(review): 429/5xx are intercepted above before
                # raise_for_status, so the two branches below look
                # unreachable from this function's own request — they would
                # only fire if raised elsewhere in the try body. Confirm
                # before removing.
                if e.response.status_code == 429:
                    last_err = e
                    wait = min(2 ** attempt * 10, 30)
                    logger.warning("Gemini/%s 429 — retry in %ds (%d/3)", current_model, wait, attempt + 1)
                    await asyncio.sleep(wait)
                    continue
                if e.response.status_code >= 500:
                    last_err = e
                    logger.warning("Gemini/%s %d — trying next fallback", current_model, e.response.status_code)
                    break  # Go to next model
                # Other 4xx (bad request, auth, ...) — not retryable, surface it.
                raise
            except httpx.TimeoutException as e:
                last_err = e
                logger.warning("Gemini/%s timeout (attempt %d/3)", current_model, attempt + 1)
                continue
        else:
            # All 3 retries exhausted for this model — try next
            logger.warning("Gemini/%s exhausted retries, trying next fallback", current_model)
    # All Gemini models exhausted — try OpenRouter fallback
    if settings.openrouter_api_key:
        logger.warning("All Gemini models failed, falling back to OpenRouter")
        try:
            return await call_openrouter(
                prompt,
                system_prompt=system_instruction,
                max_tokens=max_tokens,
                temperature=temperature,
                response_json=response_json,
            )
        except Exception as fallback_err:
            logger.error("OpenRouter fallback also failed: %s", fallback_err)
    raise last_err or ValueError("All LLM providers failed (Gemini + OpenRouter)")
def _strip_markdown_fences(text: str) -> str:
"""Strip markdown code fences (```json ... ```) from LLM output."""
text = text.strip()
text = re.sub(r"^```\w*\s*\n?", "", text) # opening fence with optional lang tag
text = re.sub(r"\n?```\s*$", "", text) # closing fence
return text.strip()
def _repair_json(text: str) -> str | None:
"""Attempt to fix common JSON issues from LLM output.
Handles trailing commas and unclosed brackets from truncated responses.
Returns repaired text, or None if the structure looks unsalvageable.
"""
# Remove trailing commas before ] or }
text = re.sub(r",\s*([}\]])", r"\1", text)
# Count bracket imbalance
open_braces = text.count("{") - text.count("}")
open_brackets = text.count("[") - text.count("]")
if open_braces < 0 or open_brackets < 0:
return None # more closers than openers — not fixable
if open_braces > 0 or open_brackets > 0:
# Truncated response — close open brackets in correct nesting order
text = text.rstrip().rstrip(",")
# Walk backwards through text to find last-opened bracket type
stack = []
for ch in text:
if ch in "{[":
stack.append(ch)
elif ch in "}]":
if stack:
stack.pop()
# Close remaining open brackets in reverse (LIFO) order
closers = {"[": "]", "{": "}"}
text += "".join(closers[b] for b in reversed(stack))
return text
# Balanced brackets — return (may just have had trailing commas)
return text
async def call_gemini_json(
    prompt: str,
    system_instruction: str | None = None,
    model: str = GEMINI_PRIMARY,
    max_tokens: int = 8192,
    _retries: int = 2,
) -> dict | list:
    """Call Gemini and parse JSON response with retry and repair on failure.

    Each failed parse triggers a re-ask with stricter instructions, up to
    ``_retries`` extra calls; the raw text is also run through the JSON
    repair helper before giving up on an attempt.
    """
    text = await call_gemini(
        prompt, system_instruction, model, response_json=True, max_tokens=max_tokens
    )
    total = _retries + 1
    for attempt in range(total):
        candidate = _strip_markdown_fences(text)
        # First try the text as-is.
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            repaired = _repair_json(candidate)
        # Then try the repaired variant, if the helper deemed it salvageable.
        if repaired is not None:
            try:
                return json.loads(repaired)
            except json.JSONDecodeError:
                pass
        logger.warning(
            "JSON parse failed (attempt %d/%d). Raw text (first 500): %s",
            attempt + 1, total, text[:500],
        )
        if attempt < _retries:
            # Re-ask with an explicit JSON-only directive appended.
            strict_prompt = (
                prompt
                + "\n\nCRITICAL: Your previous response was not valid JSON. "
                "Return ONLY a valid JSON object/array. No markdown fences, no commentary."
            )
            text = await call_gemini(
                strict_prompt, system_instruction, model, response_json=True,
                max_tokens=max_tokens,
            )
    raise ValueError(
        f"Failed to parse JSON after {_retries + 1} attempts. "
        f"Raw text starts with: {text[:200]}"
    )
async def call_openrouter(
    prompt: str,
    system_prompt: str | None = None,
    model: str | None = None,
    temperature: float = 0.7,
    max_tokens: int = 4096,
    response_json: bool = False,
) -> str:
    """Call OpenRouter free LLM. Tries multiple free models on failure.

    Raises:
        ValueError: If OPENROUTER_API_KEY is not configured.
        Exception: The last model's failure when every model errors out.
    """
    if not settings.openrouter_api_key:
        raise ValueError("OPENROUTER_API_KEY not configured")
    user_prompt = prompt
    if response_json:
        # No native JSON mode here — lean on a hard instruction instead.
        user_prompt = user_prompt + "\n\nIMPORTANT: Respond ONLY with valid JSON, no markdown fences or extra text."
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})
    candidate_models = [model] if model else _OPENROUTER_FALLBACK_MODELS
    failure: Exception | None = None
    for model_name in candidate_models:
        try:
            async with httpx.AsyncClient(timeout=90.0) as client:
                response = await client.post(
                    f"{OPENROUTER_BASE}/chat/completions",
                    headers={
                        "Authorization": f"Bearer {settings.openrouter_api_key}",
                        "Content-Type": "application/json",
                    },
                    json={
                        "model": model_name,
                        "messages": messages,
                        "temperature": temperature,
                        "max_tokens": max_tokens,
                    },
                )
            response.raise_for_status()
            payload = response.json()
            return payload["choices"][0]["message"]["content"]
        except Exception as exc:
            # Best-effort fan-out: log and move on to the next free model.
            logger.warning("OpenRouter model %s failed: %s", model_name, exc)
            failure = exc
    raise failure or ValueError("All OpenRouter models failed")
# Free vision models on OpenRouter (verified 2026-02)
_OPENROUTER_VISION_MODELS = [
"google/gemma-3-27b-it:free",
"mistralai/mistral-small-3.1-24b-instruct:free",
"nvidia/nemotron-nano-12b-v2-vl:free",
"google/gemma-3-12b-it:free",
]
_IMAGE_MIME = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".webp": "image/webp",
}
def _encode_image(path: str) -> tuple[str, str]:
"""Read an image file and return (base64_data, mime_type)."""
p = Path(path)
mime = _IMAGE_MIME.get(p.suffix.lower(), "image/png")
data = p.read_bytes()
return base64.b64encode(data).decode("utf-8"), mime
async def call_openrouter_vision(
    prompt: str,
    image_paths: list[str],
    system_prompt: str | None = None,
    model: str | None = None,
    max_tokens: int = 2048,
) -> str:
    """Send images + prompt to a free OpenRouter vision model.

    Images are inlined as base64 data URLs, placed before the text part.
    Falls through the free vision model list when a model errors out.
    """
    if not settings.openrouter_api_key:
        raise ValueError("OPENROUTER_API_KEY not configured")
    # Assemble the user content: one image part per path, then the text.
    user_content: list[dict] = []
    for path in image_paths:
        encoded, mime = _encode_image(path)
        user_content.append({
            "type": "image_url",
            "image_url": {"url": f"data:{mime};base64,{encoded}"},
        })
    user_content.append({"type": "text", "text": prompt})
    messages: list[dict] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_content})
    candidate_models = [model] if model else _OPENROUTER_VISION_MODELS
    failure: Exception | None = None
    for vision_model in candidate_models:
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.post(
                    f"{OPENROUTER_BASE}/chat/completions",
                    headers={
                        "Authorization": f"Bearer {settings.openrouter_api_key}",
                        "Content-Type": "application/json",
                    },
                    json={
                        "model": vision_model,
                        "messages": messages,
                        "max_tokens": max_tokens,
                    },
                )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except Exception as exc:
            # Best-effort fan-out: log and try the next vision model.
            logger.warning("OpenRouter vision model %s failed: %s", vision_model, exc)
            failure = exc
    raise failure or ValueError("All OpenRouter vision models failed")
async def call_openrouter_vision_json(
    prompt: str,
    image_paths: list[str],
    system_prompt: str | None = None,
    model: str | None = None,
    max_tokens: int = 2048,
) -> dict | list:
    """Call OpenRouter vision and parse JSON response with repair logic.

    Appends a JSON-only instruction to the prompt, then parses the reply;
    a failed parse falls back to the repair helper before raising.
    """
    json_prompt = prompt + "\n\nIMPORTANT: Respond ONLY with valid JSON, no markdown fences or extra text."
    raw = await call_openrouter_vision(
        json_prompt, image_paths, system_prompt, model, max_tokens
    )
    cleaned = _strip_markdown_fences(raw)
    # Direct parse first.
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        fixed = _repair_json(cleaned)
    # Then the repaired variant, if salvageable.
    if fixed is not None:
        try:
            return json.loads(fixed)
        except json.JSONDecodeError:
            pass
    logger.warning("Vision JSON parse failed. Raw (first 500): %s", raw[:500])
    raise ValueError(f"Failed to parse vision JSON. Raw starts with: {raw[:200]}")