| """ |
| Subscription service. |
| |
| This module provides functions for managing subscriptions. |
| """ |
| import os |
| import logging |
| from datetime import datetime, timedelta |
| from typing import List, Dict, Any, Optional, Tuple, Union |
|
|
| import stripe |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from sqlalchemy import select, update, delete |
| from sqlalchemy.orm import joinedload |
|
|
| from src.models.subscription import ( |
| SubscriptionPlan, UserSubscription, PaymentHistory, |
| SubscriptionTier, BillingPeriod, SubscriptionStatus, PaymentStatus |
| ) |
| from src.models.user import User |
|
|
| |
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Stripe credentials are read from the environment once, at import time.
# Either value is None when its environment variable is unset.
stripe.api_key = os.getenv("STRIPE_SECRET_KEY")
STRIPE_PUBLISHABLE_KEY = os.getenv("STRIPE_PUBLISHABLE_KEY")
|
|
|
|
async def get_subscription_plans(
    db: AsyncSession,
    active_only: bool = True
) -> List[SubscriptionPlan]:
    """
    Fetch subscription plans from the database.

    Args:
        db: Database session
        active_only: When True, restrict the result to plans whose
            ``is_active`` flag is set; when False, return every plan

    Returns:
        The matching subscription plans
    """
    stmt = select(SubscriptionPlan)
    if active_only:
        # Column comparison that builds a SQL expression, not a Python
        # truth test, hence the explicit ``== True``.
        stmt = stmt.where(SubscriptionPlan.is_active == True)  # noqa: E712

    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
async def get_subscription_plan_by_id(
    db: AsyncSession,
    plan_id: int
) -> Optional[SubscriptionPlan]:
    """
    Look up a single subscription plan by its primary key.

    Args:
        db: Database session
        plan_id: ID of the plan to fetch

    Returns:
        The matching plan, or None when no plan has that ID
    """
    stmt = select(SubscriptionPlan).where(SubscriptionPlan.id == plan_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
|
|
|
|
async def get_subscription_plan_by_tier(
    db: AsyncSession,
    tier: SubscriptionTier
) -> Optional[SubscriptionPlan]:
    """
    Look up a subscription plan by its tier.

    Args:
        db: Database session
        tier: Tier of the plan to fetch

    Returns:
        The first plan found with that tier, or None when there is none
    """
    stmt = select(SubscriptionPlan).where(SubscriptionPlan.tier == tier)
    rows = await db.execute(stmt)
    return rows.scalars().first()
|
|
|
|
async def create_subscription_plan(
    db: AsyncSession,
    name: str,
    tier: SubscriptionTier,
    description: str,
    price_monthly: float,
    price_annually: float,
    max_alerts: int = 10,
    max_reports: int = 5,
    max_searches_per_day: int = 20,
    max_monitoring_keywords: int = 10,
    max_data_retention_days: int = 30,
    supports_api_access: bool = False,
    supports_live_feed: bool = False,
    supports_dark_web_monitoring: bool = False,
    supports_export: bool = False,
    supports_advanced_analytics: bool = False,
    create_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Create a new subscription plan.

    At most one plan is created per tier: if a plan with ``tier`` already
    exists, nothing is created. When Stripe is enabled, a Stripe product
    plus one monthly and one annual USD price are created first and their
    IDs are stored on the plan. A Stripe failure is logged and does not
    prevent the plan from being saved; the plan keeps whichever Stripe IDs
    were obtained before the failure.

    Args:
        db: Database session
        name: Name of the plan
        tier: Tier of the plan
        description: Description of the plan
        price_monthly: Monthly price of the plan, in currency units
        price_annually: Annual price of the plan, in currency units
        max_alerts: Maximum number of alerts allowed
        max_reports: Maximum number of reports allowed
        max_searches_per_day: Maximum number of searches per day
        max_monitoring_keywords: Maximum number of monitoring keywords
        max_data_retention_days: Maximum number of days to retain data
        supports_api_access: Whether the plan supports API access
        supports_live_feed: Whether the plan supports live feed
        supports_dark_web_monitoring: Whether the plan supports dark web monitoring
        supports_export: Whether the plan supports data export
        supports_advanced_analytics: Whether the plan supports advanced analytics
        create_stripe_product: Whether to create a Stripe product for this plan
            (only honoured when a Stripe API key is configured)

    Returns:
        Created subscription plan, or None if a plan with this tier
        already exists
    """
    existing_plan = await get_subscription_plan_by_tier(db, tier)

    if existing_plan:
        logger.warning(f"Subscription plan with tier {tier} already exists")
        return None

    stripe_product_id = None
    stripe_monthly_price_id = None
    stripe_annual_price_id = None

    if create_stripe_product and stripe.api_key:
        # NOTE: the stripe calls below are made without ``await``, so they
        # run synchronously inside this coroutine.
        try:
            product = stripe.Product.create(
                name=name,
                description=description,
                metadata={
                    "tier": tier.value,
                    "max_alerts": max_alerts,
                    "max_reports": max_reports,
                    "max_searches_per_day": max_searches_per_day,
                    "max_monitoring_keywords": max_monitoring_keywords,
                    "max_data_retention_days": max_data_retention_days,
                    "supports_api_access": "yes" if supports_api_access else "no",
                    "supports_live_feed": "yes" if supports_live_feed else "no",
                    "supports_dark_web_monitoring": "yes" if supports_dark_web_monitoring else "no",
                    "supports_export": "yes" if supports_export else "no",
                    "supports_advanced_analytics": "yes" if supports_advanced_analytics else "no"
                }
            )
            stripe_product_id = product.id

            # Stripe wants the amount as an integer number of cents.
            # round() rather than int(): 19.99 * 100 is 1998.9999999999998
            # in binary floating point, which int() would truncate to 1998.
            monthly_price = stripe.Price.create(
                product=product.id,
                unit_amount=round(price_monthly * 100),
                currency="usd",
                recurring={"interval": "month"},
                metadata={"billing_period": "monthly"}
            )
            stripe_monthly_price_id = monthly_price.id

            annual_price = stripe.Price.create(
                product=product.id,
                unit_amount=round(price_annually * 100),
                currency="usd",
                recurring={"interval": "year"},
                metadata={"billing_period": "annually"}
            )
            stripe_annual_price_id = annual_price.id

            logger.info(f"Created Stripe product {product.id} for plan {name}")
        except Exception as e:
            # Best effort: the plan is still persisted without (or with
            # partial) Stripe IDs.
            logger.error(f"Failed to create Stripe product for plan {name}: {e}")

    plan = SubscriptionPlan(
        name=name,
        tier=tier,
        description=description,
        price_monthly=price_monthly,
        price_annually=price_annually,
        max_alerts=max_alerts,
        max_reports=max_reports,
        max_searches_per_day=max_searches_per_day,
        max_monitoring_keywords=max_monitoring_keywords,
        max_data_retention_days=max_data_retention_days,
        supports_api_access=supports_api_access,
        supports_live_feed=supports_live_feed,
        supports_dark_web_monitoring=supports_dark_web_monitoring,
        supports_export=supports_export,
        supports_advanced_analytics=supports_advanced_analytics,
        stripe_product_id=stripe_product_id,
        stripe_monthly_price_id=stripe_monthly_price_id,
        stripe_annual_price_id=stripe_annual_price_id
    )

    db.add(plan)
    await db.commit()
    # Reload so database-generated values (e.g. the ID) are populated.
    await db.refresh(plan)

    return plan
|
|
|
|
async def update_subscription_plan(
    db: AsyncSession,
    plan_id: int,
    name: Optional[str] = None,
    description: Optional[str] = None,
    price_monthly: Optional[float] = None,
    price_annually: Optional[float] = None,
    is_active: Optional[bool] = None,
    max_alerts: Optional[int] = None,
    max_reports: Optional[int] = None,
    max_searches_per_day: Optional[int] = None,
    max_monitoring_keywords: Optional[int] = None,
    max_data_retention_days: Optional[int] = None,
    supports_api_access: Optional[bool] = None,
    supports_live_feed: Optional[bool] = None,
    supports_dark_web_monitoring: Optional[bool] = None,
    supports_export: Optional[bool] = None,
    supports_advanced_analytics: Optional[bool] = None,
    update_stripe_product: bool = True
) -> Optional[SubscriptionPlan]:
    """
    Update a subscription plan.

    Only arguments that are not None are written; None means "leave this
    field unchanged". When Stripe is enabled and the plan has a Stripe
    product, the product's name, description and metadata are updated to
    match. A changed price creates a new Stripe price, whose ID replaces
    the stored one, but only if the plan already has a Stripe price ID for
    that billing period. A Stripe failure is logged and the database
    update still proceeds.

    Args:
        db: Database session
        plan_id: ID of the plan to update
        name: New name of the plan
        description: New description of the plan
        price_monthly: New monthly price of the plan, in currency units
        price_annually: New annual price of the plan, in currency units
        is_active: New active status of the plan
        max_alerts: New maximum number of alerts allowed
        max_reports: New maximum number of reports allowed
        max_searches_per_day: New maximum number of searches per day
        max_monitoring_keywords: New maximum number of monitoring keywords
        max_data_retention_days: New maximum number of days to retain data
        supports_api_access: New API access support status
        supports_live_feed: New live feed support status
        supports_dark_web_monitoring: New dark web monitoring support status
        supports_export: New data export support status
        supports_advanced_analytics: New advanced analytics support status
        update_stripe_product: Whether to update the Stripe product for this plan

    Returns:
        Updated subscription plan, or None if no plan has ``plan_id``
    """
    plan = await get_subscription_plan_by_id(db, plan_id)

    if not plan:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    # Numeric limits are mirrored to Stripe metadata as-is; feature flags
    # are mirrored as "yes"/"no" strings.
    limit_fields = {
        "max_alerts": max_alerts,
        "max_reports": max_reports,
        "max_searches_per_day": max_searches_per_day,
        "max_monitoring_keywords": max_monitoring_keywords,
        "max_data_retention_days": max_data_retention_days,
    }
    feature_flags = {
        "supports_api_access": supports_api_access,
        "supports_live_feed": supports_live_feed,
        "supports_dark_web_monitoring": supports_dark_web_monitoring,
        "supports_export": supports_export,
        "supports_advanced_analytics": supports_advanced_analytics,
    }
    requested = {
        "name": name,
        "description": description,
        "price_monthly": price_monthly,
        "price_annually": price_annually,
        "is_active": is_active,
        **limit_fields,
        **feature_flags,
    }

    # Column values to write: every argument the caller actually supplied.
    update_data = {
        field: value for field, value in requested.items() if value is not None
    }

    if update_stripe_product and plan.stripe_product_id and stripe.api_key:
        # NOTE: the stripe calls below are made without ``await``, so they
        # run synchronously inside this coroutine.
        try:
            product_update_data = {
                field: value
                for field, value in (("name", name), ("description", description))
                if value is not None
            }

            metadata_update = {
                field: value
                for field, value in limit_fields.items()
                if value is not None
            }
            metadata_update.update(
                (field, "yes" if value else "no")
                for field, value in feature_flags.items()
                if value is not None
            )

            if metadata_update:
                product_update_data["metadata"] = metadata_update

            if product_update_data:
                stripe.Product.modify(plan.stripe_product_id, **product_update_data)

            # A price change is applied by creating a new Stripe price and
            # storing its ID in place of the old one.
            if price_monthly is not None and plan.stripe_monthly_price_id:
                # round() rather than int(): 19.99 * 100 is
                # 1998.9999999999998 in binary floating point, which int()
                # would truncate to 1998 cents.
                new_monthly_price = stripe.Price.create(
                    product=plan.stripe_product_id,
                    unit_amount=round(price_monthly * 100),
                    currency="usd",
                    recurring={"interval": "month"},
                    metadata={"billing_period": "monthly"}
                )
                update_data["stripe_monthly_price_id"] = new_monthly_price.id

            if price_annually is not None and plan.stripe_annual_price_id:
                new_annual_price = stripe.Price.create(
                    product=plan.stripe_product_id,
                    unit_amount=round(price_annually * 100),
                    currency="usd",
                    recurring={"interval": "year"},
                    metadata={"billing_period": "annually"}
                )
                update_data["stripe_annual_price_id"] = new_annual_price.id

            logger.info(f"Updated Stripe product {plan.stripe_product_id} for plan {plan.name}")
        except Exception as e:
            # Best effort: the database update below still runs.
            logger.error(f"Failed to update Stripe product for plan {plan.name}: {e}")

    if update_data:
        await db.execute(
            update(SubscriptionPlan)
            .where(SubscriptionPlan.id == plan_id)
            .values(**update_data)
        )

        await db.commit()

    # Re-read the plan so the returned object reflects the stored values.
    plan = await get_subscription_plan_by_id(db, plan_id)

    return plan
|
|
|
|
async def get_user_subscription(
    db: AsyncSession,
    user_id: int
) -> Optional[UserSubscription]:
    """
    Fetch a user's current (non-canceled) subscription.

    The related plan is eagerly loaded with the subscription. No ordering
    is applied, so if several non-canceled subscriptions exist the first
    row returned by the database is used.

    Args:
        db: Database session
        user_id: ID of the user

    Returns:
        A subscription of the user whose status is not CANCELED, or None
        if there is none
    """
    stmt = (
        select(UserSubscription)
        .options(joinedload(UserSubscription.plan))
        .where(
            UserSubscription.user_id == user_id,
            UserSubscription.status != SubscriptionStatus.CANCELED,
        )
    )

    rows = await db.execute(stmt)
    return rows.scalars().first()
|
|
|
|
async def get_user_subscription_by_id(
    db: AsyncSession,
    subscription_id: int
) -> Optional[UserSubscription]:
    """
    Look up a user subscription by its primary key, whatever its status.

    The related plan is eagerly loaded with the subscription.

    Args:
        db: Database session
        subscription_id: ID of the subscription

    Returns:
        The matching subscription, or None when no subscription has that ID
    """
    stmt = (
        select(UserSubscription)
        .options(joinedload(UserSubscription.plan))
        .where(UserSubscription.id == subscription_id)
    )

    rows = await db.execute(stmt)
    return rows.scalars().first()
|
|
|
|
async def create_user_subscription(
    db: AsyncSession,
    user_id: int,
    plan_id: int,
    billing_period: BillingPeriod = BillingPeriod.MONTHLY,
    create_stripe_subscription: bool = True,
    payment_method_id: Optional[str] = None
) -> Optional[UserSubscription]:
    """
    Create a new user subscription together with its first payment record.

    A Stripe subscription is created only when requested, a Stripe API key
    is configured and the plan has a Stripe product. In that case the
    Stripe customer is looked up by the user's email (or created), the
    payment method is attached and made the default, and the subscription
    is created on the plan's Stripe price for the billing period.

    The local subscription row and its payment-history row are written in
    a single transaction. The payment is recorded as SUCCEEDED only when
    Stripe reports the new subscription as ``active``; otherwise (including
    when no Stripe subscription is created) it is recorded as PENDING.

    Args:
        db: Database session
        user_id: ID of the user
        plan_id: ID of the subscription plan
        billing_period: Billing period (monthly or annually)
        create_stripe_subscription: Whether to create a Stripe subscription
        payment_method_id: ID of the payment method to use (required if a
            Stripe subscription is going to be created)

    Returns:
        Created user subscription, or None if the user or plan does not
        exist, the user already has a non-canceled subscription, the
        billing period is invalid, a required Stripe input is missing, or
        a Stripe call fails
    """
    user_rows = await db.execute(select(User).where(User.id == user_id))
    user = user_rows.scalars().first()

    if not user:
        logger.warning(f"User with ID {user_id} not found")
        return None

    plan = await get_subscription_plan_by_id(db, plan_id)

    if not plan:
        logger.warning(f"Subscription plan with ID {plan_id} not found")
        return None

    existing_subscription = await get_user_subscription(db, user_id)

    if existing_subscription:
        logger.warning(f"User with ID {user_id} already has an active subscription")
        return None

    # Naive UTC timestamp, kept as in the rest of this module.
    now = datetime.utcnow()

    # The local billing period is approximated with fixed 30/365-day spans.
    if billing_period == BillingPeriod.MONTHLY:
        current_period_end = now + timedelta(days=30)
        price = plan.price_monthly
        stripe_price_id = plan.stripe_monthly_price_id
    elif billing_period == BillingPeriod.ANNUALLY:
        current_period_end = now + timedelta(days=365)
        price = plan.price_annually
        stripe_price_id = plan.stripe_annual_price_id
    else:
        logger.warning(f"Invalid billing period: {billing_period}")
        return None

    stripe_subscription_id = None
    stripe_customer_id = None
    stripe_subscription_status = None

    if create_stripe_subscription and stripe.api_key and plan.stripe_product_id:
        if not payment_method_id:
            logger.warning("Payment method ID is required to create a Stripe subscription")
            return None

        # Check before touching Stripe, so no customer is created and no
        # payment method attached for a subscription that cannot be made.
        if not stripe_price_id:
            logger.warning(
                f"Subscription plan with ID {plan_id} has no Stripe price "
                f"for billing period {billing_period}"
            )
            return None

        # NOTE: the stripe calls below are made without ``await``, so they
        # run synchronously inside this coroutine.
        try:
            # Reuse an existing Stripe customer with this email if any.
            customers = stripe.Customer.list(email=user.email)

            if customers.data:
                stripe_customer_id = customers.data[0].id
            else:
                customer = stripe.Customer.create(
                    email=user.email,
                    name=user.full_name,
                    metadata={"user_id": user_id}
                )
                stripe_customer_id = customer.id

            stripe.PaymentMethod.attach(
                payment_method_id,
                customer=stripe_customer_id
            )

            # Make it the default so subscription invoices are charged to it.
            stripe.Customer.modify(
                stripe_customer_id,
                invoice_settings={
                    "default_payment_method": payment_method_id
                }
            )

            stripe_subscription = stripe.Subscription.create(
                customer=stripe_customer_id,
                items=[
                    {"price": stripe_price_id}
                ],
                expand=["latest_invoice.payment_intent"]
            )

            stripe_subscription_id = stripe_subscription.id
            stripe_subscription_status = stripe_subscription.status

            logger.info(f"Created Stripe subscription {stripe_subscription.id} for user {user_id}")
        except Exception as e:
            logger.error(f"Failed to create Stripe subscription for user {user_id}: {e}")
            return None

    subscription = UserSubscription(
        user_id=user_id,
        plan_id=plan_id,
        status=SubscriptionStatus.ACTIVE,
        billing_period=billing_period,
        current_period_start=now,
        current_period_end=current_period_end,
        stripe_subscription_id=stripe_subscription_id,
        stripe_customer_id=stripe_customer_id
    )

    db.add(subscription)
    # Flush (not commit) to obtain subscription.id while keeping the
    # subscription and its payment record in one transaction.
    await db.flush()

    # A Stripe subscription can come back in a non-active state (e.g.
    # "incomplete" when the first charge did not go through), so only an
    # "active" one counts as paid.
    if stripe_subscription_status == "active":
        payment_status = PaymentStatus.SUCCEEDED
    else:
        payment_status = PaymentStatus.PENDING

    payment = PaymentHistory(
        user_id=user_id,
        subscription_id=subscription.id,
        amount=price,
        currency="USD",
        status=payment_status
    )

    db.add(payment)
    await db.commit()
    # Refresh after the final commit so the returned object is loaded and
    # not left expired by the commit.
    await db.refresh(subscription)

    return subscription
|
|
|
|
async def cancel_user_subscription(
    db: AsyncSession,
    subscription_id: int,
    cancel_stripe_subscription: bool = True
) -> Optional[UserSubscription]:
    """
    Cancel a user subscription.

    When requested and the subscription is linked to Stripe, the Stripe
    subscription is set to cancel at the end of the current billing period.
    The local row is then marked CANCELED with the current UTC time. If the
    Stripe call fails, the local row is left untouched and None is
    returned, so the subscription is never shown as canceled while Stripe
    would keep renewing it.

    Args:
        db: Database session
        subscription_id: ID of the subscription to cancel
        cancel_stripe_subscription: Whether to cancel the Stripe subscription

    Returns:
        Canceled user subscription, or None if the subscription does not
        exist or the Stripe cancellation failed
    """
    subscription = await get_user_subscription_by_id(db, subscription_id)

    if not subscription:
        logger.warning(f"Subscription with ID {subscription_id} not found")
        return None

    if cancel_stripe_subscription and subscription.stripe_subscription_id and stripe.api_key:
        # NOTE: the stripe call below is made without ``await``, so it
        # runs synchronously inside this coroutine.
        try:
            stripe.Subscription.modify(
                subscription.stripe_subscription_id,
                cancel_at_period_end=True
            )

            logger.info(f"Canceled Stripe subscription {subscription.stripe_subscription_id} at period end")
        except Exception as e:
            logger.error(f"Failed to cancel Stripe subscription {subscription.stripe_subscription_id}: {e}")
            # Report failure instead of marking the row canceled locally
            # while the Stripe subscription is still set to renew.
            return None

    # Naive UTC timestamp, kept as in the rest of this module.
    now = datetime.utcnow()

    await db.execute(
        update(UserSubscription)
        .where(UserSubscription.id == subscription_id)
        .values(
            status=SubscriptionStatus.CANCELED,
            canceled_at=now
        )
    )

    await db.commit()

    # Re-read the row so the returned object reflects the stored values.
    subscription = await get_user_subscription_by_id(db, subscription_id)

    return subscription