Spaces:
Running
Running
| import asyncio | |
| import hashlib | |
| import json | |
| import os | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| import base64 | |
| import psycopg | |
| from psycopg.rows import dict_row | |
| from supabase import Client, create_client | |
| import hmac | |
ADMINS_PATH = Path("assets/admins.json")
EXTRA_USAGE_PATH = Path("assets/extra_usage.json")


def _load_email_map(path: Path, label: str) -> dict:
    """Load a JSON object keyed by email address from *path*.

    Keys are stripped and lower-cased because every lookup in this module is
    done with an email that has itself been stripped and lower-cased, so a
    mixed-case key in the file would otherwise never match.

    A missing/unreadable file, invalid JSON, or a top-level value that is not
    a JSON object is reported on stdout (prefixed with *label*) and yields an
    empty dict, so the service can still start without the file.
    """
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        print(f"Failed to load {label}: {e}")
        return {}
    if not isinstance(data, dict):
        print(f"Failed to load {label}: expected a JSON object, got {type(data).__name__}")
        return {}
    return {str(key).strip().lower(): value for key, value in data.items()}


ADMIN_EMAIL_MAP = _load_email_map(ADMINS_PATH, "admins.json")
EXTRA_USAGE_EMAIL_MAP = _load_email_map(EXTRA_USAGE_PATH, "extra_usage.json")
# Plan keys ordered from the cheapest tier to the most expensive one.
PLAN_ORDER = ["free", "light", "core", "creator", "professional"]


def _tier(name, url, price, *, chat, images, videos, audio, verify=100):
    """Build one TIER_CONFIG entry.

    *price* is kept as a display string. A limit of ``None`` is stored as-is;
    presumably it means "no cap" (TODO confirm against the enforcement code).
    """
    return {
        "name": name,
        "url": url,
        "price": price,
        "limits": {
            "cloudChatDaily": chat,
            "imagesDaily": images,
            "videosDaily": videos,
            "audioWeekly": audio,
            "verifyTokenWithEmailDaily": verify,
        },
    }


# Per-plan display data, Stripe checkout link, and usage limits.
TIER_CONFIG = {
    "free": _tier("Free Tier", "", "0.00", chat=50, images=10, videos=3, audio=1),
    "light": _tier(
        "InferencePort AI Light",
        "https://buy.stripe.com/bJedR93gk4xw7EFeyabbG08",
        "9.99",
        chat=None, images=50, videos=10, audio=5,
    ),
    "core": _tier(
        "InferencePort AI Core",
        "https://buy.stripe.com/28E4gzg365BA1gh61EbbG09",
        "15.99",
        chat=None, images=150, videos=30, audio=25,
    ),
    "creator": _tier(
        "InferencePort AI Creator",
        "https://buy.stripe.com/dRmdR98AE5BA8IJ89MbbG0a",
        "29.99",
        chat=None, images=300, videos=50, audio=45,
    ),
    "professional": _tier(
        "InferencePort AI Professional",
        "https://buy.stripe.com/14AaEX7wA3ts7EF2PsbbG0b",
        "99.99",
        chat=None, images=None, videos=None, audio=75,
    ),
}

# Reset window for each metric name used in the "limits" dicts above.
USAGE_PERIODS = {
    "cloudChatDaily": "daily",
    "imagesDaily": "daily",
    "videosDaily": "daily",
    "audioWeekly": "weekly",
    "verifyTokenWithEmailDaily": "daily",
}

# One independent (initially empty) dict per metric, keyed in USAGE_PERIODS order.
usage_store = {metric: {} for metric in USAGE_PERIODS}
usage_locks = {metric: {} for metric in USAGE_PERIODS}
# Configuration is read from the environment at import time, so these
# variables must be set before this module is imported.
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
POSTGRE_SECRET = os.getenv("POSTGRE_SECRET")

# Supabase client authenticated with the service-role key.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# execute_query holds this lock while it uses the shared connection below.
conn_lock = asyncio.Lock()

# Shared Postgres connection (dict rows, TLS verified against the bundled CA
# file); get_conn / execute_query replace it when it is closed or fails.
conn = psycopg.connect(
    POSTGRE_SECRET,
    row_factory=dict_row,
    sslmode="verify-full",
    sslrootcert="prod-ca-2021.crt",
)
async def get_conn():
    """Return the shared module-level connection, reconnecting when needed.

    A new connection is opened when ``conn`` is ``None`` or reports
    ``closed``; it is stored back into the module global before returning.
    Note that ``psycopg.connect`` is a synchronous call, so a reconnect runs
    on the event-loop thread.
    """
    global conn
    reusable = conn is not None and not conn.closed
    if not reusable:
        conn = psycopg.connect(
            POSTGRE_SECRET,
            row_factory=dict_row,
            sslmode="verify-full",
            sslrootcert="prod-ca-2021.crt",
        )
    return conn
def _run_query(connection, query, params, fetchone, commit):
    """Execute *query* once on *connection* and return what it produced.

    Returns ``cur.fetchone()`` / ``cur.fetchall()``; for statements without a
    result set, ``None`` (fetchone) or ``[]``. Commits only when *commit*.
    """
    with connection.cursor() as cur:
        cur.execute(query, params)
        if cur.description is None:
            result = None if fetchone else []
        else:
            result = cur.fetchone() if fetchone else cur.fetchall()
    if commit:
        connection.commit()
    return result


def _rollback_quietly(connection):
    """Best-effort rollback, used while another exception is already propagating."""
    try:
        connection.rollback()
    except psycopg.Error:
        # The connection may itself be unusable; the caller re-raises the
        # original error, which is the one worth reporting.
        pass


async def execute_query(query: str, params=(), *, fetchone: bool = False, commit: bool = False):
    """Run one SQL statement on the shared connection while holding ``conn_lock``.

    Args:
        query: SQL text with ``%s`` placeholders.
        params: Values bound to the placeholders.
        fetchone: Return a single row (or ``None``) instead of a list of rows.
        commit: Commit the connection's transaction after the statement.

    Returns:
        The fetched row(s); ``None``/``[]`` when the statement has no result set.

    Raises:
        psycopg.Error: re-raised after rolling back, so the shared connection
        is not left in an aborted transaction that would fail every later query.

    On ``psycopg.OperationalError`` the connection is closed, replaced with a
    fresh one (also stored in the module-level ``conn``) and the statement is
    retried once.

    NOTE(review): the cursor work is synchronous and runs on the event-loop
    thread; and with ``commit=False`` the transaction psycopg opens implicitly
    stays open until a later commit/rollback on this connection — confirm
    that is acceptable for read-only callers.
    """
    global conn
    async with conn_lock:
        connection = await get_conn()
        try:
            return _run_query(connection, query, params, fetchone, commit)
        except psycopg.OperationalError:
            # Usually a dropped connection (statement cancellations/timeouts are
            # OperationalErrors too). Close the old handle instead of leaking it,
            # then reconnect and retry once.
            try:
                connection.close()
            except psycopg.Error:
                pass
            connection = psycopg.connect(
                POSTGRE_SECRET,
                row_factory=dict_row,
                sslmode="verify-full",
                sslrootcert="prod-ca-2021.crt",
            )
            conn = connection
            try:
                return _run_query(connection, query, params, fetchone, commit)
            except psycopg.Error:
                _rollback_quietly(connection)
                raise
        except psycopg.Error:
            # Without this rollback the shared connection stays in an aborted
            # transaction and every subsequent statement fails as well.
            _rollback_quietly(connection)
            raise
| def normalize_plan_key(plan_name: str | None) -> str: | |
| if not plan_name: | |
| return "free" | |
| normalized = "".join(ch for ch in str(plan_name).lower() if ch.isalpha()) | |
| if "professional" in normalized: | |
| return "professional" | |
| if "creator" in normalized: | |
| return "creator" | |
| if "core" in normalized: | |
| return "core" | |
| if "light" in normalized: | |
| return "light" | |
| return "free" | |
| def get_effective_limit_for_email( | |
| plan_key: str, | |
| metric: str, | |
| email: str | None = None, | |
| ): | |
| plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"] | |
| default_limit = plan.get("limits", {}).get(metric) | |
| if not isinstance(email, str) or not email.strip(): | |
| return default_limit | |
| entry = EXTRA_USAGE_EMAIL_MAP.get(email.strip().lower()) | |
| if not isinstance(entry, dict): | |
| return default_limit | |
| raw_limits = entry.get("limits") if isinstance(entry.get("limits"), dict) else entry | |
| override = raw_limits.get(metric) if isinstance(raw_limits, dict) else None | |
| if override is None and metric not in raw_limits: | |
| return default_limit | |
| if override is None: | |
| return None | |
| if isinstance(override, (int, float)) and override >= 0: | |
| return int(override) | |
| return default_limit | |
# Active/trialing/past_due subscriptions, with price and product metadata, for
# every Stripe customer whose email matches. Identities in this module are
# lower-cased, but Stripe customer emails are not guaranteed to be, hence
# lower(email); consider an expression index on lower(email) if this table grows.
_SUBSCRIPTIONS_BY_EMAIL_SQL = """
with cust as (
    select id
    from stripe.customers
    where lower(email) = %s
),
subs as (
    select
        s.id as subscription_id,
        s.status,
        s.current_period_end,
        s.items->'data'->0->'price'->>'id' as price_id
    from stripe.subscriptions s
    join cust on s.customer = cust.id
    where s.status in ('active', 'trialing', 'past_due')
)
select
    subs.subscription_id,
    subs.status,
    subs.current_period_end,
    subs.price_id,
    prices.nickname,
    prices.product as product_id,
    products.name as product_name
from subs
left join stripe.prices prices
    on prices.id = subs.price_id
left join stripe.products products
    on prices.product = products.id;
"""


def _admin_subscription_response(email, signed_up, auth_type, admin):
    """Build the fetch_subscription payload forced by an ADMIN_EMAIL_MAP entry."""
    plan_key = admin["plan_key"]
    print(f"[ADMIN OVERRIDE] {email} → forcing plan '{plan_key}'")
    subscription_obj = {
        "subscription_id": admin.get("subscription_id", "admin-override"),
        "status": admin.get("status", "active"),
        "current_period_end": None,
        "price_id": None,
        "product_name": admin.get("product_name"),
        "nickname": admin.get("nickname"),
        "plan_key": plan_key,
    }
    return {
        "email": email,
        "signed_up": signed_up,
        "subscription": [subscription_obj],
        "plan_key": plan_key,
        "auth_type": auth_type,
    }


def _subscription_from_row(row):
    """Convert one row of _SUBSCRIPTIONS_BY_EMAIL_SQL into a subscription dict."""
    return {
        "subscription_id": row["subscription_id"],
        "status": row["status"],
        "current_period_end": row["current_period_end"],
        "price_id": row["price_id"],
        "product_name": row["product_name"],
        "nickname": row["nickname"],
        "plan_key": normalize_plan_key(row["product_name"] or row["nickname"]),
    }


def _preferred_plan_key(subscriptions):
    """Highest PLAN_ORDER-ranked plan among active/trialing subscriptions, else "free".

    The query has no ORDER BY, so ranking (rather than "last row wins") keeps
    the result deterministic for users with several subscriptions.
    """
    def rank(plan_key):
        return PLAN_ORDER.index(plan_key) if plan_key in PLAN_ORDER else -1

    usable = [sub["plan_key"] for sub in subscriptions if sub["status"] in ("active", "trialing")]
    return max(usable, key=rank, default="free")


async def fetch_subscription(jwt: str):
    """Resolve *jwt* to a user and describe that user's subscription state.

    Returns a dict with ``email``, ``signed_up``, ``subscription`` (a list of
    subscription dicts, or ``None`` when nothing matches), ``plan_key`` and
    ``auth_type``. Admin emails listed in ``ADMIN_EMAIL_MAP`` get a forced
    plan. ``past_due`` subscriptions are listed but never set ``plan_key``.

    Returns an ``{"error": ...}`` dict when the token yields no identity;
    ``HTTPException`` raised by ``resolve_token_identity`` propagates.
    """
    identity = await resolve_token_identity(jwt)
    if identity is None:
        # resolve_token_identity returns None for a missing or "public" token;
        # there is no email to look up, so report it as an identity error.
        return {"error": "Authentication required"}
    if "error" in identity:
        return identity

    email = identity["email"]
    signed_up = identity["signed_up"]
    auth_type = identity["auth_type"]

    if email in ADMIN_EMAIL_MAP:
        return _admin_subscription_response(email, signed_up, auth_type, ADMIN_EMAIL_MAP[email])

    rows = await execute_query(_SUBSCRIPTIONS_BY_EMAIL_SQL, (email,))
    if not rows:
        return {
            "email": email,
            "signed_up": signed_up,
            "subscription": None,
            "plan_key": "free",
            "auth_type": auth_type,
        }

    subscriptions = [_subscription_from_row(row) for row in rows]
    return {
        "email": email,
        "signed_up": signed_up,
        "subscription": subscriptions,
        "plan_key": _preferred_plan_key(subscriptions),
        "auth_type": auth_type,
    }
| from fastapi import HTTPException | |
| async def resolve_token_identity(token: str | None): | |
| if not token or (token == "public"): | |
| return None | |
| jwt_identity = resolve_jwt_identity(token) | |
| if jwt_identity is not None: | |
| return jwt_identity | |
| api_identity = await resolve_api_key_identity(token) | |
| if api_identity is not None: | |
| if "error" in api_identity: | |
| raise HTTPException(status_code=401, detail=api_identity["error"]) | |
| return api_identity | |
| raise HTTPException(status_code=401, detail="Invalid API key or token") | |
def resolve_jwt_identity(token: str):
    """Look *token* up as a Supabase access token.

    Returns ``None`` when ``supabase.auth.get_user`` raises (any exception),
    returns no user, or the user has no non-blank email. Otherwise returns a
    ``"jwt"`` identity dict with the email stripped and lower-cased and
    ``signed_up`` set to ``created_at.isoformat()`` (``None`` if the value
    has no ``isoformat``).
    """
    try:
        response = supabase.auth.get_user(token)
    except Exception:
        # Any failure returns None so the caller can fall back to API keys.
        # NOTE(review): this also hides network/outage errors from Supabase.
        return None

    user = getattr(response, "user", None)
    if user is None:
        return None
    email = getattr(user, "email", None)
    if not isinstance(email, str) or not email.strip():
        return None

    created_at = getattr(user, "created_at", None)
    signed_up = created_at.isoformat() if hasattr(created_at, "isoformat") else None
    return {
        "auth_type": "jwt",
        "user_id": getattr(user, "id", None),
        "email": email.strip().lower(),
        "signed_up": signed_up,
    }
def verify_lightning_api_key(secret: str, stored: str) -> bool:
    """Check *secret* against a stored ``<scheme>$<salt>$<key>`` scrypt hash.

    *salt* and *key* are URL-safe base64 with optional padding. The key is
    re-derived with scrypt (N=16384, r=8, p=1, 32-byte output) and compared
    in constant time. The scheme field is not checked.

    Returns ``False`` for a wrong secret and also for a malformed *stored*
    value (not a string, wrong number of fields, undecodable base64), so a
    corrupt row reads as an invalid key instead of raising.
    """
    if not isinstance(stored, str):
        return False
    parts = stored.split("$")
    if len(parts) != 3:
        return False
    _scheme, salt_b64, key_b64 = parts
    try:
        # "==" restores stripped padding; non-strict decoding ignores the surplus.
        salt = base64.urlsafe_b64decode(salt_b64 + "==")
        stored_key = base64.urlsafe_b64decode(key_b64 + "==")
    except ValueError:  # binascii.Error subclasses ValueError
        return False
    derived_key = hashlib.scrypt(
        secret.encode("utf-8"),
        salt=salt,
        n=16384, r=8, p=1,
        dklen=32,
    )
    return hmac.compare_digest(derived_key, stored_key)
| async def resolve_api_key_identity(token: str | None): | |
| if token is None or token.strip() == "": | |
| return None | |
| prefix = token[:16] | |
| row = await execute_query( | |
| """ | |
| select | |
| k.id, | |
| k.user_id, | |
| k.name, | |
| k.expires_at, | |
| k.revoked_at, | |
| k.key_hash, | |
| u.email, | |
| u.created_at as user_created_at | |
| from public.lightning_api_keys k | |
| join auth.users u on u.id = k.user_id | |
| where k.key_prefix = %s | |
| limit 1; | |
| """, | |
| (prefix,), | |
| fetchone=True, | |
| ) | |
| if not row: | |
| raise HTTPException(status_code=401, detail="Invalid API key or token") | |
| if not verify_lightning_api_key(token, row["key_hash"]): | |
| raise HTTPException(status_code=401, detail="Invalid API key or token") | |
| if row.get("revoked_at") is not None: | |
| raise HTTPException(status_code=401, detail="Invalid API key or token") | |
| expires_at = row.get("expires_at") | |
| if expires_at is not None and expires_at <= datetime.now(timezone.utc): | |
| raise HTTPException(status_code=401, detail="Invalid API key or token") | |
| email = row.get("email") | |
| if not isinstance(email, str) or not email.strip(): | |
| raise HTTPException(status_code=401, detail="Invalid API key or token") | |
| await execute_query( | |
| """ | |
| update public.lightning_api_keys | |
| set last_used_at = timezone('utc', now()) | |
| where id = %s; | |
| """, | |
| (row["id"],), | |
| commit=True, | |
| ) | |
| return { | |
| "auth_type": "api_key", | |
| "user_id": str(row["user_id"]) if row["user_id"] is not None else None, | |
| "email": email.strip().lower(), | |
| "signed_up": row["user_created_at"].isoformat() | |
| if hasattr(row["user_created_at"], "isoformat") | |
| else None, | |
| } | |