Spaces:
Sleeping
Sleeping
| import stripe | |
| import os | |
| import json | |
| from fastapi import Request, HTTPException | |
| from datetime import datetime, timedelta | |
| import datasets_registry | |
| import subscriptions_ledger | |
| from fastapi.responses import JSONResponse | |
async def handle_stripe_webhook(request: Request):
    """Verify an incoming Stripe webhook and dispatch it to its handler.

    The raw request body is verified against the ``stripe-signature`` header
    using the secret read from the ``STRIPE_WEBHOOK_SECRET`` environment
    variable. ``checkout.session.completed`` and ``invoice.paid`` events are
    forwarded to their handlers; any other event type is acknowledged
    without further action.

    Args:
        request: Incoming FastAPI request carrying the webhook payload.

    Returns:
        A ``JSONResponse`` with body ``{"status": "success"}``.

    Raises:
        HTTPException: 400 when the payload or signature is invalid, 500
            when ``STRIPE_WEBHOOK_SECRET`` is not configured.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")
    webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET")

    if not webhook_secret:
        # Fail with an explicit error instead of handing a ``None`` secret
        # to the Stripe library's signature check.
        raise HTTPException(
            status_code=500, detail="Webhook secret not configured"
        )

    try:
        event = stripe.Webhook.construct_event(
            payload, sig_header, webhook_secret
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail="Invalid payload") from e
    except stripe.error.SignatureVerificationError as e:
        raise HTTPException(status_code=400, detail="Invalid signature") from e

    event_type = event["type"]
    data_object = event["data"]["object"]

    if event_type == "checkout.session.completed":
        await handle_checkout_session(data_object, event["id"])
    elif event_type == "invoice.paid":
        await handle_invoice_paid(data_object, event["id"])

    return JSONResponse({"status": "success"})
async def handle_checkout_session(session, event_id):
    """Append a ledger entry for a completed subscription checkout.

    Args:
        session: Checkout session object taken from the Stripe event. Its
            ``metadata`` is expected to carry ``hf_user`` and ``dataset_id``.
        event_id: ID of the Stripe event; stored on the ledger entry.

    Returns:
        None. Returns early, after printing a message, when the metadata is
        incomplete, when the session has no subscription, or when the
        subscription's price ID is not known to ``datasets_registry``.
    """
    # Imported locally: the module-level import only brings in
    # ``datetime`` and ``timedelta``.
    from datetime import timezone

    # ``metadata`` may be present but null, so fall back to an empty dict.
    metadata = session.get("metadata") or {}
    hf_user = metadata.get("hf_user")
    dataset_id = metadata.get("dataset_id")
    if not hf_user or not dataset_id:
        print(f"Missing metadata in session {session['id']}")
        return

    subscription_id = session.get("subscription")
    if not subscription_id:
        # One-time payments are not handled; only subscriptions are.
        print("Non-subscription checkout not fully supported yet.")
        return

    # The plan is derived from the first subscription item's price ID.
    sub = stripe.Subscription.retrieve(subscription_id)
    price_id = sub["items"]["data"][0]["price"]["id"]
    current_period_end = sub["current_period_end"]

    # Convert the Unix timestamp explicitly in UTC. A bare
    # ``datetime.fromtimestamp(ts)`` yields local time, which would be wrong
    # once the "Z" (UTC) suffix is appended below. tzinfo is dropped so the
    # serialized format stays "<naive ISO>Z".
    end_date = datetime.fromtimestamp(
        current_period_end, tz=timezone.utc
    ).replace(tzinfo=None)

    plan_info = datasets_registry.get_plan_by_price_id(price_id)
    if not plan_info:
        print(f"Unknown price_id {price_id}")
        return
    plan_id = plan_info["plan"]["plan_id"]

    # Generate unique access token for this subscription.
    access_token = subscriptions_ledger.generate_access_token()

    # Single timestamp so subscription_start and created_at agree exactly.
    now_iso = (
        datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    )

    ledger_entry = {
        "event_id": event_id,
        "hf_user": hf_user,
        "dataset_id": dataset_id,
        "plan_id": plan_id,
        "subscription_start": now_iso,
        "subscription_end": end_date.isoformat() + "Z",
        "source": "stripe",
        "access_token": access_token,
        "created_at": now_iso,
        "stripe_customer_id": session.get("customer"),
        "stripe_subscription_id": subscription_id,
    }
    subscriptions_ledger.append_subscription_event(ledger_entry)
async def handle_invoice_paid(invoice, event_id):
    """Append a ledger entry for a paid subscription invoice (e.g. a renewal).

    The user and dataset are read from the metadata of the Stripe
    subscription referenced by the invoice. If the ledger already has an
    active subscription with an access token for that user and dataset, the
    token is reused; otherwise a new one is generated.

    Args:
        invoice: Invoice object taken from the Stripe event.
        event_id: ID of the Stripe event; stored on the ledger entry.

    Returns:
        None. Returns early when the invoice has no subscription, when the
        subscription metadata lacks ``hf_user`` or ``dataset_id``, or when
        the price ID is not known to ``datasets_registry``.
    """
    # Imported locally: the module-level import only brings in
    # ``datetime`` and ``timedelta``.
    from datetime import timezone

    subscription_id = invoice.get("subscription")
    if not subscription_id:
        return

    # The user/dataset association is read from the subscription's metadata.
    # NOTE(review): this assumes the metadata was set on the subscription
    # itself when it was created - confirm against the checkout setup.
    sub = stripe.Subscription.retrieve(subscription_id)
    # ``metadata`` may be present but null, so fall back to an empty dict.
    metadata = sub.get("metadata") or {}
    hf_user = metadata.get("hf_user")
    dataset_id = metadata.get("dataset_id")
    if not hf_user or not dataset_id:
        print(f"Missing metadata in subscription {subscription_id}")
        return

    price_id = sub["items"]["data"][0]["price"]["id"]
    current_period_end = sub["current_period_end"]

    # Convert the Unix timestamp explicitly in UTC. A bare
    # ``datetime.fromtimestamp(ts)`` yields local time, which would be wrong
    # once the "Z" (UTC) suffix is appended below. tzinfo is dropped so the
    # serialized format stays "<naive ISO>Z".
    end_date = datetime.fromtimestamp(
        current_period_end, tz=timezone.utc
    ).replace(tzinfo=None)

    plan_info = datasets_registry.get_plan_by_price_id(price_id)
    if not plan_info:
        # Same message as the checkout handler, so an unknown price is
        # not dropped silently.
        print(f"Unknown price_id {price_id}")
        return
    plan_id = plan_info["plan"]["plan_id"]

    # For renewals, try to preserve the existing access token.
    existing_sub = subscriptions_ledger.get_active_subscription(
        hf_user, dataset_id
    )
    if existing_sub and existing_sub.get("access_token"):
        access_token = existing_sub["access_token"]
    else:
        access_token = subscriptions_ledger.generate_access_token()

    # Single timestamp so subscription_start and created_at agree exactly.
    now_iso = (
        datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    )

    ledger_entry = {
        "event_id": event_id,
        "hf_user": hf_user,
        "dataset_id": dataset_id,
        "plan_id": plan_id,
        "subscription_start": now_iso,  # Or period start
        "subscription_end": end_date.isoformat() + "Z",
        "source": "stripe",
        "access_token": access_token,
        "created_at": now_iso,
        "stripe_customer_id": invoice.get("customer"),
        "stripe_subscription_id": subscription_id,
    }
    subscriptions_ledger.append_subscription_event(ledger_entry)