lightning / app.py
sharktide's picture
Update app.py
da410a9 verified
import os
import time
import hashlib
from fastapi import FastAPI, Request, HTTPException, status, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
Response,
JSONResponse,
StreamingResponse,
RedirectResponse,
)
import httpx
from bs4 import BeautifulSoup
from typing import List, Dict, Any
import asyncio
import re
import random
from urllib.parse import quote
import base64
from helper.subscriptions import (
fetch_subscription,
normalize_plan_key,
TIER_CONFIG,
PLAN_ORDER,
)
from typing import Optional
from helper.keywords import *
from helper.assets import (
save_base64_image,
cleanup_image,
is_base64_image,
asset_router,
)
from helper.ratelimit import (
enforce_rate_limit,
resolve_rate_limit_identity,
check_audio_rate_limit,
check_video_rate_limit,
check_image_rate_limit,
MAX_CHAT_PROMPT_BYTES,
MAX_CHAT_PROMPT_CHARS,
MAX_GROQ_PROMPT_BYTES,
MAX_GROQ_PROMPT_CHARS,
MAX_MEDIA_PROMPT_BYTES,
MAX_MEDIA_PROMPT_CHARS,
extract_user_text,
calculate_messages_size,
normalize_prompt_value,
enforce_prompt_size,
resolve_bound_subject,
get_usage_snapshot_for_subject,
)
# Single ASGI application object that every route in this module attaches to.
app = FastAPI()

# Permissive CORS policy: any origin and any request header may call the API,
# restricted to the three HTTP methods the routes below actually use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["GET", "POST", "HEAD"],
)

# Mount the routes exported by helper.assets.
app.include_router(asset_router)
@app.get("/")
async def reroute_to_home():
    """Send visitors of the bare API root to the public web front end.

    A 308 (permanent) redirect is used so the client repeats the original
    method against the new location.
    """
    front_end = "https://inference.js.org"
    return RedirectResponse(
        url=front_end,
        status_code=status.HTTP_308_PERMANENT_REDIRECT,
    )
# Public Ollama model catalogue page; scraped by the /models endpoint.
OLLAMA_LIBRARY_URL: str = "https://ollama.com/library"
def is_complex_reasoning(prompt: str) -> bool:
    """Heuristically decide whether *prompt* needs a reasoning-capable model.

    The prompt is considered complex when, checked in this order:

    * it is longer than 800 characters, or
    * it contains any entry of ``REASONING_KEYWORDS`` as a substring, or
    * it contains one of the whole-word connectives "if", "therefore",
      "assume", "let x" or "given that".

    NOTE(review): both the keyword scan and the regex are case-sensitive, so
    e.g. a sentence-initial "If" does not match — confirm this is intended.
    """
    if len(prompt) > 800:
        return True
    if any(keyword in prompt for keyword in REASONING_KEYWORDS):
        return True
    connective = re.search(r"\b(if|therefore|assume|let x|given that)\b", prompt)
    return connective is not None
def is_lightweight(prompt: str) -> bool:
    """Return True for a short prompt (< 100 chars) containing a lightweight keyword.

    Prompts of 100 characters or more are never lightweight, whatever their
    content; shorter ones are matched by plain substring search against
    ``LIGHTWEIGHT_KEYWORDS``.
    """
    if len(prompt) >= 100:
        return False
    return any(keyword in prompt for keyword in LIGHTWEIGHT_KEYWORDS)
def is_cinematic_image_prompt(prompt: str) -> bool:
    """Return True when the lower-cased *prompt* contains any ``CREATIVE_KEYWORDS`` entry.

    The prompt is lower-cased once, rather than once per keyword. Keywords
    are matched as plain substrings and are not lower-cased themselves, so
    they are expected to be written in lower case.
    """
    lowered = prompt.lower()
    return any(keyword in lowered for keyword in CREATIVE_KEYWORDS)
# Pollinations API keys, read from the environment ("" when unset).
# PKEY is used by the music/SFX and video endpoints, PKEY2 by image
# generation and PKEY3 by text-to-speech.
PKEY: str = os.getenv("POLLINATIONS_KEY", "")
PKEY2: str = os.getenv("POLLINATIONS2_KEY", "")
PKEY3: str = os.getenv("POLLINATIONS3_KEY", "")

# Groq-hosted model identifiers grouped as tool-capable models.
GROQ_TOOL_MODELS: List[str] = [
    "openai/gpt-oss-120b",
    "openai/gpt-oss-20b",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "qwen/qwen3-32b",
    "moonshotai/kimi-k2-instruct",
]

# Groq-hosted model identifiers for plain chat requests.
GROQ_NORMAL_MODELS: List[str] = [
    "llama-3.1-8b-instant",
    "llama-3.3-70b-versatile",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-guard-4-12b",
    "openai/gpt-oss-safeguard-20b",
    "qwen/qwen3-32b",
]

# Cerebras-hosted model identifiers.
CEREBRAS_MODELS: List[str] = [
    "gpt-oss-120b",
    "llama3.1-8b",
    "qwen-3-235b-a22b-instruct-2507",
    "zai-glm-4.7",
]
async def check_chat_rate_limit(
    request: Request,
    authorization: Optional[str],
    client_id: Optional[str] = None,
):
    """Apply the ``cloudChatDaily`` rate limit to a chat request.

    Thin wrapper around ``enforce_rate_limit``: returns whatever it returns
    and lets any exception it raises propagate to the caller.
    """
    limit_name = "cloudChatDaily"
    return await enforce_rate_limit(request, authorization, limit_name, client_id)
def _status_probe_response(content_type: str) -> Response:
    """Build the empty 200 reply shared by every ``HEAD /status/*`` probe."""
    return Response(
        status_code=200,
        headers={
            "Content-Type": content_type,
            "Accept-Ranges": "bytes",
        },
    )


@app.head("/status/sfx")
async def head_sfx():
    """Static probe for the music/SFX service; always 200 with audio/mpeg headers."""
    return _status_probe_response("audio/mpeg")


@app.head("/status/image")
async def head_image():
    """Static probe for the image service; always 200 with image/jpeg headers."""
    return _status_probe_response("image/jpeg")


@app.head("/status/video")
async def head_video():
    """Static probe for the video service; always 200 with video/mp4 headers."""
    return _status_probe_response("video/mp4")


@app.head("/status/text")
async def head_text():
    """Static probe for the text service; always 200 with JSON headers."""
    return _status_probe_response("application/json")
@app.get("/status")
async def get_status():
    """Report the (currently hard-coded) health of each public service.

    Every service is listed as healthy. The overall ``state`` turns into
    "degraded" only if an entry below is edited to a state other than "ok".
    """

    def _ok() -> dict:
        return {"code": 200, "state": "ok", "message": "Running normally"}

    services = {
        "Video Generation": _ok(),
        "Image Generation": _ok(),
        "Lightning-Text v2": _ok(),
        "Music/SFX Generation": _ok(),
    }
    everything_ok = all(entry["state"] == "ok" for entry in services.values())
    return JSONResponse(
        status_code=200,
        content={
            "state": "ok" if everything_ok else "degraded",
            "services": services,
            "notifications": "",
            "latest": "2.4.0",
        },
    )
@app.post("/gen/image")
@app.get("/genimg/{prompt}")
async def generate_image(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate an image from a text prompt via Pollinations.

    The prompt comes from the URL path (``GET /genimg/{prompt}``), from the
    query string, or from a JSON object body (``POST /gen/image``) with a
    ``prompt`` key and an optional ``mode``: "fantasy" forces the ``flux``
    model and "realistic" forces ``zimage``. Without a mode, prompts matched
    by ``is_cinematic_image_prompt`` go to ``flux`` and all others to
    ``zimage``.

    Returns the upstream image bytes with media type ``image/jpeg``.

    Raises:
        HTTPException 400: the request body is not a JSON object.
        HTTPException 500: Pollinations answered with a non-200 status.
        HTTPException 502: the request to Pollinations failed in transport.
        Prompt-size and rate-limit errors come from the helper functions.
    """
    timeout = httpx.Timeout(300.0, read=300.0)
    payload: Dict[str, Any] = {}
    if prompt is None:
        try:
            payload = await request.json()
        except ValueError:
            # Malformed JSON is a client error, not an internal one.
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")
    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Image prompt"
    )
    await check_image_rate_limit(request, authorization, x_client_id)

    # Keyword heuristic first; an explicit "mode" in the body overrides it.
    chosen_model = "flux" if is_cinematic_image_prompt(prompt) else "zimage"
    mode = payload.get("mode")
    if isinstance(mode, str):
        normalized_mode = mode.strip().lower()
        if normalized_mode == "fantasy":
            chosen_model = "flux"
        elif normalized_mode == "realistic":
            chosen_model = "zimage"
    print(f"[IMAGE GEN] Routing to model: {chosen_model}")

    # The prompt is fully percent-encoded so it cannot alter the query string.
    url = f"https://gen.pollinations.ai/image/{quote(prompt, safe='')}?model={chosen_model}&key={PKEY2}"
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(url)
    except httpx.HTTPError as exc:
        # Report only the exception type: the request URL embeds the API key.
        raise HTTPException(
            502, f"Pollinations request failed: {type(exc).__name__}"
        ) from exc
    if response.status_code != 200:
        raise HTTPException(
            status_code=500, detail=f"Pollinations error: {response.status_code}"
        )
    return Response(content=response.content, media_type="image/jpeg")
def _parse_ollama_model(item) -> Dict[str, Any]:
    """Convert one ``li[x-test-model]`` element of the Ollama library page to a dict.

    Missing fields fall back to "" (name), "No description" (description),
    "Unknown" (pulls, updated) or None (link).
    """
    name = item.select_one("[x-test-model-title] span")
    description = item.select_one("p.max-w-lg")
    pulls = item.select_one("[x-test-pull-count]")
    updated = item.select_one("[x-test-updated]")
    link = item.select_one("a")
    return {
        "name": name.get_text(strip=True) if name else "",
        "description": (
            description.get_text(strip=True) if description else "No description"
        ),
        "sizes": [el.get_text(strip=True) for el in item.select("[x-test-size]")],
        "pulls": pulls.get_text(strip=True) if pulls else "Unknown",
        "tags": [
            t.get_text(strip=True) for t in item.select('span[class*="text-blue-600"]')
        ],
        "updated": updated.get_text(strip=True) if updated else "Unknown",
        "link": link.get("href") if link else None,
    }


@app.head("/models")
@app.get("/models")
async def get_models() -> List[Dict]:
    """Scrape the public Ollama model library and return one dict per model.

    Each entry has the keys name, description, sizes, pulls, tags, updated
    and link (see ``_parse_ollama_model``).

    Raises:
        HTTPException 502: the library page could not be fetched or answered
            with an HTTP error status. Previously such failures were hidden
            behind an empty list or an unhandled 500.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(OLLAMA_LIBRARY_URL)
    except httpx.HTTPError as exc:
        raise HTTPException(
            502, f"Could not reach the Ollama library: {type(exc).__name__}"
        ) from exc
    if response.status_code >= 400:
        raise HTTPException(
            502, f"Ollama library returned HTTP {response.status_code}"
        )
    soup = BeautifulSoup(response.text, "html.parser")
    return [_parse_ollama_model(item) for item in soup.select("li[x-test-model]")]
def _sse_error_event(message: str) -> str:
    """Format *message* as one Server-Sent Events ``data:`` frame.

    The payload ``{"error": message}`` is serialised with ``json.dumps``, which
    escapes quotes, backslashes and control characters, so the frame is
    always valid JSON and never contains a raw line break.
    """
    import json

    return "data: " + json.dumps({"error": message}) + "\n\n"


def _pick_provider_key(env_var: str) -> str:
    """Return one random key from the comma-separated list in ``env_var``.

    Raises:
        HTTPException 500: the variable is unset or holds no non-blank key.
    """
    keys = [k.strip() for k in os.getenv(env_var, "").split(",") if k.strip()]
    if not keys:
        raise HTTPException(500, f"Missing {env_var}(s)")
    return random.choice(keys)


# provider name -> (environment variable holding its API keys, completions URL)
_CHAT_PROVIDERS = {
    "groq": ("GROQ_KEY", "https://api.groq.com/openai/v1/chat/completions"),
    "cerebras": ("CER_KEY", "https://api.cerebras.ai/v1/chat/completions"),
}


@app.post("/gen/chat/completions")
async def generate_text(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Route an OpenAI-style chat completion request to Groq or Cerebras.

    The body is a JSON object with a non-empty ``messages`` list plus any
    other chat-completion fields; ``model`` is always overwritten with the
    model chosen here. The model is picked heuristically:

    * images in ``messages`` -> meta-llama/llama-4-scout-17b-16e-instruct (Groq);
    * tools requested -> openai/gpt-oss-120b if score >= 4, else
      openai/gpt-oss-20b (Groq);
    * code present -> gpt-oss-120b (Cerebras) if code-heavy and score >= 6,
      else llama-3.3-70b-versatile (Groq) if score >= 4;
    * no code, score >= 4 -> meta-llama/llama-4-scout-17b-16e-instruct (Groq);
    * anything else -> llama-3.1-8b-instant (Groq).

    With ``"stream": true`` the upstream SSE stream is relayed line by line.
    Upstream HTTP errors and transport failures are each reported as a single
    ``data: {"error": ...}`` frame. Without streaming, the upstream status
    code and JSON body are returned as-is.

    Raises:
        HTTPException 400: body is not a JSON object or ``messages`` is missing/empty.
        HTTPException 413: prompt exceeds the global or the Groq-specific size limit.
        HTTPException 500: no API key configured for the chosen provider.
        HTTPException 502: the non-streaming upstream request failed in transport.
        Rate-limit errors come from ``check_chat_rate_limit``.
    """
    try:
        body = await request.json()
    except ValueError:
        # Malformed JSON is a client error, not an internal one.
        raise HTTPException(400, "Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(400, "Request body must be a JSON object")

    messages = body.get("messages", [])
    if not isinstance(messages, list) or len(messages) == 0:
        raise HTTPException(400, "messages[] is required")
    total_chars, total_bytes = calculate_messages_size(messages)
    if total_chars > MAX_CHAT_PROMPT_CHARS or total_bytes > MAX_CHAT_PROMPT_BYTES:
        raise HTTPException(
            status_code=413,
            detail=(
                f"Prompt context too large ({total_chars} chars, {total_bytes} bytes). "
                f"Max allowed is {MAX_CHAT_PROMPT_CHARS} chars or {MAX_CHAT_PROMPT_BYTES} bytes."
            ),
        )

    prompt_text = extract_user_text(messages)
    tools = body.get("tools")
    uses_tools = (isinstance(tools, list) and len(tools) > 0) or (
        body.get("tool_choice") not in [None, "none"]
    )

    # Prompt features feeding the routing score.
    long_context = is_long_context(messages)
    code_present = contains_code(prompt_text)
    math_heavy = is_math_heavy(prompt_text)
    structured_task = is_structured_task(prompt_text)
    multi_q = multiple_questions(prompt_text)
    code_heavy = is_code_heavy(prompt_text, code_present, long_context)

    score = 0
    if long_context:
        score += 3
    if math_heavy:
        score += 3
    if structured_task:
        score += 2
    if code_present:
        score += 2
    if multi_q:
        score += 1
    score += sum(1 for kw in REASONING_KEYWORDS if kw in prompt_text)

    chosen_model = "llama-3.1-8b-instant"
    provider = "groq"
    has_images = contains_images(messages)
    if has_images:
        chosen_model = "meta-llama/llama-4-scout-17b-16e-instruct"
        provider = "groq"
    else:
        # Clamp for logging; every routing threshold below is <= 6.
        if score > 10:
            score = 10
        if uses_tools:
            if score >= 4:
                chosen_model = "openai/gpt-oss-120b"
            else:
                chosen_model = "openai/gpt-oss-20b"
            provider = "groq"
        elif code_present:
            if code_heavy and score >= 6:
                chosen_model = "gpt-oss-120b"
                provider = "cerebras"
            elif score >= 4:
                chosen_model = "llama-3.3-70b-versatile"
                provider = "groq"
        elif score >= 4:
            chosen_model = "meta-llama/llama-4-scout-17b-16e-instruct"
            provider = "groq"

    if provider == "groq" and (
        total_chars > MAX_GROQ_PROMPT_CHARS or total_bytes > MAX_GROQ_PROMPT_BYTES
    ):
        raise HTTPException(
            status_code=413,
            detail=(
                f"Prompt exceeds Groq-safe size ({total_chars} chars, {total_bytes} bytes). "
                f"Max Groq-safe size is {MAX_GROQ_PROMPT_CHARS} chars or {MAX_GROQ_PROMPT_BYTES} bytes."
            ),
        )
    await check_chat_rate_limit(request, authorization, x_client_id)
    body["model"] = chosen_model
    print(
        f"""
[ADVANCED ROUTER]
Score: {score}
Uses tools: {uses_tools}
Long context: {long_context}
Code present: {code_present}
Math heavy: {math_heavy}
Structured: {structured_task}
Multi-question: {multi_q}
MULTIMODAL REQUIRED: {has_images}
→ Selected: {chosen_model} ({provider})
"""
    )

    stream = body.get("stream", False)
    try:
        key_env, url = _CHAT_PROVIDERS[provider]
    except KeyError:
        raise HTTPException(500, "Unknown provider routing error") from None
    headers = {"Authorization": f"Bearer {_pick_provider_key(key_env)}"}

    if stream:
        body["stream"] = True

        async def event_generator():
            try:
                async with httpx.AsyncClient(timeout=None) as client:
                    async with client.stream(
                        "POST",
                        url,
                        json=body,
                        headers=headers,
                    ) as r:
                        if r.status_code >= 400:
                            try:
                                raw = await r.aread()
                                error_payload = raw.decode("utf-8", errors="replace")[:800]
                            except Exception:  # best effort: report the status alone
                                error_payload = ""
                            # Keep the message on one line; JSON escaping is done
                            # by _sse_error_event.
                            error_payload = error_payload.replace("\n", " ").replace("\r", " ")
                            yield _sse_error_event(
                                f"Upstream provider error ({r.status_code}): {error_payload}"
                            )
                            return
                        async for line in r.aiter_lines():
                            # Re-append a newline per upstream line so blank
                            # SSE frame separators are relayed as "\n".
                            yield line + "\n"
            except asyncio.CancelledError:
                # The client went away; stop relaying quietly.
                return
            except Exception as e:
                yield _sse_error_event(str(e))

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # critical for nginx
            },
        )

    try:
        async with httpx.AsyncClient(timeout=None) as client:
            r = await client.post(url, json=body, headers=headers)
    except httpx.HTTPError as exc:
        raise HTTPException(
            502, f"Upstream provider request failed: {type(exc).__name__}"
        ) from exc
    content_type = (r.headers.get("content-type") or "").lower()
    if "application/json" in content_type:
        try:
            payload = r.json()
        except ValueError:
            payload = {"error": "Upstream returned invalid JSON"}
    else:
        payload = {
            "error": "Upstream returned non-JSON response",
            "status_code": r.status_code,
            "message": r.text[:1000],
        }
    return JSONResponse(status_code=r.status_code, content=payload)
@app.get("/gen/sfx/{prompt}")
@app.post("/gen/sfx")
async def gensfx(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate music or sound effects for a prompt with Pollinations' ``elevenmusic`` model.

    The prompt comes from the URL path, the query string, or a JSON object
    body with a ``prompt`` key. Returns the upstream bytes as
    ``audio/mpeg``. A non-200 upstream status is mirrored as a JSON error
    object that includes up to 1000 characters of the upstream body.

    Raises:
        HTTPException 400: the request body is not a JSON object.
        HTTPException 502: the request to Pollinations failed in transport.
        Prompt-size and rate-limit errors come from the helper functions.
    """
    if prompt is None:
        try:
            payload = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")
    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt"
    )
    await check_audio_rate_limit(request, authorization, x_client_id)
    # Percent-encode the whole prompt: a raw "?", "#", "&" or "/" would
    # otherwise rewrite the path or query (e.g. override or drop ``key``).
    url = (
        f"https://gen.pollinations.ai/audio/{quote(prompt, safe='')}"
        f"?model=elevenmusic&key={PKEY}"
    )
    try:
        async with httpx.AsyncClient(timeout=None) as client:
            response = await client.get(url)
    except httpx.HTTPError as exc:
        # Report only the exception type: the request URL embeds the API key.
        raise HTTPException(
            502, f"Upstream music/sfx request failed: {type(exc).__name__}"
        ) from exc
    if response.status_code != 200:
        # Decode the body only on failure; on success it is binary audio.
        try:
            body_text = response.text
        except Exception:  # best effort: fall back to an empty message
            body_text = ""
        return JSONResponse(
            status_code=response.status_code,
            content={
                "success": False,
                "error": "Upstream music/sfx generation failed",
                "status_code": response.status_code,
                "message": body_text[:1000],
            },
        )
    return Response(response.content, media_type="audio/mpeg")
@app.get("/gen/tts/{prompt}")
@app.post("/gen/tts")
async def gensfx(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Convert a text prompt to speech with Pollinations' audio endpoint.

    Unlike the SFX route, no explicit ``model`` parameter is sent. The
    prompt comes from the URL path, the query string, or a JSON object body
    with a ``prompt`` key. Returns the upstream bytes as ``audio/mpeg``. A
    non-200 upstream status is mirrored as a JSON error object.

    NOTE(review): this function reuses the name ``gensfx`` of the SFX
    handler above. FastAPI registers both routes, but this ``def`` rebinds
    the module-level name. Consider renaming it once no caller depends on
    the name.

    Raises:
        HTTPException 400: the request body is not a JSON object.
        HTTPException 502: the request to Pollinations failed in transport.
        Prompt-size and rate-limit errors come from the helper functions.
    """
    if prompt is None:
        try:
            payload = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(payload, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = payload.get("prompt")
    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Audio prompt"
    )
    await check_audio_rate_limit(request, authorization, x_client_id)
    # Percent-encode the whole prompt so it cannot rewrite the path or query
    # string (a raw "#" would, for instance, strip the ``key`` parameter).
    url = f"https://gen.pollinations.ai/audio/{quote(prompt, safe='')}?key={PKEY3}"
    try:
        async with httpx.AsyncClient(timeout=None) as client:
            response = await client.get(url)
    except httpx.HTTPError as exc:
        # Report only the exception type: the request URL embeds the API key.
        raise HTTPException(
            502, f"Upstream audio request failed: {type(exc).__name__}"
        ) from exc
    if response.status_code != 200:
        # Decode the body only on failure; on success it is binary audio.
        try:
            body_text = response.text
        except Exception:  # best effort: fall back to an empty message
            body_text = ""
        return JSONResponse(
            status_code=response.status_code,
            content={
                "success": False,
                "error": "Upstream audio generation failed",
                "status_code": response.status_code,
                "message": body_text[:1000],
            },
        )
    return Response(response.content, media_type="audio/mpeg")
@app.get("/gen/video/{prompt}")
@app.post("/gen/video")
@app.head("/gen/video")
async def genvideo_airforce(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate a short video with Pollinations' ``grok-video`` model.

    ``HEAD`` describes the accepted fields as ``Y-*`` response headers.
    ``GET /gen/video/{prompt}`` uses the defaults. ``POST /gen/video`` takes
    a JSON object with:

    * ``prompt``;
    * optional ``ratio`` ("3:2", "2:3" or "1:1");
    * optional ``mode`` ("normal" or "fun");
    * optional ``duration``, clamped to 1-10 seconds (5 if unparsable);
    * up to two ``image_urls``. These are plain URLs or base64 images; base64
      images are saved as temporary assets served under
      ``asset-cdn/assets/`` and always cleaned up afterwards.

    Returns the MP4 bytes. A non-200 upstream status is mirrored as a JSON
    error object.

    NOTE(review): this function shares its name with the api.airforce video
    handler defined later in the module. Both routes are registered, but the
    later ``def`` rebinds the module-level name.

    Raises:
        HTTPException 400: malformed body or invalid ratio/mode/image_urls.
        HTTPException 502: the upstream request failed or returned no content.
        Prompt-size and rate-limit errors come from the helper functions.
    """
    if request.method == "HEAD":
        # Header values must be latin-1 encodable (Starlette encodes them as
        # latin-1), so only ASCII dashes are used here.
        return Response(
            status_code=200,
            headers={
                "Y-prompt": "string - required. The text prompt used to generate the video.",
                "Y-ratio": "string - optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string - optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer - optional. Duration in seconds (1-10).",
                "Y-duration-default": "5",
                "Y-image_urls": "array<string> - optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                "Y-response_format": "video/mp4",
                "Y-model": "grok-video",
            },
        )

    aspect_ratio = "3:2"
    input_mode = "normal"
    duration = 5
    image_urls = None
    ratio = None
    mode = None
    if prompt is None:
        try:
            user_body = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(user_body, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")
        duration = user_body.get("duration", 5)

    # The isinstance guard short-circuits before the set lookup, so an
    # unhashable JSON value (list/object) yields a 400 instead of a TypeError.
    if not (ratio is None or isinstance(ratio, str)) or ratio not in valid_ratios:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid aspect ratio '{ratio}'. Must be one of 3:2, 2:3, or 1:1.",
        )
    if ratio in ratios:
        aspect_ratio = ratio
    if not (mode is None or isinstance(mode, str)) or mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode '{mode}'. Must be 'normal' or 'fun'.",
        )
    if mode in modes:
        input_mode = mode
    if image_urls:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 2:
            raise HTTPException(400, "You may provide at most two image URLs")
        if not all(isinstance(img, str) for img in image_urls):
            raise HTTPException(400, "image_urls entries must be strings")

    # Clamp duration; OverflowError covers a JSON ``Infinity`` value.
    try:
        duration = max(1, min(10, int(duration)))
    except (TypeError, ValueError, OverflowError):
        duration = 5

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )
    await check_video_rate_limit(request, authorization, x_client_id)

    # Maps the public ratio names onto the ratios sent to Pollinations.
    RATIO_MAP = {
        "3:2": "16:9",
        "2:3": "9:16",
        "1:1": "1:1",
    }
    params = {
        "model": "grok-video",
        "duration": duration,
        "aspectRatio": RATIO_MAP.get(aspect_ratio, "16:9"),
        "seed": -1,
    }

    temp_assets = []
    resp = None
    # Everything that may create temporary assets sits inside the try, so
    # they are cleaned up even if saving a later image fails.
    try:
        if image_urls:
            processed_urls = []
            for img in image_urls[:2]:
                if is_base64_image(img):
                    image_id = save_base64_image(img)
                    temp_assets.append(image_id)
                    processed_urls.append(
                        f"{request.base_url}asset-cdn/assets/{image_id}"
                    )
                else:
                    processed_urls.append(img)
            params["image"] = "|".join(processed_urls)
        if input_mode == "fun":
            params["enhance"] = "true"
        query_string = "&".join(
            f"{k}={quote(str(v), safe='')}" for k, v in params.items()
        )
        url = f"https://gen.pollinations.ai/image/{quote(prompt, safe='')}?{query_string}"
        # Logged before the key is appended so the secret never reaches logs.
        print(f"[VIDEO GEN] Pollinations URL: {url}")
        url = url + f"&key={PKEY}"
        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.get(url)
    except httpx.HTTPError as exc:
        # Leave resp as None so the 502 below is reported.
        print(f"[VIDEO GEN] Upstream request failed: {type(exc).__name__}")
    finally:
        for aid in temp_assets:
            cleanup_image(aid)

    if resp is None:
        raise HTTPException(502, "Video generation request failed")
    if resp.status_code != 200:
        try:
            body_text = resp.text
        except Exception:  # best effort: fall back to an empty message
            body_text = ""
        return JSONResponse(
            status_code=resp.status_code,
            content={
                "success": False,
                "error": "Upstream video generation failed",
                "status_code": resp.status_code,
                "message": body_text[:1000],
            },
        )
    if not resp.content:
        raise HTTPException(502, "Pollinations returned empty response")
    return Response(
        content=resp.content,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(resp.content)),
            "Accept-Ranges": "bytes",
        },
    )
# api.airforce video settings. AIRFORCE_KEY is None when the AIRFORCE
# environment variable is unset.
AIRFORCE_KEY = os.getenv("AIRFORCE")
AIRFORCE_VIDEO_MODEL = "grok-imagine-video"
AIRFORCE_API_URL = "https://api.airforce/v1/images/generations"

# Concrete aspect ratios and generation modes the video endpoints act on.
ratios = {"3:2", "2:3", "1:1"}
modes = {"normal", "fun"}
# Accepted request values: the concrete options plus "" and None, which the
# endpoints treat as "keep the default".
valid_ratios = ratios | {"", None}
valid_modes = modes | {"", None}

# NOTE(review): not referenced by the handlers visible in this module; confirm
# whether it is still needed.
MAX_VIDEO_RETRIES = 6
@app.get("/gen/video/airforce/{prompt}")
@app.post("/gen/video/airforce")
@app.head("/gen/video/airforce")
async def genvideo_airforce(
    request: Request,
    prompt: str = None,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Generate a video with api.airforce's ``grok-imagine-video`` model.

    ``HEAD`` describes the accepted fields as ``Y-*`` response headers.
    ``GET .../{prompt}`` uses the defaults. ``POST`` takes a JSON object with:

    * ``prompt``;
    * optional ``ratio`` ("3:2", "2:3" or "1:1");
    * optional ``mode`` ("normal" or "fun");
    * up to two ``image_urls`` strings, which are forwarded unchanged.

    The upstream answers with base64 (``b64_json``), which is decoded and
    returned as ``video/mp4``. A non-200 upstream status is mirrored with its
    JSON body, or a JSON error object if that body is not JSON.

    Raises:
        HTTPException 400: malformed body or invalid ratio/mode/image_urls.
        HTTPException 500: the AIRFORCE API key is not configured.
        HTTPException 502: transport failure, or an empty, malformed or
            undecodable upstream response.
        Prompt-size and rate-limit errors come from the helper functions.
    """
    if request.method == "HEAD":
        # Header values must be latin-1 encodable (Starlette encodes them as
        # latin-1), so only ASCII dashes are used here.
        return Response(
            status_code=200,
            headers={
                # Required field
                "Y-prompt": "string - required. The text prompt used to generate the video.",
                # Optional fields
                "Y-ratio": "string - optional. Aspect ratio of the output video.",
                "Y-ratio-values": "3:2,2:3,1:1",
                "Y-ratio-default": "3:2",
                "Y-mode": "string - optional. Controls generation style.",
                "Y-mode-values": "normal,fun",
                "Y-mode-default": "normal",
                "Y-duration": "integer - optional. Duration in seconds.",
                "Y-duration-default": "5",
                "Y-image_urls": "array<string> - optional. Up to 2 image URLs for conditioning.",
                "Y-image_urls-max": "2",
                # Response format
                "Y-response_format": "video/mp4",
                # Model info
                "Y-model": "grok-imagine-video",
            },
        )

    aspect_ratio = "3:2"
    input_mode = "normal"
    image_urls = None
    ratio = None
    mode = None
    if prompt is None:
        try:
            user_body = await request.json()
        except ValueError:
            raise HTTPException(400, "Request body must be valid JSON") from None
        if not isinstance(user_body, dict):
            raise HTTPException(400, "Request body must be a JSON object")
        prompt = user_body.get("prompt")
        ratio = user_body.get("ratio")
        mode = user_body.get("mode")
        image_urls = user_body.get("image_urls")

    # The isinstance guard short-circuits before the set lookup, so an
    # unhashable JSON value (list/object) yields a 400 instead of a TypeError.
    if not (ratio is None or isinstance(ratio, str)) or ratio not in valid_ratios:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid aspect ratio {ratio}. Must be one of 3:2, 2:3, or 1:1. Default is 3:2",
        )
    if ratio in ratios:
        aspect_ratio = ratio
    if not (mode is None or isinstance(mode, str)) or mode not in valid_modes:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid mode {mode}. Must be 'normal' or 'fun'. Default is normal",
        )
    if mode in modes:
        input_mode = mode
    if image_urls:
        if not isinstance(image_urls, list):
            raise HTTPException(400, "image_urls must be a list")
        if len(image_urls) > 2:
            raise HTTPException(400, "You may provide at most two image URLs")
        if not all(isinstance(img, str) for img in image_urls):
            raise HTTPException(400, "image_urls entries must be strings")

    prompt = normalize_prompt_value(prompt, "prompt")
    enforce_prompt_size(
        prompt, MAX_MEDIA_PROMPT_CHARS, MAX_MEDIA_PROMPT_BYTES, "Video prompt"
    )
    # Without a key the upstream would be called with "Bearer None" and its
    # 401 forwarded, misleadingly blaming the caller's credentials.
    if not AIRFORCE_KEY:
        raise HTTPException(500, "Missing AIRFORCE key")
    await check_video_rate_limit(request, authorization, x_client_id)

    payload = {
        "model": AIRFORCE_VIDEO_MODEL,
        "prompt": prompt,
        "n": 1,
        "size": "1024x1024",
        "response_format": "b64_json",
        "sse": False,
        "mode": input_mode,
        "aspectRatio": aspect_ratio,
    }
    if image_urls:
        payload["image_urls"] = image_urls

    try:
        async with httpx.AsyncClient(timeout=600) as client:
            resp = await client.post(
                AIRFORCE_API_URL,
                headers={
                    "Authorization": f"Bearer {AIRFORCE_KEY}",
                    "Content-Type": "application/json",
                },
                json=payload,
            )
    except httpx.HTTPError as exc:
        raise HTTPException(
            502, f"api.airforce request failed: {type(exc).__name__}"
        ) from exc

    if resp.status_code != 200:
        # Error bodies are not guaranteed to be JSON (e.g. proxy HTML pages).
        try:
            error_content = resp.json()
        except ValueError:
            error_content = {
                "success": False,
                "error": "Upstream video generation failed",
                "status_code": resp.status_code,
                "message": resp.text[:1000],
            }
        return JSONResponse(status_code=resp.status_code, content=error_content)
    if not resp.content:
        raise HTTPException(502, "api.airforce returned empty response")
    try:
        result = resp.json()
        b64_video = result["data"][0]["b64_json"]
    except (ValueError, KeyError, IndexError, TypeError):
        raise HTTPException(
            502, f"Invalid api.airforce response: {resp.text[:500]}"
        ) from None
    if not b64_video:
        raise HTTPException(502, "Airforce returned empty b64_json")
    try:
        video_bytes = base64.b64decode(b64_video)
    except (ValueError, TypeError):  # binascii.Error subclasses ValueError
        raise HTTPException(502, "Airforce returned invalid base64 video data") from None
    return Response(
        content=video_bytes,
        media_type="video/mp4",
        headers={
            "Content-Length": str(len(video_bytes)),
            "Accept-Ranges": "bytes",
        },
    )
@app.get("/subscription")
async def get_subscription(authorization: Optional[str] = Header(None)):
    """Return the caller's subscription record with its plan key and plan name.

    Expects an ``Authorization: Bearer <jwt>`` header. The record comes from
    ``fetch_subscription``. Its ``plan_key`` is normalised, and ``plan_name``
    is taken from ``TIER_CONFIG``, falling back to the free tier.

    Raises:
        HTTPException 401: the header is missing or not a Bearer token, or
            ``fetch_subscription`` reports an error.
    """
    has_bearer = bool(authorization) and authorization.startswith("Bearer ")
    if not has_bearer:
        raise HTTPException(401, "Missing or invalid Authorization header")
    _, token = authorization.split(" ", 1)
    subscription = await fetch_subscription(token)
    if "error" in subscription:
        raise HTTPException(401, subscription["error"])
    key = normalize_plan_key(subscription.get("plan_key"))
    subscription["plan_key"] = key
    plan = TIER_CONFIG.get(key) or TIER_CONFIG["free"]
    subscription["plan_name"] = plan["name"]
    return subscription
@app.get("/usage")
async def get_usage(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Report the caller's plan and current usage snapshot.

    The caller is identified by ``resolve_rate_limit_identity``. The plan
    name falls back to the free tier's config, then to "Free Tier".
    ``generated_at`` is the current UTC time in ISO-8601 form with a "Z"
    suffix.
    """
    identity = await resolve_rate_limit_identity(request, authorization, x_client_id)
    plan_key, subject = identity
    plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    snapshot = get_usage_snapshot_for_subject(plan_key, subject)
    response_body = {
        "plan_key": plan_key,
        "plan_name": plan.get("name", "Free Tier"),
        "usage": snapshot,
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    return JSONResponse(status_code=200, content=response_body)
@app.get("/tier-config")
async def tier_config():
    """List every configured plan, in PLAN_ORDER order, for client-side rendering.

    ``order`` is the plan's index within PLAN_ORDER, counted before keys
    missing from TIER_CONFIG are skipped, so the numbers may have gaps.
    """
    candidates = (
        (position, key, TIER_CONFIG.get(key)) for position, key in enumerate(PLAN_ORDER)
    )
    plans = [
        {
            "key": key,
            "name": plan["name"],
            "url": plan["url"],
            "price": plan["price"],
            "limits": plan["limits"],
            "order": position,
        }
        for position, key, plan in candidates
        if plan
    ]
    return JSONResponse(
        status_code=200,
        content={"defaultPlanKey": "free", "plans": plans},
    )
@app.get("/tiers")
async def tiers():
    """Return the paid plans (everything in PLAN_ORDER except "free") as a JSON list.

    Keys missing from TIER_CONFIG are skipped silently.
    """
    candidates = ((key, TIER_CONFIG.get(key)) for key in PLAN_ORDER if key != "free")
    paid_plans = [
        {
            "key": key,
            "name": plan["name"],
            "url": plan["url"],
            "price": plan["price"],
            "limits": plan["limits"],
        }
        for key, plan in candidates
        if plan
    ]
    return JSONResponse(status_code=200, content=paid_plans)
@app.get("/portal")
@app.post("/portal")
async def redirect_to_protal(request: Request):
    """Send the caller to the Stripe customer billing-portal login.

    GET always redirects (302) to the portal login page. POST accepts an
    optional JSON body ``{"email": "..."}``:

    * with a non-empty string email, it returns ``{"redirect_url": ...}``
      with that address pre-filled (URL-encoded);
    * with a missing, empty, non-string or unparsable email, it redirects
      (302) to the plain login page.
    """
    base_url = "https://billing.stripe.com/p/login/5kQdR9aIM3ts4steyabbG00"
    email = None
    if request.method == "POST":
        try:
            body = await request.json()
        except ValueError:
            # An unparsable body just means "no email"; no bare except, so
            # cancellation and system exits still propagate.
            body = None
        if isinstance(body, dict):
            email = body.get("email")
    if not isinstance(email, str) or not email:
        return RedirectResponse(url=base_url, status_code=status.HTTP_302_FOUND)
    # Only POST reaches this point (GET never reads an email). Percent-encode
    # the address so characters such as "+", "&" or "#" stay part of it.
    return JSONResponse(
        {"redirect_url": f"{base_url}?prefilled_email={quote(email, safe='@')}"}
    )