Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| # coding: utf-8 | |
| import logging | |
| import math | |
| import os | |
| from contextlib import contextmanager | |
| from typing import Literal | |
| from sqlalchemy import ( | |
| BigInteger, | |
| Column, | |
| Enum, | |
| Float, | |
| ForeignKey, | |
| Integer, | |
| String, | |
| create_engine, | |
| ) | |
| from sqlalchemy.types import TEXT | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.orm import relationship, sessionmaker | |
| from config import ENABLE_VIP, FREE_DOWNLOAD | |
class PaymentStatus:
    """String constants naming the states a ``Payment`` row can be in.

    These exact strings are the members of the ``Payment.status`` SQL Enum,
    so changing them requires a matching schema change.
    """

    PENDING, COMPLETED, FAILED, REFUNDED = "pending", "completed", "failed", "refunded"
# Declarative base shared by every ORM model in this module; its metadata is
# what create_session() passes to create_all() to create missing tables.
Base = declarative_base()
class User(Base):
    """A bot user together with their free and paid download quotas."""

    __tablename__ = "users"

    # Surrogate primary key; Setting.user_id and Payment.user_id point here,
    # not at the Telegram id.
    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(BigInteger, unique=True, nullable=False)  # telegram user id
    # Remaining free downloads; new rows start at FREE_DOWNLOAD.
    free = Column(Integer, default=FREE_DOWNLOAD)
    # Purchased downloads; new rows start at 0.
    paid = Column(Integer, default=0)
    # NOTE(review): free-form text column, not referenced by the visible code.
    config = Column(TEXT)
    # One-to-one (uselist=False): a user's Setting row, deleted with the user.
    settings = relationship("Setting", back_populates="user", cascade="all, delete-orphan", uselist=False)
    # One-to-many payment history, deleted with the user.
    payments = relationship("Payment", back_populates="user", cascade="all, delete-orphan")
class Setting(Base):
    """A user's download preferences: quality and output format."""

    __tablename__ = "settings"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Allowed values match the Literal return type of get_quality_settings().
    quality = Column(Enum("high", "medium", "low", "audio", "custom"), nullable=False, default="high")
    # Allowed values match the Literal return type of get_format_settings().
    format = Column(Enum("video", "audio", "document"), nullable=False, default="video")
    # FK to users.id (the surrogate key, not the Telegram id).
    # NOTE(review): not declared unique even though User.settings is one-to-one
    # (uselist=False), so the schema would not reject duplicate rows per user.
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    user = relationship("User", back_populates="settings")
class Payment(Base):
    """A payment attributed to a user; credit_account() inserts these rows."""

    __tablename__ = "payments"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Payment method label; credit_account() defaults it to "stripe".
    method = Column(String(50), nullable=False)
    # credit_account() stores its total_amount here unconverted, and treats that
    # value as cents when logging (it divides by 100 for the dollar figure).
    # NOTE(review): Float is lossy for monetary values.
    amount = Column(Float, nullable=False)
    # One of the PaymentStatus string constants.
    status = Column(
        Enum(
            PaymentStatus.PENDING,
            PaymentStatus.COMPLETED,
            PaymentStatus.FAILED,
            PaymentStatus.REFUNDED,
        ),
        nullable=False,
    )
    # Transaction identifier supplied by the caller of credit_account().
    transaction_id = Column(String(100))
    # FK to users.id (the surrogate key, not the Telegram id).
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    user = relationship("User", back_populates="payments")
def create_session():
    """Build the engine, verify connectivity, create tables, and return a session factory.

    Configuration is read from the environment:
      * ``DB_DSN`` - SQLAlchemy database URL (required).
      * ``SSL_CA_PATH`` - CA bundle passed to the driver as ``ssl={"ca": ...}``;
        defaults to ``/ca.pem``, the previously hard-coded path.

    Returns:
        A ``sessionmaker`` bound to the new engine.

    Raises:
        RuntimeError: if ``DB_DSN`` is unset or empty.
        Exception: any error from the ``SELECT 1`` connectivity check is logged
            and re-raised.
    """
    # Raw SQL must be wrapped in text(): SQLAlchemy 2.x no longer accepts plain
    # strings in Connection.execute().
    from sqlalchemy import text

    dsn = os.getenv("DB_DSN")
    if not dsn:
        raise RuntimeError("DB_DSN environment variable is not set")

    engine = create_engine(
        dsn,
        pool_size=50,
        max_overflow=100,
        pool_timeout=30,
        pool_recycle=1800,  # seconds; recycle pooled connections every 30 minutes
        connect_args={
            "ssl": {
                "ca": os.getenv("SSL_CA_PATH", "/ca.pem"),
            }
        },
    )

    # Fail fast at startup if the database is unreachable.
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception as e:
        logging.error("Failed to connect to database: %s", e)
        raise

    Base.metadata.create_all(engine)
    return sessionmaker(bind=engine)
# Built once at import time: importing this module connects to the database,
# runs create_session()'s connectivity check, and creates any missing tables.
SessionFactory = create_session()
@contextmanager
def session_manager():
    """Yield a Session that commits on success and rolls back on error.

    The session is always closed on exit. Any ``Exception`` raised inside the
    ``with`` block triggers a rollback and is then re-raised unchanged.
    """
    session = SessionFactory()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
def get_quality_settings(tgid) -> Literal["high", "medium", "low", "audio", "custom"]:
    """Return the stored download quality for Telegram user *tgid*.

    Falls back to ``"high"`` when the user or their Setting row is missing.
    """
    with session_manager() as session:
        record = session.query(User).filter(User.user_id == tgid).first()
        prefs = record.settings if record else None
        return prefs.quality if prefs else "high"
def get_format_settings(tgid) -> Literal["video", "audio", "document"]:
    """Return the stored output format for Telegram user *tgid*.

    Falls back to ``"video"`` when the user or their Setting row is missing.
    """
    with session_manager() as session:
        owner = session.query(User).filter(User.user_id == tgid).first()
        if owner is None or owner.settings is None:
            return "video"
        return owner.settings.format
def set_user_settings(tgid: int, key: str, value: str):
    """Store one preference (``quality`` or ``format``) for Telegram user *tgid*.

    Creates the User row if it does not exist yet. Then either updates the
    user's existing Setting row or creates one; a newly created row keeps the
    column default for the other preference.

    Raises:
        ValueError: if *key* is not ``"quality"`` or ``"format"``.
    """
    # Only these columns are user-settable; anything else (e.g. "user_id")
    # would otherwise be written blindly through setattr.
    if key not in ("quality", "format"):
        raise ValueError(f"unsupported setting: {key!r}")
    with session_manager() as session:
        user = session.query(User).filter(User.user_id == tgid).first()
        if user is None:
            # Unknown user: create the owner row instead of failing on None.id.
            user = User(user_id=tgid)
            session.add(user)
        # Upsert through the one-to-one relationship; the FK is filled on flush.
        if user.settings is None:
            user.settings = Setting(**{key: value})
        else:
            setattr(user.settings, key, value)
def get_free_quota(uid: int):
    """Return the remaining free downloads for Telegram user *uid*.

    Returns ``math.inf`` when ENABLE_VIP is off, and FREE_DOWNLOAD when the
    user has no row yet.
    """
    if ENABLE_VIP:
        with session_manager() as session:
            row = session.query(User).filter(User.user_id == uid).first()
            return row.free if row else FREE_DOWNLOAD
    return math.inf
def get_paid_quota(uid: int):
    """Return the purchased downloads left for Telegram user *uid*.

    Returns ``math.inf`` when ENABLE_VIP is off, and 0 when the user has no row.
    """
    if not ENABLE_VIP:
        return math.inf
    with session_manager() as session:
        row = session.query(User).filter(User.user_id == uid).first()
        if row is None:
            return 0
        return row.paid
def reset_free_quota(uid: int):
    """Restore Telegram user *uid*'s free downloads to FREE_DOWNLOAD.

    Uses the same configured value as the ``User.free`` column default and
    reset_free(), instead of a hard-coded 5. Does nothing if the user has no row.
    """
    with session_manager() as session:
        data = session.query(User).filter(User.user_id == uid).first()
        if data:
            data.free = FREE_DOWNLOAD
def add_paid_quota(uid: int, amount: int):
    """Add *amount* purchased downloads to Telegram user *uid*.

    Silently does nothing when the user has no row.
    """
    with session_manager() as session:
        target = session.query(User).filter(User.user_id == uid).first()
        if target is None:
            return
        target.paid = target.paid + amount
def check_quota(uid: int):
    """Raise if Telegram user *uid* has no free or paid downloads left.

    No-op when ENABLE_VIP is off or when the user has no row.

    Raises:
        Exception: when the user's free + paid quota is zero or negative.
    """
    if ENABLE_VIP:
        with session_manager() as session:
            row = session.query(User).filter(User.user_id == uid).first()
            exhausted = row is not None and (row.free + row.paid) <= 0
            if exhausted:
                raise Exception("Quota exhausted. Please /buy or wait until free quota is reset")
def use_quota(uid: int):
    """Consume one download for Telegram user *uid*, spending free before paid.

    No-op when ENABLE_VIP is off or when the user has no row.

    Raises:
        Exception: when both the free and paid counters are exhausted.
    """
    if not ENABLE_VIP:
        return
    with session_manager() as session:
        account = session.query(User).filter(User.user_id == uid).first()
        if account is None:
            return
        # Try each bucket in priority order; decrement the first with credit left.
        for bucket in ("free", "paid"):
            remaining = getattr(account, bucket)
            if remaining > 0:
                setattr(account, bucket, remaining - 1)
                return
        raise Exception("Quota exhausted. Please /buy or wait until free quota is reset")
def init_user(uid: int):
    """Create a User row for Telegram user *uid* unless one already exists."""
    with session_manager() as session:
        already_known = session.query(User).filter(User.user_id == uid).first() is not None
        if not already_known:
            session.add(User(user_id=uid))
def reset_free():
    """Reset every user's free download quota to FREE_DOWNLOAD.

    Issues one bulk UPDATE instead of loading every User row into memory.
    session_manager commits the change on exit.
    """
    with session_manager() as session:
        # synchronize_session=False: no User objects are loaded in this session,
        # so there is nothing in the identity map to reconcile.
        session.query(User).update({User.free: FREE_DOWNLOAD}, synchronize_session=False)
def credit_account(who, total_amount: int, quota: int, transaction, method="stripe"):
    """Add *quota* paid downloads to Telegram user *who* and record the payment.

    Args:
        who: Telegram user id (matched against ``User.user_id``).
        total_amount: payment amount. It is divided by 100 for the dollar figure
            in the log line and stored unconverted in ``Payment.amount``.
        quota: number of paid downloads to add.
        transaction: transaction id stored on the Payment row.
        method: payment method label stored on the Payment row.

    Returns:
        ``(free, paid)`` balances after crediting, or ``(None, None)`` when the
        user has no row. In that case nothing is credited or recorded.
    """
    with session_manager() as session:
        user = session.query(User).filter(User.user_id == who).first()
        if user is None:
            # Nothing is stored for an unknown user; log it so a possibly
            # charged-but-uncredited payment is not lost silently.
            logging.warning(
                "payment %s (%s, amount %s) for unknown user %s was not credited",
                transaction,
                method,
                total_amount,
                who,
            )
            return None, None

        user.paid += quota
        logging.info(
            "user %s credited with %d tokens (balance %d), payment:$%.2f",
            who,
            quota,
            user.paid,
            total_amount / 100,
        )
        session.add(
            Payment(
                method=method,
                amount=total_amount,
                status=PaymentStatus.COMPLETED,
                transaction_id=transaction,
                user_id=user.id,
            )
        )
        # session_manager commits on exit, so no explicit commit (and no
        # post-commit refresh) is needed before reading the balances.
        return user.free, user.paid