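# lightning / app.py
# Author: incognitolm - commit ba195af ("Update app.py", verified) - 34.4 kB
# (Scraped repository-page header, kept as comments so the module parses;
# the page's "raw" / "history" / "blame" links are omitted.)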
import os
import time
import hashlib
from fastapi import WebSocket, WebSocketDisconnect
from collections import defaultdict, deque
import json
from fastapi import FastAPI, Request, HTTPException, status, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
Response,
JSONResponse,
StreamingResponse,
RedirectResponse,
)
import httpx
from bs4 import BeautifulSoup
from typing import List, Dict, Any
import asyncio
import re
import random
from urllib.parse import quote
import base64
from helper.subscriptions import (
fetch_subscription,
normalize_plan_key,
TIER_CONFIG,
PLAN_ORDER,
)
from typing import Optional
from helper.keywords import *
from helper.assets import (
save_base64_image,
cleanup_image,
is_base64_image,
asset_router,
)
from helper.ratelimit import (
enforce_rate_limit,
resolve_rate_limit_identity,
check_audio_rate_limit,
check_video_rate_limit,
check_image_rate_limit,
MAX_CHAT_PROMPT_BYTES,
MAX_CHAT_PROMPT_CHARS,
MAX_GROQ_PROMPT_BYTES,
MAX_GROQ_PROMPT_CHARS,
MAX_MEDIA_PROMPT_BYTES,
MAX_MEDIA_PROMPT_CHARS,
extract_user_text,
calculate_messages_size,
normalize_prompt_value,
enforce_prompt_size,
resolve_bound_subject,
get_usage_snapshot_for_subject,
)
# Core application object; middleware and routers are attached below.
app = FastAPI()

# Shared secret expected in the first frame on /ws/chat (None when the
# environment variable is unset).
WEBSOCKET_KEY = os.getenv("WEBSOCKET_KEY")

# Sliding-window bookkeeping for websocket authentication attempts:
# client IP -> deque of attempt timestamps (time.time() seconds).
AUTH_WINDOW_SECONDS = 60
AUTH_MAX_ATTEMPTS = 10
AUTH_ATTEMPTS = defaultdict(deque)

# Permissive CORS: any origin and header, GET/POST/HEAD only.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "HEAD"],
    allow_headers=["*"],
)

# Asset routes from helper.assets (presumably serving the
# "asset-cdn/assets/<id>" URLs built by the media endpoints).
app.include_router(asset_router)
def check_ws_auth_rate_limit(ip: str):
    """Record a websocket auth attempt for ``ip`` and say whether it may proceed.

    Timestamps older than ``AUTH_WINDOW_SECONDS`` are discarded first. When the
    remaining count has already reached ``AUTH_MAX_ATTEMPTS`` the attempt is
    refused and NOT recorded; otherwise it is recorded and allowed.

    Returns:
        True if the attempt is allowed, False if the IP is throttled.
    """
    current = time.time()
    attempts = AUTH_ATTEMPTS[ip]
    # Expire entries from the oldest end of the window.
    while attempts and current - attempts[0] > AUTH_WINDOW_SECONDS:
        attempts.popleft()
    throttled = len(attempts) >= AUTH_MAX_ATTEMPTS
    if not throttled:
        attempts.append(current)
    return not throttled
@app.get("/")
async def reroute_to_home():
    """Permanently (308) redirect the API root to the public web frontend."""
    target = "https://inference.js.org"
    return RedirectResponse(url=target, status_code=status.HTTP_308_PERMANENT_REDIRECT)
# Public Ollama model catalogue page scraped by get_models().
OLLAMA_LIBRARY_URL: str = "https://ollama.com/library"
def is_complex_reasoning(prompt: str) -> bool:
    """Heuristically decide whether ``prompt`` needs a reasoning-grade model.

    True when the prompt is longer than 800 characters, contains any entry of
    ``REASONING_KEYWORDS`` as a (case-sensitive) substring, or contains one of
    the whole-word cues "if", "therefore", "assume", "let x", "given that"
    (also case-sensitive).
    """
    if len(prompt) > 800:
        return True
    if any(keyword in prompt for keyword in REASONING_KEYWORDS):
        return True
    cue = re.search(r"\b(if|therefore|assume|let x|given that)\b", prompt)
    return cue is not None
def is_lightweight(prompt: str) -> bool:
    """Return True for short prompts (< 100 chars) containing a lightweight keyword.

    Matching is a case-sensitive substring test against ``LIGHTWEIGHT_KEYWORDS``.
    """
    if len(prompt) >= 100:
        return False
    return any(keyword in prompt for keyword in LIGHTWEIGHT_KEYWORDS)
def is_cinematic_image_prompt(prompt: str) -> bool:
    """Return True when the lower-cased prompt contains any ``CREATIVE_KEYWORDS`` entry.

    Used by the image endpoint to prefer the ``flux`` model for creative prompts.
    """
    return any(keyword in prompt.lower() for keyword in CREATIVE_KEYWORDS)
# Pollinations API keys (empty string when the variable is unset):
#   PKEY  -> music/SFX audio and video generation
#   PKEY2 -> image generation
#   PKEY3 -> text-to-speech audio
PKEY, PKEY2, PKEY3 = (
    os.getenv(var_name, "")
    for var_name in ("POLLINATIONS_KEY", "POLLINATIONS2_KEY", "POLLINATIONS3_KEY")
)

# Model catalogues per provider/capability.
# NOTE(review): not referenced by the chat router in this file, which
# hard-codes its model IDs.
GROQ_TOOL_MODELS = [
    "openai/gpt-oss-120b",
    "openai/gpt-oss-20b",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "qwen/qwen3-32b",
    "moonshotai/kimi-k2-instruct",
]
GROQ_NORMAL_MODELS = [
    "llama-3.1-8b-instant",
    "llama-3.3-70b-versatile",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-guard-4-12b",
    "openai/gpt-oss-safeguard-20b",
    "qwen/qwen3-32b",
]
CEREBRAS_MODELS = [
    "gpt-oss-120b",
    "llama3.1-8b",
    "qwen-3-235b-a22b-instruct-2507",
    "zai-glm-4.7",
]
async def check_chat_rate_limit(
    request: Request,
    authorization: Optional[str],
    client_id: Optional[str] = None,
):
    """Apply the ``cloudChatDaily`` rate-limit bucket to this caller.

    Thin wrapper over ``enforce_rate_limit``; its return value and any
    exception it raises pass straight through.
    """
    bucket = "cloudChatDaily"
    return await enforce_rate_limit(request, authorization, bucket, client_id)
@app.head("/status/sfx")
async def head_sfx():
    """HEAD probe for the music/SFX service; always 200 with an MPEG audio type."""
    probe_headers = {"Content-Type": "audio/mpeg", "Accept-Ranges": "bytes"}
    return Response(status_code=200, headers=probe_headers)
@app.head("/status/image")
async def head_image():
    """HEAD probe for the image service; always 200 with a JPEG content type."""
    probe_headers = {"Content-Type": "image/jpeg", "Accept-Ranges": "bytes"}
    return Response(status_code=200, headers=probe_headers)
@app.head("/status/video")
async def head_video():
    """HEAD probe for the video service; always 200 with an MP4 content type."""
    probe_headers = {"Content-Type": "video/mp4", "Accept-Ranges": "bytes"}
    return Response(status_code=200, headers=probe_headers)
@app.head("/status/text")
async def head_text():
    """HEAD probe for the text service; always 200 with a JSON content type."""
    probe_headers = {"Content-Type": "application/json", "Accept-Ranges": "bytes"}
    return Response(status_code=200, headers=probe_headers)
@app.get("/status")
async def get_status():
    """Report per-service health, the current release notice and latest version.

    The top-level ``state`` is "ok" only if every service is "ok", otherwise
    "degraded". All service entries are currently hard-coded as healthy.
    """
    latest_version = "2.8.0"
    notify = "Added the studio for professional media generation. Available in v2.8.0"

    def _ok(message: str) -> Dict[str, Any]:
        # Uniform shape of a healthy service entry.
        return {"code": 200, "state": "ok", "message": message}

    services = {
        "Video Generation": _ok("Running Normally"),
        "Image Generation": _ok("Running Normally"),
        "Lightning-Text v2": _ok("Running normally"),
        "Music/SFX Generation": _ok("Running normally"),
    }
    degraded = any(entry["state"] != "ok" for entry in services.values())
    return JSONResponse(
        status_code=200,
        content={
            "state": "degraded" if degraded else "ok",
            "services": services,
            "notifications": notify,
            "latest": latest_version,
        },
    )
@app.post("/gen/image")
@app.get("/genimg/{prompt}")
async def generate_image(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """
    Image generation endpoint.
    --------------------------------------------------------------
    - Accepts a plain-text prompt: the ``{prompt}`` path segment / a
      ``prompt`` query parameter, or the ``prompt`` field of a JSON
      object body.
    - Optional fields (JSON body when the prompt comes from the body,
      query string otherwise):
        - mode: "fantasy" forces model ``flux``, "realistic" forces
          ``zimage``; without it, prompts matching ``CREATIVE_KEYWORDS``
          get ``flux`` and everything else ``zimage``.
        - image_urls: list of at most 4 image URLs or base-64 strings;
          only the first 2 are forwarded upstream.
    - If *any* image is supplied the Pollinations ``klein`` model is
      always used, regardless of ``mode``.
    - Base-64 images are saved temporarily with ``save_base64_image``,
      served from the asset CDN like the video endpoint does, and are
      always removed with ``cleanup_image`` once the request finishes
      (including when a later save or the upstream call fails).
    --------------------------------------------------------------
    Returns the upstream bytes as ``image/jpeg``.

    Raises:
        HTTPException 400: non-object JSON body or malformed ``image_urls``
            (validated before the rate limit so bad requests cost no quota).
        HTTPException 500: Pollinations answered with a non-200 status.
        Plus whatever the prompt-size and rate-limit helpers raise.
    """
    timeout = httpx.Timeout(300.0, read=300.0)
    if prompt is None:
        payload = await request.json()
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")
        mode = payload.get("mode")
        image_urls = payload.get("image_urls")
    else:
        mode = request.query_params.get("mode")
        image_urls = request.query_params.getlist("image_urls")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Image prompt"
    )

    has_input_image = bool(image_urls)
    if has_input_image:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 4:
            raise HTTPException(400, "Maximum of four image URLs allowed")
        # Non-string entries would crash is_base64_image / "|".join (HTTP 500).
        if not all(isinstance(img, str) for img in image_urls):
            raise HTTPException(400, "image_urls entries must be strings")

    await check_image_rate_limit(request, authorization, x_client_id)

    if has_input_image:
        chosen_model = "klein"
    else:
        chosen_model = "flux" if is_cinematic_image_prompt(prompt) else "zimage"
        if isinstance(mode, str):
            normalized_mode = mode.strip().lower()
            if normalized_mode == "fantasy":
                chosen_model = "flux"
            elif normalized_mode == "realistic":
                chosen_model = "zimage"

    params: Dict[str, Any] = {
        "model": chosen_model,
        "key": PKEY2,
    }
    temp_assets: List[str] = []
    try:
        if has_input_image:
            processed_urls: List[str] = []
            for img in image_urls[:2]:
                if is_base64_image(img):
                    image_id = save_base64_image(img)
                    # Track immediately so a failure on the next image still
                    # cleans this one up in the finally block.
                    temp_assets.append(image_id)
                    processed_urls.append(
                        f"{request.base_url}asset-cdn/assets/{image_id}"
                    )
                else:
                    processed_urls.append(img)
            params["image"] = "|".join(processed_urls)

        encoded_prompt = quote(prompt, safe="")
        query_string = "&".join(
            f"{k}={quote(str(v), safe='')}" for k, v in params.items()
        )
        url = f"https://gen.pollinations.ai/image/{encoded_prompt}?{query_string}"
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(url)
    finally:
        for aid in temp_assets:
            cleanup_image(aid)

    if response.status_code != 200:
        raise HTTPException(
            status_code=500,
            detail=f"Pollinations error: {response.status_code}",
        )
    return Response(content=response.content, media_type="image/jpeg")
def _parse_ollama_models(html: str) -> List[Dict]:
    """Extract one summary dict per ``li[x-test-model]`` entry of the library page."""

    def text_of(node, fallback):
        # Stripped text of an optional element, or ``fallback`` when absent.
        return node.get_text(strip=True) if node else fallback

    models = []
    for item in BeautifulSoup(html, "html.parser").select("li[x-test-model]"):
        link = item.select_one("a")
        models.append(
            {
                "name": text_of(item.select_one("[x-test-model-title] span"), ""),
                "description": text_of(
                    item.select_one("p.max-w-lg"), "No description"
                ),
                "sizes": [
                    el.get_text(strip=True) for el in item.select("[x-test-size]")
                ],
                "pulls": text_of(item.select_one("[x-test-pull-count]"), "Unknown"),
                "tags": [
                    t.get_text(strip=True)
                    for t in item.select('span[class*="text-blue-600"]')
                ],
                "updated": text_of(item.select_one("[x-test-updated]"), "Unknown"),
                "link": link.get("href") if link else None,
            }
        )
    return models


@app.head("/models")
@app.get("/models")
async def get_models() -> List[Dict]:
    """Scrape the public Ollama library page and return one summary per model.

    Each entry has ``name``, ``description``, ``sizes``, ``pulls``, ``tags``,
    ``updated`` and ``link``; missing fields fall back to "", "No description",
    "Unknown" or None.

    Raises:
        HTTPException 502: the page could not be fetched, or it answered with a
            4xx/5xx status (previously the error page was parsed as if it were
            the catalogue, yielding a misleading empty list).
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(OLLAMA_LIBRARY_URL)
    except httpx.HTTPError as exc:
        raise HTTPException(502, f"Failed to fetch Ollama library: {exc}") from exc
    if response.status_code >= 400:
        raise HTTPException(
            502, f"Ollama library returned status {response.status_code}"
        )
    return _parse_ollama_models(response.text)
def _sse_error(message: str) -> str:
    """Return one Server-Sent-Events frame whose data is ``{"error": message}``.

    ``json.dumps`` keeps the frame valid JSON whatever ``message`` contains
    (quotes, backslashes, newlines, non-ASCII).
    """
    return "data: " + json.dumps({"error": message}) + "\n\n"


def _pick_provider_endpoint(provider: str) -> tuple:
    """Return ``(api_key, url)`` for chat ``provider`` ("groq" or "cerebras").

    Keys are read from the comma-separated env var ``GROQ_KEY`` / ``CER_KEY``
    and one is chosen at random. Key material is deliberately never logged.

    Raises:
        HTTPException 500: unknown provider, or no keys configured.
    """
    endpoints = {
        "groq": ("GROQ_KEY", "https://api.groq.com/openai/v1/chat/completions"),
        "cerebras": ("CER_KEY", "https://api.cerebras.ai/v1/chat/completions"),
    }
    if provider not in endpoints:
        raise HTTPException(500, "Unknown provider routing error")
    env_name, url = endpoints[provider]
    keys = [k.strip() for k in os.getenv(env_name, "").split(",") if k.strip()]
    if not keys:
        raise HTTPException(500, f"Missing {env_name}(s)")
    return random.choice(keys), url


@app.post("/gen/chat/completions")
async def generate_text(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """OpenAI-compatible chat-completions proxy with heuristic model routing.

    The request body is forwarded upstream as-is except that ``model`` is
    overwritten with the routed choice (and ``stream`` set to True when
    streaming). Routing, in priority order:

    * any image in ``messages`` -> llama-4-scout (Groq);
    * tools requested -> gpt-oss-120b if score >= 4, else gpt-oss-20b (Groq);
    * code present -> qwen-3-235b (Cerebras) if code-heavy and score >= 6,
      llama-3.3-70b (Groq) if score >= 4, else the default;
    * otherwise score >= 4 -> llama-4-scout (Groq);
    * default -> llama-3.1-8b-instant (Groq).

    A Groq route whose prompt exceeds MAX_GROQ_PROMPT_CHARS/BYTES is moved to
    qwen-3-235b on Cerebras. ``score`` sums the heuristic signals below
    (capped at 10 on the non-image path).

    Returns:
        A ``text/event-stream`` relay when ``stream`` is truthy, otherwise the
        upstream JSON (or a JSON error description) with the upstream status.

    Raises:
        HTTPException 400: body is not a JSON object or ``messages`` is empty.
        HTTPException 500: no API keys configured for the routed provider.
        HTTPException 502: the non-streaming upstream request itself failed.
        Plus whatever ``check_chat_rate_limit`` raises.
    """
    body = await request.json()
    if not isinstance(body, dict):
        raise HTTPException(400, "Request body must be a JSON object")
    messages = body.get("messages", [])
    if not isinstance(messages, list) or len(messages) == 0:
        raise HTTPException(400, "messages[] is required")

    # NOTE: the global 413 cap against MAX_CHAT_PROMPT_CHARS/BYTES is currently
    # disabled; the totals only drive the Groq -> Cerebras overflow below.
    total_chars, total_bytes = calculate_messages_size(messages)
    prompt_text = extract_user_text(messages)

    tools = body.get("tools")
    uses_tools = (isinstance(tools, list) and len(tools) > 0) or (
        "tool_choice" in body and body["tool_choice"] not in [None, "none"]
    )
    long_context = is_long_context(messages)
    code_present = contains_code(prompt_text)
    math_heavy = is_math_heavy(prompt_text)
    structured_task = is_structured_task(prompt_text)
    multi_q = multiple_questions(prompt_text)
    code_heavy = is_code_heavy(prompt_text, code_present, long_context)

    score = 0
    if long_context:
        score += 3
    if math_heavy:
        score += 3
    if structured_task:
        score += 2
    if code_present:
        score += 2
    if multi_q:
        score += 1
    # One point per reasoning keyword occurring in the user's text.
    score += sum(1 for kw in REASONING_KEYWORDS if kw in prompt_text)

    chosen_model = "llama-3.1-8b-instant"
    provider = "groq"
    has_images = contains_images(messages)
    if has_images:
        chosen_model = "meta-llama/llama-4-scout-17b-16e-instruct"
    else:
        score = min(score, 10)
        if uses_tools:
            chosen_model = (
                "openai/gpt-oss-120b" if score >= 4 else "openai/gpt-oss-20b"
            )
        elif code_present:
            if code_heavy and score >= 6:
                chosen_model = "qwen-3-235b-a22b-instruct-2507"
                provider = "cerebras"
            elif score >= 4:
                chosen_model = "llama-3.3-70b-versatile"
        elif score >= 4:
            chosen_model = "meta-llama/llama-4-scout-17b-16e-instruct"

    # Oversized prompts go to Cerebras instead of Groq.
    if provider == "groq" and (
        total_chars > MAX_GROQ_PROMPT_CHARS or total_bytes > MAX_GROQ_PROMPT_BYTES
    ):
        provider = "cerebras"
        chosen_model = "qwen-3-235b-a22b-instruct-2507"

    await check_chat_rate_limit(request, authorization, x_client_id)
    body["model"] = chosen_model
    print(
        f"""
[ADVANCED ROUTER]
Score: {score}
Uses tools: {uses_tools}
Long context: {long_context}
Code present: {code_present}
Math heavy: {math_heavy}
Structured: {structured_task}
Multi-question: {multi_q}
MULTIMODAL REQUIRED: {has_images}
→ Selected: {chosen_model} ({provider})
"""
    )

    stream = body.get("stream", False)
    api_key, url = _pick_provider_endpoint(provider)
    headers = {"Authorization": f"Bearer {api_key}"}

    if stream:
        body["stream"] = True

        async def event_generator():
            try:
                async with httpx.AsyncClient(timeout=None) as client:
                    async with client.stream(
                        "POST", url, json=body, headers=headers
                    ) as r:
                        if r.status_code >= 400:
                            try:
                                raw = await r.aread()
                                error_payload = raw.decode("utf-8", errors="replace")[:800]
                            except Exception:
                                error_payload = ""
                            # Flatten newlines so the message stays one SSE line.
                            error_payload = error_payload.replace("\n", " ").replace(
                                "\r", " "
                            )
                            yield _sse_error(
                                f"Upstream provider error ({r.status_code}): {error_payload}"
                            )
                            return
                        # Relay the upstream SSE stream line by line; blank
                        # lines (event separators) are preserved.
                        async for line in r.aiter_lines():
                            yield line + "\n"
            except asyncio.CancelledError:
                # Client went away: let cancellation propagate instead of
                # swallowing it, so the server can tear the response down.
                raise
            except Exception as e:
                yield _sse_error(str(e))

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # critical for nginx
            },
        )

    try:
        async with httpx.AsyncClient(timeout=None) as client:
            r = await client.post(url, json=body, headers=headers)
    except httpx.HTTPError as exc:
        raise HTTPException(502, f"Upstream provider request failed: {exc}") from exc

    content_type = (r.headers.get("content-type") or "").lower()
    if "application/json" in content_type:
        try:
            payload = r.json()
        except Exception:
            payload = {"error": "Upstream returned invalid JSON"}
    else:
        payload = {
            "error": "Upstream returned non-JSON response",
            "status_code": r.status_code,
            "message": r.text[:1000],
        }
    return JSONResponse(status_code=r.status_code, content=payload)
@app.get("/gen/sfx/{prompt}")
@app.post("/gen/sfx")
async def gensfx(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate music / a sound effect with Pollinations' ``acestep`` model.

    The prompt comes from the path segment (GET) or from the ``prompt`` field
    of a JSON object body (POST). On success the upstream bytes are returned
    as ``audio/mpeg``; a non-200 upstream answer is relayed as a JSON error
    with the same status code and the first 1000 characters of its body.

    Raises:
        HTTPException 400: the POST body is not a JSON object.
        Plus whatever the prompt-size and rate-limit helpers raise.
    """
    if prompt is None:
        payload = await request.json()
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")
    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt"
    )
    await check_audio_rate_limit(request, authorization, x_client_id)

    # Percent-encode the prompt (as the image/video endpoints do) so '?', '#',
    # '&' or '/' cannot end the path segment or smuggle in query parameters
    # such as a different model or key.
    encoded_prompt = quote(prompt, safe="")
    url = f"https://gen.pollinations.ai/audio/{encoded_prompt}?model=acestep&key={PKEY}"
    async with httpx.AsyncClient(timeout=None) as client:
        response = await client.get(url)

    if response.status_code != 200:
        # Decode the body only on failure; on success it is binary audio.
        try:
            body_text = response.text
        except Exception:
            body_text = ""
        return JSONResponse(
            status_code=response.status_code,
            content={
                "success": False,
                "error": "Upstream music/sfx generation failed",
                "status_code": response.status_code,
                "message": body_text[:1000],
            },
        )
    return Response(response.content, media_type="audio/mpeg")
@app.get("/gen/tts/{prompt}")
@app.post("/gen/tts")
async def gensfx(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate speech audio for ``prompt`` via Pollinations (default audio model).

    The prompt comes from the path segment (GET) or from the ``prompt`` field
    of a JSON object body (POST). On success the upstream bytes are returned
    as ``audio/mpeg``; a non-200 upstream answer is relayed as a JSON error
    with the same status code and the first 1000 characters of its body.

    NOTE(review): this handler reuses the name ``gensfx`` of the SFX handler.
    Both routes are registered at decoration time so both work, but the
    module attribute ``gensfx`` ends up bound to this TTS function.

    Raises:
        HTTPException 400: the POST body is not a JSON object.
        Plus whatever the prompt-size and rate-limit helpers raise.
    """
    if prompt is None:
        payload = await request.json()
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")
    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt"
    )
    await check_audio_rate_limit(request, authorization, x_client_id)

    # Percent-encode the prompt so '?', '#', '&' or '/' cannot end the path
    # segment or inject extra query parameters into the upstream URL.
    encoded_prompt = quote(prompt, safe="")
    url = f"https://gen.pollinations.ai/audio/{encoded_prompt}?key={PKEY3}"
    async with httpx.AsyncClient(timeout=None) as client:
        response = await client.get(url)

    if response.status_code != 200:
        # Decode the body only on failure; on success it is binary audio.
        try:
            body_text = response.text
        except Exception:
            body_text = ""
        return JSONResponse(
            status_code=response.status_code,
            content={
                "success": False,
                "error": "Upstream audio generation failed",
                "status_code": response.status_code,
                "message": body_text[:1000],
            },
        )
    return Response(response.content, media_type="audio/mpeg")
@app.get("/gen/video/{prompt}")
@app.post("/gen/video")
@app.head("/gen/video")
async def genvideo_airforce(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate a short video with Pollinations' ``ltx-2`` model.

    HEAD returns a self-describing schema in ``Y-*`` headers. GET takes the
    prompt from the path; POST takes a JSON object with ``prompt`` and the
    optional ``ratio`` ("3:2" | "2:3" | "1:1", default "3:2"), ``mode``
    ("normal" | "fun", default "normal"; "fun" sets ``enhance=true``),
    ``duration`` (clamped to 1-10 s, default 5 when missing/unparseable) and
    ``image_urls`` (at most 2 URLs or base-64 images for conditioning;
    base-64 images are served temporarily from the asset CDN and always
    cleaned up afterwards).

    NOTE(review): shares its name with the api.airforce handler defined later
    in the file; both routes are registered at decoration time, but the
    module attribute ``genvideo_airforce`` is bound to the later function.

    Raises:
        HTTPException 400: invalid body, ratio, mode or image_urls.
        HTTPException 502: the upstream request failed or returned no bytes.
        Plus whatever the prompt-size and rate-limit helpers raise.
    """
    if request.method == "HEAD":
        # Starlette encodes header values as latin-1, so only ASCII
        # punctuation is used here (em/en dashes made HEAD fail with a 500).
        return Response(
            status_code=200,
            headers={
                "Y-prompt": "string - required. The text prompt used to generate the video.",
                "Y-ratio": "string - optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string - optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer - optional. Duration in seconds (1-10).",
                "Y-duration-default": "5",
                "Y-image_urls": "array<string> - optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                "Y-response_format": "video/mp4",
                "Y-model": "grok-video",
            },
        )

    aspectRatio = "3:2"
    inputMode = "normal"
    duration = 5
    image_urls = None
    ratio = None
    mode = None
    if prompt is None:
        user_body = await request.json()
        if not isinstance(user_body, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")
        duration = user_body.get("duration", 5)

    # Reject non-string values first: lists/dicts are unhashable and would
    # make the set-membership tests raise TypeError (HTTP 500).
    if (ratio is not None and not isinstance(ratio, str)) or ratio not in valid_ratios:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid aspect ratio '{ratio}'. Must be one of 3:2, 2:3, or 1:1.",
        )
    if ratio in ratios:
        aspectRatio = ratio
    if (mode is not None and not isinstance(mode, str)) or mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode '{mode}'. Must be 'normal' or 'fun'.",
        )
    if mode in modes:
        inputMode = mode
    if image_urls:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 2:
            raise HTTPException(400, "You may provide at most two image URLs")
        if not all(isinstance(img, str) for img in image_urls):
            raise HTTPException(400, "image_urls entries must be strings")

    # Clamp duration; JSON "Infinity" parses to float inf, hence OverflowError.
    try:
        duration = max(1, min(10, int(duration)))
    except (TypeError, ValueError, OverflowError):
        duration = 5

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )
    await check_video_rate_limit(request, authorization, x_client_id)

    # NOTE(review): 1:1 maps to 9:16 - presumably the upstream offers no square
    # output; confirm.
    RATIO_MAP = {
        "3:2": "16:9",
        "2:3": "9:16",
        "1:1": "9:16",
    }
    pollinations_ratio = RATIO_MAP.get(aspectRatio, "16:9")
    encoded_prompt = quote(prompt, safe="")
    params: Dict[str, Any] = {
        "model": "ltx-2",
        "duration": duration,
        "aspectRatio": pollinations_ratio,
        "seed": -1,
    }
    temp_assets: List[str] = []
    try:
        if image_urls:
            processed_urls: List[str] = []
            for img in image_urls[:2]:
                if is_base64_image(img):
                    image_id = save_base64_image(img)
                    # Track immediately so a later failure still cleans it up.
                    temp_assets.append(image_id)
                    processed_urls.append(
                        f"{request.base_url}asset-cdn/assets/{image_id}"
                    )
                else:
                    processed_urls.append(img)
            params["image"] = "|".join(processed_urls)
        if inputMode == "fun":
            params["enhance"] = "true"
        query_string = "&".join(
            f"{k}={quote(str(v), safe='')}" for k, v in params.items()
        )
        url = f"https://gen.pollinations.ai/image/{encoded_prompt}?{query_string}"
        # Log before the key is appended so the secret never reaches the logs.
        print(f"[VIDEO GEN] Pollinations URL: {url}")
        url = url + f"&key={PKEY}"
        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.get(url)
    except httpx.HTTPError as exc:
        raise HTTPException(502, "Video generation request failed") from exc
    finally:
        for aid in temp_assets:
            cleanup_image(aid)

    if resp.status_code != 200:
        try:
            body_text = resp.text
        except Exception:
            body_text = ""
        return JSONResponse(
            status_code=resp.status_code,
            content={
                "success": False,
                "error": "Upstream video generation failed",
                "status_code": resp.status_code,
                "message": body_text[:1000],
            },
        )
    if not resp.content:
        raise HTTPException(502, "Pollinations returned empty response")
    return Response(
        content=resp.content,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(resp.content)),
            "Accept-Ranges": "bytes",
        },
    )
# api.airforce video backend used by the /gen/video/airforce endpoints.
AIRFORCE_KEY = os.getenv("AIRFORCE")
AIRFORCE_VIDEO_MODEL = "grok-imagine-video"
AIRFORCE_API_URL = "https://api.airforce/v1/images/generations"

# Accepted "ratio" / "mode" values for the video endpoints. The plain sets
# hold the concrete choices; the valid_* supersets also accept "" and None,
# which mean "keep the default".
ratios = {"3:2", "2:3", "1:1"}
valid_ratios = ratios | {"", None}
modes = {"normal", "fun"}
valid_modes = modes | {"", None}

# NOTE(review): not referenced in this file; no retry loop uses it.
MAX_VIDEO_RETRIES = 6
@app.get("/gen/video/airforce/{prompt}")
@app.post("/gen/video/airforce")
@app.head("/gen/video/airforce")
async def genvideo_airforce(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate a video through api.airforce (``AIRFORCE_VIDEO_MODEL``).

    HEAD returns a self-describing schema in ``Y-*`` headers. GET takes the
    prompt from the path; POST takes a JSON object with ``prompt`` and the
    optional ``ratio`` ("3:2" | "2:3" | "1:1", default "3:2"), ``mode``
    ("normal" | "fun", default "normal") and ``image_urls`` (at most 2,
    forwarded as-is). The upstream base-64 video is decoded and returned as
    ``video/mp4``. A non-200 upstream answer is relayed with its status code
    (its JSON body when parseable, otherwise a JSON error description).

    Raises:
        HTTPException 400: invalid body, ratio, mode or image_urls.
        HTTPException 500: the AIRFORCE API key is not configured.
        HTTPException 502: the upstream request failed, or returned an empty
            or malformed payload / invalid base-64.
        Plus whatever the prompt-size and rate-limit helpers raise.
    """
    if request.method == "HEAD":
        # Starlette encodes header values as latin-1, so only ASCII
        # punctuation is used here (em dashes made HEAD fail with a 500).
        return Response(
            status_code=200,
            headers={
                # Required field
                "Y-prompt": "string - required. The text prompt used to generate the video.",
                # Optional fields
                "Y-ratio": "string - optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string - optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer - optional. Duration in seconds.",
                "Y-duration-default": "5",
                "Y-image_urls": "array<string> - optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                # Response format
                "Y-response_format": "video/mp4",
                # Model info
                "Y-model": "grok-imagine-video",
            },
        )

    aspectRatio = "3:2"
    inputMode = "normal"
    image_urls = None
    ratio = None
    mode = None
    if prompt is None:
        user_body = await request.json()
        if not isinstance(user_body, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")

    # Reject non-string values first: lists/dicts are unhashable and would
    # make the set-membership tests raise TypeError (HTTP 500).
    if (ratio is not None and not isinstance(ratio, str)) or ratio not in valid_ratios:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid aspect ratio {ratio}. Must be one of 3:2, 2:3, or 1:1. Default is 3:2",
        )
    if ratio in ratios:
        aspectRatio = ratio
    if (mode is not None and not isinstance(mode, str)) or mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode {mode}. Must be 'normal' or 'fun'. Default is normal",
        )
    if mode in modes:
        inputMode = mode
    if image_urls:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 2:
            raise HTTPException(400, "You may provide at most two image URLs")
        if not all(isinstance(img, str) for img in image_urls):
            raise HTTPException(400, "image_urls entries must be strings")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )
    # Fail fast on misconfiguration (previously sent "Bearer None") before
    # charging the caller's quota.
    if not AIRFORCE_KEY:
        raise HTTPException(500, "Missing AIRFORCE key")
    await check_video_rate_limit(request, authorization, x_client_id)

    payload: Dict[str, Any] = {
        "model": AIRFORCE_VIDEO_MODEL,
        "prompt": prompt,
        "n": 1,
        "size": "1024x1024",
        "response_format": "b64_json",
        "sse": False,
        "mode": inputMode,
        "aspectRatio": aspectRatio,
    }
    if image_urls:
        payload["image_urls"] = image_urls

    try:
        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.post(
                AIRFORCE_API_URL,
                headers={
                    "Authorization": f"Bearer {AIRFORCE_KEY}",
                    "Content-Type": "application/json",
                },
                json=payload,
            )
    except httpx.HTTPError as exc:
        raise HTTPException(502, "api.airforce request failed") from exc

    if resp.status_code != 200:
        # Error bodies are not guaranteed to be JSON; don't turn them into a 500.
        try:
            error_content = resp.json()
        except ValueError:
            error_content = {
                "success": False,
                "error": "Upstream video generation failed",
                "status_code": resp.status_code,
                "message": resp.text[:1000],
            }
        return JSONResponse(status_code=resp.status_code, content=error_content)
    if not resp.content:
        raise HTTPException(502, "api.airforce returned empty response")
    try:
        result = resp.json()
        b64_video = result["data"][0]["b64_json"]
    except (ValueError, KeyError, IndexError, TypeError):
        raise HTTPException(502, f"Invalid api.airforce response: {resp.text[:500]}")
    if not b64_video:
        raise HTTPException(502, "Airforce returned empty b64_json")
    try:
        video_bytes = base64.b64decode(b64_video)
    except (ValueError, TypeError) as exc:  # binascii.Error subclasses ValueError
        raise HTTPException(502, "Airforce returned invalid base64 video") from exc
    return Response(
        content=video_bytes,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(video_bytes)),
            "Accept-Ranges": "bytes",
        },
    )
@app.get("/subscription")
async def get_subscription(authorization: Optional[str] = Header(None)):
    """Return the caller's subscription record annotated with plan key and name.

    Expects ``Authorization: Bearer <jwt>``. ``plan_key`` is normalized and
    ``plan_name`` taken from TIER_CONFIG, falling back to the "free" tier.

    Raises:
        HTTPException 401: missing/malformed header, or the lookup reported
            an error.
    """
    scheme_prefix = "Bearer "
    if authorization is None or not authorization.startswith(scheme_prefix):
        raise HTTPException(401, "Missing or invalid Authorization header")
    _, jwt = authorization.split(" ", 1)

    result = await fetch_subscription(jwt)
    if "error" in result:
        raise HTTPException(401, result["error"])

    plan_key = normalize_plan_key(result.get("plan_key"))
    result["plan_key"] = plan_key
    plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    result["plan_name"] = plan["name"]
    return result
@app.get("/usage")
async def get_usage(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Report the caller's plan and current usage snapshot.

    The caller is identified by ``resolve_rate_limit_identity`` from the
    Authorization and X-Client-Id headers; an unknown plan key falls back to
    the "free" tier. ``generated_at`` is a UTC ISO-8601 timestamp.
    """
    identity = await resolve_rate_limit_identity(request, authorization, x_client_id)
    plan_key, subject = identity
    plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    usage = get_usage_snapshot_for_subject(plan_key, subject)
    generated_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    return JSONResponse(
        status_code=200,
        content={
            "plan_key": plan_key,
            "plan_name": plan.get("name", "Free Tier"),
            "usage": usage,
            "generated_at": generated_at,
        },
    )
@app.get("/tier-config")
async def tier_config():
    """List every configured plan, in PLAN_ORDER, for the frontend.

    ``order`` is the plan's index in PLAN_ORDER, so it can skip numbers when
    a key has no TIER_CONFIG entry. ``defaultPlanKey`` is always "free".
    """
    plans = []
    for position, plan_key in enumerate(PLAN_ORDER):
        plan = TIER_CONFIG.get(plan_key)
        if plan:
            plans.append(
                {
                    "key": plan_key,
                    "name": plan["name"],
                    "url": plan["url"],
                    "price": plan["price"],
                    "limits": plan["limits"],
                    "order": position,
                }
            )
    return JSONResponse(
        status_code=200,
        content={"defaultPlanKey": "free", "plans": plans},
    )
@app.get("/tiers")
async def tiers():
    """List the paid plans (everything in PLAN_ORDER except "free") that exist."""
    candidates = (
        (plan_key, TIER_CONFIG.get(plan_key))
        for plan_key in PLAN_ORDER
        if plan_key != "free"
    )
    paid_plans = [
        {
            "key": plan_key,
            "name": plan["name"],
            "url": plan["url"],
            "price": plan["price"],
            "limits": plan["limits"],
        }
        for plan_key, plan in candidates
        if plan
    ]
    return JSONResponse(status_code=200, content=paid_plans)
@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Authenticated websocket bridge to this app's /gen/chat/completions.

    Protocol:
      1. More than AUTH_MAX_ATTEMPTS connections from one IP within
         AUTH_WINDOW_SECONDS are closed with code 4408.
      2. The first frame must be a JSON object whose ``key`` equals
         WEBSOCKET_KEY (compared in constant time); otherwise the socket is
         closed with 4403. On success ``{"type": "auth", "status": "ok"}``
         is sent.
      3. Each later frame is JSON ``{"body": {...}, "headers": {...}}``; it is
         POSTed to /gen/chat/completions on the same host, and every
         non-empty response line is relayed back as a text frame.
    Any other error is reported as ``{"error": ...}`` (best effort) before
    the socket is closed.
    """
    import hmac  # local import: constant-time comparison of the shared secret

    # ws.client can be None (e.g. some ASGI transports/test clients).
    ip = ws.client.host if ws.client else "unknown"
    await ws.accept()
    if not check_ws_auth_rate_limit(ip):
        await ws.close(code=4408)
        return
    try:
        auth_data = json.loads(await ws.receive_text())
        provided_key = auth_data.get("key") if isinstance(auth_data, dict) else None
        authorized = (
            bool(WEBSOCKET_KEY)
            and isinstance(provided_key, str)
            and hmac.compare_digest(
                provided_key.encode("utf-8"), WEBSOCKET_KEY.encode("utf-8")
            )
        )
        if not authorized:
            await ws.close(code=4403)
            return
        await ws.send_json({"type": "auth", "status": "ok"})

        # The target never changes during the session, so compute it once.
        # NOTE(review): derived from the request URL (i.e. the client-supplied
        # Host header); consider pinning a trusted internal base URL instead.
        http_url = str(ws.url).replace("ws://", "http://").replace("wss://", "https://")
        target_url = http_url.split("/ws/chat")[0] + "/gen/chat/completions"

        while True:
            data = json.loads(await ws.receive_text())
            body = data.get("body")
            headers = data.get("headers", {})
            if not body:
                await ws.send_json({"error": "Missing body"})
                continue
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream(
                    "POST", target_url, json=body, headers=headers
                ) as r:
                    async for line in r.aiter_lines():
                        if line:
                            await ws.send_text(line)
    except WebSocketDisconnect:
        return
    except Exception as e:
        try:
            await ws.send_json({"error": str(e)})
        except Exception:
            pass  # best effort: the socket may already be unusable
        try:
            await ws.close()
        except Exception:
            pass  # already closed by the peer or after a failed send
@app.get("/portal")
@app.post("/portal")
async def redirect_to_protal(request: Request):
    """Send the caller to the Stripe billing-portal login.

    GET (and POST without a usable email) gets a 302 redirect to the portal.
    POST with a JSON object body carrying a string ``email`` gets JSON
    ``{"redirect_url": ...}`` with the email pre-filled. The email is
    percent-encoded so characters such as '+', '&' or '#' neither corrupt the
    address nor inject extra query parameters into the Stripe URL.
    """
    base_url = "https://billing.stripe.com/p/login/5kQdR9aIM3ts4steyabbG00"
    email = None
    if request.method == "POST":
        try:
            body = await request.json()
        except Exception:
            body = None  # missing or malformed JSON: fall back to a plain redirect
        if isinstance(body, dict):
            candidate = body.get("email")
            if isinstance(candidate, str):
                email = candidate

    # Only POST can supply an email, so GET always takes this branch.
    if not email:
        return RedirectResponse(url=base_url, status_code=status.HTTP_302_FOUND)
    prefilled_url = f"{base_url}?prefilled_email={quote(email, safe='@')}"
    return JSONResponse({"redirect_url": prefilled_url})