import asyncio
import base64
import hashlib
import hmac
import json
import os
import random
import re
import time
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional
from urllib.parse import quote

import httpx
from bs4 import BeautifulSoup
from fastapi import (
    FastAPI,
    Header,
    HTTPException,
    Request,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
    JSONResponse,
    RedirectResponse,
    Response,
    StreamingResponse,
)

from helper.subscriptions import (
    fetch_subscription,
    normalize_plan_key,
    TIER_CONFIG,
    PLAN_ORDER,
)

# Star import supplies the routing keyword lists (REASONING_KEYWORDS,
# LIGHTWEIGHT_KEYWORDS, CREATIVE_KEYWORDS) and the prompt heuristics used by
# the chat router (is_long_context, contains_code, contains_images, ...).
from helper.keywords import *  # noqa: F401,F403
from helper.assets import (
    save_base64_image,
    cleanup_image,
    is_base64_image,
    asset_router,
)
from helper.ratelimit import (
    enforce_rate_limit,
    resolve_rate_limit_identity,
    check_audio_rate_limit,
    check_video_rate_limit,
    check_image_rate_limit,
    MAX_CHAT_PROMPT_BYTES,
    MAX_CHAT_PROMPT_CHARS,
    MAX_GROQ_PROMPT_BYTES,
    MAX_GROQ_PROMPT_CHARS,
    MAX_MEDIA_PROMPT_BYTES,
    MAX_MEDIA_PROMPT_CHARS,
    extract_user_text,
    calculate_messages_size,
    normalize_prompt_value,
    enforce_prompt_size,
    resolve_bound_subject,
    get_usage_snapshot_for_subject,
)

app = FastAPI()

WEBSOCKET_KEY = os.getenv("WEBSOCKET_KEY")

# Sliding-window record of websocket authentication attempts, keyed by client IP.
AUTH_ATTEMPTS = defaultdict(deque)
AUTH_WINDOW_SECONDS = 60
AUTH_MAX_ATTEMPTS = 10

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "HEAD"],
    allow_headers=["*"],
)

app.include_router(asset_router)


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------


async def _read_json_object(request: Request) -> Dict[str, Any]:
    """Parse the request body as a JSON object.

    Raises:
        HTTPException(400): if the body is not valid JSON or is not an object.
    """
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(400, "Request body must be valid JSON") from None
    if not isinstance(data, dict):
        raise HTTPException(400, "Request body must be a JSON object")
    return data


def _encode_query(params: Dict[str, Any]) -> str:
    """Build a query string with every value fully percent-encoded."""
    return "&".join(f"{k}={quote(str(v), safe='')}" for k, v in params.items())


def _materialize_image_urls(
    request: Request, images: List[Any], temp_assets: List[str]
) -> List[str]:
    """Turn base64 images into temporary asset-CDN URLs; pass other entries through.

    Each saved asset id is appended to ``temp_assets`` as soon as it exists, so
    the caller's cleanup also covers a list that fails part-way through.
    """
    urls: List[str] = []
    for img in images:
        if is_base64_image(img):
            image_id = save_base64_image(img)
            temp_assets.append(image_id)
            urls.append(f"{request.base_url}asset-cdn/assets/{image_id}")
        else:
            urls.append(img)
    return urls


def _in_choices(value: Any, choices: set) -> bool:
    """Set membership that treats unhashable values (lists, dicts) as "not present"."""
    try:
        return value in choices
    except TypeError:
        return False


def _upstream_failure(response: httpx.Response, error: str) -> JSONResponse:
    """Mirror a failed upstream response as a structured JSON error."""
    try:
        body_text = response.text
    except Exception:  # best effort: an undecodable body still yields an error reply
        body_text = ""
    return JSONResponse(
        status_code=response.status_code,
        content={
            "success": False,
            "error": error,
            "status_code": response.status_code,
            "message": body_text[:1000],
        },
    )


def check_ws_auth_rate_limit(ip: str) -> bool:
    """Record a websocket auth attempt from ``ip`` and report whether it is allowed.

    Attempts older than AUTH_WINDOW_SECONDS are discarded. Once AUTH_MAX_ATTEMPTS
    remain inside the window, further attempts are rejected and not recorded.
    """
    now = time.time()
    attempts = AUTH_ATTEMPTS[ip]
    while attempts and now - attempts[0] > AUTH_WINDOW_SECONDS:
        attempts.popleft()
    if len(attempts) >= AUTH_MAX_ATTEMPTS:
        return False
    attempts.append(now)
    return True


@app.get("/")
async def reroute_to_home():
    """Permanently redirect the API root to the project homepage."""
    return RedirectResponse(
        url="https://inference.js.org", status_code=status.HTTP_308_PERMANENT_REDIRECT
    )


OLLAMA_LIBRARY_URL = "https://ollama.com/library"

_REASONING_PATTERN = re.compile(r"\b(if|therefore|assume|let x|given that)\b")


def is_complex_reasoning(prompt: str) -> bool:
    """Heuristic: long prompts, reasoning keywords or logical connectives."""
    if len(prompt) > 800:
        return True
    if any(kw in prompt for kw in REASONING_KEYWORDS):
        return True
    return _REASONING_PATTERN.search(prompt) is not None


def is_lightweight(prompt: str) -> bool:
    """Heuristic: a short prompt (< 100 chars) containing a lightweight keyword."""
    return len(prompt) < 100 and any(kw in prompt for kw in LIGHTWEIGHT_KEYWORDS)


def is_cinematic_image_prompt(prompt: str) -> bool:
    """True if the lower-cased prompt contains any CREATIVE_KEYWORDS entry."""
    lowered = prompt.lower()
    return any(kw in lowered for kw in CREATIVE_KEYWORDS)


PKEY = os.getenv("POLLINATIONS_KEY", "")
PKEY2 = os.getenv("POLLINATIONS2_KEY", "")
PKEY3 = os.getenv("POLLINATIONS3_KEY", "")

GROQ_TOOL_MODELS = [
    "openai/gpt-oss-120b",
    "openai/gpt-oss-20b",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "qwen/qwen3-32b",
    "moonshotai/kimi-k2-instruct",
]

GROQ_NORMAL_MODELS = [
    "llama-3.1-8b-instant",
    "llama-3.3-70b-versatile",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-guard-4-12b",
    "openai/gpt-oss-safeguard-20b",
    "qwen/qwen3-32b",
]

CEREBRAS_MODELS = [
    "gpt-oss-120b",
    "llama3.1-8b",
    "qwen-3-235b-a22b-instruct-2507",
    "zai-glm-4.7",
]


async def check_chat_rate_limit(
    request: Request,
    authorization: Optional[str],
    client_id: Optional[str] = None,
):
    """Apply the "cloudChatDaily" quota via ``enforce_rate_limit``."""
    return await enforce_rate_limit(request, authorization, "cloudChatDaily", client_id)


# ---------------------------------------------------------------------------
# Health / status probes
# ---------------------------------------------------------------------------


def _probe_response(media_type: str) -> Response:
    """Empty 200 response advertising ``media_type`` and byte-range support."""
    return Response(
        status_code=200,
        headers={
            "Content-Type": media_type,
            "Accept-Ranges": "bytes",
        },
    )


@app.head("/status/sfx")
async def head_sfx():
    return _probe_response("audio/mpeg")


@app.head("/status/image")
async def head_image():
    return _probe_response("image/jpeg")


@app.head("/status/video")
async def head_video():
    return _probe_response("video/mp4")


@app.head("/status/text")
async def head_text():
    return _probe_response("application/json")


@app.get("/status")
async def get_status():
    """Report the static per-service status table plus the release notice."""
    notify = (
        "Added the studio for professional media generation. Available in v2.8.0"
    )
    services = {
        "Video Generation": {"code": 200, "state": "ok", "message": "Running Normally"},
        "Image Generation": {"code": 200, "state": "ok", "message": "Running Normally"},
        "Lightning-Text v2": {
            "code": 200,
            "state": "ok",
            "message": "Running normally",
        },
        "Music/SFX Generation": {
            "code": 200,
            "state": "ok",
            "message": "Running normally",
        },
    }
    overall_state = (
        "ok" if all(s["state"] == "ok" for s in services.values()) else "degraded"
    )
    return JSONResponse(
        status_code=200,
        content={
            "state": overall_state,
            "services": services,
            "notifications": notify,
            "latest": "2.8.0",
        },
    )


# ---------------------------------------------------------------------------
# Image generation
# ---------------------------------------------------------------------------


@app.post("/gen/image")
@app.get("/genimg/{prompt}")
async def generate_image(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """
    Image generation endpoint.

    * Accepts a plain-text prompt (GET path segment or JSON body ``prompt``).
    * Optional inputs (JSON body fields, or query parameters for GET):
        - mode: "fantasy" | "realistic" - forces the flux / zimage model.
        - image_urls: list of up to four image URLs or base64 strings; only
          the first two are forwarded upstream.
    * Without a mode, prompts matching CREATIVE_KEYWORDS use "flux" and all
      others use "zimage".
    * If any input image is supplied the Pollinations model "klein" is used.
    * Base64 images are saved temporarily with ``save_base64_image``, served
      from the asset CDN (as the video endpoint does) and removed afterwards.

    Raises:
        HTTPException: 400 for an invalid body or image list, 500 when
        Pollinations answers with a non-200 status (plus whatever the
        prompt-size and rate-limit helpers raise).
    """
    timeout = httpx.Timeout(300.0, read=300.0)

    if prompt is None:
        payload = await _read_json_object(request)
        prompt = payload.get("prompt")
        mode = payload.get("mode")
        image_urls = payload.get("image_urls")
    else:
        mode = request.query_params.get("mode")
        image_urls = request.query_params.getlist("image_urls")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Image prompt"
    )

    # Validate before charging the caller's quota.
    has_input_image = bool(image_urls)
    if has_input_image:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 4:
            raise HTTPException(400, "Maximum of four image URLs allowed")

    await check_image_rate_limit(request, authorization, x_client_id)

    if has_input_image:
        chosen_model = "klein"
    else:
        chosen_model = "flux" if is_cinematic_image_prompt(prompt) else "zimage"
        if isinstance(mode, str):
            normalized_mode = mode.strip().lower()
            if normalized_mode == "fantasy":
                chosen_model = "flux"
            elif normalized_mode == "realistic":
                chosen_model = "zimage"

    params: Dict[str, Any] = {
        "model": chosen_model,
        "key": PKEY2,
    }

    temp_assets: List[str] = []
    try:
        if has_input_image:
            # NOTE(review): up to four entries are accepted above but only the
            # first two are forwarded - confirm which limit is intended.
            params["image"] = "|".join(
                _materialize_image_urls(request, image_urls[:2], temp_assets)
            )
        url = (
            f"https://gen.pollinations.ai/image/{quote(prompt, safe='')}"
            f"?{_encode_query(params)}"
        )
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(url)
    finally:
        for aid in temp_assets:
            cleanup_image(aid)

    if response.status_code != 200:
        raise HTTPException(
            status_code=500,
            detail=f"Pollinations error: {response.status_code}",
        )

    return Response(content=response.content, media_type="image/jpeg")


# ---------------------------------------------------------------------------
# Ollama model library scrape
# ---------------------------------------------------------------------------


def _parse_model_card(item) -> Dict[str, Any]:
    """Extract one model entry from an ollama.com ``li[x-test-model]`` element."""
    name = item.select_one("[x-test-model-title] span")
    description = item.select_one("p.max-w-lg")
    pulls = item.select_one("[x-test-pull-count]")
    updated = item.select_one("[x-test-updated]")
    link = item.select_one("a")
    return {
        "name": name.get_text(strip=True) if name else "",
        "description": (
            description.get_text(strip=True) if description else "No description"
        ),
        "sizes": [el.get_text(strip=True) for el in item.select("[x-test-size]")],
        "pulls": pulls.get_text(strip=True) if pulls else "Unknown",
        "tags": [
            t.get_text(strip=True)
            for t in item.select('span[class*="text-blue-600"]')
        ],
        "updated": updated.get_text(strip=True) if updated else "Unknown",
        "link": link.get("href") if link else None,
    }


@app.head("/models")
@app.get("/models")
async def get_models() -> List[Dict]:
    """Scrape the Ollama public model library into a list of model summaries.

    Raises:
        HTTPException(502): if the library page answers with an error status.
    """
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.get(OLLAMA_LIBRARY_URL)

    if response.status_code >= 400:
        raise HTTPException(
            502, f"Failed to fetch the Ollama model library ({response.status_code})"
        )

    soup = BeautifulSoup(response.text, "html.parser")
    return [_parse_model_card(item) for item in soup.select("li[x-test-model]")]


# ---------------------------------------------------------------------------
# Chat completions (Groq / Cerebras router)
# ---------------------------------------------------------------------------

# provider -> (env var holding comma-separated API keys, completions URL)
_CHAT_PROVIDERS = {
    "groq": ("GROQ_KEY", "https://api.groq.com/openai/v1/chat/completions"),
    "cerebras": ("CER_KEY", "https://api.cerebras.ai/v1/chat/completions"),
}


def _pick_api_key(env_var: str) -> str:
    """Return a random key from the comma-separated list stored in ``env_var``.

    Raises:
        HTTPException(500): if the variable is unset or contains no keys.
    """
    keys = [k.strip() for k in os.getenv(env_var, "").split(",") if k.strip()]
    if not keys:
        raise HTTPException(500, f"Missing {env_var}(s)")
    return random.choice(keys)


def _sse_error(message: str) -> str:
    """Format ``message`` as a JSON-encoded ``{"error": ...}`` server-sent event."""
    payload = json.dumps({"error": message})
    return f"data: {payload}\n\n"


@app.post("/gen/chat/completions")
async def generate_text(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Route an OpenAI-style chat completion to Groq or Cerebras and proxy the reply.

    The model is chosen from a heuristic score of the user text (long context,
    math, structure, code, multiple questions, reasoning keywords), plus
    whether tools or images are present. Requests too large for Groq are
    rerouted to Cerebras. ``body["model"]`` is overwritten with the choice.
    With ``stream`` set, the upstream SSE stream is relayed line by line;
    otherwise the upstream JSON (or a JSON error wrapper) is returned with the
    upstream status code.
    """
    body = await _read_json_object(request)
    messages = body.get("messages", [])

    if not isinstance(messages, list) or not messages:
        raise HTTPException(400, "messages[] is required")

    # NOTE: the global MAX_CHAT_PROMPT_CHARS / MAX_CHAT_PROMPT_BYTES cap is
    # deliberately not enforced here; oversized Groq requests are rerouted to
    # Cerebras further down instead.
    total_chars, total_bytes = calculate_messages_size(messages)

    prompt_text = extract_user_text(messages)

    tools = body.get("tools")
    uses_tools = (isinstance(tools, list) and len(tools) > 0) or (
        body.get("tool_choice") not in (None, "none")
    )

    long_context = is_long_context(messages)
    code_present = contains_code(prompt_text)
    math_heavy = is_math_heavy(prompt_text)
    structured_task = is_structured_task(prompt_text)
    multi_q = multiple_questions(prompt_text)
    code_heavy = is_code_heavy(prompt_text, code_present, long_context)

    weighted_signals = (
        (long_context, 3),
        (math_heavy, 3),
        (structured_task, 2),
        (code_present, 2),
        (multi_q, 1),
    )
    score = sum(weight for flag, weight in weighted_signals if flag)
    score += sum(1 for kw in REASONING_KEYWORDS if kw in prompt_text)

    chosen_model = "llama-3.1-8b-instant"
    provider = "groq"

    has_images = contains_images(messages)

    if has_images:
        chosen_model = "meta-llama/llama-4-scout-17b-16e-instruct"
        provider = "groq"
    else:
        score = min(score, 10)

        if uses_tools:
            chosen_model = "openai/gpt-oss-120b" if score >= 4 else "openai/gpt-oss-20b"
            provider = "groq"
        elif code_present:
            if code_heavy and score >= 6:
                chosen_model = "qwen-3-235b-a22b-instruct-2507"
                provider = "cerebras"
            elif score >= 4:
                chosen_model = "llama-3.3-70b-versatile"
                provider = "groq"
        elif score >= 4:
            chosen_model = "meta-llama/llama-4-scout-17b-16e-instruct"
            provider = "groq"

    if provider == "groq" and (
        total_chars > MAX_GROQ_PROMPT_CHARS or total_bytes > MAX_GROQ_PROMPT_BYTES
    ):
        provider = "cerebras"
        chosen_model = "qwen-3-235b-a22b-instruct-2507"

    await check_chat_rate_limit(request, authorization, x_client_id)

    body["model"] = chosen_model

    print(
        f"""
[ADVANCED ROUTER]
Score: {score}
Uses tools: {uses_tools}
Long context: {long_context}
Code present: {code_present}
Math heavy: {math_heavy}
Structured: {structured_task}
Multi-question: {multi_q}
MULTIMODAL REQUIRED: {has_images}
→ Selected: {chosen_model} ({provider})
"""
    )

    stream = body.get("stream", False)

    if provider not in _CHAT_PROVIDERS:
        raise HTTPException(500, "Unknown provider routing error")
    key_env, url = _CHAT_PROVIDERS[provider]
    # Never log the key itself.
    headers = {"Authorization": f"Bearer {_pick_api_key(key_env)}"}

    if stream:
        body["stream"] = True

        async def event_generator():
            try:
                async with httpx.AsyncClient(timeout=None) as client:
                    async with client.stream(
                        "POST",
                        url,
                        json=body,
                        headers=headers,
                    ) as r:
                        if r.status_code >= 400:
                            try:
                                raw = await r.aread()
                                error_payload = raw.decode("utf-8", errors="replace")[:800]
                            except Exception:  # best effort: body may be unreadable
                                error_payload = ""
                            yield _sse_error(
                                f"Upstream provider error ({r.status_code}): {error_payload}"
                            )
                            return

                        # Re-emit each SSE line (blank separator lines included).
                        async for line in r.aiter_lines():
                            yield line + "\n"
            except asyncio.CancelledError:
                # Client went away; stop streaming quietly.
                return
            except Exception as e:
                yield _sse_error(str(e))

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # critical for nginx
            },
        )

    async with httpx.AsyncClient(timeout=None) as client:
        r = await client.post(url, json=body, headers=headers)

        content_type = (r.headers.get("content-type") or "").lower()
        if "application/json" in content_type:
            try:
                payload = r.json()
            except ValueError:
                payload = {"error": "Upstream returned invalid JSON"}
        else:
            payload = {
                "error": "Upstream returned non-JSON response",
                "status_code": r.status_code,
                "message": r.text[:1000],
            }

        return JSONResponse(status_code=r.status_code, content=payload)


# ---------------------------------------------------------------------------
# Audio generation (music / SFX / TTS)
# ---------------------------------------------------------------------------


async def _fetch_pollinations_audio(url: str, failure_message: str) -> Response:
    """GET ``url`` and return the audio bytes, or a JSON error mirroring the failure."""
    async with httpx.AsyncClient(timeout=None) as client:
        response = await client.get(url)

    if response.status_code != 200:
        return _upstream_failure(response, failure_message)

    return Response(response.content, media_type="audio/mpeg")


@app.get("/gen/sfx/{prompt}")
@app.post("/gen/sfx")
async def gensfx(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate music / sound effects with the Pollinations "acestep" model."""
    if prompt is None:
        payload = await _read_json_object(request)
        prompt = payload.get("prompt")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt"
    )

    await check_audio_rate_limit(request, authorization, x_client_id)

    # Percent-encode so "?", "#", "&" or "/" in the prompt cannot alter the URL.
    url = (
        f"https://gen.pollinations.ai/audio/{quote(prompt, safe='')}"
        f"?model=acestep&key={PKEY}"
    )
    return await _fetch_pollinations_audio(url, "Upstream music/sfx generation failed")


@app.get("/gen/tts/{prompt}")
@app.post("/gen/tts")
async def gentts(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate speech audio with Pollinations' default audio model."""
    if prompt is None:
        payload = await _read_json_object(request)
        prompt = payload.get("prompt")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt"
    )

    await check_audio_rate_limit(request, authorization, x_client_id)

    url = f"https://gen.pollinations.ai/audio/{quote(prompt, safe='')}?key={PKEY3}"
    return await _fetch_pollinations_audio(url, "Upstream audio generation failed")


# ---------------------------------------------------------------------------
# Video generation
# ---------------------------------------------------------------------------

# Accepted values for the video endpoints' optional "ratio" / "mode" fields
# ("" and None mean "use the default").
valid_ratios = {"3:2", "2:3", "1:1", "", None}
ratios = {"3:2", "2:3", "1:1"}
valid_modes = {"normal", "fun", "", None}
modes = {"normal", "fun"}

# Client-facing ratio -> Pollinations aspectRatio.
_POLLINATIONS_RATIO_MAP = {
    "3:2": "16:9",
    "2:3": "9:16",
    "1:1": "9:16",
}


@app.get("/gen/video/{prompt}")
@app.post("/gen/video")
@app.head("/gen/video")
async def genvideo_pollinations(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate a video through Pollinations ("ltx-2" model).

    HEAD returns a field description in ``Y-*`` headers. POST reads
    prompt/ratio/mode/image_urls/duration from a JSON body; GET takes the
    prompt from the path and uses defaults. Duration is clamped to 1-10 s.
    Base64 images are served temporarily from the asset CDN.
    """
    if request.method == "HEAD":
        # Header values must be latin-1 encodable, so plain ASCII only.
        # NOTE(review): "Y-model" advertises grok-video although this endpoint
        # requests "ltx-2" upstream - confirm which value clients expect.
        return Response(
            status_code=200,
            headers={
                "Y-prompt": "string - required. The text prompt used to generate the video.",
                "Y-ratio": "string - optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string - optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer - optional. Duration in seconds (1-10).",
                "Y-duration-default": "5",
                "Y-image_urls": "array - optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                "Y-response_format": "video/mp4",
                "Y-model": "grok-video",
            },
        )

    aspect_ratio = "3:2"
    input_mode = "normal"
    duration = 5
    image_urls = None

    if prompt is None:
        user_body = await _read_json_object(request)
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")
        duration = user_body.get("duration", 5)

        if not _in_choices(ratio, valid_ratios):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid aspect ratio '{ratio}'. Must be one of 3:2, 2:3, or 1:1.",
            )
        if ratio in ratios:
            aspect_ratio = ratio

        if not _in_choices(mode, valid_modes):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid mode '{mode}'. Must be 'normal' or 'fun'.",
            )
        if mode in modes:
            input_mode = mode

        if image_urls:
            if not isinstance(image_urls, list):
                raise HTTPException(400, "image_urls must be a list")
            if len(image_urls) > 2:
                raise HTTPException(400, "You may provide at most two image URLs")

    # Clamp duration; non-numeric or infinite values fall back to the default.
    try:
        duration = max(1, min(10, int(duration)))
    except (TypeError, ValueError, OverflowError):
        duration = 5

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )

    await check_video_rate_limit(request, authorization, x_client_id)

    params: Dict[str, Any] = {
        "model": "ltx-2",
        "duration": duration,
        "aspectRatio": _POLLINATIONS_RATIO_MAP.get(aspect_ratio, "16:9"),
        "seed": -1,
    }

    temp_assets: List[str] = []
    try:
        if image_urls:
            params["image"] = "|".join(
                _materialize_image_urls(request, image_urls[:2], temp_assets)
            )
        if input_mode == "fun":
            params["enhance"] = "true"

        url = (
            f"https://gen.pollinations.ai/image/{quote(prompt, safe='')}"
            f"?{_encode_query(params)}"
        )
        # Logged before the key is appended so the secret never reaches stdout.
        print(f"[VIDEO GEN] Pollinations URL: {url}")

        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.get(url + f"&key={PKEY}")
    except httpx.HTTPError:
        raise HTTPException(502, "Video generation request failed") from None
    finally:
        for aid in temp_assets:
            cleanup_image(aid)

    if resp.status_code != 200:
        return _upstream_failure(resp, "Upstream video generation failed")

    if not resp.content:
        raise HTTPException(502, "Pollinations returned empty response")

    return Response(
        content=resp.content,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(resp.content)),
            "Accept-Ranges": "bytes",
        },
    )


AIRFORCE_KEY = os.getenv("AIRFORCE")
AIRFORCE_VIDEO_MODEL = "grok-imagine-video"
AIRFORCE_API_URL = "https://api.airforce/v1/images/generations"

MAX_VIDEO_RETRIES = 6


@app.get("/gen/video/airforce/{prompt}")
@app.post("/gen/video/airforce")
@app.head("/gen/video/airforce")
async def genvideo_airforce(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate a video through api.airforce and return the decoded MP4 bytes.

    HEAD returns a field description in ``Y-*`` headers. POST reads
    prompt/ratio/mode/image_urls from a JSON body; GET takes the prompt from
    the path. Upstream errors are mirrored with the upstream status code.
    """
    if request.method == "HEAD":
        # Header values must be latin-1 encodable, so plain ASCII only.
        # NOTE(review): "Y-duration" is advertised but this endpoint does not
        # read a duration field.
        return Response(
            status_code=200,
            headers={
                # Required field
                "Y-prompt": "string - required. The text prompt used to generate the video.",
                # Optional fields
                "Y-ratio": "string - optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string - optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer - optional. Duration in seconds.",
                "Y-duration-default": "5",
                "Y-image_urls": "array - optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                # Response format
                "Y-response_format": "video/mp4",
                # Model info
                "Y-model": AIRFORCE_VIDEO_MODEL,
            },
        )

    aspect_ratio = "3:2"
    input_mode = "normal"
    image_urls = None

    if prompt is None:
        user_body = await _read_json_object(request)
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")

        if not _in_choices(ratio, valid_ratios):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid aspect ratio {ratio}. Must be one of 3:2, 2:3, or 1:1. Default is 3:2",
            )
        if ratio in ratios:
            aspect_ratio = ratio

        if not _in_choices(mode, valid_modes):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid mode {mode}. Must be 'normal' or 'fun'. Default is normal",
            )
        if mode in modes:
            input_mode = mode

        if image_urls:
            if not isinstance(image_urls, list):
                raise HTTPException(400, "image_urls must be a list")
            if len(image_urls) > 2:
                raise HTTPException(400, "You may provide at most two image URLs")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )

    # Fail before charging quota when the server is not configured.
    if not AIRFORCE_KEY:
        raise HTTPException(500, "Missing AIRFORCE key")

    await check_video_rate_limit(request, authorization, x_client_id)

    payload = {
        "model": AIRFORCE_VIDEO_MODEL,
        "prompt": prompt,
        "n": 1,
        "size": "1024x1024",
        "response_format": "b64_json",
        "sse": False,
        "mode": input_mode,
        "aspectRatio": aspect_ratio,
    }
    if image_urls:
        payload["image_urls"] = image_urls

    async with httpx.AsyncClient(timeout=600) as client:
        resp = await client.post(
            AIRFORCE_API_URL,
            headers={
                "Authorization": f"Bearer {AIRFORCE_KEY}",
                "Content-Type": "application/json",
            },
            json=payload,
        )

    if resp.status_code != 200:
        try:
            error_content = resp.json()
        except ValueError:
            # Non-JSON error body: fall back to the structured wrapper.
            return _upstream_failure(resp, "Upstream video generation failed")
        return JSONResponse(status_code=resp.status_code, content=error_content)

    if not resp.content:
        raise HTTPException(502, "api.airforce returned empty response")

    try:
        b64_video = resp.json()["data"][0]["b64_json"]
    except (ValueError, KeyError, IndexError, TypeError):
        raise HTTPException(
            502, f"Invalid api.airforce response: {resp.text[:500]}"
        ) from None

    if not b64_video:
        raise HTTPException(502, "Airforce returned empty b64_json")

    try:
        video_bytes = base64.b64decode(b64_video)
    except (ValueError, TypeError):
        raise HTTPException(502, "Airforce returned invalid base64 video data") from None

    return Response(
        content=video_bytes,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(video_bytes)),
            "Accept-Ranges": "bytes",
        },
    )


# ---------------------------------------------------------------------------
# Subscriptions, usage and plan catalogue
# ---------------------------------------------------------------------------


@app.get("/subscription")
async def get_subscription(authorization: Optional[str] = Header(None)):
    """Resolve the caller's subscription from a ``Bearer <jwt>`` header.

    Raises:
        HTTPException(401): missing/malformed header, or ``fetch_subscription``
        returned an ``error`` entry.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(401, "Missing or invalid Authorization header")

    jwt = authorization.split(" ", 1)[1]
    result = await fetch_subscription(jwt)

    if "error" in result:
        raise HTTPException(401, result["error"])

    plan_key = normalize_plan_key(result.get("plan_key"))
    result["plan_key"] = plan_key
    result["plan_name"] = (TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"])["name"]
    return result


@app.get("/usage")
async def get_usage(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Return the caller's plan and usage snapshot with a UTC timestamp."""
    plan_key, subject = await resolve_rate_limit_identity(
        request, authorization, x_client_id
    )
    plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    usage = get_usage_snapshot_for_subject(plan_key, subject)

    return JSONResponse(
        status_code=200,
        content={
            "plan_key": plan_key,
            "plan_name": plan.get("name", "Free Tier"),
            "usage": usage,
            "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        },
    )


def _plan_summary(key: str, plan: Dict[str, Any]) -> Dict[str, Any]:
    """Public fields of one TIER_CONFIG entry."""
    return {
        "key": key,
        "name": plan["name"],
        "url": plan["url"],
        "price": plan["price"],
        "limits": plan["limits"],
    }


@app.get("/tier-config")
async def tier_config():
    """List every configured plan in PLAN_ORDER, each tagged with its position."""
    plans = [
        {**_plan_summary(key, TIER_CONFIG[key]), "order": idx}
        for idx, key in enumerate(PLAN_ORDER)
        if TIER_CONFIG.get(key)
    ]
    return JSONResponse(
        status_code=200,
        content={
            "defaultPlanKey": "free",
            "plans": plans,
        },
    )


@app.get("/tiers")
async def tiers():
    """List the configured paid plans (everything except "free") in PLAN_ORDER."""
    paid_plans = [
        _plan_summary(key, TIER_CONFIG[key])
        for key in PLAN_ORDER
        if key != "free" and TIER_CONFIG.get(key)
    ]
    return JSONResponse(
        status_code=200,
        content=paid_plans,
    )


# ---------------------------------------------------------------------------
# WebSocket chat relay
# ---------------------------------------------------------------------------


def _is_valid_ws_key(provided_key: Any) -> bool:
    """Constant-time check of a client-supplied key against WEBSOCKET_KEY.

    Always False when WEBSOCKET_KEY is unset or the provided key is not a string.
    """
    if not WEBSOCKET_KEY or not isinstance(provided_key, str):
        return False
    return hmac.compare_digest(
        provided_key.encode("utf-8"), WEBSOCKET_KEY.encode("utf-8")
    )


@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Authenticate, then relay chat requests to this server's completions endpoint.

    Protocol: the first text frame must be ``{"key": ...}``. On success the
    server replies ``{"type": "auth", "status": "ok"}``. Each later frame is
    ``{"body": {...}, "headers": {...}}`` and is POSTed to
    ``/gen/chat/completions``; every non-empty response line is sent back as a
    text frame. Close codes: 4408 too many auth attempts, 4403 bad key.
    """
    ip = ws.client.host if ws.client else "unknown"
    await ws.accept()

    if not check_ws_auth_rate_limit(ip):
        await ws.close(code=4408)
        return

    # Requests are relayed over HTTP to the same host that served this socket.
    http_url = str(ws.url).replace("ws://", "http://").replace("wss://", "https://")
    completions_url = http_url.split("/ws/chat")[0] + "/gen/chat/completions"

    try:
        auth_data = json.loads(await ws.receive_text())
        provided_key = auth_data.get("key") if isinstance(auth_data, dict) else None

        if not _is_valid_ws_key(provided_key):
            await ws.close(code=4403)
            return

        await ws.send_json({"type": "auth", "status": "ok"})

        async with httpx.AsyncClient(timeout=None) as client:
            while True:
                data = json.loads(await ws.receive_text())
                body = data.get("body")
                headers = data.get("headers", {})

                if not body:
                    await ws.send_json({"error": "Missing body"})
                    continue

                async with client.stream(
                    "POST",
                    completions_url,
                    json=body,
                    headers=headers,
                ) as r:
                    async for line in r.aiter_lines():
                        if line:
                            await ws.send_text(line)

    except WebSocketDisconnect:
        return
    except Exception as e:
        # Best effort: the socket may already be unusable.
        try:
            await ws.send_json({"error": str(e)})
        except Exception:
            pass
        try:
            await ws.close()
        except Exception:
            pass


# ---------------------------------------------------------------------------
# Billing portal
# ---------------------------------------------------------------------------


@app.get("/portal")
@app.post("/portal")
async def redirect_to_protal(request: Request):
    """Send the caller to the Stripe billing-portal login, optionally pre-filling an email.

    GET reads ``email`` from the query string and always answers with a 302.
    POST reads ``email`` from a JSON body. With an email it returns
    ``{"redirect_url": ...}``; without one it returns a 302 to the bare portal.
    The email is percent-encoded so characters such as ``+`` or ``&`` survive.
    """
    base_url = "https://billing.stripe.com/p/login/5kQdR9aIM3ts4steyabbG00"

    if request.method == "POST":
        try:
            body = await request.json()
        except Exception:  # unreadable or invalid body: fall back to no email
            body = None
        email = body.get("email") if isinstance(body, dict) else None
    else:
        email = request.query_params.get("email")

    if not isinstance(email, str) or not email:
        return RedirectResponse(url=base_url, status_code=status.HTTP_302_FOUND)

    prefilled_url = f"{base_url}?prefilled_email={quote(email, safe='@')}"

    if request.method != "POST":
        return RedirectResponse(url=prefilled_url, status_code=status.HTTP_302_FOUND)
    return JSONResponse({"redirect_url": prefilled_url})