File size: 5,382 Bytes
f1b8a40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import stripe
import os
import json
from fastapi import Request, HTTPException
from datetime import datetime, timedelta
import datasets_registry
import subscriptions_ledger

from fastapi.responses import JSONResponse

async def handle_stripe_webhook(request: Request):
    """Verify an incoming Stripe webhook request and dispatch it by event type.

    The raw request body is checked against the ``stripe-signature`` header
    with ``stripe.Webhook.construct_event`` using the ``STRIPE_WEBHOOK_SECRET``
    environment variable. ``checkout.session.completed`` and ``invoice.paid``
    events are forwarded to their handlers; every other event type is
    acknowledged without further processing.

    Args:
        request: The incoming FastAPI request carrying the Stripe event.

    Returns:
        ``JSONResponse({"status": "success"})`` after the event is handled
        or ignored.

    Raises:
        HTTPException: 500 if ``STRIPE_WEBHOOK_SECRET`` is not set; 400 if the
            signature header is missing, the payload is invalid, or the
            signature does not verify.
    """
    webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET")
    if not webhook_secret:
        # Without a secret nothing can be verified; report the misconfiguration
        # explicitly instead of letting the SDK fail on a None secret.
        raise HTTPException(status_code=500, detail="Webhook secret not configured")

    sig_header = request.headers.get("stripe-signature")
    if not sig_header:
        # A request without a signature is a client error, not a server crash.
        raise HTTPException(status_code=400, detail="Missing signature")

    payload = await request.body()
    try:
        event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
    except ValueError as e:
        raise HTTPException(status_code=400, detail="Invalid payload") from e
    except stripe.error.SignatureVerificationError as e:
        raise HTTPException(status_code=400, detail="Invalid signature") from e

    event_type = event["type"]
    data_object = event["data"]["object"]

    # NOTE(review): Stripe may deliver the same event more than once; the event
    # id is passed through to the handlers but is not deduplicated here.
    if event_type == "checkout.session.completed":
        await handle_checkout_session(data_object, event["id"])
    elif event_type == "invoice.paid":
        await handle_invoice_paid(data_object, event["id"])

    return JSONResponse({"status": "success"})

async def handle_checkout_session(session, event_id):
    """Append a ledger entry for a subscription started through Stripe Checkout.

    The session metadata must carry ``hf_user`` and ``dataset_id``. The Stripe
    subscription created by the checkout is retrieved to read the price id of
    its first item and the end of its current billing period. The price id is
    mapped to a plan with ``datasets_registry.get_plan_by_price_id``. A fresh
    access token is generated, and the entry is written with
    ``subscriptions_ledger.append_subscription_event``.

    Sessions are logged with ``print`` and skipped when:

    * the metadata is missing,
    * there is no subscription (e.g. one-time payments),
    * no period end can be read, or
    * the price is unknown to the registry.

    Args:
        session: Data object of a ``checkout.session.completed`` event.
        event_id: Id of the Stripe event, stored on the ledger entry.
    """
    from datetime import timezone

    # `or {}` also covers a metadata key that is present but null.
    metadata = session.get("metadata") or {}
    hf_user = metadata.get("hf_user")
    dataset_id = metadata.get("dataset_id")
    if not hf_user or not dataset_id:
        print(f"Missing metadata in session {session['id']}")
        return

    subscription_id = session.get("subscription")
    if not subscription_id:
        # Only subscription-mode checkouts are handled for now.
        print("Non-subscription checkout not fully supported yet.")
        return

    # NOTE(review): synchronous Stripe API call inside an async handler; it
    # blocks the event loop while the request is in flight.
    sub = stripe.Subscription.retrieve(subscription_id)
    # The plan is identified by the price of the first subscription item.
    first_item = sub["items"]["data"][0]
    price_id = first_item["price"]["id"]

    # Older Stripe API versions expose current_period_end on the subscription.
    # Newer ones (2025-03-31.basil onward — confirm against the pinned API
    # version) expose it on each subscription item instead.
    current_period_end = sub.get("current_period_end") or first_item.get("current_period_end")
    if not current_period_end:
        print(f"Missing current_period_end in subscription {subscription_id}")
        return

    plan_info = datasets_registry.get_plan_by_price_id(price_id)
    if not plan_info:
        print(f"Unknown price_id {price_id}")
        return
    plan_id = plan_info["plan"]["plan_id"]

    # Every stored timestamp carries a "Z" (UTC) suffix, so convert the Unix
    # timestamp in UTC. A naive fromtimestamp() would use the host's local zone.
    end_iso = datetime.fromtimestamp(current_period_end, tz=timezone.utc).isoformat().replace("+00:00", "Z")
    # One instant for both start and creation time keeps them consistent.
    now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

    # Each completed checkout gets a brand-new access token.
    access_token = subscriptions_ledger.generate_access_token()

    ledger_entry = {
        "event_id": event_id,
        "hf_user": hf_user,
        "dataset_id": dataset_id,
        "plan_id": plan_id,
        "subscription_start": now_iso,
        "subscription_end": end_iso,
        "source": "stripe",
        "access_token": access_token,
        "created_at": now_iso,
        "stripe_customer_id": session.get("customer"),
        "stripe_subscription_id": subscription_id,
    }

    subscriptions_ledger.append_subscription_event(ledger_entry)

async def handle_invoice_paid(invoice, event_id):
    """Append a ledger entry when an invoice for a subscription is paid.

    The subscription behind the invoice is retrieved from Stripe. Its metadata
    must carry ``hf_user`` and ``dataset_id``. The price id of the first item
    is mapped to a plan with ``datasets_registry.get_plan_by_price_id``, and
    the current billing period end becomes ``subscription_end``. If
    ``subscriptions_ledger.get_active_subscription`` returns an entry with an
    access token, that token is reused; otherwise a new one is generated.

    Invoices are skipped when:

    * they are not tied to a subscription (silently), or
    * the subscription metadata, the period end, or the registry plan is
      missing (logged with ``print``).

    Args:
        invoice: Data object of an ``invoice.paid`` event.
        event_id: Id of the Stripe event, stored on the ledger entry.
    """
    from datetime import timezone

    subscription_id = invoice.get("subscription")
    if not subscription_id:
        # Newer Stripe API versions (2025-03-31.basil onward — confirm against
        # the pinned API version) moved this to
        # invoice.parent.subscription_details.subscription.
        parent = invoice.get("parent") or {}
        details = parent.get("subscription_details") or {}
        subscription_id = details.get("subscription")
    if not subscription_id:
        return

    # NOTE(review): synchronous Stripe API call inside an async handler.
    sub = stripe.Subscription.retrieve(subscription_id)
    # The user/dataset link is expected on the subscription's metadata; the
    # ledger is not consulted as a fallback.
    metadata = sub.get("metadata") or {}
    hf_user = metadata.get("hf_user")
    dataset_id = metadata.get("dataset_id")
    if not hf_user or not dataset_id:
        print(f"Missing metadata in subscription {subscription_id}")
        return

    first_item = sub["items"]["data"][0]
    price_id = first_item["price"]["id"]
    # Older API versions: on the subscription. Newer ones: on each item.
    current_period_end = sub.get("current_period_end") or first_item.get("current_period_end")
    if not current_period_end:
        print(f"Missing current_period_end in subscription {subscription_id}")
        return

    plan_info = datasets_registry.get_plan_by_price_id(price_id)
    if not plan_info:
        print(f"Unknown price_id {price_id}")
        return
    plan_id = plan_info["plan"]["plan_id"]

    # Keep the existing token across renewals so clients do not lose access.
    # NOTE(review): Stripe also sends invoice.paid for a subscription's first
    # invoice, so a new checkout may produce two ledger entries. Because
    # webhook order is not guaranteed, this may run before the checkout entry
    # exists and mint a second token — confirm the intended behavior.
    existing_sub = subscriptions_ledger.get_active_subscription(hf_user, dataset_id)
    if existing_sub and existing_sub.get("access_token"):
        access_token = existing_sub["access_token"]
    else:
        access_token = subscriptions_ledger.generate_access_token()

    # Timestamps are stored with a "Z" suffix, so convert in UTC rather than
    # the host's local zone.
    end_iso = datetime.fromtimestamp(current_period_end, tz=timezone.utc).isoformat().replace("+00:00", "Z")
    now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

    ledger_entry = {
        "event_id": event_id,
        "hf_user": hf_user,
        "dataset_id": dataset_id,
        "plan_id": plan_id,
        "subscription_start": now_iso,  # Could use the invoice period start instead.
        "subscription_end": end_iso,
        "source": "stripe",
        "access_token": access_token,
        "created_at": now_iso,
        "stripe_customer_id": invoice.get("customer"),
        "stripe_subscription_id": subscription_id,
    }

    subscriptions_ledger.append_subscription_event(ledger_entry)