Spaces:
Running
Running
| """ | |
| 用户与支付 SQLite 数据库模块 | |
| - 使用 SQLAlchemy 管理 user_data.db | |
| - 支持从 Hugging Face Dataset 下载/上传用户库 | |
| """ | |
from __future__ import annotations

import hashlib
import hmac
import logging
import os
import secrets
import time
from dataclasses import dataclass
from datetime import date, datetime, timedelta, timezone
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Iterator, Optional

from dotenv import load_dotenv

# Load .env before the remaining third-party imports (and before the module-level
# os.getenv calls): those libraries may read configuration from the environment
# at import time.
load_dotenv()

from huggingface_hub import hf_hub_download, upload_file
from sqlalchemy import (
    Date,
    DateTime,
    Float,
    ForeignKey,
    Integer,
    String,
    Text,
    create_engine,
    text,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker
logger = logging.getLogger(__name__)

# Hugging Face Dataset settings. The HF sync helpers in this module return early
# when either the token or the repo id is missing.
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "")
# Path of the user database inside the dataset repository.
USER_DB_FILENAME = os.getenv("USER_DB_FILENAME", "data/user_data.db")

# Local SQLite location: /tmp on non-Windows hosts, a relative path on Windows.
if os.name == "nt":
    _default_db_path = "backend/data/user_data.db"
else:
    _default_db_path = "/tmp/user_data.db"
USER_DB_PATH = os.getenv("USER_DB_PATH", _default_db_path)
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model in this module.

    Its ``metadata`` is what ``init_user_db`` passes to ``create_all``.
    """
def get_beijing_time() -> datetime:
    """Return the current Beijing wall-clock time (UTC+8) as a naive datetime.

    The tzinfo is stripped so the value can be stored in, and compared with,
    the naive ``DateTime`` columns used throughout this module.
    """
    beijing_tz = timezone(timedelta(hours=8))
    aware_now = datetime.now(beijing_tz)
    return aware_now.replace(tzinfo=None)
class User(Base):
    """A registered account (table ``users``)."""

    __tablename__ = "users"

    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)
    username: Mapped[str] = mapped_column(
        String(64), index=True, nullable=False, unique=True
    )
    # Salted hash string produced by hash_password(); never the raw password.
    password_hash: Mapped[str] = mapped_column(String(128), nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time)

    def is_admin(self) -> bool:
        """Return True only for the hard-coded primary administrator account."""
        primary_admin = "583079759"
        return self.username == primary_admin
class UserMembership(Base):
    """VIP membership state; at most one row per user (``user_id`` is the primary key)."""

    __tablename__ = "user_membership"

    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), primary_key=True)
    # Naive Beijing-time expiry; nullable.
    vip_expire_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
    # Refreshed on every UPDATE through the ORM.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, onupdate=get_beijing_time, default=get_beijing_time
    )
class PaymentOrder(Base):
    """A payment order created by create_payment_order (table ``payment_orders``)."""

    __tablename__ = "payment_orders"

    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)
    order_id: Mapped[str] = mapped_column(
        String(128), index=True, nullable=False, unique=True
    )
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False)
    amount: Mapped[float] = mapped_column(Float, nullable=False)
    # Payment channel: 1 = Alipay, 2 = WeChat Pay.
    pay_type: Mapped[int] = mapped_column(Integer, default=1)
    # Number of VIP months purchased with this order.
    vip_duration_months: Mapped[int] = mapped_column(Integer, default=1)
    # One of: pending, paid, expired.
    status: Mapped[str] = mapped_column(String(32), index=True, default="pending")
    created_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time)
    paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
    raw_payload: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
class UserSession(Base):
    """A login session keyed by its random token (table ``user_sessions``)."""

    __tablename__ = "user_sessions"

    token: Mapped[str] = mapped_column(String(128), primary_key=True)
    user_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("users.id"), index=True, nullable=False
    )
    # Naive Beijing time; get_user_by_token deletes the row once this has passed.
    expire_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time)
class DailyUsage(Base):
    """Per-user daily usage counter: one row per (user, calendar date)."""

    __tablename__ = "daily_usage"

    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), primary_key=True)
    # A SQL DATE, so the Python type is ``date`` (not ``datetime``); the helpers
    # in this module key it with ``get_beijing_time().date()``.
    use_date: Mapped[date] = mapped_column(Date, primary_key=True)
    count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
def _ensure_db_dir() -> None:
    """Create the directory that will hold USER_DB_PATH, including missing ancestors."""
    db_dir = Path(USER_DB_PATH).parent
    db_dir.mkdir(exist_ok=True, parents=True)
def _db_url() -> str:
    """Build the SQLAlchemy SQLite URL for USER_DB_PATH (forward slashes on every OS)."""
    posix_path = Path(USER_DB_PATH).as_posix()
    return "sqlite:///" + posix_path
# check_same_thread=False lets pooled SQLite connections be used from threads
# other than the one that opened them.
engine = create_engine(
    _db_url(),
    connect_args={"check_same_thread": False},
)
# Session factory bound to the engine above.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def _migrate_payment_orders() -> None:
    """Add ``payment_orders.vip_duration_months`` to databases created before it existed.

    The current column list is read with ``PRAGMA table_info`` so the ALTER only
    runs when the column is really missing. Failures are logged (not raised),
    keeping the migration best-effort as before, but now visibly.
    """
    try:
        with engine.begin() as conn:
            rows = conn.execute(text("PRAGMA table_info(payment_orders)"))
            # table_info rows are (cid, name, type, notnull, dflt_value, pk).
            existing_columns = {row[1] for row in rows}
            if "vip_duration_months" in existing_columns:
                return
            conn.execute(
                text("ALTER TABLE payment_orders ADD COLUMN vip_duration_months INTEGER DEFAULT 1")
            )
        logger.info("Migrated payment_orders table: added vip_duration_months")
    except Exception:
        logger.exception("Migration of payment_orders.vip_duration_months failed")


def _seed_admin_user() -> None:
    """Create the built-in admin account, plus a ~100-year VIP membership, if missing.

    The password is read from the ``ADMIN_PASSWORD`` environment variable
    instead of being hard-coded in source. If the account does not exist yet and
    the variable is unset, seeding is skipped with a warning.

    NOTE(review): earlier versions hard-coded the admin password in this file;
    deployments seeded by them should rotate it.
    """
    admin_username = "583079759"  # must match User.is_admin
    with SessionLocal() as db:
        existing_admin = db.query(User).filter(User.username == admin_username).first()
        if existing_admin:
            return
        admin_password_raw = os.getenv("ADMIN_PASSWORD")
        if not admin_password_raw:
            logger.warning(
                "ADMIN_PASSWORD not set; skipping creation of admin user %s", admin_username
            )
            return
        logger.info("Seeding super admin user: %s", admin_username)
        admin_user = User(
            username=admin_username,
            password_hash=hash_password(admin_password_raw),
        )
        db.add(admin_user)
        db.commit()
        # Give the admin a "permanent" membership by default. Per the original note,
        # admin limits are also bypassed in code; this is mostly for display.
        extend_vip_membership(db, admin_user.id, days=36500)


def init_user_db() -> None:
    """Prepare the local user database at start-up.

    Steps, in order:
    1. Ensure the directory containing USER_DB_PATH exists.
    2. Pull the user database and the payment QR-code images from the HF
       Dataset. ``sync_user_db_from_hf`` re-raises after exhausting its
       retries, so a failed download aborts here.
    3. Create any missing tables and apply the ``vip_duration_months`` column
       migration for older databases.
    4. Seed the built-in admin account if it does not exist (password from
       the ``ADMIN_PASSWORD`` environment variable).
    """
    _ensure_db_dir()
    sync_user_db_from_hf()
    sync_qr_codes_from_hf()
    Base.metadata.create_all(bind=engine)
    _migrate_payment_orders()
    _seed_admin_user()
def get_user_db() -> Iterator[Session]:
    """Yield a fresh ORM session and always close it afterwards.

    This is a generator (it yields exactly once), so the return annotation is
    ``Iterator[Session]`` rather than ``Session``. It is presumably consumed as a
    dependency-injection provider by the web layer; confirm against callers.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
def sync_user_db_from_hf() -> None:
    """Download the user database from the HF Dataset to USER_DB_PATH.

    Does nothing beyond an info log when HF_TOKEN or DATASET_REPO_ID is unset.
    The download is tried up to 5 times, pausing 2 seconds between attempts.
    After the last failure the exception is re-raised, so start-up fails rather
    than continuing without the remote data.

    The downloaded bytes are written to a temporary sibling of the target and
    then moved into place with ``os.replace``. An interrupted copy therefore
    cannot leave a truncated database at USER_DB_PATH.
    """
    if not HF_TOKEN or not DATASET_REPO_ID:
        logger.info("HF_TOKEN or DATASET_REPO_ID not set, skip user db download")
        return
    max_retries = 5
    retry_delay = 2  # seconds
    for attempt in range(max_retries):
        try:
            downloaded = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=USER_DB_FILENAME,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            src = Path(downloaded)
            dst = Path(USER_DB_PATH)
            # Skip the copy when the downloaded path already is the target file.
            if src.exists() and src.resolve() != dst.resolve():
                tmp = dst.with_name(dst.name + ".partial")
                tmp.write_bytes(src.read_bytes())
                # Same directory as dst, so the rename replaces the old file in one step.
                os.replace(tmp, dst)
            logger.info(f"User db synced from HF: {USER_DB_FILENAME}")
            return  # Success, exit the loop
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed to sync user db from HF: {e}")
            if attempt < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                logger.error(f"All {max_retries} attempts failed to sync user db from HF. Application will not start properly.")
                raise  # Re-raise the exception to prevent application startup
def sync_qr_codes_from_hf() -> None:
    """Copy the payment QR-code images from the HF Dataset into ``static/images``.

    The target directory is resolved relative to this file so it does not depend
    on the working directory. Each image is fetched independently; a failure is
    logged as a warning and the remaining images are still attempted. Returns
    silently when HF_TOKEN or DATASET_REPO_ID is unset.
    """
    if not (HF_TOKEN and DATASET_REPO_ID):
        return

    target_dir = Path(__file__).parent / "static" / "images"
    target_dir.mkdir(parents=True, exist_ok=True)

    qr_files = ("qrcode_alipay.png", "qrcode_wechatpay.png", "qrcode_wechat.png")
    for name in qr_files:
        remote_path = f"data/images/{name}"
        try:
            local_copy = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=remote_path,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            (target_dir / name).write_bytes(Path(local_copy).read_bytes())
            logger.info(f"QR code synced: {name}")
        except Exception as err:
            logger.warning(f"Failed to sync QR code {name}: {err}")
def upload_user_db_to_hf() -> None:
    """Push the local user database to the HF Dataset (best effort).

    Skips the upload, with a log line, when HF credentials are missing, when the
    local file does not exist, or when it is smaller than 1 KB (treated as an
    empty database that must not overwrite the remote copy). Up to 3 attempts
    are made, 2 seconds apart. A final failure is logged, not raised.
    """
    if not (HF_TOKEN and DATASET_REPO_ID):
        logger.info("HF_TOKEN or DATASET_REPO_ID not set, skip user db upload")
        return

    db_file = Path(USER_DB_PATH)
    if not db_file.exists():
        logger.warning("User db file not found, skip upload")
        return

    size_bytes = db_file.stat().st_size
    if size_bytes < 1024:
        logger.warning(f"User db file is too small ({size_bytes} bytes), skip upload to prevent overwriting remote data")
        return

    attempts, pause = 3, 2  # pause is in seconds
    for attempt_no in range(1, attempts + 1):
        try:
            upload_file(
                path_or_fileobj=str(db_file),
                path_in_repo=USER_DB_FILENAME,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            logger.info(f"User db uploaded to HF: {USER_DB_FILENAME}")
            return
        except Exception as err:
            logger.error(f"Attempt {attempt_no} failed to upload user db to HF: {err}")
            if attempt_no == attempts:
                logger.error(f"All {attempts} attempts failed to upload user db to HF")
            else:
                logger.info(f"Retrying in {pause} seconds...")
                time.sleep(pause)
def sync_user_db_after_update(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator: upload ``user_data.db`` to the HF Dataset after *func* returns.

    The upload runs only when *func* returns normally. If *func* raises, the
    exception propagates and nothing is uploaded. *func*'s return value is
    passed through unchanged.

    ``functools.wraps`` copies *func*'s name, docstring and ``__wrapped__``, so
    ``inspect.signature`` and any framework that introspects the decorated
    callable see the real parameters instead of ``(*args, **kwargs)``.
    """
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = func(*args, **kwargs)
        upload_user_db_to_hf()
        return result

    return wrapper
def register_user(db: Session, username: str, password_hash: str) -> User:
    """Insert a user whose password is already hashed and return the persisted row.

    The session is committed, and the instance is refreshed so DB-generated
    fields (``id``, ``created_at``) are populated.
    """
    new_user = User(password_hash=password_hash, username=username)
    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    return new_user
_PBKDF2_PREFIX = "pbkdf2_sha256"
# OWASP's current recommendation for PBKDF2-HMAC-SHA256.
_PBKDF2_ITERATIONS = 600_000


def hash_password(password: str, *, iterations: int = _PBKDF2_ITERATIONS) -> str:
    """Hash *password* for storage using salted PBKDF2-HMAC-SHA256.

    Returns ``"pbkdf2_sha256$<iterations>$<salt_hex>$<digest_hex>"``, which is
    118 characters at the default iteration count and fits ``users.password_hash``
    (String(128)). The old scheme was a single SHA-256 round, fast enough to
    brute-force. Hashes in that legacy ``"<salt>$<digest>"`` format are still
    accepted by ``verify_password``.

    Args:
        password: The plaintext password.
        iterations: PBKDF2 work factor (keyword-only; mainly for tests).
    """
    salt = secrets.token_hex(16)
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt.encode("utf-8"), iterations
    ).hex()
    return f"{_PBKDF2_PREFIX}${iterations}${salt}${digest}"


def verify_password(password: str, password_hash: str) -> bool:
    """Return True if *password* matches the stored *password_hash*.

    Accepts both the current PBKDF2 format and the legacy single-SHA-256
    ``"<salt>$<digest>"`` format. Malformed hashes yield False. Digests are
    compared in constant time.
    """
    parts = password_hash.split("$")
    if len(parts) == 4 and parts[0] == _PBKDF2_PREFIX:
        _, iterations_text, salt, digest = parts
        try:
            iterations = int(iterations_text)
        except ValueError:
            return False
        if iterations <= 0:
            return False
        new_digest = hashlib.pbkdf2_hmac(
            "sha256", password.encode("utf-8"), salt.encode("utf-8"), iterations
        ).hex()
        return hmac.compare_digest(digest, new_digest)

    # Legacy format written by earlier versions: "<salt>$sha256(salt:password)".
    try:
        salt, digest = password_hash.split("$", 1)
    except ValueError:
        return False
    new_digest = hashlib.sha256(f"{salt}:{password}".encode("utf-8")).hexdigest()
    return hmac.compare_digest(digest, new_digest)
def create_session_token() -> str:
    """Return a new random URL-safe session token (48 random bytes, 64 characters)."""
    entropy_bytes = 48
    return secrets.token_urlsafe(entropy_bytes)
def create_payment_order(db: Session, user_id: int, amount: float, pay_type: int = 1, months: int = 1) -> str:
    """Create a pending payment order and return its order number.

    The order number is the Beijing-time timestamp ``YYYYMMDDHHMMSS`` followed
    by 6 random hex characters. The new row is committed immediately.
    """
    timestamp = get_beijing_time().strftime("%Y%m%d%H%M%S")
    order_id = f"{timestamp}{secrets.token_hex(3)}"
    db.add(
        PaymentOrder(
            order_id=order_id,
            user_id=user_id,
            amount=amount,
            pay_type=pay_type,
            vip_duration_months=months,
            status="pending",
            raw_payload="",
        )
    )
    db.commit()
    return order_id
def extend_vip_membership(db: Session, user_id: int, days: int = 30) -> datetime:
    """Extend a user's VIP membership by *days* and return the new expiry.

    If the membership is still active, the extension stacks on the current
    expiry. Otherwise (expired, unset, or no membership row) it starts from now.
    A membership row is created when missing. The change is committed.
    """
    membership = db.get(UserMembership, user_id)
    now = get_beijing_time()
    extension = timedelta(days=days)
    if membership is not None:
        current_expiry = membership.vip_expire_at
        still_active = current_expiry is not None and current_expiry > now
        membership.vip_expire_at = (current_expiry if still_active else now) + extension
    else:
        membership = UserMembership(user_id=user_id, vip_expire_at=now + extension)
        db.add(membership)
    db.commit()
    db.refresh(membership)
    return membership.vip_expire_at or now
def update_order_status(db: Session, order_id: str, status: str) -> bool:
    """Set an order's status (manual admin action); return False if the order is unknown.

    Switching to ``"paid"`` stamps ``paid_at`` unless it is already set.
    Switching back to ``"pending"`` clears it. Other statuses leave ``paid_at``
    untouched.
    """
    order = db.query(PaymentOrder).filter(PaymentOrder.order_id == order_id).first()
    if order is None:
        return False
    order.status = status
    if status == "paid":
        if not order.paid_at:
            order.paid_at = get_beijing_time()
    elif status == "pending":
        order.paid_at = None
    db.commit()
    return True
def get_user_by_token(db: Session, token: str) -> Optional[User]:
    """Resolve a session token to its user.

    Returns None for an unknown token. An expired session row is deleted
    (and committed) and also yields None.
    """
    row = db.get(UserSession, token)
    if row is None:
        return None
    expired = row.expire_at < get_beijing_time()
    if not expired:
        return db.get(User, row.user_id)
    db.delete(row)
    db.commit()
    return None
# Daily quota for free users (default 3), overridable via the FREE_DAILY_LIMIT
# environment variable. Presumably compared against get_daily_usage by callers.
FREE_DAILY_LIMIT = int(os.environ.get("FREE_DAILY_LIMIT", "3"))
def get_daily_usage(db: Session, user_id: int) -> int:
    """Return how many times the user has used the service today (Beijing date); 0 if none."""
    key = {"user_id": user_id, "use_date": get_beijing_time().date()}
    usage = db.get(DailyUsage, key)
    return 0 if usage is None else usage.count
def increment_daily_usage(db: Session, user_id: int) -> int:
    """Add one to today's usage count (Beijing date), commit, and return the new count.

    Creates today's row with count 1 when none exists yet.
    """
    today = get_beijing_time().date()
    usage = db.get(DailyUsage, {"user_id": user_id, "use_date": today})
    if usage is not None:
        usage.count += 1
    else:
        usage = DailyUsage(user_id=user_id, use_date=today, count=1)
        db.add(usage)
    db.commit()
    db.refresh(usage)
    return usage.count
def delete_user(db: Session, user_id: int) -> bool:
    """Permanently delete a user and every row in this module's tables that references them.

    Rows are removed from daily_usage, payment_orders, user_sessions and
    user_membership first, then the ``users`` row itself, and everything is
    committed together. If anything fails, the session is rolled back and the
    exception re-raised, so no half-deleted state is left pending.

    Returns:
        True if a ``users`` row existed and was deleted. False if there was no
        such user (dependent rows carrying that user_id are still purged).
    """
    # Constant statements only; user_id is always a bound parameter.
    dependent_deletes = (
        "DELETE FROM daily_usage WHERE user_id = :uid",
        "DELETE FROM payment_orders WHERE user_id = :uid",
        "DELETE FROM user_sessions WHERE user_id = :uid",
        "DELETE FROM user_membership WHERE user_id = :uid",
    )
    try:
        for statement in dependent_deletes:
            db.execute(text(statement), {"uid": user_id})
        user = db.get(User, user_id)
        if user is not None:
            db.delete(user)
        db.commit()
    except Exception:
        db.rollback()
        raise
    return user is not None