Spaces:
Running
Running
Fix connection closing
Browse files- subscriptions.py +55 -34
subscriptions.py
CHANGED
|
@@ -2,6 +2,10 @@ import os
|
|
| 2 |
import psycopg
|
| 3 |
from psycopg.rows import dict_row
|
| 4 |
from supabase import create_client, Client
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
SUPABASE_URL = os.getenv("SUPABASE_URL")
|
| 7 |
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
|
|
@@ -11,6 +15,15 @@ POSTGRE_SECRET = os.getenv("POSTGRE_SECRET")
|
|
| 11 |
conn = psycopg.connect(POSTGRE_SECRET, row_factory=dict_row)
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def normalize_plan_key(plan_name: str | None) -> str:
|
| 15 |
if not plan_name:
|
| 16 |
return "free"
|
|
@@ -27,7 +40,6 @@ def normalize_plan_key(plan_name: str | None) -> str:
|
|
| 27 |
|
| 28 |
|
| 29 |
async def fetch_subscription(jwt: str):
|
| 30 |
-
# You still need Supabase auth to get the user email
|
| 31 |
auth_res = supabase.auth.get_user(jwt)
|
| 32 |
if auth_res.user is None:
|
| 33 |
return {"error": "Invalid or expired session"}
|
|
@@ -35,39 +47,48 @@ async def fetch_subscription(jwt: str):
|
|
| 35 |
user = auth_res.user
|
| 36 |
email = user.email
|
| 37 |
|
| 38 |
-
with
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
if not rows:
|
| 73 |
return {
|
|
|
|
| 2 |
import psycopg
|
| 3 |
from psycopg.rows import dict_row
|
| 4 |
from supabase import create_client, Client
|
| 5 |
+
import asyncio
|
| 6 |
+
|
| 7 |
+
conn = None
|
| 8 |
+
conn_lock = asyncio.Lock()
|
| 9 |
|
| 10 |
SUPABASE_URL = os.getenv("SUPABASE_URL")
|
| 11 |
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
|
|
|
|
| 15 |
conn = psycopg.connect(POSTGRE_SECRET, row_factory=dict_row)
|
| 16 |
|
| 17 |
|
| 18 |
+
async def get_conn():
|
| 19 |
+
global conn
|
| 20 |
+
|
| 21 |
+
if conn is None or conn.closed:
|
| 22 |
+
conn = psycopg.connect(POSTGRE_SECRET, row_factory=dict_row)
|
| 23 |
+
|
| 24 |
+
return conn
|
| 25 |
+
|
| 26 |
+
|
| 27 |
def normalize_plan_key(plan_name: str | None) -> str:
|
| 28 |
if not plan_name:
|
| 29 |
return "free"
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
async def fetch_subscription(jwt: str):
|
|
|
|
| 43 |
auth_res = supabase.auth.get_user(jwt)
|
| 44 |
if auth_res.user is None:
|
| 45 |
return {"error": "Invalid or expired session"}
|
|
|
|
| 47 |
user = auth_res.user
|
| 48 |
email = user.email
|
| 49 |
|
| 50 |
+
async with conn_lock:
|
| 51 |
+
connection = await get_conn()
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
with connection.cursor() as cur:
|
| 55 |
+
cur.execute("""
|
| 56 |
+
with cust as (
|
| 57 |
+
select id
|
| 58 |
+
from stripe.customers
|
| 59 |
+
where email = %s
|
| 60 |
+
),
|
| 61 |
+
subs as (
|
| 62 |
+
select
|
| 63 |
+
s.id as subscription_id,
|
| 64 |
+
s.status,
|
| 65 |
+
s.current_period_end,
|
| 66 |
+
s.items->'data'->0->'price'->>'id' as price_id
|
| 67 |
+
from stripe.subscriptions s
|
| 68 |
+
join cust on s.customer = cust.id
|
| 69 |
+
where s.status in ('active', 'trialing', 'past_due')
|
| 70 |
+
)
|
| 71 |
+
select
|
| 72 |
+
subs.subscription_id,
|
| 73 |
+
subs.status,
|
| 74 |
+
subs.current_period_end,
|
| 75 |
+
subs.price_id,
|
| 76 |
+
prices.nickname,
|
| 77 |
+
prices.product as product_id,
|
| 78 |
+
products.name as product_name
|
| 79 |
+
from subs
|
| 80 |
+
left join stripe.prices prices
|
| 81 |
+
on prices.id = subs.price_id
|
| 82 |
+
left join stripe.products products
|
| 83 |
+
on prices.product = products.id;
|
| 84 |
+
""", (email,))
|
| 85 |
+
rows = cur.fetchall()
|
| 86 |
+
|
| 87 |
+
except psycopg.OperationalError:
|
| 88 |
+
connection = psycopg.connect(POSTGRE_SECRET, row_factory=dict_row)
|
| 89 |
+
with connection.cursor() as cur:
|
| 90 |
+
cur.execute(""" ... same SQL ... """, (email,))
|
| 91 |
+
rows = cur.fetchall()
|
| 92 |
|
| 93 |
if not rows:
|
| 94 |
return {
|