Spaces:
Running
Running
| import os | |
| import time | |
| import hashlib | |
| from fastapi import FastAPI, Request, HTTPException, status, Header | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import ( | |
| Response, | |
| JSONResponse, | |
| StreamingResponse, | |
| RedirectResponse, | |
| ) | |
| import httpx | |
| from bs4 import BeautifulSoup | |
| from typing import List, Dict, Any | |
| import asyncio | |
| import re | |
| import random | |
| from urllib.parse import quote | |
| import base64 | |
| from helper.subscriptions import ( | |
| fetch_subscription, | |
| normalize_plan_key, | |
| TIER_CONFIG, | |
| PLAN_ORDER, | |
| ) | |
| from typing import Optional | |
| from helper.keywords import * | |
| from helper.assets import ( | |
| save_base64_image, | |
| cleanup_image, | |
| is_base64_image, | |
| asset_router, | |
| ) | |
| from helper.ratelimit import ( | |
| enforce_rate_limit, | |
| resolve_rate_limit_identity, | |
| check_audio_rate_limit, | |
| check_video_rate_limit, | |
| check_image_rate_limit, | |
| MAX_CHAT_PROMPT_BYTES, | |
| MAX_CHAT_PROMPT_CHARS, | |
| MAX_GROQ_PROMPT_BYTES, | |
| MAX_GROQ_PROMPT_CHARS, | |
| MAX_MEDIA_PROMPT_BYTES, | |
| MAX_MEDIA_PROMPT_CHARS, | |
| extract_user_text, | |
| calculate_messages_size, | |
| normalize_prompt_value, | |
| enforce_prompt_size, | |
| resolve_bound_subject, | |
| get_usage_snapshot_for_subject, | |
| ) | |
# Application object that the rest of this module configures.
app = FastAPI()

# CORS: any origin and any request header are accepted, but only the
# GET, POST and HEAD methods are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "HEAD"],
    allow_headers=["*"],
)

# Attach the router exported by helper.assets.
app.include_router(asset_router)
async def reroute_to_home():
    """Send the caller to the project home page with an HTTP 308 redirect."""
    home_url = "https://inference.js.org"
    return RedirectResponse(
        url=home_url,
        status_code=status.HTTP_308_PERMANENT_REDIRECT,
    )
# Page that get_models() downloads and parses.
OLLAMA_LIBRARY_URL = "https://ollama.com/library"
def is_complex_reasoning(prompt: str) -> bool:
    """Return True when *prompt* looks like it needs multi-step reasoning.

    A prompt qualifies if any one of these holds, checked in order:
    it is longer than 800 characters, it contains an entry of
    ``REASONING_KEYWORDS`` as a substring, or it contains one of the
    whole-word connectives matched by the regular expression below.
    All matching is case-sensitive.
    """
    if len(prompt) > 800:
        return True
    if any(keyword in prompt for keyword in REASONING_KEYWORDS):
        return True
    connective = re.search(r"\b(if|therefore|assume|let x|given that)\b", prompt)
    return connective is not None
def is_lightweight(prompt: str) -> bool:
    """Return True for a short prompt that contains a lightweight keyword.

    Only prompts shorter than 100 characters are considered; for those,
    the result is True when any entry of ``LIGHTWEIGHT_KEYWORDS`` occurs
    in *prompt* as a case-sensitive substring.
    """
    if len(prompt) >= 100:
        return False
    return any(keyword in prompt for keyword in LIGHTWEIGHT_KEYWORDS)
def is_cinematic_image_prompt(prompt: str) -> bool:
    """Return True when *prompt* contains any entry of ``CREATIVE_KEYWORDS``.

    The prompt is lower-cased once before matching; keywords are compared
    as written, so they presumably are stored in lower case (verify in
    helper.keywords).
    """
    # Lower-case once instead of once per keyword.
    lowered = prompt.lower()
    return any(keyword in lowered for keyword in CREATIVE_KEYWORDS)
# Pollinations API keys, read from the environment once at import time.
# Each falls back to the empty string when its variable is unset.
PKEY = os.environ.get("POLLINATIONS_KEY", "")
PKEY2 = os.environ.get("POLLINATIONS2_KEY", "")
PKEY3 = os.environ.get("POLLINATIONS3_KEY", "")

# Model identifier lists. None of these three lists is referenced in the
# part of the file visible here.
GROQ_TOOL_MODELS = [
    "openai/gpt-oss-120b",
    "openai/gpt-oss-20b",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "qwen/qwen3-32b",
    "moonshotai/kimi-k2-instruct",
]

GROQ_NORMAL_MODELS = [
    "llama-3.1-8b-instant",
    "llama-3.3-70b-versatile",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-guard-4-12b",
    "openai/gpt-oss-safeguard-20b",
    "qwen/qwen3-32b",
]

CEREBRAS_MODELS = [
    "gpt-oss-120b",
    "llama3.1-8b",
    "qwen-3-235b-a22b-instruct-2507",
    "zai-glm-4.7",
]
async def check_chat_rate_limit(
    request: Request,
    authorization: Optional[str],
    client_id: Optional[str] = None,
):
    """Run ``enforce_rate_limit`` for the ``"cloudChatDaily"`` limit key.

    Args:
        request: Incoming request, forwarded unchanged.
        authorization: Raw ``Authorization`` header value, if any.
        client_id: Optional client identifier, forwarded unchanged.

    Returns:
        Whatever ``enforce_rate_limit`` returns.
    """
    limit_key = "cloudChatDaily"
    return await enforce_rate_limit(request, authorization, limit_key, client_id)
async def head_sfx():
    """Return an empty 200 response advertising ``audio/mpeg`` and byte ranges."""
    probe_headers = {
        "Content-Type": "audio/mpeg",
        "Accept-Ranges": "bytes",
    }
    return Response(status_code=200, headers=probe_headers)
async def head_image():
    """Return an empty 200 response advertising ``image/jpeg`` and byte ranges."""
    probe_headers = {
        "Content-Type": "image/jpeg",
        "Accept-Ranges": "bytes",
    }
    return Response(status_code=200, headers=probe_headers)
async def head_video():
    """Return an empty 200 response advertising ``video/mp4`` and byte ranges."""
    probe_headers = {
        "Content-Type": "video/mp4",
        "Accept-Ranges": "bytes",
    }
    return Response(status_code=200, headers=probe_headers)
async def head_text():
    """Return an empty 200 response advertising ``application/json`` and byte ranges."""
    probe_headers = {
        "Content-Type": "application/json",
        "Accept-Ranges": "bytes",
    }
    return Response(status_code=200, headers=probe_headers)
async def get_status():
    """Return the service status document as JSON.

    Every service entry is hard-coded as healthy, so the overall state
    computed from them is ``"ok"``; it would become ``"degraded"`` if any
    entry's ``state`` were changed to something else.
    """
    notify = ""
    service_names = (
        "Video Generation",
        "Image Generation",
        "Lightning-Text v2",
        "Music/SFX Generation",
    )
    # One independent dict per service, all starting out healthy.
    services = {
        name: {"code": 200, "state": "ok", "message": "Running normally"}
        for name in service_names
    }

    any_unhealthy = any(entry["state"] != "ok" for entry in services.values())
    overall_state = "degraded" if any_unhealthy else "ok"

    document = {
        "state": overall_state,
        "services": services,
        "notifications": notify,
        "latest": "2.4.0",
    }
    return JSONResponse(status_code=200, content=document)
async def generate_image(
    request: Request,
    prompt: Optional[str] = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate an image through Pollinations and return its bytes.

    Args:
        request: Incoming request; its JSON body is read only when
            *prompt* was not supplied as a parameter.
        prompt: Text prompt. When ``None`` it is taken from the ``prompt``
            field of the JSON body.
        authorization: ``Authorization`` header, forwarded to the rate limiter.
        x_client_id: ``X-Client-Id`` header, forwarded to the rate limiter.

    Returns:
        A ``Response`` with the upstream bytes and media type ``image/jpeg``.

    Raises:
        HTTPException: 400 when the body is not a JSON object; 500 when
            Pollinations answers with a non-200 status. The prompt and
            rate-limit helpers may raise their own errors.
    """
    payload: Dict[str, Any] = {}
    if prompt is None:
        try:
            payload = await request.json()
        except ValueError:
            # Covers malformed JSON and undecodable bytes.
            raise HTTPException(400, "Request body must be valid JSON") from None
        # Validate before the first .get(); a JSON array or scalar would
        # otherwise fail with AttributeError and surface as a 500.
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Image prompt"
    )
    await check_image_rate_limit(request, authorization, x_client_id)

    # Default routing comes from the prompt text ...
    chosen_model = "flux" if is_cinematic_image_prompt(prompt) else "zimage"

    # ... and an explicit, recognised "mode" in the body overrides it.
    mode = payload.get("mode")
    if isinstance(mode, str):
        normalized_mode = mode.strip().lower()
        if normalized_mode == "fantasy":
            chosen_model = "flux"
        elif normalized_mode == "realistic":
            chosen_model = "zimage"

    print(f"[IMAGE GEN] Routing to model: {chosen_model}")

    url = f"https://gen.pollinations.ai/image/{quote(prompt, safe='')}?model={chosen_model}&key={PKEY2}"
    timeout = httpx.Timeout(300.0, read=300.0)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.get(url)

    if response.status_code != 200:
        raise HTTPException(
            status_code=500, detail=f"Pollinations error: {response.status_code}"
        )
    return Response(content=response.content, media_type="image/jpeg")
async def get_models() -> List[Dict]:
    """Scrape the Ollama library page into a list of model descriptions.

    Returns:
        One dict per ``li[x-test-model]`` element with the keys ``name``,
        ``description``, ``sizes``, ``pulls``, ``tags``, ``updated`` and
        ``link``. Missing elements fall back to ``""``, ``"No description"``,
        ``"Unknown"`` or ``None`` as shown below.

    Raises:
        HTTPException: 502 when the library page does not answer with a
            2xx status, so an error page is not reported as "no models".
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(OLLAMA_LIBRARY_URL)

    if not 200 <= response.status_code < 300:
        raise HTTPException(
            status_code=502,
            detail=f"Ollama library returned HTTP {response.status_code}",
        )

    def _text(node, fallback: str) -> str:
        """Stripped text of *node*, or *fallback* when the element is absent."""
        return node.get_text(strip=True) if node is not None else fallback

    soup = BeautifulSoup(response.text, "html.parser")
    models = []
    for item in soup.select("li[x-test-model]"):
        link = item.select_one("a")
        models.append(
            {
                "name": _text(item.select_one("[x-test-model-title] span"), ""),
                "description": _text(
                    item.select_one("p.max-w-lg"), "No description"
                ),
                "sizes": [
                    el.get_text(strip=True) for el in item.select("[x-test-size]")
                ],
                "pulls": _text(item.select_one("[x-test-pull-count]"), "Unknown"),
                "tags": [
                    t.get_text(strip=True)
                    for t in item.select('span[class*="text-blue-600"]')
                ],
                "updated": _text(item.select_one("[x-test-updated]"), "Unknown"),
                "link": link.get("href") if link is not None else None,
            }
        )
    return models
def _sse_error_event(message: str) -> str:
    """Return one server-sent ``data:`` event whose payload is ``{"error": message}``.

    ``json.dumps`` performs the escaping, so quotes, backslashes and control
    characters in *message* cannot produce malformed JSON.
    """
    import json  # function-scope import: json is not in the file's visible import block

    return f"data: {json.dumps({'error': message})}\n\n"


def _resolve_provider(provider: str):
    """Return ``(api_key, url)`` for *provider*, picking one configured key at random.

    Keys come from a comma-separated environment variable (``GROQ_KEY`` or
    ``CER_KEY``). Raises ``HTTPException`` 500 when the provider is unknown
    or no key is configured.
    """
    if provider == "groq":
        env_name = "GROQ_KEY"
        url = "https://api.groq.com/openai/v1/chat/completions"
    elif provider == "cerebras":
        env_name = "CER_KEY"
        url = "https://api.cerebras.ai/v1/chat/completions"
    else:
        raise HTTPException(500, "Unknown provider routing error")

    keys = [k.strip() for k in os.getenv(env_name, "").split(",") if k.strip()]
    if not keys:
        raise HTTPException(500, f"Missing {env_name}(s)")
    return random.choice(keys), url


async def generate_text(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Route a chat-completion request to Groq or Cerebras and relay the answer.

    The JSON body must contain a non-empty ``messages`` list. A heuristic
    score built from the prompt picks the model; the body's ``model`` field
    is overwritten with that choice before the request is forwarded.

    Args:
        request: Incoming request carrying the chat-completion JSON body.
        authorization: ``Authorization`` header, forwarded to the rate limiter.
        x_client_id: ``X-Client-Id`` header, forwarded to the rate limiter.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) when ``stream`` is
        truthy in the body, otherwise a ``JSONResponse`` mirroring the
        upstream status code.

    Raises:
        HTTPException: 400 for an invalid body or missing ``messages``;
            413 when the prompt exceeds the chat or Groq size limits;
            500 when no API key is configured for the chosen provider.
            The rate limiter may raise its own errors.
    """
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(400, "Request body must be valid JSON") from None
    # A JSON array or scalar would otherwise crash on body.get() below.
    if not isinstance(body, dict):
        raise HTTPException(400, "Request body must be a JSON object")

    messages = body.get("messages", [])
    if not isinstance(messages, list) or not messages:
        raise HTTPException(400, "messages[] is required")

    total_chars, total_bytes = calculate_messages_size(messages)
    if total_chars > MAX_CHAT_PROMPT_CHARS or total_bytes > MAX_CHAT_PROMPT_BYTES:
        raise HTTPException(
            status_code=413,
            detail=(
                f"Prompt context too large ({total_chars} chars, {total_bytes} bytes). "
                f"Max allowed is {MAX_CHAT_PROMPT_CHARS} chars or {MAX_CHAT_PROMPT_BYTES} bytes."
            ),
        )

    prompt_text = extract_user_text(messages)
    uses_tools = (
        "tools" in body and isinstance(body["tools"], list) and len(body["tools"]) > 0
    ) or ("tool_choice" in body and body["tool_choice"] not in [None, "none"])

    # Prompt features that feed the routing score.
    long_context = is_long_context(messages)
    code_present = contains_code(prompt_text)
    math_heavy = is_math_heavy(prompt_text)
    structured_task = is_structured_task(prompt_text)
    multi_q = multiple_questions(prompt_text)
    code_heavy = is_code_heavy(prompt_text, code_present, long_context)

    score = 0
    if long_context:
        score += 3
    if math_heavy:
        score += 3
    if structured_task:
        score += 2
    if code_present:
        score += 2
    if multi_q:
        score += 1
    # One extra point per reasoning keyword found in the prompt.
    score += sum(1 for kw in REASONING_KEYWORDS if kw in prompt_text)

    # Default route; overridden below when the prompt calls for more.
    chosen_model = "llama-3.1-8b-instant"
    provider = "groq"
    has_images = contains_images(messages)
    if has_images:
        # Image input always goes to this model, regardless of score.
        chosen_model = "meta-llama/llama-4-scout-17b-16e-instruct"
        provider = "groq"
    else:
        score = min(score, 10)
        if uses_tools:
            chosen_model = (
                "openai/gpt-oss-120b" if score >= 4 else "openai/gpt-oss-20b"
            )
            provider = "groq"
        elif code_present:
            if code_heavy and score >= 6:
                chosen_model = "gpt-oss-120b"
                provider = "cerebras"
            elif score >= 4:
                chosen_model = "llama-3.3-70b-versatile"
                provider = "groq"
        elif score >= 4:
            chosen_model = "meta-llama/llama-4-scout-17b-16e-instruct"
            provider = "groq"

    if provider == "groq" and (
        total_chars > MAX_GROQ_PROMPT_CHARS or total_bytes > MAX_GROQ_PROMPT_BYTES
    ):
        raise HTTPException(
            status_code=413,
            detail=(
                f"Prompt exceeds Groq-safe size ({total_chars} chars, {total_bytes} bytes). "
                f"Max Groq-safe size is {MAX_GROQ_PROMPT_CHARS} chars or {MAX_GROQ_PROMPT_BYTES} bytes."
            ),
        )

    await check_chat_rate_limit(request, authorization, x_client_id)
    body["model"] = chosen_model
    print(
        f"""
[ADVANCED ROUTER]
Score: {score}
Uses tools: {uses_tools}
Long context: {long_context}
Code present: {code_present}
Math heavy: {math_heavy}
Structured: {structured_task}
Multi-question: {multi_q}
MULTIMODAL REQUIRED: {has_images}
→ Selected: {chosen_model} ({provider})
"""
    )

    stream = body.get("stream", False)
    api_key, url = _resolve_provider(provider)
    headers = {"Authorization": f"Bearer {api_key}"}

    if stream:
        # Normalise any truthy value to a real boolean for the upstream API.
        body["stream"] = True

        async def event_generator():
            try:
                async with httpx.AsyncClient(timeout=None) as client:
                    async with client.stream(
                        "POST",
                        url,
                        json=body,
                        headers=headers,
                    ) as r:
                        if r.status_code >= 400:
                            try:
                                error_payload = (
                                    (await r.aread()).decode("utf-8", errors="replace")
                                )[:800]
                            except Exception:
                                # Best effort: report the status even if the
                                # error body cannot be read.
                                error_payload = ""
                            flattened = error_payload.replace("\n", " ").replace(
                                "\r", " "
                            )
                            yield _sse_error_event(
                                f"Upstream provider error ({r.status_code}): {flattened}"
                            )
                            return
                        # Relay upstream lines verbatim; an empty line is the
                        # SSE event separator and is passed through as "\n".
                        async for line in r.aiter_lines():
                            yield line + "\n"
            except asyncio.CancelledError:
                # Client went away; stop quietly.
                return
            except Exception as e:
                yield _sse_error_event(str(e))

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # critical for nginx
            },
        )

    async with httpx.AsyncClient(timeout=None) as client:
        r = await client.post(url, json=body, headers=headers)

    content_type = (r.headers.get("content-type") or "").lower()
    if "application/json" in content_type:
        try:
            payload = r.json()
        except Exception:
            payload = {"error": "Upstream returned invalid JSON"}
    else:
        payload = {
            "error": "Upstream returned non-JSON response",
            "status_code": r.status_code,
            "message": r.text[:1000],
        }
    return JSONResponse(status_code=r.status_code, content=payload)
async def gensfx(
    request: Request,
    prompt: Optional[str] = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate music/SFX through Pollinations (``model=elevenmusic``).

    NOTE(review): a second function named ``gensfx`` is defined later in this
    file; presumably each is registered on its own route by a decorator that
    is not visible here. Confirm before relying on the module-level name.

    Args:
        request: Incoming request; its JSON body is read only when
            *prompt* was not supplied as a parameter.
        prompt: Text prompt. When ``None`` it is taken from the ``prompt``
            field of the JSON body.
        authorization: ``Authorization`` header, forwarded to the rate limiter.
        x_client_id: ``X-Client-Id`` header, forwarded to the rate limiter.

    Returns:
        A ``Response`` with media type ``audio/mpeg`` on success, or a
        ``JSONResponse`` mirroring the upstream status code on failure.

    Raises:
        HTTPException: 400 when the body is not a JSON object. The prompt
            and rate-limit helpers may raise their own errors.
    """
    if prompt is None:
        try:
            payload = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt"
    )
    await check_audio_rate_limit(request, authorization, x_client_id)

    # Percent-encode the prompt: it is user input placed in the URL path, and
    # a raw "?", "#", "&" or "/" would change the path or inject parameters.
    encoded_prompt = quote(prompt, safe="")
    url = f"https://gen.pollinations.ai/audio/{encoded_prompt}?model=elevenmusic&key={PKEY}"
    async with httpx.AsyncClient(timeout=None) as client:
        response = await client.get(url)

    if response.status_code != 200:
        try:
            body_text = response.text
        except Exception:
            # Best effort: the error body is only informational.
            body_text = ""
        return JSONResponse(
            status_code=response.status_code,
            content={
                "success": False,
                "error": "Upstream music/sfx generation failed",
                "status_code": response.status_code,
                "message": body_text[:1000],
            },
        )
    return Response(response.content, media_type="audio/mpeg")
async def gensfx(
    request: Request,
    prompt: Optional[str] = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate audio through Pollinations using its default audio model.

    NOTE(review): an earlier function in this file is also named ``gensfx``;
    presumably each is registered on its own route by a decorator that is
    not visible here. Confirm before relying on the module-level name.

    Args:
        request: Incoming request; its JSON body is read only when
            *prompt* was not supplied as a parameter.
        prompt: Text prompt. When ``None`` it is taken from the ``prompt``
            field of the JSON body.
        authorization: ``Authorization`` header, forwarded to the rate limiter.
        x_client_id: ``X-Client-Id`` header, forwarded to the rate limiter.

    Returns:
        A ``Response`` with media type ``audio/mpeg`` on success, or a
        ``JSONResponse`` mirroring the upstream status code on failure.

    Raises:
        HTTPException: 400 when the body is not a JSON object. The prompt
            and rate-limit helpers may raise their own errors.
    """
    if prompt is None:
        try:
            payload = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt"
    )
    await check_audio_rate_limit(request, authorization, x_client_id)

    # Percent-encode the prompt: it is user input placed in the URL path, and
    # a raw "?", "#", "&" or "/" would change the path or inject parameters.
    encoded_prompt = quote(prompt, safe="")
    url = f"https://gen.pollinations.ai/audio/{encoded_prompt}?key={PKEY3}"
    async with httpx.AsyncClient(timeout=None) as client:
        response = await client.get(url)

    if response.status_code != 200:
        try:
            body_text = response.text
        except Exception:
            # Best effort: the error body is only informational.
            body_text = ""
        return JSONResponse(
            status_code=response.status_code,
            content={
                "success": False,
                "error": "Upstream audio generation failed",
                "status_code": response.status_code,
                "message": body_text[:1000],
            },
        )
    return Response(response.content, media_type="audio/mpeg")
async def genvideo_airforce(
    request: Request,
    prompt: Optional[str] = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate a video through Pollinations (``model=grok-video``).

    Despite its name, this variant calls ``gen.pollinations.ai``.
    NOTE(review): a later function in this file has the same name and calls
    api.airforce; presumably each is bound to its own route by a decorator
    not visible here. ``valid_ratios``, ``ratios``, ``valid_modes`` and
    ``modes`` are module-level sets defined further down in the file.

    Args:
        request: Incoming request. A HEAD request gets a parameter
            description in ``Y-*`` headers. Otherwise the JSON body is read
            only when *prompt* was not supplied as a parameter.
        prompt: Text prompt. When ``None``, ``prompt``, ``ratio``, ``mode``,
            ``image_urls`` and ``duration`` are taken from the JSON body.
        authorization: ``Authorization`` header, forwarded to the rate limiter.
        x_client_id: ``X-Client-Id`` header, forwarded to the rate limiter.

    Returns:
        A ``video/mp4`` ``Response`` on success, or a ``JSONResponse``
        mirroring the upstream status code on failure.

    Raises:
        HTTPException: 400 for an invalid body, ratio, mode or
            ``image_urls``; 502 when the upstream request fails or returns
            an empty body. The prompt and rate-limit helpers may raise
            their own errors.
    """
    if request.method == "HEAD":
        # Header values must be latin-1 encodable, so plain ASCII hyphens are
        # used; an em/en dash makes the response constructor raise.
        return Response(
            status_code=200,
            headers={
                "Y-prompt": "string - required. The text prompt used to generate the video.",
                "Y-ratio": "string - optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string - optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer - optional. Duration in seconds (1-10).",
                "Y-duration-default": "5",
                "Y-image_urls": "array<string> - optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                "Y-response_format": "video/mp4",
                "Y-model": "grok-video",
            },
        )

    aspect_ratio = "3:2"
    input_mode = "normal"
    duration = 5
    image_urls = None
    ratio = None
    mode = None

    if prompt is None:
        try:
            user_body = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(user_body, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")
        duration = user_body.get("duration", 5)

    # "" and None are accepted and mean "use the default".
    if ratio not in valid_ratios:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid aspect ratio '{ratio}'. Must be one of 3:2, 2:3, or 1:1.",
        )
    if ratio in ratios:
        aspect_ratio = ratio

    if mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode '{mode}'. Must be 'normal' or 'fun'.",
        )
    if mode in modes:
        input_mode = mode

    if image_urls:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 2:
            raise HTTPException(400, "You may provide at most two image URLs")
        # Entries are joined into a query parameter below, so they must be text.
        if not all(isinstance(img, str) for img in image_urls):
            raise HTTPException(400, "image_urls entries must be strings")

    # Clamp duration to 1..10 seconds; anything unparsable falls back to 5.
    try:
        duration = max(1, min(10, int(duration)))
    except (TypeError, ValueError, OverflowError):
        duration = 5

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )
    await check_video_rate_limit(request, authorization, x_client_id)

    # Map the public ratio names onto the values sent upstream.
    ratio_map = {
        "3:2": "16:9",
        "2:3": "9:16",
        "1:1": "1:1",
    }
    params = {
        "model": "grok-video",
        "duration": duration,
        "aspectRatio": ratio_map.get(aspect_ratio, "16:9"),
        "seed": -1,
    }
    encoded_prompt = quote(prompt, safe="")

    temp_assets = []
    try:
        # Asset creation happens inside the try so that a failure on the
        # second image still cleans up the first.
        if image_urls:
            processed_urls = []
            for img in image_urls[:2]:
                if is_base64_image(img):
                    image_id = save_base64_image(img)
                    temp_assets.append(image_id)
                    processed_urls.append(
                        f"{request.base_url}asset-cdn/assets/{image_id}"
                    )
                else:
                    processed_urls.append(img)
            params["image"] = "|".join(processed_urls)

        if input_mode == "fun":
            params["enhance"] = "true"

        query_string = "&".join(
            f"{k}={quote(str(v), safe='')}" for k, v in params.items()
        )
        url = f"https://gen.pollinations.ai/image/{encoded_prompt}?{query_string}"
        # Logged before the key is appended so the key is not printed.
        print(f"[VIDEO GEN] Pollinations URL: {url}")
        url = url + f"&key={PKEY}"

        try:
            async with httpx.AsyncClient(timeout=600) as client:
                resp = await client.get(url)
        except httpx.HTTPError:
            raise HTTPException(502, "Video generation request failed") from None
    finally:
        for aid in temp_assets:
            cleanup_image(aid)

    if resp.status_code != 200:
        try:
            body_text = resp.text
        except Exception:
            # Best effort: the error body is only informational.
            body_text = ""
        return JSONResponse(
            status_code=resp.status_code,
            content={
                "success": False,
                "error": "Upstream video generation failed",
                "status_code": resp.status_code,
                "message": body_text[:1000],
            },
        )

    if not resp.content:
        raise HTTPException(502, "Pollinations returned empty response")

    return Response(
        content=resp.content,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(resp.content)),
            "Accept-Ranges": "bytes",
        },
    )
# api.airforce settings. The key is None when AIRFORCE is unset.
AIRFORCE_KEY = os.environ.get("AIRFORCE")
AIRFORCE_VIDEO_MODEL = "grok-imagine-video"
AIRFORCE_API_URL = "https://api.airforce/v1/images/generations"

# Request validation sets used by the video endpoints. The "valid_*" sets
# additionally accept "" and None, which those endpoints treat as "use the
# default".
ratios = {"3:2", "2:3", "1:1"}
valid_ratios = ratios | {"", None}
modes = {"normal", "fun"}
valid_modes = modes | {"", None}

# Not referenced in the part of the file visible here.
MAX_VIDEO_RETRIES = 6
async def genvideo_airforce(
    request: Request,
    prompt: Optional[str] = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate a video through api.airforce and return it as ``video/mp4``.

    NOTE(review): an earlier function in this file has the same name and
    calls Pollinations; presumably each is bound to its own route by a
    decorator not visible here.

    Args:
        request: Incoming request. A HEAD request gets a parameter
            description in ``Y-*`` headers. Otherwise the JSON body is read
            only when *prompt* was not supplied as a parameter.
        prompt: Text prompt. When ``None``, ``prompt``, ``ratio``, ``mode``
            and ``image_urls`` are taken from the JSON body.
        authorization: ``Authorization`` header, forwarded to the rate limiter.
        x_client_id: ``X-Client-Id`` header, forwarded to the rate limiter.

    Returns:
        A ``video/mp4`` ``Response`` holding the decoded ``b64_json``
        payload, or a ``JSONResponse`` mirroring a non-200 upstream status.

    Raises:
        HTTPException: 400 for an invalid body, ratio, mode or
            ``image_urls``; 500 when ``AIRFORCE_KEY`` is not configured;
            502 when the upstream request fails or its response is empty,
            malformed or not valid base64. The prompt and rate-limit
            helpers may raise their own errors.
    """
    if request.method == "HEAD":
        # Header values must be latin-1 encodable, so plain ASCII hyphens are
        # used; an em dash makes the response constructor raise.
        return Response(
            status_code=200,
            headers={
                # Required field
                "Y-prompt": "string - required. The text prompt used to generate the video.",
                # Optional fields
                "Y-ratio": "string - optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string - optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer - optional. Duration in seconds.",
                "Y-duration-default": "5",
                "Y-image_urls": "array<string> - optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                # Response format
                "Y-response_format": "video/mp4",
                # Model info
                "Y-model": "grok-imagine-video",
            },
        )

    aspect_ratio = "3:2"
    input_mode = "normal"
    image_urls = None
    ratio = None
    mode = None

    if prompt is None:
        try:
            user_body = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(user_body, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")

    # "" and None are accepted and mean "use the default".
    if ratio not in valid_ratios:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid aspect ratio {ratio}. Must be one of 3:2, 2:3, or 1:1. Default is 3:2",
        )
    if ratio in ratios:
        aspect_ratio = ratio

    if mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode {mode}. Must be 'normal' or 'fun'. Default is normal",
        )
    if mode in modes:
        input_mode = mode

    if image_urls:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 2:
            raise HTTPException(400, "You may provide at most two image URLs")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )

    # Fail clearly on misconfiguration instead of sending "Bearer None"
    # upstream, and do so before the request is counted by the rate limiter.
    if not AIRFORCE_KEY:
        raise HTTPException(500, "Missing AIRFORCE key")

    await check_video_rate_limit(request, authorization, x_client_id)

    payload = {
        "model": AIRFORCE_VIDEO_MODEL,
        "prompt": prompt,
        "n": 1,
        "size": "1024x1024",
        "response_format": "b64_json",
        "sse": False,
        "mode": input_mode,
        "aspectRatio": aspect_ratio,
    }
    if image_urls:
        payload["image_urls"] = image_urls

    try:
        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.post(
                AIRFORCE_API_URL,
                headers={
                    "Authorization": f"Bearer {AIRFORCE_KEY}",
                    "Content-Type": "application/json",
                },
                json=payload,
            )
    except httpx.HTTPError:
        raise HTTPException(502, "Video generation request failed") from None

    if resp.status_code != 200:
        # Relay the upstream JSON error when there is one; an HTML or empty
        # error body must not turn into an unhandled decode error.
        try:
            error_content = resp.json()
        except ValueError:
            error_content = {
                "success": False,
                "error": "Upstream video generation failed",
                "status_code": resp.status_code,
                "message": resp.text[:1000],
            }
        return JSONResponse(status_code=resp.status_code, content=error_content)

    if not resp.content:
        raise HTTPException(502, "api.airforce returned empty response")

    try:
        b64_video = resp.json()["data"][0]["b64_json"]
    except (ValueError, KeyError, IndexError, TypeError):
        raise HTTPException(
            502, f"Invalid api.airforce response: {resp.text[:500]}"
        ) from None

    if not b64_video:
        raise HTTPException(502, "Airforce returned empty b64_json")

    try:
        video_bytes = base64.b64decode(b64_video)
    except (ValueError, TypeError):
        # binascii.Error is a ValueError subclass; TypeError covers non-text data.
        raise HTTPException(
            502, "api.airforce returned invalid base64 video data"
        ) from None

    return Response(
        content=video_bytes,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(video_bytes)),
            "Accept-Ranges": "bytes",
        },
    )
async def get_subscription(authorization: Optional[str] = Header(None)):
    """Return the caller's subscription with a normalised plan key and name.

    Args:
        authorization: ``Authorization`` header; must start with ``"Bearer "``.

    Returns:
        The mapping produced by ``fetch_subscription`` with ``plan_key``
        replaced by its normalised form and ``plan_name`` filled in from
        ``TIER_CONFIG`` (falling back to the ``"free"`` entry).

    Raises:
        HTTPException: 401 when the header is missing or malformed, or when
            the fetched result contains an ``"error"`` key.
    """
    bearer_prefix = "Bearer "
    if not (authorization and authorization.startswith(bearer_prefix)):
        raise HTTPException(401, "Missing or invalid Authorization header")

    # Everything after the first space is the token.
    token = authorization[len(bearer_prefix):]
    subscription = await fetch_subscription(token)
    if "error" in subscription:
        raise HTTPException(401, subscription["error"])

    plan_key = normalize_plan_key(subscription.get("plan_key"))
    subscription["plan_key"] = plan_key
    tier = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    subscription["plan_name"] = tier["name"]
    return subscription
async def get_usage(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Return the caller's plan and current usage snapshot as JSON.

    Args:
        request: Incoming request, forwarded to the identity resolver.
        authorization: ``Authorization`` header, forwarded likewise.
        x_client_id: ``X-Client-Id`` header, forwarded likewise.

    Returns:
        A 200 ``JSONResponse`` with ``plan_key``, ``plan_name``, ``usage``
        and a UTC ``generated_at`` timestamp.
    """
    identity = await resolve_rate_limit_identity(request, authorization, x_client_id)
    plan_key, subject = identity

    # Unknown plan keys fall back to the "free" tier entry.
    plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    usage = get_usage_snapshot_for_subject(plan_key, subject)
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    report = {
        "plan_key": plan_key,
        "plan_name": plan.get("name", "Free Tier"),
        "usage": usage,
        "generated_at": timestamp,
    }
    return JSONResponse(status_code=200, content=report)
async def tier_config():
    """Return every configured plan, in ``PLAN_ORDER`` order, as JSON.

    Keys of ``PLAN_ORDER`` with no (or an empty) ``TIER_CONFIG`` entry are
    skipped. Each plan's ``order`` is its index in ``PLAN_ORDER``, so gaps
    appear where entries were skipped.
    """
    indexed_plans = (
        (position, key, TIER_CONFIG.get(key))
        for position, key in enumerate(PLAN_ORDER)
    )
    plans = [
        {
            "key": key,
            "name": plan["name"],
            "url": plan["url"],
            "price": plan["price"],
            "limits": plan["limits"],
            "order": position,
        }
        for position, key, plan in indexed_plans
        if plan
    ]
    document = {"defaultPlanKey": "free", "plans": plans}
    return JSONResponse(status_code=200, content=document)
async def tiers():
    """Return the paid plans, in ``PLAN_ORDER`` order, as a JSON list.

    The ``"free"`` key is excluded, as are keys with no (or an empty)
    ``TIER_CONFIG`` entry.
    """
    keyed_plans = (
        (key, TIER_CONFIG.get(key)) for key in PLAN_ORDER if key != "free"
    )
    paid_plans = [
        {
            "key": key,
            "name": plan["name"],
            "url": plan["url"],
            "price": plan["price"],
            "limits": plan["limits"],
        }
        for key, plan in keyed_plans
        if plan
    ]
    return JSONResponse(status_code=200, content=paid_plans)
async def redirect_to_protal(request: Request):
    """Point the caller at the Stripe billing portal login page.

    For a POST request the JSON body's ``email`` field, when it is a
    non-empty string, is added to the portal URL as ``prefilled_email``
    and the URL is returned as JSON. In every other case the caller is
    redirected (302) to the plain portal URL.

    Args:
        request: Incoming request; its body is read only for POST.

    Returns:
        ``JSONResponse({"redirect_url": ...})`` for a POST with an email,
        otherwise a 302 ``RedirectResponse``.
    """
    email = None
    if request.method == "POST":
        try:
            body = await request.json()
        except Exception:
            # Best effort: an unreadable or non-JSON body means "no email".
            body = None
        if isinstance(body, dict):
            candidate = body.get("email")
            # Only text can be placed in the URL; anything else is ignored.
            if isinstance(candidate, str):
                email = candidate

    base_url = "https://billing.stripe.com/p/login/5kQdR9aIM3ts4steyabbG00"
    if not email:
        return RedirectResponse(url=base_url, status_code=status.HTTP_302_FOUND)

    # Percent-encode the address so characters such as "+" or "&" survive
    # as part of the query-string value.
    target = f"{base_url}?prefilled_email={quote(email, safe='')}"

    # NOTE(review): email is only ever set for POST, so this branch cannot
    # currently be taken; it is kept in case a non-POST email source is added.
    if request.method != "POST":
        return RedirectResponse(url=target, status_code=status.HTTP_302_FOUND)
    return JSONResponse({"redirect_url": target})