lightning / helper /subscriptions.py
R.C.M.
Add extra usage overrides
8f16f13
import asyncio
import hashlib
import json
import os
from datetime import datetime, timezone
from pathlib import Path
import base64
import psycopg
from psycopg.rows import dict_row
from supabase import Client, create_client
import hmac
ADMINS_PATH = Path("assets/admins.json")
EXTRA_USAGE_PATH = Path("assets/extra_usage.json")


def load_email_map(path: Path) -> dict:
    """Load a JSON object keyed by email address from *path*.

    Keys are stripped and lower-cased so they match the normalized emails that
    the identity resolvers in this module produce (``email.strip().lower()``);
    otherwise a key such as "Admin@Example.com" could never be found.

    Any failure (missing/unreadable file, invalid JSON, or a top-level value
    that is not an object) is printed and yields an empty map, so the service
    still starts without the file.

    Args:
        path: Location of the JSON file.

    Returns:
        A dict mapping normalized email -> the entry stored for it.
    """
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        print(f"Failed to load {path.name}: {e}")
        return {}
    if not isinstance(data, dict):
        # A list/str/number would break later `.get()` / `[email]` lookups.
        print(f"Failed to load {path.name}: expected a JSON object, got {type(data).__name__}")
        return {}
    # JSON object keys are always strings.
    return {key.strip().lower(): value for key, value in data.items()}


ADMIN_EMAIL_MAP = load_email_map(ADMINS_PATH)
EXTRA_USAGE_EMAIL_MAP = load_email_map(EXTRA_USAGE_PATH)
# Plan keys, listed in ascending price order.
PLAN_ORDER = ["free", "light", "core", "creator", "professional"]


def _tier(name, url, price, *, cloud_chat, images, videos, audio, verify_token=100):
    """Build one TIER_CONFIG entry (display name, Stripe checkout URL, price, limits).

    ``None`` appears as a limit on paid tiers — presumably meaning "no cap";
    enforcement is not visible in this section.
    """
    return {
        "name": name,
        "url": url,
        "price": price,
        "limits": {
            "cloudChatDaily": cloud_chat,
            "imagesDaily": images,
            "videosDaily": videos,
            "audioWeekly": audio,
            "verifyTokenWithEmailDaily": verify_token,
        },
    }


TIER_CONFIG = {
    "free": _tier(
        "Free Tier",
        "",
        "0.00",
        cloud_chat=50, images=10, videos=3, audio=1,
    ),
    "light": _tier(
        "InferencePort AI Light",
        "https://buy.stripe.com/bJedR93gk4xw7EFeyabbG08",
        "9.99",
        cloud_chat=None, images=50, videos=10, audio=5,
    ),
    "core": _tier(
        "InferencePort AI Core",
        "https://buy.stripe.com/28E4gzg365BA1gh61EbbG09",
        "15.99",
        cloud_chat=None, images=150, videos=30, audio=25,
    ),
    "creator": _tier(
        "InferencePort AI Creator",
        "https://buy.stripe.com/dRmdR98AE5BA8IJ89MbbG0a",
        "29.99",
        cloud_chat=None, images=300, videos=50, audio=45,
    ),
    "professional": _tier(
        "InferencePort AI Professional",
        "https://buy.stripe.com/14AaEX7wA3ts7EF2PsbbG0b",
        "99.99",
        cloud_chat=None, images=None, videos=None, audio=75,
    ),
}
# Reset period ("daily" / "weekly") for every usage metric. Its key set is the
# single source of truth for which metrics the tables below track.
USAGE_PERIODS = {
    "cloudChatDaily": "daily",
    "imagesDaily": "daily",
    "videosDaily": "daily",
    "audioWeekly": "weekly",
    "verifyTokenWithEmailDaily": "daily",
}

# One empty bucket per metric for usage data and for per-metric locks
# (presumably keyed per user and populated elsewhere — not visible here).
# Both are derived from USAGE_PERIODS so adding a metric there automatically
# adds its buckets; hand-maintained copies of the key list could drift apart.
# Each bucket is a distinct dict object.
usage_store = {metric: {} for metric in USAGE_PERIODS}
usage_locks = {metric: {} for metric in USAGE_PERIODS}
# Shared Postgres connection handle; declared first and then replaced by the
# eager connection opened at the end of this section.
conn = None
# Serializes every query that goes through execute_query().
conn_lock = asyncio.Lock()

# Supabase client built from service-role credentials; used to validate JWTs.
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# Postgres connection string, read from the environment.
POSTGRE_SECRET = os.getenv("POSTGRE_SECRET")
conn = psycopg.connect(
    POSTGRE_SECRET,
    row_factory=dict_row,
    sslmode="verify-full",
    sslrootcert="prod-ca-2021.crt",
)
async def get_conn():
    """Return the module-wide psycopg connection, reopening it when needed.

    A fresh connection (dict rows, ``verify-full`` SSL against
    ``prod-ca-2021.crt``) is opened when there is none yet or the current one
    reports ``closed``; otherwise the existing connection is returned as-is.

    NOTE(review): ``psycopg.connect`` is synchronous, so a reconnect blocks the
    event loop for its duration.
    """
    global conn
    if conn is not None and not conn.closed:
        return conn
    conn = psycopg.connect(
        POSTGRE_SECRET,
        row_factory=dict_row,
        sslmode="verify-full",
        sslrootcert="prod-ca-2021.crt",
    )
    return conn
async def execute_query(query: str, params=(), *, fetchone: bool = False, commit: bool = False):
    """Run one SQL statement on the shared connection, serialized by ``conn_lock``.

    Args:
        query: SQL text with ``%s`` placeholders.
        params: Values bound to the placeholders.
        fetchone: Return a single row (or None) instead of a list of rows.
        commit: Commit the connection's transaction after the statement.

    Returns:
        For statements that produce rows, ``cur.fetchone()`` / ``cur.fetchall()``
        (rows are dicts when the connection uses ``dict_row``, as the
        connections created in this module do). For statements without a
        result set, None when *fetchone* is set, else ``[]``.

    Raises:
        psycopg.Error: any database error. On ``psycopg.OperationalError`` the
            old connection is closed and the statement is retried once on a
            brand-new connection; errors from that retry propagate.

    NOTE(review): psycopg calls here are synchronous, so each query blocks the
    event loop while it runs.
    """
    global conn

    def _run(connection):
        # Execute on *connection*. On any non-connection database error, roll
        # back so the shared connection is usable again.
        try:
            with connection.cursor() as cur:
                cur.execute(query, params)
                if cur.description is None:
                    result = None if fetchone else []
                else:
                    result = cur.fetchone() if fetchone else cur.fetchall()
                if commit:
                    connection.commit()
                return result
        except psycopg.OperationalError:
            raise  # the caller reconnects and retries
        except psycopg.Error:
            # Without a rollback PostgreSQL rejects every later statement on
            # this connection ("current transaction is aborted"), and since the
            # connection is not closed get_conn() would keep handing it out.
            try:
                connection.rollback()
            except psycopg.Error:
                pass  # best effort; the original error below is what matters
            raise

    async with conn_lock:
        connection = await get_conn()
        try:
            return _run(connection)
        except psycopg.OperationalError:
            # Close the failed connection before replacing it so it isn't leaked.
            try:
                connection.close()
            except psycopg.Error:
                pass
            connection = psycopg.connect(
                POSTGRE_SECRET,
                row_factory=dict_row,
                sslmode="verify-full",
                sslrootcert="prod-ca-2021.crt",
            )
            conn = connection
            return _run(connection)
def normalize_plan_key(plan_name: str | None) -> str:
if not plan_name:
return "free"
normalized = "".join(ch for ch in str(plan_name).lower() if ch.isalpha())
if "professional" in normalized:
return "professional"
if "creator" in normalized:
return "creator"
if "core" in normalized:
return "core"
if "light" in normalized:
return "light"
return "free"
def get_effective_limit_for_email(
    plan_key: str,
    metric: str,
    email: str | None = None,
):
    """Return the limit for *metric* on *plan_key*, honouring per-email overrides.

    The default comes from ``TIER_CONFIG[plan_key]["limits"][metric]``; unknown
    plan keys fall back to the "free" tier. If *email* (stripped, lower-cased)
    has a dict entry in ``EXTRA_USAGE_EMAIL_MAP``, overrides are read from the
    entry's ``"limits"`` object when that is a dict, otherwise from the entry
    itself:

    * metric absent from the overrides        -> tier default
    * metric explicitly ``null``              -> None (the value paid tiers use
      for uncapped metrics)
    * finite number >= 0                      -> ``int(value)`` (floats truncate)
    * anything else (negative, bool, string, NaN/Infinity) -> tier default

    Args:
        plan_key: Key into TIER_CONFIG.
        metric: Limit name, e.g. "imagesDaily".
        email: Account email used to look up overrides; blank/None skips them.

    Returns:
        An int limit, or None.
    """
    plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    default_limit = plan.get("limits", {}).get(metric)
    if not isinstance(email, str) or not email.strip():
        return default_limit

    entry = EXTRA_USAGE_EMAIL_MAP.get(email.strip().lower())
    if not isinstance(entry, dict):
        return default_limit

    nested = entry.get("limits")
    overrides = nested if isinstance(nested, dict) else entry
    if metric not in overrides:
        return default_limit

    override = overrides[metric]
    if override is None:
        return None
    # bool is an int subclass, so JSON true/false would otherwise become 1/0.
    if isinstance(override, bool) or not isinstance(override, (int, float)):
        return default_limit
    # json.loads accepts NaN/Infinity; int(inf) raises OverflowError. The
    # chained comparison is False for negatives, NaN and +/-inf alike.
    if not 0 <= override < float("inf"):
        return default_limit
    return int(override)
async def fetch_subscription(jwt: str):
    """Resolve *jwt* to an account and describe its subscription state.

    Emails present in ADMIN_EMAIL_MAP skip Stripe and get the plan configured
    for them. Everyone else is looked up in the ``stripe.*`` tables:
    subscriptions with status active, trialing or past_due are listed, but only
    active/trialing ones grant a plan. When several grant a plan, the
    highest-ranked one in PLAN_ORDER wins.

    Args:
        jwt: Supabase JWT or Lightning API key (anything resolve_token_identity
            accepts).

    Returns:
        On success a dict with ``email``, ``signed_up``, ``subscription`` (list
        of subscription dicts, or None when no matching subscriptions exist),
        ``plan_key`` and ``auth_type``. When no identity is supplied (missing or
        "public" token) a dict with a single ``error`` key.

    Raises:
        HTTPException: propagated from resolve_token_identity for invalid
            tokens or API keys.
    """
    identity = await resolve_token_identity(jwt)
    if identity is None:
        # resolve_token_identity returns None for a missing/"public" token;
        # without this guard `"error" in identity` raises TypeError.
        return {"error": "Missing API key or token"}
    if "error" in identity:
        return identity

    email = identity["email"]
    signed_up = identity["signed_up"]
    auth_type = identity["auth_type"]

    if email in ADMIN_EMAIL_MAP:
        admin = ADMIN_EMAIL_MAP[email]
        print(f"[ADMIN OVERRIDE] {email} → forcing plan '{admin['plan_key']}'")
        subscription_obj = {
            "subscription_id": admin.get("subscription_id", "admin-override"),
            "status": admin.get("status", "active"),
            "current_period_end": None,
            "price_id": None,
            "product_name": admin.get("product_name"),
            "nickname": admin.get("nickname"),
            "plan_key": admin.get("plan_key"),
        }
        return {
            "email": email,
            "signed_up": signed_up,
            "subscription": [subscription_obj],
            "plan_key": admin["plan_key"],
            "auth_type": auth_type,
        }

    rows = await execute_query(
        """
        with cust as (
            select id
            from stripe.customers
            where email = %s
        ),
        subs as (
            select
                s.id as subscription_id,
                s.status,
                s.current_period_end,
                s.items->'data'->0->'price'->>'id' as price_id
            from stripe.subscriptions s
            join cust on s.customer = cust.id
            where s.status in ('active', 'trialing', 'past_due')
        )
        select
            subs.subscription_id,
            subs.status,
            subs.current_period_end,
            subs.price_id,
            prices.nickname,
            prices.product as product_id,
            products.name as product_name
        from subs
        left join stripe.prices prices
            on prices.id = subs.price_id
        left join stripe.products products
            on prices.product = products.id;
        """,
        (email,),
    )
    if not rows:
        return {
            "email": email,
            "signed_up": signed_up,
            "subscription": None,
            "plan_key": "free",
            "auth_type": auth_type,
        }

    # The query has no ORDER BY, so "last active row wins" would be arbitrary
    # for customers with several live subscriptions; rank plans explicitly.
    plan_rank = {key: rank for rank, key in enumerate(PLAN_ORDER)}
    subscriptions = []
    preferred_plan_key = "free"
    for row in rows:
        plan_key = normalize_plan_key(row["product_name"] or row["nickname"])
        subscriptions.append({
            "subscription_id": row["subscription_id"],
            "status": row["status"],
            "current_period_end": row["current_period_end"],
            "price_id": row["price_id"],
            "product_name": row["product_name"],
            "nickname": row["nickname"],
            "plan_key": plan_key,
        })
        # past_due subscriptions are listed but do not grant their plan.
        if row["status"] in ("active", "trialing") and (
            plan_rank.get(plan_key, 0) > plan_rank.get(preferred_plan_key, 0)
        ):
            preferred_plan_key = plan_key

    return {
        "email": email,
        "signed_up": signed_up,
        "subscription": subscriptions,
        "plan_key": preferred_plan_key,
        "auth_type": auth_type,
    }
from fastapi import HTTPException
async def resolve_token_identity(token: str | None):
    """Turn a bearer credential into an identity dict.

    An empty token or the literal "public" means "no credential" and yields
    None. Otherwise the token is tried as a Supabase JWT first, then as a
    Lightning API key.

    Raises:
        HTTPException: 401 when the token is neither a valid JWT nor a valid
            API key, or the API-key resolver reports an ``error``.
    """
    if not token or token == "public":
        return None

    identity = resolve_jwt_identity(token)
    if identity is None:
        identity = await resolve_api_key_identity(token)
        if identity is None:
            raise HTTPException(status_code=401, detail="Invalid API key or token")
        if "error" in identity:
            raise HTTPException(status_code=401, detail=identity["error"])
    return identity
def resolve_jwt_identity(token: str):
    """Validate *token* with Supabase Auth and return the caller's identity.

    Returns None when ``supabase.auth.get_user`` raises (any exception,
    including network failures, is treated as "not a valid JWT"), when the
    response carries no user, or when the user has no non-blank email.
    Otherwise returns a dict with ``auth_type`` "jwt", the user's ``id``, the
    stripped/lower-cased email, and ``created_at.isoformat()`` as
    ``signed_up`` (None when created_at has no ``isoformat``).
    """
    try:
        response = supabase.auth.get_user(token)
    except Exception:
        return None

    user = getattr(response, "user", None)
    if user is None:
        return None
    email = getattr(user, "email", None)
    if not isinstance(email, str) or not email.strip():
        return None

    created_at = getattr(user, "created_at", None)
    signed_up = created_at.isoformat() if hasattr(created_at, "isoformat") else None
    return {
        "auth_type": "jwt",
        "user_id": getattr(user, "id", None),
        "email": email.strip().lower(),
        "signed_up": signed_up,
    }
def _b64url_decode_unpadded(text: str) -> bytes:
    """Decode URL-safe base64 whose trailing '=' padding may have been stripped."""
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def verify_lightning_api_key(secret: str, stored: str) -> bool:
    """Check *secret* against a stored scrypt hash using a constant-time compare.

    *stored* has the form ``<label>$<salt>$<key>``, where salt and key are
    URL-safe base64 (padding optional); the label field is not inspected. The
    secret is re-derived with scrypt (n=16384, r=8, p=1, 32-byte key) and
    compared with ``hmac.compare_digest``.

    Returns:
        True on a match. False on a mismatch, and also (instead of raising)
        when *stored* is not a string, does not have exactly three
        '$'-separated fields, or contains invalid base64. A corrupt row is
        therefore treated as a wrong key rather than causing a server error.
    """
    if not isinstance(stored, str):
        return False
    parts = stored.split("$")
    if len(parts) != 3:
        return False
    _, salt_b64, key_b64 = parts
    try:
        salt = _b64url_decode_unpadded(salt_b64)
        stored_key = _b64url_decode_unpadded(key_b64)
    except ValueError:  # binascii.Error and non-ASCII input are ValueErrors
        return False
    derived_key = hashlib.scrypt(
        secret.encode("utf-8"),
        salt=salt,
        n=16384, r=8, p=1,
        dklen=32,
    )
    return hmac.compare_digest(derived_key, stored_key)
async def resolve_api_key_identity(token: str | None):
    """Authenticate a Lightning API key and return the owning user's identity.

    The first 16 characters of *token* select the key row (joined to
    ``auth.users`` for the email and signup time). The full token must match
    the row's scrypt hash, and the key must be neither revoked nor expired. On
    success the key's ``last_used_at`` is stamped and committed.

    Returns:
        None for a None/blank token. Otherwise a dict with ``auth_type``
        "api_key", ``user_id`` (as str, or None), the stripped/lower-cased
        email, and the user's ISO-formatted signup time (or None).

    Raises:
        HTTPException: 401 "Invalid API key or token" for every failure (unknown
            prefix, hash mismatch, revoked, expired, or missing email). All
            failures use the same response.
    """
    if token is None or token.strip() == "":
        return None

    prefix = token[:16]
    row = await execute_query(
        """
        select
            k.id,
            k.user_id,
            k.name,
            k.expires_at,
            k.revoked_at,
            k.key_hash,
            u.email,
            u.created_at as user_created_at
        from public.lightning_api_keys k
        join auth.users u on u.id = k.user_id
        where k.key_prefix = %s
        limit 1;
        """,
        (prefix,),
        fetchone=True,
    )
    if not row:
        raise HTTPException(status_code=401, detail="Invalid API key or token")

    # scrypt is deliberately expensive; run it in a worker thread so it does
    # not stall every other request on the event loop.
    if not await asyncio.to_thread(verify_lightning_api_key, token, row["key_hash"]):
        raise HTTPException(status_code=401, detail="Invalid API key or token")
    if row.get("revoked_at") is not None:
        raise HTTPException(status_code=401, detail="Invalid API key or token")

    expires_at = row.get("expires_at")
    if expires_at is not None:
        if isinstance(expires_at, datetime) and expires_at.tzinfo is None:
            # A naive value cannot be compared with the aware "now" below
            # (TypeError). NOTE(review): treated as UTC, matching the
            # timezone('utc', now()) convention used for last_used_at —
            # confirm against the column type.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if expires_at <= datetime.now(timezone.utc):
            raise HTTPException(status_code=401, detail="Invalid API key or token")

    email = row.get("email")
    if not isinstance(email, str) or not email.strip():
        raise HTTPException(status_code=401, detail="Invalid API key or token")

    await execute_query(
        """
        update public.lightning_api_keys
        set last_used_at = timezone('utc', now())
        where id = %s;
        """,
        (row["id"],),
        commit=True,
    )

    user_created_at = row["user_created_at"]
    return {
        "auth_type": "api_key",
        "user_id": str(row["user_id"]) if row["user_id"] is not None else None,
        "email": email.strip().lower(),
        "signed_up": user_created_at.isoformat()
        if hasattr(user_created_at, "isoformat")
        else None,
    }