sharktide commited on
Commit
6aeec47
·
verified ·
1 Parent(s): 317b4a5

Update helper/subscriptions.py

Browse files
Files changed (1) hide show
  1. helper/subscriptions.py +175 -94
helper/subscriptions.py CHANGED
@@ -1,11 +1,14 @@
1
- import os
2
- import psycopg
3
- from psycopg.rows import dict_row
4
- from supabase import create_client, Client
5
  import asyncio
 
6
  import json
 
 
7
  from pathlib import Path
8
 
 
 
 
 
9
  ADMINS_PATH = Path("assets/admins.json")
10
 
11
  try:
@@ -111,6 +114,39 @@ async def get_conn():
111
  return conn
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def normalize_plan_key(plan_name: str | None) -> str:
115
  if not plan_name:
116
  return "free"
@@ -127,24 +163,17 @@ def normalize_plan_key(plan_name: str | None) -> str:
127
 
128
 
129
  async def fetch_subscription(jwt: str):
130
- auth_res = supabase.auth.get_user(jwt)
131
- if auth_res.user is None:
132
- return {"error": "Invalid or expired session"}
133
-
134
- user = auth_res.user
135
- email = user.email
136
-
137
- auth_res = supabase.auth.get_user(jwt)
138
- if auth_res.user is None:
139
- return {"error": "Invalid or expired session"}
140
-
141
- user = auth_res.user
142
- email = user.email.lower()
143
-
144
  if email in ADMIN_EMAIL_MAP:
145
  admin = ADMIN_EMAIL_MAP[email]
146
  print(f"[ADMIN OVERRIDE] {email} → forcing plan '{admin['plan_key']}'")
147
-
148
  subscription_obj = {
149
  "subscription_id": admin.get("subscription_id", "admin-override"),
150
  "status": admin.get("status", "active"),
@@ -154,92 +183,56 @@ async def fetch_subscription(jwt: str):
154
  "nickname": admin.get("nickname"),
155
  "plan_key": admin.get("plan_key"),
156
  }
157
-
158
  return {
159
  "email": email,
160
- "signed_up": user.created_at.isoformat(),
161
  "subscription": [subscription_obj],
162
  "plan_key": admin["plan_key"],
 
163
  }
164
 
165
- async with conn_lock:
166
- connection = await get_conn()
167
-
168
- try:
169
- with connection.cursor() as cur:
170
- cur.execute("""
171
- with cust as (
172
- select id
173
- from stripe.customers
174
- where email = %s
175
- ),
176
- subs as (
177
- select
178
- s.id as subscription_id,
179
- s.status,
180
- s.current_period_end,
181
- s.items->'data'->0->'price'->>'id' as price_id
182
- from stripe.subscriptions s
183
- join cust on s.customer = cust.id
184
- where s.status in ('active', 'trialing', 'past_due')
185
- )
186
- select
187
- subs.subscription_id,
188
- subs.status,
189
- subs.current_period_end,
190
- subs.price_id,
191
- prices.nickname,
192
- prices.product as product_id,
193
- products.name as product_name
194
- from subs
195
- left join stripe.prices prices
196
- on prices.id = subs.price_id
197
- left join stripe.products products
198
- on prices.product = products.id;
199
- """, (email,))
200
- rows = cur.fetchall()
201
-
202
- except psycopg.OperationalError:
203
- connection = psycopg.connect(POSTGRE_SECRET, row_factory=dict_row)
204
- with connection.cursor() as cur:
205
- cur.execute("""
206
- with cust as (
207
- select id
208
- from stripe.customers
209
- where email = %s
210
- ),
211
- subs as (
212
- select
213
- s.id as subscription_id,
214
- s.status,
215
- s.current_period_end,
216
- s.items->'data'->0->'price'->>'id' as price_id
217
- from stripe.subscriptions s
218
- join cust on s.customer = cust.id
219
- where s.status in ('active', 'trialing', 'past_due')
220
- )
221
- select
222
- subs.subscription_id,
223
- subs.status,
224
- subs.current_period_end,
225
- subs.price_id,
226
- prices.nickname,
227
- prices.product as product_id,
228
- products.name as product_name
229
- from subs
230
- left join stripe.prices prices
231
- on prices.id = subs.price_id
232
- left join stripe.products products
233
- on prices.product = products.id;
234
- """, (email,))
235
- rows = cur.fetchall()
236
 
237
  if not rows:
238
  return {
239
  "email": email,
240
- "signed_up": user.created_at.isoformat(),
241
  "subscription": None,
242
  "plan_key": "free",
 
243
  }
244
 
245
  subscriptions = []
@@ -263,7 +256,95 @@ async def fetch_subscription(jwt: str):
263
 
264
  return {
265
  "email": email,
266
- "signed_up": user.created_at.isoformat(),
267
  "subscription": subscriptions,
268
  "plan_key": preferred_plan_key,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  }
 
 
 
 
 
1
  import asyncio
2
+ import hashlib
3
  import json
4
+ import os
5
+ from datetime import datetime, timezone
6
  from pathlib import Path
7
 
8
+ import psycopg
9
+ from psycopg.rows import dict_row
10
+ from supabase import Client, create_client
11
+
12
  ADMINS_PATH = Path("assets/admins.json")
13
 
14
  try:
 
114
  return conn
115
 
116
 
117
+ async def execute_query(query: str, params=(), *, fetchone: bool = False, commit: bool = False):
118
+ async with conn_lock:
119
+ connection = await get_conn()
120
+
121
+ try:
122
+ with connection.cursor() as cur:
123
+ cur.execute(query, params)
124
+ if cur.description is None:
125
+ result = None if fetchone else []
126
+ else:
127
+ result = cur.fetchone() if fetchone else cur.fetchall()
128
+ if commit:
129
+ connection.commit()
130
+ return result
131
+ except psycopg.OperationalError:
132
+ connection = psycopg.connect(
133
+ POSTGRE_SECRET,
134
+ row_factory=dict_row,
135
+ sslmode="verify-full",
136
+ sslrootcert="prod-ca-2021.crt",
137
+ )
138
+ globals()["conn"] = connection
139
+ with connection.cursor() as cur:
140
+ cur.execute(query, params)
141
+ if cur.description is None:
142
+ result = None if fetchone else []
143
+ else:
144
+ result = cur.fetchone() if fetchone else cur.fetchall()
145
+ if commit:
146
+ connection.commit()
147
+ return result
148
+
149
+
150
  def normalize_plan_key(plan_name: str | None) -> str:
151
  if not plan_name:
152
  return "free"
 
163
 
164
 
165
  async def fetch_subscription(jwt: str):
166
+ identity = await resolve_token_identity(jwt)
167
+ if "error" in identity:
168
+ return identity
169
+
170
+ email = identity["email"]
171
+ signed_up = identity["signed_up"]
172
+
 
 
 
 
 
 
 
173
  if email in ADMIN_EMAIL_MAP:
174
  admin = ADMIN_EMAIL_MAP[email]
175
  print(f"[ADMIN OVERRIDE] {email} → forcing plan '{admin['plan_key']}'")
176
+
177
  subscription_obj = {
178
  "subscription_id": admin.get("subscription_id", "admin-override"),
179
  "status": admin.get("status", "active"),
 
183
  "nickname": admin.get("nickname"),
184
  "plan_key": admin.get("plan_key"),
185
  }
186
+
187
  return {
188
  "email": email,
189
+ "signed_up": signed_up,
190
  "subscription": [subscription_obj],
191
  "plan_key": admin["plan_key"],
192
+ "auth_type": identity["auth_type"],
193
  }
194
 
195
+ rows = await execute_query(
196
+ """
197
+ with cust as (
198
+ select id
199
+ from stripe.customers
200
+ where email = %s
201
+ ),
202
+ subs as (
203
+ select
204
+ s.id as subscription_id,
205
+ s.status,
206
+ s.current_period_end,
207
+ s.items->'data'->0->'price'->>'id' as price_id
208
+ from stripe.subscriptions s
209
+ join cust on s.customer = cust.id
210
+ where s.status in ('active', 'trialing', 'past_due')
211
+ )
212
+ select
213
+ subs.subscription_id,
214
+ subs.status,
215
+ subs.current_period_end,
216
+ subs.price_id,
217
+ prices.nickname,
218
+ prices.product as product_id,
219
+ products.name as product_name
220
+ from subs
221
+ left join stripe.prices prices
222
+ on prices.id = subs.price_id
223
+ left join stripe.products products
224
+ on prices.product = products.id;
225
+ """,
226
+ (email,),
227
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  if not rows:
230
  return {
231
  "email": email,
232
+ "signed_up": signed_up,
233
  "subscription": None,
234
  "plan_key": "free",
235
+ "auth_type": identity["auth_type"],
236
  }
237
 
238
  subscriptions = []
 
256
 
257
  return {
258
  "email": email,
259
+ "signed_up": signed_up,
260
  "subscription": subscriptions,
261
  "plan_key": preferred_plan_key,
262
+ "auth_type": identity["auth_type"],
263
+ }
264
+
265
+
266
+ async def resolve_token_identity(token: str):
267
+ jwt_identity = resolve_jwt_identity(token)
268
+ if jwt_identity is not None:
269
+ return jwt_identity
270
+
271
+ api_key_identity = await resolve_api_key_identity(token)
272
+ if api_key_identity is not None:
273
+ return api_key_identity
274
+
275
+ return {"error": "Invalid or expired credentials"}
276
+
277
+
278
+ def resolve_jwt_identity(token: str):
279
+ try:
280
+ auth_res = supabase.auth.get_user(token)
281
+ except Exception:
282
+ return None
283
+
284
+ user = getattr(auth_res, "user", None)
285
+ email = getattr(user, "email", None)
286
+ created_at = getattr(user, "created_at", None)
287
+ user_id = getattr(user, "id", None)
288
+ if user is None or not isinstance(email, str) or not email.strip():
289
+ return None
290
+
291
+ return {
292
+ "auth_type": "jwt",
293
+ "user_id": user_id,
294
+ "email": email.strip().lower(),
295
+ "signed_up": created_at.isoformat() if hasattr(created_at, "isoformat") else None,
296
+ }
297
+
298
+
299
+ async def resolve_api_key_identity(token: str):
300
+ token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest()
301
+ row = await execute_query(
302
+ """
303
+ select
304
+ k.id,
305
+ k.user_id,
306
+ k.name,
307
+ k.expires_at,
308
+ k.revoked_at,
309
+ u.email,
310
+ u.created_at as user_created_at
311
+ from public.lightning_api_keys k
312
+ join auth.users u on u.id = k.user_id
313
+ where k.key_hash = %s
314
+ limit 1;
315
+ """,
316
+ (token_hash,),
317
+ fetchone=True,
318
+ )
319
+ if not row:
320
+ return None
321
+
322
+ if row.get("revoked_at") is not None:
323
+ return {"error": "API key has been revoked"}
324
+
325
+ expires_at = row.get("expires_at")
326
+ if expires_at is not None and expires_at <= datetime.now(timezone.utc):
327
+ return {"error": "API key has expired"}
328
+
329
+ email = row.get("email")
330
+ created_at = row.get("user_created_at")
331
+ user_id = row.get("user_id")
332
+ if not isinstance(email, str) or not email.strip():
333
+ return {"error": "API key is not linked to a valid user"}
334
+
335
+ await execute_query(
336
+ """
337
+ update public.lightning_api_keys
338
+ set last_used_at = timezone('utc', now())
339
+ where id = %s;
340
+ """,
341
+ (row["id"],),
342
+ commit=True,
343
+ )
344
+
345
+ return {
346
+ "auth_type": "api_key",
347
+ "user_id": str(user_id) if user_id is not None else None,
348
+ "email": email.strip().lower(),
349
+ "signed_up": created_at.isoformat() if hasattr(created_at, "isoformat") else None,
350
  }