Spaces:
Running
Running
| import os | |
| import base64 | |
| import random | |
| import httpx | |
| from urllib.parse import quote | |
| from fastapi import APIRouter, Request, HTTPException, Header | |
| from fastapi.responses import Response, JSONResponse, StreamingResponse | |
| import re | |
| from typing import Optional | |
| import json | |
| from helper.assets import ( | |
| save_base64_image, | |
| cleanup_image, | |
| is_base64_image, | |
| ) | |
| import asyncio | |
| from helper.ratelimit import ( | |
| enforce_rate_limit, | |
| resolve_rate_limit_identity, | |
| check_audio_rate_limit, | |
| check_video_rate_limit, | |
| check_image_rate_limit, | |
| MAX_CHAT_PROMPT_BYTES, | |
| MAX_CHAT_PROMPT_CHARS, | |
| MAX_GROQ_PROMPT_BYTES, | |
| MAX_GROQ_PROMPT_CHARS, | |
| MAX_MEDIA_PROMPT_BYTES, | |
| MAX_MEDIA_PROMPT_CHARS, | |
| extract_user_text, | |
| calculate_messages_size, | |
| normalize_prompt_value, | |
| enforce_prompt_size, | |
| resolve_bound_subject, | |
| get_usage_snapshot_for_subject, | |
| ) | |
| from helper.keywords import * | |
router = APIRouter(prefix="/gen")

# Pollinations API keys; different endpoints use different keys ("" when unset).
PKEY = os.getenv("POLLINATIONS_KEY", "")
PKEY2 = os.getenv("POLLINATIONS2_KEY", "")
PKEY3 = os.getenv("POLLINATIONS3_KEY", "")

# api.airforce credentials and endpoint (AIRFORCE_KEY is None when unset).
AIRFORCE_KEY = os.getenv("AIRFORCE")
AIRFORCE_VIDEO_MODEL = "grok-imagine-video"
AIRFORCE_API_URL = "https://api.airforce/v1/images/generations"

# Meaningful request choices, and the accepted superset that also admits
# "" / None (meaning "use the default").
ratios = {"3:2", "2:3", "1:1"}
valid_ratios = ratios | {"", None}
modes = {"normal", "fun"}
valid_modes = modes | {"", None}

# Not referenced by the code visible in this module; possibly used by importers.
MAX_VIDEO_RETRIES = 6
def is_cinematic_image_prompt(prompt: str) -> bool:
    """Return True when the lower-cased *prompt* contains any ``CREATIVE_KEYWORDS`` entry.

    Matching is a plain substring test; presumably the keywords in
    ``helper.keywords`` are lower-case so this is case-insensitive — confirm.
    """
    lowered = prompt.lower()
    return any(keyword in lowered for keyword in CREATIVE_KEYWORDS)
def is_complex_reasoning(prompt: str) -> bool:
    """Heuristically decide whether *prompt* calls for a stronger reasoning model.

    True when any of the following holds:
      * the prompt is longer than 800 characters;
      * it contains (case-sensitively) an entry of ``REASONING_KEYWORDS``;
      * it contains "if", "therefore", "assume", "let x" or "given that"
        as whole words (case-sensitive regex).
    """
    if len(prompt) > 800:
        return True
    if any(keyword in prompt for keyword in REASONING_KEYWORDS):
        return True
    return re.search(r"\b(if|therefore|assume|let x|given that)\b", prompt) is not None
def is_lightweight(prompt: str) -> bool:
    """Return True for a short prompt (under 100 chars) containing a ``LIGHTWEIGHT_KEYWORDS`` entry.

    Prompts of 100 characters or more are never considered lightweight; the
    keyword test is a case-sensitive substring match.
    """
    if len(prompt) >= 100:
        return False
    return any(keyword in prompt for keyword in LIGHTWEIGHT_KEYWORDS)
async def check_chat_rate_limit(
    request: Request,
    authorization: Optional[str],
    client_id: Optional[str] = None,
):
    """Charge this caller against the ``cloudChatDaily`` quota.

    Thin wrapper over ``enforce_rate_limit``: returns whatever it returns and
    propagates whatever it raises.
    """
    bucket = "cloudChatDaily"
    return await enforce_rate_limit(
        request,
        authorization,
        bucket,
        client_id,
    )
# -----------------------------
# IMAGE GENERATION
# -----------------------------
async def generate_image(
    request: Request,
    prompt: Optional[str] = None,
    authorization: str = Header(None),
    x_client_id: str = Header(None),
):
    """
    Image generation endpoint backed by Pollinations.

    • When the ``prompt`` query parameter is present, ``mode`` and (repeatable)
      ``image_urls`` are read from the query string as well; otherwise all
      three are read from a JSON object body.
    • mode: "fantasy" forces the ``flux`` model, "realistic" forces ``zimage``;
      any other value keeps the keyword heuristic (``flux`` for cinematic
      prompts, ``zimage`` otherwise).
    • image_urls: list of image URLs or base-64 strings; only the first two
      are used. If any image is supplied, the ``klein`` (editing) model is used
      regardless of ``mode``.
    • Base-64 images are saved with ``save_base64_image``, exposed at
      ``{base_url}asset-cdn/assets/<id>`` for the upstream to fetch, and
      always removed again once the upstream call finishes (or fails).

    Raises:
        HTTPException(400): invalid JSON body or malformed ``image_urls``.
        HTTPException(502): the upstream request could not be completed.
        HTTPException(500): the upstream answered with a non-200 status.
    """
    timeout = httpx.Timeout(300.0, read=300.0)
    if prompt is None:
        try:
            payload = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")
        mode = payload.get("mode")
        image_urls = payload.get("image_urls")
    else:
        mode = request.query_params.get("mode")
        image_urls = request.query_params.getlist("image_urls")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Image prompt")

    # Reject malformed input before charging the rate limit. A bare string here
    # would otherwise be sliced character-by-character below.
    if image_urls and (
        not isinstance(image_urls, list)
        or not all(isinstance(img, str) for img in image_urls)
    ):
        raise HTTPException(400, "image_urls must be a list of strings")

    await check_image_rate_limit(request, authorization, x_client_id)

    chosen_model = "flux" if is_cinematic_image_prompt(prompt) else "zimage"
    if isinstance(mode, str):
        requested = mode.strip().lower()
        if requested == "fantasy":
            chosen_model = "flux"
        elif requested == "realistic":
            chosen_model = "zimage"
    has_input_image = bool(image_urls)
    if has_input_image:
        chosen_model = "klein"

    params = {"model": chosen_model, "key": PKEY2}
    temp_assets = []
    try:
        if has_input_image:
            processed = []
            for img in image_urls[:2]:
                if is_base64_image(img):
                    image_id = save_base64_image(img)
                    # Record immediately so the asset is cleaned up even if a
                    # later image or the upstream call fails.
                    temp_assets.append(image_id)
                    processed.append(f"{request.base_url}asset-cdn/assets/{image_id}")
                else:
                    processed.append(img)
            params["image"] = "|".join(processed)

        encoded_prompt = quote(prompt, safe="")
        query = "&".join(f"{k}={quote(str(v), safe='')}" for k, v in params.items())
        url = f"https://gen.pollinations.ai/image/{encoded_prompt}?{query}"
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.get(url)
    except httpx.HTTPError as exc:
        raise HTTPException(502, "Image generation request failed") from exc
    finally:
        for aid in temp_assets:
            cleanup_image(aid)

    if resp.status_code != 200:
        raise HTTPException(500, f"Pollinations error: {resp.status_code}")

    # Pass through the upstream image type when it reports one instead of
    # assuming JPEG; fall back to JPEG otherwise.
    upstream_type = (resp.headers.get("content-type") or "").split(";")[0].strip().lower()
    media_type = upstream_type if upstream_type.startswith("image/") else "image/jpeg"
    return Response(content=resp.content, media_type=media_type)
# -----------------------------
# SFX GENERATION
# -----------------------------
async def gensfx(
    request: Request,
    prompt: Optional[str] = None,
    authorization: str = Header(None),
    x_client_id: str = Header(None),
):
    """
    Music / sound-effect generation endpoint (Pollinations ``acestep`` model).

    The prompt comes from the ``prompt`` query parameter or, when absent, from
    the ``prompt`` field of a JSON object body. On success the upstream audio
    bytes are returned as ``audio/mpeg``. On an upstream non-200, a JSON error
    carrying the upstream status code is returned.

    Raises:
        HTTPException(400): the body is not a valid JSON object.
        HTTPException(502): the upstream request could not be completed.
    """
    if prompt is None:
        try:
            payload = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")
    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt")
    await check_audio_rate_limit(request, authorization, x_client_id)

    # The prompt is free text: percent-encode it (as the image/video endpoints
    # do) so "?", "&", "#" or "/" cannot alter the path or inject parameters.
    url = (
        f"https://gen.pollinations.ai/audio/{quote(prompt, safe='')}"
        f"?model=acestep&key={quote(PKEY, safe='')}"
    )
    try:
        # Bounded timeout (same 600 s as the video endpoint) instead of none,
        # so a stalled upstream cannot hold the connection forever.
        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.get(url)
    except httpx.HTTPError as exc:
        raise HTTPException(502, "Music/sfx generation request failed") from exc

    if resp.status_code != 200:
        return JSONResponse(
            status_code=resp.status_code,
            content={"success": False, "error": "Upstream music/sfx generation failed"},
        )
    return Response(resp.content, media_type="audio/mpeg")
# -----------------------------
# TTS GENERATION
# -----------------------------
async def gentts(
    request: Request,
    prompt: Optional[str] = None,
    authorization: str = Header(None),
    x_client_id: str = Header(None),
):
    """
    Text-to-speech endpoint backed by Pollinations' audio route.

    The prompt comes from the ``prompt`` query parameter or, when absent, from
    the ``prompt`` field of a JSON object body. On success the upstream audio
    bytes are returned as ``audio/mpeg``. On an upstream non-200, a JSON error
    carrying the upstream status code is returned.

    Raises:
        HTTPException(400): the body is not a valid JSON object.
        HTTPException(502): the upstream request could not be completed.
    """
    if prompt is None:
        try:
            payload = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")
    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt")
    await check_audio_rate_limit(request, authorization, x_client_id)

    # The prompt is free text: percent-encode it (as the image/video endpoints
    # do) so "?", "&", "#" or "/" cannot alter the path or inject parameters.
    url = f"https://gen.pollinations.ai/audio/{quote(prompt, safe='')}?key={quote(PKEY3, safe='')}"
    try:
        # Bounded timeout (same 600 s as the video endpoint) instead of none,
        # so a stalled upstream cannot hold the connection forever.
        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.get(url)
    except httpx.HTTPError as exc:
        raise HTTPException(502, "Audio generation request failed") from exc

    if resp.status_code != 200:
        return JSONResponse(
            status_code=resp.status_code,
            content={"success": False, "error": "Upstream audio generation failed"},
        )
    return Response(resp.content, media_type="audio/mpeg")
# -----------------------------
# VIDEO GENERATION (Pollinations)
# -----------------------------
async def genvideo(
    request: Request,
    prompt: Optional[str] = None,
    authorization: str = Header(None),
    x_client_id: str = Header(None),
):
    """
    Video generation endpoint backed by the Pollinations ``ltx-2`` model.

    HEAD requests return a description of the accepted fields in ``Y-*``
    response headers. Otherwise the prompt comes from the ``prompt`` query
    parameter, or — when that is absent — from a JSON object body, which may
    also carry:
      * ratio: "3:2" (default), "2:3" or "1:1";
      * mode: "normal" (default) or "fun" (sets the upstream ``enhance`` flag);
      * duration: seconds, clamped to 1–10 (default 5; unparsable → 5);
      * image_urls: at most two image URLs or base-64 images. Base-64 ones are
        temporarily served from the asset CDN and always cleaned up afterwards.

    Returns the MP4 bytes on success, or a JSON error carrying the upstream
    status when Pollinations answers with a non-200.

    Raises:
        HTTPException(400): malformed body, ratio, mode or image_urls.
        HTTPException(502): the upstream request failed or returned no data.
    """
    if request.method == "HEAD":
        return Response(
            status_code=200,
            headers={
                "Y-prompt": "string — required. The text prompt used to generate the video.",
                "Y-ratio": "string — optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string — optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer — optional. Duration in seconds (1–10).",
                "Y-duration-default": "5",
                "Y-image_urls": "array<string> — optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                "Y-response_format": "video/mp4",
                # Advertise the model actually requested below.
                "Y-model": "ltx-2",
            },
        )

    aspect_ratio = "3:2"
    input_mode = "normal"
    duration = 5
    image_urls = None
    ratio = None
    mode = None
    if prompt is None:
        try:
            user_body = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(user_body, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")
        duration = user_body.get("duration", 5)

    # The type check comes first so an unhashable JSON value (list/dict) gets a
    # 400 instead of a TypeError from the set membership test.
    if not (ratio is None or isinstance(ratio, str)) or ratio not in valid_ratios:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid aspect ratio '{ratio}'. Must be one of 3:2, 2:3, or 1:1.",
        )
    if ratio in ratios:
        aspect_ratio = ratio
    if not (mode is None or isinstance(mode, str)) or mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode '{mode}'. Must be 'normal' or 'fun'.",
        )
    if mode in modes:
        input_mode = mode
    if image_urls:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 2:
            raise HTTPException(400, "You may provide at most two image URLs")
        if not all(isinstance(img, str) for img in image_urls):
            raise HTTPException(400, "image_urls must be a list of strings")

    # Clamp duration; OverflowError covers JSON "Infinity".
    try:
        duration = max(1, min(10, int(duration)))
    except (TypeError, ValueError, OverflowError):
        duration = 5

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )
    await check_video_rate_limit(request, authorization, x_client_id)

    # NOTE(review): "1:1" is sent upstream as portrait 9:16 — presumably because
    # ltx-2 offers no square output; confirm.
    ratio_map = {
        "3:2": "16:9",
        "2:3": "9:16",
        "1:1": "9:16",
    }
    params = {
        "model": "ltx-2",
        "duration": duration,
        "aspectRatio": ratio_map.get(aspect_ratio, "16:9"),
        "seed": -1,
    }

    temp_assets = []
    try:
        if image_urls:
            processed_urls = []
            for img in image_urls[:2]:
                if is_base64_image(img):
                    image_id = save_base64_image(img)
                    # Record immediately so the asset is cleaned up even if a
                    # later image or the upstream call fails.
                    temp_assets.append(image_id)
                    processed_urls.append(f"{request.base_url}asset-cdn/assets/{image_id}")
                else:
                    processed_urls.append(img)
            params["image"] = "|".join(processed_urls)
        if input_mode == "fun":
            params["enhance"] = "true"

        encoded_prompt = quote(prompt, safe="")
        query_string = "&".join(f"{k}={quote(str(v), safe='')}" for k, v in params.items())
        url = f"https://gen.pollinations.ai/video/{encoded_prompt}?{query_string}"
        print(f"[VIDEO GEN] Pollinations URL: {url}")
        # Append the key only after logging so it never reaches the logs.
        url = f"{url}&key={quote(PKEY, safe='')}"

        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.get(url)
    except httpx.HTTPError as exc:
        raise HTTPException(502, "Video generation request failed") from exc
    finally:
        for aid in temp_assets:
            cleanup_image(aid)

    if resp.status_code != 200:
        return JSONResponse(
            status_code=resp.status_code,
            content={
                "success": False,
                "error": "Upstream video generation failed",
                "status_code": resp.status_code,
                "message": resp.text[:1000],
            },
        )
    if not resp.content:
        raise HTTPException(502, "Pollinations returned empty response")
    return Response(
        content=resp.content,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(resp.content)),
            "Accept-Ranges": "bytes",
        },
    )
async def genvideo_airforce(
    request: Request,
    prompt: Optional[str] = None,
    authorization: str = Header(None),
    x_client_id: str = Header(None),
):
    """
    Video generation endpoint backed by api.airforce (``AIRFORCE_VIDEO_MODEL``).

    HEAD requests return a description of the accepted fields in ``Y-*``
    response headers. Otherwise the prompt comes from the ``prompt`` query
    parameter, or — when absent — from a JSON object body that may also carry
    ``ratio`` ("3:2" default, "2:3", "1:1"), ``mode`` ("normal" default or
    "fun") and ``image_urls`` (at most two, forwarded as-is).

    NOTE(review): the HEAD description mentions ``duration``, but this handler
    never reads or forwards it.

    Returns the decoded MP4 bytes on success. On an upstream non-200, returns
    the upstream status with its JSON error body (or a JSON wrapper around the
    raw text when the body is not JSON).

    Raises:
        HTTPException(400): malformed body, ratio, mode or image_urls.
        HTTPException(500): the AIRFORCE key is not configured.
        HTTPException(502): the upstream request failed or returned unusable data.
    """
    if request.method == "HEAD":
        return Response(
            status_code=200,
            headers={
                # Required field
                "Y-prompt": "string — required. The text prompt used to generate the video.",
                # Optional fields
                "Y-ratio": "string — optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string — optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer — optional. Duration in seconds.",
                "Y-duration-default": "5",
                "Y-image_urls": "array<string> — optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                # Response format
                "Y-response_format": "video/mp4",
                # Model info
                "Y-model": "grok-imagine-video",
            },
        )

    aspect_ratio = "3:2"
    input_mode = "normal"
    image_urls = None
    ratio = None
    mode = None
    if prompt is None:
        try:
            user_body = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(user_body, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")

    # The type check comes first so an unhashable JSON value (list/dict) gets a
    # 400 instead of a TypeError from the set membership test.
    if not (ratio is None or isinstance(ratio, str)) or ratio not in valid_ratios:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid aspect ratio {ratio}. Must be one of 3:2, 2:3, or 1:1. Default is 3:2",
        )
    if ratio in ratios:
        aspect_ratio = ratio
    if not (mode is None or isinstance(mode, str)) or mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode {mode}. Must be 'normal' or 'fun'. Default is normal",
        )
    if mode in modes:
        input_mode = mode
    if image_urls:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 2:
            raise HTTPException(400, "You may provide at most two image URLs")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )
    # Fail fast on misconfiguration (before charging the rate limit) instead of
    # sending "Bearer None" upstream.
    if not AIRFORCE_KEY:
        raise HTTPException(500, "Missing AIRFORCE key")
    await check_video_rate_limit(request, authorization, x_client_id)

    payload = {
        "model": AIRFORCE_VIDEO_MODEL,
        "prompt": prompt,
        "n": 1,
        "size": "1024x1024",
        "response_format": "b64_json",
        "sse": False,
        "mode": input_mode,
        "aspectRatio": aspect_ratio,
    }
    if image_urls:
        payload["image_urls"] = image_urls

    try:
        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.post(
                AIRFORCE_API_URL,
                headers={
                    "Authorization": f"Bearer {AIRFORCE_KEY}",
                    "Content-Type": "application/json",
                },
                json=payload,
            )
    except httpx.HTTPError as exc:
        raise HTTPException(502, "api.airforce request failed") from exc

    if resp.status_code != 200:
        # Error pages (e.g. a gateway's HTML 502) are not JSON; don't crash on them.
        try:
            error_body = resp.json()
        except ValueError:
            error_body = {
                "success": False,
                "error": "Upstream video generation failed",
                "status_code": resp.status_code,
                "message": resp.text[:1000],
            }
        return JSONResponse(status_code=resp.status_code, content=error_body)
    if not resp.content:
        raise HTTPException(502, "api.airforce returned empty response")

    try:
        b64_video = resp.json()["data"][0]["b64_json"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        raise HTTPException(502, f"Invalid api.airforce response: {resp.text[:500]}") from exc
    if not b64_video:
        raise HTTPException(502, "Airforce returned empty b64_json")
    try:
        video_bytes = base64.b64decode(b64_video)
    except (ValueError, TypeError) as exc:  # binascii.Error is a ValueError
        raise HTTPException(502, "Airforce returned invalid base64 video data") from exc

    return Response(
        content=video_bytes,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(video_bytes)),
            "Accept-Ranges": "bytes",
        },
    )
# Display names for the upstream model ids the router can select. They are
# sent to clients in the streaming ``router_metadata`` event and returned by
# ``analyze_prompt``.
MODEL_MAP = {
    "llama-3.1-8b-instant": "Meta Llama 3.1 8B Instant",
    "gpt-4o-mini": "OpenAI GPT 4o Mini",
    "gpt-4.1": "OpenAI GPT 4.1",
    "nemotron-3-super": "NVIDIA Nemotron 3 Super",
    "openai/gpt-oss-120b": "OpenAI GPT-OSS 120B",
    "openai/gpt-oss-20b": "OpenAI GPT-OSS 20B",
    "qwen-3-235b-a22b-instruct-2507": "Qwen3 Instruct",
    "llama-3.3-70b-versatile": "Meta Llama 3.3 70B Versatile",
    "meta-llama/llama-4-scout-17b-16e-instruct": "Meta Llama 4 Scout",
    "o3-mini": "OpenAI o3 Mini",
}

_GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions"
_NAVY_CHAT_URL = "https://api.navy/v1/chat/completions"

# provider -> (env var with comma-separated keys, endpoint, error when no key is set)
_PROVIDERS = {
    "groq": ("GROQ_KEY", _GROQ_CHAT_URL, "Missing GROQ_KEY(s)"),
    "cerebras": ("CER_KEY", "https://api.cerebras.ai/v1/chat/completions", "Missing CER_KEY(s)"),
    "navy vision": ("NAVY_KEY", _NAVY_CHAT_URL, "Missing NAVY Keys(s)"),
    "navy": ("NAVY_TEXT_ONLY", _NAVY_CHAT_URL, "Missing NAVY TEXT ONLY Keys(s)"),
}

# Model used (always via Groq) when the primary provider fails.
_FALLBACK_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"

# Upstream timeout; bounded so a stalled provider cannot hold a connection forever.
_CHAT_TIMEOUT = httpx.Timeout(600.0)


def _load_keys(env_var: str) -> list:
    """Return the non-empty, whitespace-stripped comma-separated values of *env_var*."""
    return [k.strip() for k in os.getenv(env_var, "").split(",") if k.strip()]


def _route_request(messages: list, body: dict) -> dict:
    """Score a chat request and choose the upstream model and provider for it.

    Shared by ``generate_text`` (which routes the request) and
    ``analyze_prompt`` (which only reports the choice), so the two cannot drift.

    Returns a dict with ``model``, ``provider``, ``score`` (clamped to 10) and
    the individual feature flags, which are used for logging.
    """
    total_chars, total_bytes = calculate_messages_size(messages)
    # NOTE: the overall MAX_CHAT_PROMPT_CHARS/BYTES (413) limit is currently not
    # enforced; oversized Groq-bound requests are rerouted to navy below instead.
    prompt_text = extract_user_text(messages)

    tools = body.get("tools")
    uses_tools = (isinstance(tools, list) and len(tools) > 0) or (
        "tool_choice" in body and body["tool_choice"] not in [None, "none"]
    )
    long_context = is_long_context(messages)
    code_present = contains_code(prompt_text)
    math_heavy = is_math_heavy(prompt_text)
    structured_task = is_structured_task(prompt_text)
    multi_q = multiple_questions(prompt_text)
    code_heavy = is_code_heavy(prompt_text, code_present, long_context)

    weighted_flags = (
        (long_context, 3),
        (math_heavy, 3),
        (structured_task, 2),
        (code_present, 2),
        (multi_q, 1),
    )
    score = sum(weight for flag, weight in weighted_flags if flag)
    score += sum(1 for kw in REASONING_KEYWORDS if kw in prompt_text)
    score = min(score, 10)

    has_images = contains_images(messages)
    model, provider = "llama-3.3-70b-versatile", "groq"
    if has_images:
        model, provider = "gpt-4.1", "navy vision"
    elif uses_tools:
        if score >= 6:
            model, provider = "nemotron-3-super", "navy"
        elif score >= 4:
            model, provider = "openai/gpt-oss-120b", "groq"
        else:
            model, provider = "openai/gpt-oss-20b", "groq"
    elif code_present:
        # Code below the o3-mini threshold stays on the default Llama 3.3 70B.
        if code_heavy and score >= 6:
            model, provider = "o3-mini", "navy"
    elif score >= 4:
        # NOTE(review): a "score >= 6 -> sonar (navy)" route used to follow this
        # branch but could never be reached, so it was removed. Test score >= 6
        # first if that route is actually wanted.
        model, provider = "meta-llama/llama-4-scout-17b-16e-instruct", "groq"

    if provider == "groq" and (
        total_chars > MAX_GROQ_PROMPT_CHARS or total_bytes > MAX_GROQ_PROMPT_BYTES
    ):
        model, provider = "gpt-4o-mini", "navy"

    return {
        "model": model,
        "provider": provider,
        "score": score,
        "uses_tools": uses_tools,
        "long_context": long_context,
        "code_present": code_present,
        "math_heavy": math_heavy,
        "structured_task": structured_task,
        "multi_q": multi_q,
        "has_images": has_images,
    }


def _sse_error(message: str) -> str:
    """Format *message* as an SSE ``data:`` event with a JSON ``error`` field (properly escaped)."""
    return f"data: {json.dumps({'error': message})}\n\n"


def _is_error_line(line: str) -> bool:
    """Return True if an upstream SSE line is ``event: error`` or a ``data:`` JSON object whose ``error`` is an object."""
    if line.startswith("event: error"):
        return True
    if line.startswith("data:"):
        try:
            obj = json.loads(line[5:].strip())
        except ValueError:  # e.g. "data: [DONE]"
            return False
        return isinstance(obj, dict) and isinstance(obj.get("error"), dict)
    return False


async def _stream_fallback(client, messages):
    """Yield SSE lines from the Groq fallback model for *messages*.

    Problems (no key, upstream error status, transport failure) are reported
    as SSE error events rather than raised, because the response is already
    streaming.
    """
    groq_keys = _load_keys("GROQ_KEY")
    if not groq_keys:
        yield _sse_error("Fallback provider failed: Missing GROQ_KEY(s)")
        return
    fallback_body = {"model": _FALLBACK_MODEL, "messages": messages, "stream": True}
    fallback_headers = {"Authorization": f"Bearer {random.choice(groq_keys)}"}
    print("[FALLBACK] Starting Groq fallback stream")
    try:
        async with client.stream(
            "POST", _GROQ_CHAT_URL, json=fallback_body, headers=fallback_headers
        ) as r:
            if r.status_code >= 400:
                err = (await r.aread()).decode("utf-8", errors="replace")
                yield _sse_error(f"Fallback provider failed: {err[:500]}")
                return
            async for line in r.aiter_lines():
                if not line:
                    yield "\n"
                elif line.startswith("data:"):
                    yield line + "\n"
                else:
                    yield f"data: {line}\n\n"
    except httpx.HTTPError as exc:
        yield _sse_error(f"Fallback provider failed: {exc}")


async def _stream_primary(client, url, body, headers):
    """Yield SSE lines from the routed provider, switching to the fallback on failure.

    The fallback is started only after the primary response has been closed,
    and only if no primary ``data:`` line has been forwarded yet. Otherwise the
    client would receive two models' answers spliced together, so a late
    failure is reported as an SSE error event instead.
    """
    failed = False
    forwarded = False
    try:
        async with client.stream("POST", url, json=body, headers=headers) as r:
            if r.status_code >= 400:
                failed = True
            else:
                async for line in r.aiter_lines():
                    if not line:
                        yield "\n"
                        continue
                    if _is_error_line(line):
                        failed = True
                        break
                    if line.startswith("data:"):
                        forwarded = True
                    yield line + "\n"
    except httpx.HTTPError as exc:
        print(f"[STREAM ERROR] {exc}")
        failed = True

    if not failed:
        return
    if forwarded:
        yield _sse_error("Primary provider failed mid-stream")
        return
    print("[STREAM FALLBACK] Primary provider failed → switching to Groq fallback")
    async for chunk in _stream_fallback(client, body["messages"]):
        yield chunk


async def generate_text(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """OpenAI-compatible chat completion proxy with automatic model routing.

    Reads an OpenAI-style JSON object body. ``messages`` is required; the whole
    body (``stream``, ``tools``, ...) is forwarded with ``model`` overwritten by
    the routed choice. The chat rate limit is applied, then the request goes to
    the chosen provider using a randomly picked key for it.

    * ``stream`` truthy: returns an SSE stream whose first event is a
      ``router_metadata`` object naming the selected model. If the primary
      provider fails before sending any data, the conversation is replayed on
      the Groq fallback model.
    * otherwise: returns the upstream JSON (or a JSON wrapper around a non-JSON
      reply) with the upstream status code. A failing "navy vision" call is
      retried once on the Groq fallback model.

    Raises:
        HTTPException(400): invalid body or missing ``messages``.
        HTTPException(500): no API key configured for the chosen provider.
        HTTPException(502): the upstream request could not be completed
            (non-streaming only).
    """
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(400, "Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "Request body must be a JSON object")
    messages = body.get("messages", [])
    if not isinstance(messages, list) or len(messages) == 0:
        raise HTTPException(400, "messages[] is required")

    route = _route_request(messages, body)
    chosen_model = route["model"]
    provider = route["provider"]

    await check_chat_rate_limit(request, authorization, x_client_id)
    body["model"] = chosen_model
    print(
        f"""
[ADVANCED ROUTER]
Score: {route["score"]}
Uses tools: {route["uses_tools"]}
Long context: {route["long_context"]}
Code present: {route["code_present"]}
Math heavy: {route["math_heavy"]}
Structured: {route["structured_task"]}
Multi-question: {route["multi_q"]}
MULTIMODAL REQUIRED: {route["has_images"]}
→ Selected: {chosen_model} ({provider})
"""
    )

    provider_config = _PROVIDERS.get(provider)
    if provider_config is None:
        raise HTTPException(500, "Unknown provider routing error")
    env_var, url, missing_message = provider_config
    keys = _load_keys(env_var)
    if not keys:
        raise HTTPException(500, missing_message)
    # Key material is deliberately never logged.
    headers = {"Authorization": f"Bearer {random.choice(keys)}"}

    if body.get("stream", False):
        body["stream"] = True
        meta_event = f"data: {json.dumps({'router_metadata': {'model_name': MODEL_MAP.get(chosen_model, chosen_model)}})}\n\n"

        async def event_generator():
            sent_metadata = False
            async with httpx.AsyncClient(timeout=_CHAT_TIMEOUT) as client:
                async for chunk in _stream_primary(client, url, body, headers):
                    if not sent_metadata:
                        yield meta_event
                        sent_metadata = True
                    yield chunk

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    try:
        async with httpx.AsyncClient(timeout=_CHAT_TIMEOUT) as client:
            r = await client.post(url, json=body, headers=headers)
            if provider == "navy vision" and r.status_code >= 400:
                print("[FALLBACK] Navy vision failed — switching to 17B Groq")
                groq_keys = _load_keys("GROQ_KEY")
                if not groq_keys:
                    raise HTTPException(500, "Missing GROQ_KEY(s) for fallback")
                r = await client.post(
                    _GROQ_CHAT_URL,
                    json={**body, "model": _FALLBACK_MODEL},
                    headers={"Authorization": f"Bearer {random.choice(groq_keys)}"},
                )
    except httpx.HTTPError as exc:
        raise HTTPException(502, "Upstream chat request failed") from exc

    content_type = (r.headers.get("content-type") or "").lower()
    if "application/json" in content_type:
        try:
            payload = r.json()
        except ValueError:
            payload = {"error": "Upstream returned invalid JSON"}
    else:
        payload = {
            "error": "Upstream returned non-JSON response",
            "status_code": r.status_code,
            "message": r.text[:1000],
        }
    return JSONResponse(status_code=r.status_code, content=payload)


async def analyze_prompt(
    request: Request
):
    """Report which model the router would pick for a conversation, without calling it.

    Expects a JSON object body whose ``prompt`` field holds the messages list
    (same shape as ``generate_text``'s ``messages``). ``tools``/``tool_choice``
    are honoured too. Returns a one-element list (a JSON array) holding the
    chosen model's display name.

    Raises:
        HTTPException(400): invalid body or empty/missing ``prompt`` list.
    """
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(400, "Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "Request body must be a JSON object")
    messages = body.get("prompt", [])
    if not isinstance(messages, list) or len(messages) == 0:
        raise HTTPException(400, "messages[] is required")
    model = _route_request(messages, body)["model"]
    return [MODEL_MAP.get(model, model)]
def return_models_openai():
    """Return an OpenAI-style model-list payload that exposes the single model "lightning"."""
    lightning = {
        "id": "lightning",
        "object": "model",
        "created": 1767225600,
        "owned_by": "inferenceport-ai",
    }
    return {"object": "list", "data": [lightning]}