#!/usr/bin/env python3
# coding: utf-8

"""Database models and helpers for per-user settings, quota and payments.

Importing this module has side effects: it builds the SQLAlchemy engine from
the ``DB_DSN`` environment variable, verifies the connection, creates any
missing tables and exposes ``SessionFactory``.
"""

import logging
import math
import os
from contextlib import contextmanager
from typing import Literal

from sqlalchemy import (
    BigInteger,
    Column,
    Enum,
    Float,
    ForeignKey,
    Integer,
    String,
    create_engine,
    text,
)
from sqlalchemy.types import TEXT
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker

from config import ENABLE_VIP, FREE_DOWNLOAD


class PaymentStatus:
    """String constants used as the allowed values of ``Payment.status``."""

    PENDING = "pending"
    COMPLETED = "completed"
    FAILED = "failed"
    REFUNDED = "refunded"


Base = declarative_base()


class User(Base):
    """A user row holding the free/paid quota counters."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(BigInteger, unique=True, nullable=False)  # telegram user id
    free = Column(Integer, default=FREE_DOWNLOAD)
    paid = Column(Integer, default=0)
    config = Column(TEXT)

    # One-to-one: uselist=False makes ``user.settings`` a single Setting or None.
    settings = relationship("Setting", back_populates="user", cascade="all, delete-orphan", uselist=False)
    payments = relationship("Payment", back_populates="user", cascade="all, delete-orphan")


class Setting(Base):
    """Per-user quality and format preferences."""

    __tablename__ = "settings"

    id = Column(Integer, primary_key=True, autoincrement=True)
    quality = Column(Enum("high", "medium", "low", "audio", "custom"), nullable=False, default="high")
    format = Column(Enum("video", "audio", "document"), nullable=False, default="video")
    # References users.id (the surrogate key), not the telegram user id.
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    user = relationship("User", back_populates="settings")


class Payment(Base):
    """A payment record attached to a user."""

    __tablename__ = "payments"

    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(String(50), nullable=False)
    amount = Column(Float, nullable=False)
    status = Column(
        Enum(
            PaymentStatus.PENDING,
            PaymentStatus.COMPLETED,
            PaymentStatus.FAILED,
            PaymentStatus.REFUNDED,
        ),
        nullable=False,
    )
    transaction_id = Column(String(100))
    # References users.id (the surrogate key), not the telegram user id.
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    user = relationship("User", back_populates="payments")


# Names of the Setting columns that set_user_settings is allowed to write.
_SETTING_KEYS = ("quality", "format")


def create_session():
    """Create the engine, check connectivity, create tables, return a session factory.

    The DSN is read from the ``DB_DSN`` environment variable. The CA
    certificate path is read from ``SSL_CA_PATH`` and falls back to
    ``/ca.pem`` when the variable is unset.

    Returns:
        A ``sessionmaker`` bound to the new engine.

    Raises:
        Exception: whatever the driver raises when the connectivity probe
            fails; the error is logged and re-raised unchanged.
    """
    engine = create_engine(
        os.getenv("DB_DSN"),
        pool_size=50,
        max_overflow=100,
        pool_timeout=30,
        pool_recycle=1800,
        # NOTE(review): the {"ssl": {"ca": ...}} shape is driver-specific —
        # confirm it matches the driver named in DB_DSN.
        connect_args={
            "ssl": {
                "ca": os.getenv("SSL_CA_PATH", "/ca.pem"),
            }
        },
    )

    # Verify the connection works before handing out sessions.
    try:
        with engine.connect() as conn:
            # text() is required: SQLAlchemy 2.x no longer accepts raw strings.
            conn.execute(text("SELECT 1"))
    except Exception as e:
        logging.error("Failed to connect to database: %s", e)
        raise

    Base.metadata.create_all(engine)
    return sessionmaker(bind=engine)


SessionFactory = create_session()


@contextmanager
def session_manager():
    """Yield a session that commits on success and rolls back on error.

    The session is always closed on exit. Any exception raised inside the
    ``with`` body is re-raised after the rollback.
    """
    session = SessionFactory()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def _get_user(session, tgid):
    """Return the User whose ``user_id`` equals *tgid*, or None if absent."""
    return session.query(User).filter(User.user_id == tgid).first()


def get_quality_settings(tgid) -> Literal["high", "medium", "low", "audio", "custom"]:
    """Return the stored quality for *tgid*, or ``"high"`` when none is stored."""
    with session_manager() as session:
        user = _get_user(session, tgid)
        if user and user.settings:
            return user.settings.quality
        return "high"


def get_format_settings(tgid) -> Literal["video", "audio", "document"]:
    """Return the stored format for *tgid*, or ``"video"`` when none is stored."""
    with session_manager() as session:
        user = _get_user(session, tgid)
        if user and user.settings:
            return user.settings.format
        return "video"


def set_user_settings(tgid: int, key: str, value: str) -> None:
    """Set the ``quality`` or ``format`` setting for a user (upsert).

    Args:
        tgid: Telegram user id (``User.user_id``).
        key: Setting column to write; must be ``"quality"`` or ``"format"``.
        value: New value for that column.

    Raises:
        ValueError: if *key* is not a known setting column.
    """
    if key not in _SETTING_KEYS:
        raise ValueError(f"unknown setting {key!r}; expected one of {_SETTING_KEYS}")

    with session_manager() as session:
        user = _get_user(session, tgid)
        if user is None:
            # No row yet: create it instead of failing on ``None.id`` below.
            user = User(user_id=tgid)
            session.add(user)
            session.flush()  # assigns user.id for the foreign key

        setting = session.query(Setting).filter(Setting.user_id == user.id).first()
        if setting:
            setattr(setting, key, value)
        else:
            session.add(Setting(user_id=user.id, **{key: value}))


def get_free_quota(uid: int) -> float:
    """Return the user's free quota.

    Returns ``math.inf`` when ``ENABLE_VIP`` is false, the stored ``free``
    counter when the user exists, and ``FREE_DOWNLOAD`` otherwise.
    """
    if not ENABLE_VIP:
        return math.inf

    with session_manager() as session:
        user = _get_user(session, uid)
        if user:
            return user.free
        return FREE_DOWNLOAD


def get_paid_quota(uid: int) -> float:
    """Return the user's paid quota.

    Returns ``math.inf`` when ``ENABLE_VIP`` is false, the stored ``paid``
    counter when the user exists, and ``0`` otherwise.
    """
    if not ENABLE_VIP:
        return math.inf

    with session_manager() as session:
        user = _get_user(session, uid)
        if user:
            return user.paid
        return 0


def reset_free_quota(uid: int) -> None:
    """Reset one user's free quota to ``FREE_DOWNLOAD``; no-op if the user is absent."""
    with session_manager() as session:
        user = _get_user(session, uid)
        if user:
            # Use the configured allowance so this agrees with reset_free()
            # and the column default rather than a hard-coded number.
            user.free = FREE_DOWNLOAD


def add_paid_quota(uid: int, amount: int) -> None:
    """Add *amount* to the user's paid quota; no-op if the user is absent."""
    with session_manager() as session:
        user = _get_user(session, uid)
        if user:
            user.paid += amount


def check_quota(uid: int) -> None:
    """Raise if the user exists and has no free or paid quota left.

    Does nothing when ``ENABLE_VIP`` is false or the user is absent.

    Raises:
        Exception: when ``free + paid`` is zero or negative.
    """
    if not ENABLE_VIP:
        return

    with session_manager() as session:
        user = _get_user(session, uid)
        if user and (user.free + user.paid) <= 0:
            raise Exception("Quota exhausted. Please /buy or wait until free quota is reset")


def use_quota(uid: int) -> None:
    """Consume one unit of quota, using free quota first, then paid.

    Does nothing when ``ENABLE_VIP`` is false or the user is absent.

    Raises:
        Exception: when both counters are exhausted.
    """
    if not ENABLE_VIP:
        return

    with session_manager() as session:
        user = _get_user(session, uid)
        if not user:
            return
        if user.free > 0:
            user.free -= 1
        elif user.paid > 0:
            user.paid -= 1
        else:
            raise Exception("Quota exhausted. Please /buy or wait until free quota is reset")


def init_user(uid: int) -> None:
    """Insert a User row for *uid* if one does not already exist."""
    with session_manager() as session:
        if not _get_user(session, uid):
            session.add(User(user_id=uid))


def reset_free() -> None:
    """Reset every user's free quota to ``FREE_DOWNLOAD``."""
    with session_manager() as session:
        # A single bulk UPDATE avoids loading every user row into memory;
        # session_manager commits on exit.
        session.query(User).update({User.free: FREE_DOWNLOAD}, synchronize_session=False)


def credit_account(who, total_amount: int, quota: int, transaction, method="stripe"):
    """Credit paid quota to a user and record a completed payment.

    Args:
        who: Telegram user id (``User.user_id``).
        total_amount: Payment amount; stored as given in ``Payment.amount``
            and divided by 100 only for the log message.
        quota: Units added to the user's paid quota.
        transaction: Stored as ``Payment.transaction_id``.
        method: Stored as ``Payment.method``.

    Returns:
        ``(free, paid)`` after crediting, or ``(None, None)`` when the user
        does not exist.
    """
    with session_manager() as session:
        user = _get_user(session, who)
        if not user:
            return None, None

        dollar = total_amount / 100
        user.paid += quota
        # The logged token count is the user's new paid balance.
        logging.info("user %d credited with %d tokens, payment:$%.2f", who, user.paid, dollar)
        session.add(
            Payment(
                method=method,
                amount=total_amount,
                status=PaymentStatus.COMPLETED,
                transaction_id=transaction,
                user_id=user.id,
            )
        )
        # session_manager commits when the block exits, after the return
        # value has been evaluated.
        return user.free, user.paid