"""Authentication, subscription lookup, and usage-limit helpers.

Resolves a bearer token (Supabase JWT or Lightning API key) to an identity,
looks up the caller's Stripe subscription in Postgres, and exposes the
per-plan usage limits with optional per-email overrides.
"""

import asyncio
import base64
import contextlib
import hashlib
import hmac
import json
import math
import os
from datetime import datetime, timezone
from pathlib import Path

import psycopg
from fastapi import HTTPException
from psycopg.rows import dict_row
from supabase import Client, create_client

ADMINS_PATH = Path("assets/admins.json")
EXTRA_USAGE_PATH = Path("assets/extra_usage.json")


def _load_email_map(path: Path) -> dict:
    """Load a JSON object from *path*; return ``{}`` (and log) on any failure.

    Loading is best-effort so that a missing or broken asset file does not
    prevent the module from importing.
    """
    try:
        loaded = json.loads(path.read_text(encoding="utf-8"))
    except Exception as e:
        print(f"Failed to load {path.name}: {e}")
        return {}
    if not isinstance(loaded, dict):
        # Every consumer below does dict lookups, so reject other JSON roots.
        print(f"Failed to load {path.name}: expected a JSON object")
        return {}
    return loaded


ADMIN_EMAIL_MAP = _load_email_map(ADMINS_PATH)
EXTRA_USAGE_EMAIL_MAP = _load_email_map(EXTRA_USAGE_PATH)

# Plan keys ordered from lowest to highest tier.
PLAN_ORDER = ["free", "light", "core", "creator", "professional"]
_PLAN_RANK = {plan_key: rank for rank, plan_key in enumerate(PLAN_ORDER)}

# Per-plan display data and usage limits.
# NOTE(review): a limit of None presumably means "unlimited" -- confirm
# against the code that enforces these limits.
TIER_CONFIG = {
    "free": {
        "name": "Free Tier",
        "url": "",
        "price": "0.00",
        "limits": {
            "cloudChatDaily": 50,
            "imagesDaily": 10,
            "videosDaily": 3,
            "audioWeekly": 1,
            "verifyTokenWithEmailDaily": 100,
        },
    },
    "light": {
        "name": "InferencePort AI Light",
        "url": "https://buy.stripe.com/bJedR93gk4xw7EFeyabbG08",
        "price": "9.99",
        "limits": {
            "cloudChatDaily": None,
            "imagesDaily": 50,
            "videosDaily": 10,
            "audioWeekly": 5,
            "verifyTokenWithEmailDaily": 100,
        },
    },
    "core": {
        "name": "InferencePort AI Core",
        "url": "https://buy.stripe.com/28E4gzg365BA1gh61EbbG09",
        "price": "15.99",
        "limits": {
            "cloudChatDaily": None,
            "imagesDaily": 150,
            "videosDaily": 30,
            "audioWeekly": 25,
            "verifyTokenWithEmailDaily": 100,
        },
    },
    "creator": {
        "name": "InferencePort AI Creator",
        "url": "https://buy.stripe.com/dRmdR98AE5BA8IJ89MbbG0a",
        "price": "29.99",
        "limits": {
            "cloudChatDaily": None,
            "imagesDaily": 300,
            "videosDaily": 50,
            "audioWeekly": 45,
            "verifyTokenWithEmailDaily": 100,
        },
    },
    "professional": {
        "name": "InferencePort AI Professional",
        "url": "https://buy.stripe.com/14AaEX7wA3ts7EF2PsbbG0b",
        "price": "99.99",
        "limits": {
            "cloudChatDaily": None,
            "imagesDaily": None,
            "videosDaily": None,
            "audioWeekly": 75,
            "verifyTokenWithEmailDaily": 100,
        },
    },
}

# Reset period for each usage metric.
USAGE_PERIODS = {
    "cloudChatDaily": "daily",
    "imagesDaily": "daily",
    "videosDaily": "daily",
    "audioWeekly": "weekly",
    "verifyTokenWithEmailDaily": "daily",
}

# In-memory per-metric containers; both start empty and are keyed by metric.
usage_store = {metric: {} for metric in USAGE_PERIODS}
usage_locks = {metric: {} for metric in USAGE_PERIODS}

# Serializes all use of the single shared database connection.
conn_lock = asyncio.Lock()

SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

POSTGRE_SECRET = os.getenv("POSTGRE_SECRET")


def _connect():
    """Open a new psycopg connection with the module's TLS settings."""
    return psycopg.connect(
        POSTGRE_SECRET,
        row_factory=dict_row,
        sslmode="verify-full",
        sslrootcert="prod-ca-2021.crt",
    )


# The shared connection is opened eagerly at import time.
conn = _connect()


async def get_conn():
    """Return the shared connection, reconnecting if it is missing or closed."""
    global conn
    if conn is None or conn.closed:
        conn = _connect()
    return conn


def _run_query(connection, query: str, params, fetchone: bool, commit: bool):
    """Execute *query* on *connection* and return its rows.

    Statements that produce no result set yield ``None`` (``fetchone``) or
    ``[]``; otherwise one row or all rows are fetched.
    """
    with connection.cursor() as cur:
        cur.execute(query, params)
        if cur.description is None:
            result = None if fetchone else []
        else:
            result = cur.fetchone() if fetchone else cur.fetchall()
        if commit:
            connection.commit()
        return result


def _rollback_quietly(connection) -> None:
    """Roll back *connection*, ignoring database errors raised by the rollback."""
    with contextlib.suppress(psycopg.Error):
        connection.rollback()


async def execute_query(
    query: str,
    params=(),
    *,
    fetchone: bool = False,
    commit: bool = False,
):
    """Run a SQL statement on the shared connection and return its result.

    Args:
        query: SQL text with ``%s`` placeholders.
        params: Values bound to the placeholders.
        fetchone: Return a single row (or ``None``) instead of a list of rows.
        commit: Commit the transaction after executing.

    Returns:
        One row dict / ``None`` when ``fetchone`` is true, otherwise a list
        of row dicts (empty when the statement returns no result set).

    Raises:
        psycopg.Error: If the statement fails. On ``OperationalError`` the
            connection is replaced and the statement is retried once first.

    NOTE(review): the psycopg calls here are synchronous and run on the
    event loop while ``conn_lock`` is held; consider ``asyncio.to_thread``
    or an async connection if latency becomes an issue.
    """
    global conn
    async with conn_lock:
        connection = await get_conn()
        try:
            return _run_query(connection, query, params, fetchone, commit)
        except psycopg.OperationalError:
            # Drop the old connection instead of leaking it, then retry once
            # on a fresh one.
            with contextlib.suppress(Exception):
                connection.close()
            connection = _connect()
            conn = connection
            try:
                return _run_query(connection, query, params, fetchone, commit)
            except psycopg.Error:
                _rollback_quietly(connection)
                raise
        except psycopg.Error:
            # Without a rollback the shared connection stays in an aborted
            # transaction and every later query on it would fail.
            _rollback_quietly(connection)
            raise


def normalize_plan_key(plan_name: str | None) -> str:
    """Map a product name or nickname to a ``PLAN_ORDER`` key.

    Only alphabetic characters are considered, case-insensitively; the first
    matching tier keyword (checked from highest tier down) wins. Anything
    unrecognized, including an empty value, maps to ``"free"``.
    """
    if not plan_name:
        return "free"
    normalized = "".join(ch for ch in str(plan_name).lower() if ch.isalpha())
    for plan_key in ("professional", "creator", "core", "light"):
        if plan_key in normalized:
            return plan_key
    return "free"


def get_effective_limit_for_email(
    plan_key: str,
    metric: str,
    email: str | None = None,
):
    """Return the usage limit for *metric*, honoring per-email overrides.

    Args:
        plan_key: Key into ``TIER_CONFIG``; unknown keys fall back to "free".
        metric: Limit name, e.g. ``"imagesDaily"``.
        email: Optional email whose entry in ``EXTRA_USAGE_EMAIL_MAP`` may
            override the plan default.

    Returns:
        The override when the email's entry defines a valid one for
        *metric* (``None`` for an explicit JSON null, or a non-negative
        number truncated to ``int``); otherwise the plan's default limit.
    """
    plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    default_limit = plan.get("limits", {}).get(metric)

    if not isinstance(email, str) or not email.strip():
        return default_limit

    entry = EXTRA_USAGE_EMAIL_MAP.get(email.strip().lower())
    if not isinstance(entry, dict):
        return default_limit

    # Overrides live either under a nested "limits" object or directly on
    # the entry itself.
    nested = entry.get("limits")
    raw_limits = nested if isinstance(nested, dict) else entry
    if metric not in raw_limits:
        return default_limit

    override = raw_limits[metric]
    if override is None:
        # An explicit null is passed through as "no numeric limit".
        return None
    if isinstance(override, bool):
        # bool is a subclass of int; true/false is not a meaningful limit.
        return default_limit
    if (
        isinstance(override, (int, float))
        and math.isfinite(override)
        and override >= 0
    ):
        return int(override)
    return default_limit


async def fetch_subscription(jwt: str):
    """Resolve the caller and return their subscription summary.

    Args:
        jwt: Bearer token; a Supabase JWT or a Lightning API key.

    Returns:
        A dict with ``email``, ``signed_up``, ``subscription`` (list of
        subscription dicts, or ``None`` when there are none), ``plan_key``
        and ``auth_type``; or a dict containing ``"error"`` when no identity
        could be resolved.

    Raises:
        HTTPException: Propagated from ``resolve_token_identity`` when the
            token is invalid.
    """
    identity = await resolve_token_identity(jwt)
    if identity is None:
        # Missing or "public" tokens resolve to no identity at all.
        return {"error": "Authentication required"}
    if "error" in identity:
        return identity

    email = identity["email"]
    signed_up = identity["signed_up"]

    if email in ADMIN_EMAIL_MAP:
        admin = ADMIN_EMAIL_MAP[email]
        print(f"[ADMIN OVERRIDE] {email} → forcing plan '{admin['plan_key']}'")
        subscription_obj = {
            "subscription_id": admin.get("subscription_id", "admin-override"),
            "status": admin.get("status", "active"),
            "current_period_end": None,
            "price_id": None,
            "product_name": admin.get("product_name"),
            "nickname": admin.get("nickname"),
            "plan_key": admin.get("plan_key"),
        }
        return {
            "email": email,
            "signed_up": signed_up,
            "subscription": [subscription_obj],
            "plan_key": admin["plan_key"],
            "auth_type": identity["auth_type"],
        }

    rows = await execute_query(
        """
        with cust as (
            select id from stripe.customers where email = %s
        ),
        subs as (
            select
                s.id as subscription_id,
                s.status,
                s.current_period_end,
                s.items->'data'->0->'price'->>'id' as price_id
            from stripe.subscriptions s
            join cust on s.customer = cust.id
            where s.status in ('active', 'trialing', 'past_due')
        )
        select
            subs.subscription_id,
            subs.status,
            subs.current_period_end,
            subs.price_id,
            prices.nickname,
            prices.product as product_id,
            products.name as product_name
        from subs
        left join stripe.prices prices on prices.id = subs.price_id
        left join stripe.products products on prices.product = products.id;
        """,
        (email,),
    )

    if not rows:
        return {
            "email": email,
            "signed_up": signed_up,
            "subscription": None,
            "plan_key": "free",
            "auth_type": identity["auth_type"],
        }

    subscriptions = []
    preferred_plan_key = "free"
    for row in rows:
        plan_key = normalize_plan_key(row["product_name"] or row["nickname"])
        subscriptions.append({
            "subscription_id": row["subscription_id"],
            "status": row["status"],
            "current_period_end": row["current_period_end"],
            "price_id": row["price_id"],
            "product_name": row["product_name"],
            "nickname": row["nickname"],
            "plan_key": plan_key,
        })
        # The query has no ORDER BY, so choose the highest active tier
        # rather than whichever row happens to come last. Subscriptions that
        # are only past_due do not grant a plan.
        if (
            row["status"] in ("active", "trialing")
            and _PLAN_RANK[plan_key] > _PLAN_RANK[preferred_plan_key]
        ):
            preferred_plan_key = plan_key

    return {
        "email": email,
        "signed_up": signed_up,
        "subscription": subscriptions,
        "plan_key": preferred_plan_key,
        "auth_type": identity["auth_type"],
    }


async def resolve_token_identity(token: str | None):
    """Resolve *token* to an identity dict, trying JWT first, then API key.

    Returns:
        ``None`` when the token is missing or the literal ``"public"``;
        otherwise a dict with ``auth_type``, ``user_id``, ``email`` and
        ``signed_up``.

    Raises:
        HTTPException: 401 when the token matches neither scheme.
    """
    if not token or (token == "public"):
        return None

    jwt_identity = resolve_jwt_identity(token)
    if jwt_identity is not None:
        return jwt_identity

    api_identity = await resolve_api_key_identity(token)
    if api_identity is not None:
        if "error" in api_identity:
            raise HTTPException(status_code=401, detail=api_identity["error"])
        return api_identity

    raise HTTPException(status_code=401, detail="Invalid API key or token")


def resolve_jwt_identity(token: str):
    """Validate *token* with Supabase Auth and return the user's identity.

    Returns:
        An identity dict with ``auth_type`` ``"jwt"``, or ``None`` when
        Supabase rejects the token, the call fails, or the user has no
        usable email.
    """
    try:
        auth_res = supabase.auth.get_user(token)
    except Exception:
        # Deliberate best-effort: any failure means "not a valid JWT" so the
        # caller can fall through to API-key resolution.
        return None

    user = getattr(auth_res, "user", None)
    email = getattr(user, "email", None)
    created_at = getattr(user, "created_at", None)
    user_id = getattr(user, "id", None)

    if user is None or not isinstance(email, str) or not email.strip():
        return None

    return {
        "auth_type": "jwt",
        "user_id": user_id,
        "email": email.strip().lower(),
        "signed_up": (
            created_at.isoformat() if hasattr(created_at, "isoformat") else None
        ),
    }


def verify_lightning_api_key(secret: str, stored: str) -> bool:
    """Check *secret* against a stored ``<scheme>$<salt>$<key>`` scrypt hash.

    The salt and key are URL-safe base64 (padding optional). The secret is
    re-derived with scrypt (n=16384, r=8, p=1, 32-byte key) and compared in
    constant time.

    Returns:
        ``True`` on a match; ``False`` on a mismatch or when *stored* is
        malformed.
    """
    try:
        _, salt_b64, key_b64 = stored.split("$")
        # Extra "=" padding is tolerated by the decoder, so always append it.
        salt = base64.urlsafe_b64decode(salt_b64 + "==")
        stored_key = base64.urlsafe_b64decode(key_b64 + "==")
    except (AttributeError, TypeError, ValueError):
        # Wrong field count, invalid base64 (binascii.Error is a ValueError),
        # or a non-string hash: treat as "does not verify", never as a crash.
        return False

    derived_key = hashlib.scrypt(
        secret.encode("utf-8"), salt=salt, n=16384, r=8, p=1, dklen=32
    )
    return hmac.compare_digest(derived_key, stored_key)


def _invalid_credentials() -> HTTPException:
    """Build the uniform 401 error used for every API-key rejection."""
    return HTTPException(status_code=401, detail="Invalid API key or token")


async def resolve_api_key_identity(token: str | None):
    """Resolve a Lightning API key to its owner's identity.

    The key row is looked up by the first 16 characters of *token*, the full
    token is verified against the stored hash, and revoked or expired keys
    are rejected. On success ``last_used_at`` is updated.

    Returns:
        ``None`` for an empty token; otherwise an identity dict with
        ``auth_type`` ``"api_key"``.

    Raises:
        HTTPException: 401 for an unknown, mismatched, revoked, or expired
            key, or when the owning user has no email.
    """
    if token is None or token.strip() == "":
        return None

    prefix = token[:16]
    row = await execute_query(
        """
        select
            k.id,
            k.user_id,
            k.name,
            k.expires_at,
            k.revoked_at,
            k.key_hash,
            u.email,
            u.created_at as user_created_at
        from public.lightning_api_keys k
        join auth.users u on u.id = k.user_id
        where k.key_prefix = %s
        limit 1;
        """,
        (prefix,),
        fetchone=True,
    )

    if not row:
        raise _invalid_credentials()
    if not verify_lightning_api_key(token, row["key_hash"]):
        raise _invalid_credentials()
    if row.get("revoked_at") is not None:
        raise _invalid_credentials()

    expires_at = row.get("expires_at")
    if expires_at is not None:
        if isinstance(expires_at, datetime) and expires_at.tzinfo is None:
            # NOTE(review): a naive timestamp is assumed to be UTC so the
            # comparison below cannot raise TypeError -- confirm column type.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if expires_at <= datetime.now(timezone.utc):
            raise _invalid_credentials()

    email = row.get("email")
    if not isinstance(email, str) or not email.strip():
        raise _invalid_credentials()

    await execute_query(
        """
        update public.lightning_api_keys
        set last_used_at = timezone('utc', now())
        where id = %s;
        """,
        (row["id"],),
        commit=True,
    )

    user_created_at = row["user_created_at"]
    return {
        "auth_type": "api_key",
        "user_id": str(row["user_id"]) if row["user_id"] is not None else None,
        "email": email.strip().lower(),
        "signed_up": (
            user_created_at.isoformat()
            if hasattr(user_created_at, "isoformat")
            else None
        ),
    }