""" 用户与支付 SQLite 数据库模块 - 使用 SQLAlchemy 管理 user_data.db - 支持从 Hugging Face Dataset 下载/上传用户库 """ from __future__ import annotations import os import secrets import hashlib import hmac import logging import time from dotenv import load_dotenv # 加载 .env 环境变量 load_dotenv() from dataclasses import dataclass from datetime import datetime, timedelta, timezone from functools import wraps from pathlib import Path from typing import Optional, Callable, Any from sqlalchemy import ( create_engine, Integer, String, DateTime, Float, ForeignKey, Text, Date, text, ) from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker, Session from huggingface_hub import hf_hub_download, upload_file logger = logging.getLogger(__name__) HF_TOKEN = os.getenv("HF_TOKEN") DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "") USER_DB_FILENAME = os.getenv("USER_DB_FILENAME", "data/user_data.db") _default_db_path = "/tmp/user_data.db" if os.name != "nt" else "backend/data/user_data.db" USER_DB_PATH = os.getenv("USER_DB_PATH", _default_db_path) class Base(DeclarativeBase): pass def get_beijing_time() -> datetime: """获取北京时间 (UTC+8)""" return datetime.now(timezone(timedelta(hours=8))).replace(tzinfo=None) class User(Base): __tablename__ = "users" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) password_hash: Mapped[str] = mapped_column(String(128), nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time) @property def is_admin(self) -> bool: """主管理员权限""" return self.username == "583079759" class UserMembership(Base): __tablename__ = "user_membership" user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), primary_key=True) vip_expire_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) updated_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time, onupdate=get_beijing_time) 
# Names used only by the definitions below (annotations of get_user_db and
# DailyUsage.use_date). NOTE: consider moving these into the import section at
# the top of the file.
from collections.abc import Iterator
from datetime import date


class PaymentOrder(Base):
    """A payment order; ``vip_duration_months`` records how many VIP months were bought."""

    __tablename__ = "payment_orders"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    order_id: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False)
    amount: Mapped[float] = mapped_column(Float, nullable=False)
    pay_type: Mapped[int] = mapped_column(Integer, default=1)  # 1: Alipay, 2: WeChat Pay
    vip_duration_months: Mapped[int] = mapped_column(Integer, default=1)  # number of months purchased
    status: Mapped[str] = mapped_column(String(32), default="pending", index=True)  # pending, paid, expired
    created_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time)
    paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
    raw_payload: Mapped[Optional[str]] = mapped_column(Text, nullable=True)


class UserSession(Base):
    """Login session: maps an opaque token to a user until ``expire_at``."""

    __tablename__ = "user_sessions"

    token: Mapped[str] = mapped_column(String(128), primary_key=True)
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
    expire_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time)


class DailyUsage(Base):
    """Per-user usage counter for one calendar day."""

    __tablename__ = "daily_usage"

    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), primary_key=True)
    # A Date column, so the Python-side type is datetime.date (not datetime).
    use_date: Mapped[date] = mapped_column(Date, primary_key=True)
    count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)


def _ensure_db_dir() -> None:
    """Create the parent directory of the local database file if it is missing."""
    Path(USER_DB_PATH).parent.mkdir(parents=True, exist_ok=True)


def _db_url() -> str:
    """Build the SQLAlchemy SQLite URL for USER_DB_PATH (forward slashes on every OS)."""
    return f"sqlite:///{Path(USER_DB_PATH).as_posix()}"


engine = create_engine(_db_url(), connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)


def init_user_db() -> None:
    """Prepare the user database at startup.

    In order:
    1. ensure the local directory exists;
    2. pull the database and the QR-code images from the HF Dataset (when configured);
    3. create missing tables;
    4. add the ``vip_duration_months`` column to older databases;
    5. seed the built-in super admin.

    Raises:
        Exception: whatever ``sync_user_db_from_hf`` re-raises after exhausting its retries.
    """
    _ensure_db_dir()
    sync_user_db_from_hf()
    sync_qr_codes_from_hf()
    Base.metadata.create_all(bind=engine)
    _migrate_add_vip_duration_months()
    _seed_super_admin()


def _migrate_add_vip_duration_months() -> None:
    """Add ``payment_orders.vip_duration_months`` to databases created before it existed.

    ``create_all`` only creates missing tables, so an older ``payment_orders`` table
    keeps its old columns. Best effort: unexpected failures are logged, not raised.
    """
    try:
        with engine.connect() as conn:
            # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
            existing = {row[1] for row in conn.execute(text("PRAGMA table_info(payment_orders)"))}
            if "vip_duration_months" in existing:
                return
            conn.execute(text("ALTER TABLE payment_orders ADD COLUMN vip_duration_months INTEGER DEFAULT 1"))
            conn.commit()
    except Exception as e:
        # Previously every failure (not just "column exists") was hidden at DEBUG level.
        logger.warning(f"Migration of payment_orders.vip_duration_months failed: {e}")
        return
    logger.info("Migrated payment_orders table: added vip_duration_months")


def _seed_super_admin() -> None:
    """Create the built-in super admin account if it does not exist yet."""
    admin_username = "583079759"  # must match the username checked by User.is_admin
    env_password = os.getenv("ADMIN_PASSWORD")
    # NOTE(review): the fallback password is committed to source control. Set
    # ADMIN_PASSWORD in the environment and rotate/remove this default.
    admin_password_raw = env_password or "superXU520"
    with SessionLocal() as db:
        existing_admin = db.query(User).filter(User.username == admin_username).first()
        if existing_admin:
            return
        if not env_password:
            logger.warning("ADMIN_PASSWORD not set, seeding super admin with the built-in default password")
        logger.info(f"Seeding super admin user: {admin_username}")
        admin_user = User(
            username=admin_username,
            password_hash=hash_password(admin_password_raw),
        )
        db.add(admin_user)
        db.commit()
        # Grant a ~100-year ("permanent") membership. The original note says admin
        # limits are already skipped in code elsewhere; this is mainly for display.
        extend_vip_membership(db, admin_user.id, days=36500)


def get_user_db() -> Iterator[Session]:
    """Yield a new session and close it once the consumer is done (generator dependency)."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def sync_user_db_from_hf() -> None:
    """Download the user database from the HF Dataset into USER_DB_PATH.

    No-op when HF_TOKEN or DATASET_REPO_ID is unset. Retries up to 5 times, 2 s apart.

    Raises:
        Exception: the last download error, re-raised after all retries fail, so that
            startup is aborted rather than continuing with an empty local database.
    """
    if not HF_TOKEN or not DATASET_REPO_ID:
        logger.info("HF_TOKEN or DATASET_REPO_ID not set, skip user db download")
        return

    max_retries = 5
    retry_delay = 2  # seconds

    for attempt in range(max_retries):
        try:
            downloaded = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=USER_DB_FILENAME,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            src = Path(downloaded)
            dst = Path(USER_DB_PATH)
            # Skip the copy when the cached download already is the target file.
            if src.exists() and src.resolve() != dst.resolve():
                dst.write_bytes(src.read_bytes())
            logger.info(f"User db synced from HF: {USER_DB_FILENAME}")
            return  # Success, exit the loop
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed to sync user db from HF: {e}")
            if attempt < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                logger.error(f"All {max_retries} attempts failed to sync user db from HF. Application will not start properly.")
                raise  # Re-raise the exception to prevent application startup


def sync_qr_codes_from_hf() -> None:
    """Download the payment QR-code images from the Dataset into ``static/images``.

    No-op when HF_TOKEN or DATASET_REPO_ID is unset. Each file is best effort: a
    failure is logged as a warning and the remaining files are still attempted.
    """
    if not HF_TOKEN or not DATASET_REPO_ID:
        return

    # Resolve relative to this file so the target does not depend on the CWD.
    current_dir = Path(__file__).parent
    qr_dir = current_dir / "static" / "images"
    qr_dir.mkdir(parents=True, exist_ok=True)

    for filename in ["qrcode_alipay.png", "qrcode_wechatpay.png", "qrcode_wechat.png"]:
        try:
            downloaded = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=f"data/images/{filename}",
                repo_type="dataset",
                token=HF_TOKEN,
            )
            dst = qr_dir / filename
            dst.write_bytes(Path(downloaded).read_bytes())
            logger.info(f"QR code synced: {filename}")
        except Exception as e:
            logger.warning(f"Failed to sync QR code {filename}: {e}")


def upload_user_db_to_hf() -> None:
    """Upload the local user database to the HF Dataset (best effort, never raises).

    Skipped when HF is not configured, the file is missing, or the file is under
    1 KiB (treated as empty, to avoid overwriting good remote data). Retries up to
    3 times, 2 s apart; the final failure is only logged.
    """
    if not HF_TOKEN or not DATASET_REPO_ID:
        logger.info("HF_TOKEN or DATASET_REPO_ID not set, skip user db upload")
        return
    db_path = Path(USER_DB_PATH)
    if not db_path.exists():
        logger.warning("User db file not found, skip upload")
        return

    # Check if the database file is too small (likely empty), prevent overwriting
    file_size = db_path.stat().st_size
    if file_size < 1024:  # Less than 1KB, likely empty
        logger.warning(f"User db file is too small ({file_size} bytes), skip upload to prevent overwriting remote data")
        return

    max_retries = 3
    retry_delay = 2  # seconds

    for attempt in range(max_retries):
        try:
            upload_file(
                path_or_fileobj=str(db_path),
                path_in_repo=USER_DB_FILENAME,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            logger.info(f"User db uploaded to HF: {USER_DB_FILENAME}")
            return  # Success, exit the loop
        except Exception as e:
            logger.error(f"Attempt {attempt + 1} failed to upload user db to HF: {e}")
            if attempt < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                logger.error(f"All {max_retries} attempts failed to upload user db to HF")


def sync_user_db_after_update(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator: after ``func`` returns normally, upload user_data.db to the HF Dataset.

    If ``func`` raises, no upload happens and the exception propagates.
    NOTE(review): the upload runs synchronously (with retry sleeps) on every
    decorated call, including increment_daily_usage; consider batching/debouncing.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        upload_user_db_to_hf()
        return result

    return wrapper


@sync_user_db_after_update
def register_user(db: Session, username: str, password_hash: str) -> User:
    """Insert a new user with an already-hashed password, commit, and return it refreshed."""
    user = User(username=username, password_hash=password_hash)
    db.add(user)
    db.commit()
    db.refresh(user)
    return user


# Password hashing: PBKDF2-HMAC-SHA256 stored as "pbkdf2_sha256$<iterations>$<salt>$<hex digest>".
# The iteration count is stored in each hash, so it can be raised later without
# invalidating existing hashes. Hashes in the legacy "<salt>$<sha256 hex>" format
# (single SHA-256 round) are still accepted by verify_password.
_PBKDF2_SCHEME = "pbkdf2_sha256"
_PBKDF2_ITERATIONS = 600_000


def _pbkdf2_hex(password: str, salt: str, iterations: int) -> str:
    """Return the hex PBKDF2-HMAC-SHA256 digest of ``password`` with ``salt``."""
    return hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt.encode("utf-8"), iterations
    ).hex()


def hash_password(password: str) -> str:
    """Hash ``password`` with a fresh random salt using PBKDF2-HMAC-SHA256."""
    salt = secrets.token_hex(16)
    digest = _pbkdf2_hex(password, salt, _PBKDF2_ITERATIONS)
    return f"{_PBKDF2_SCHEME}${_PBKDF2_ITERATIONS}${salt}${digest}"


def verify_password(password: str, password_hash: str) -> bool:
    """Check ``password`` against a stored hash (PBKDF2 or legacy salted SHA-256).

    Returns False for malformed hashes instead of raising. Digests are compared in
    constant time.
    """
    parts = password_hash.split("$")
    if len(parts) == 4 and parts[0] == _PBKDF2_SCHEME:
        _, iterations_str, salt, digest = parts
        try:
            iterations = int(iterations_str)
        except ValueError:
            return False
        if iterations <= 0:
            return False
        return hmac.compare_digest(digest, _pbkdf2_hex(password, salt, iterations))

    # Legacy format: "<salt>$<sha256(salt:password) hex>".
    try:
        salt, digest = password_hash.split("$", 1)
    except ValueError:
        return False
    new_digest = hashlib.sha256(f"{salt}:{password}".encode("utf-8")).hexdigest()
    return hmac.compare_digest(digest, new_digest)


def create_session_token() -> str:
    """Return a new random URL-safe session token (48 random bytes)."""
    return secrets.token_urlsafe(48)


@sync_user_db_after_update
def create_payment_order(db: Session, user_id: int, amount: float, pay_type: int = 1, months: int = 1) -> str:
    """Create a pending payment order, commit it, and return its order id."""
    # Order id: YYYYMMDDHHMMSS (Beijing time) + 6 random hex characters.
    order_id = get_beijing_time().strftime("%Y%m%d%H%M%S") + secrets.token_hex(3)
    order = PaymentOrder(
        order_id=order_id,
        user_id=user_id,
        amount=amount,
        pay_type=pay_type,
        vip_duration_months=months,
        status="pending",
        raw_payload="",
    )
    db.add(order)
    db.commit()
    return order_id


@sync_user_db_after_update
def extend_vip_membership(db: Session, user_id: int, days: int = 30) -> datetime:
    """Extend a user's VIP by ``days`` and return the new expiry.

    Time is added to the current expiry if it is still in the future, otherwise to
    now. A membership row is created if the user has none. Commits.
    """
    membership = db.get(UserMembership, user_id)
    now = get_beijing_time()
    if membership is None:
        membership = UserMembership(user_id=user_id, vip_expire_at=now + timedelta(days=days))
        db.add(membership)
    else:
        base = membership.vip_expire_at if membership.vip_expire_at and membership.vip_expire_at > now else now
        membership.vip_expire_at = base + timedelta(days=days)
    db.commit()
    db.refresh(membership)
    return membership.vip_expire_at or now


@sync_user_db_after_update
def update_order_status(db: Session, order_id: str, status: str) -> bool:
    """Set an order's status (manual admin switch); return False if the order is unknown.

    Stamps ``paid_at`` on the first switch to "paid" and clears it when reset to
    "pending". Does not itself change any membership.
    """
    order = db.query(PaymentOrder).filter(PaymentOrder.order_id == order_id).first()
    if not order:
        return False
    order.status = status
    if status == "paid" and not order.paid_at:
        order.paid_at = get_beijing_time()
    elif status == "pending":
        order.paid_at = None
    db.commit()
    return True


def get_user_by_token(db: Session, token: str) -> Optional[User]:
    """Return the user owning ``token``, or None if unknown or expired.

    An expired session row is deleted (and committed) as a side effect.
    """
    session_row = db.get(UserSession, token)
    if not session_row:
        return None
    if session_row.expire_at < get_beijing_time():
        db.delete(session_row)
        db.commit()
        return None
    return db.get(User, session_row.user_id)


FREE_DAILY_LIMIT = int(os.getenv("FREE_DAILY_LIMIT", "3"))


def get_daily_usage(db: Session, user_id: int) -> int:
    """Return how many times the user has used the service today (Beijing date)."""
    today = get_beijing_time().date()
    row = db.get(DailyUsage, {"user_id": user_id, "use_date": today})
    return row.count if row else 0


@sync_user_db_after_update
def increment_daily_usage(db: Session, user_id: int) -> int:
    """Increment today's usage count (Beijing date), commit, and return the new count."""
    today = get_beijing_time().date()
    row = db.get(DailyUsage, {"user_id": user_id, "use_date": today})
    if row is None:
        row = DailyUsage(user_id=user_id, use_date=today, count=1)
        db.add(row)
    else:
        row.count += 1
    db.commit()
    db.refresh(row)
    return row.count


@sync_user_db_after_update
def delete_user(db: Session, user_id: int) -> bool:
    """Permanently delete a user and every row that references it; always returns True."""
    # 1. Daily usage records
    db.execute(text("DELETE FROM daily_usage WHERE user_id = :uid"), {"uid": user_id})
    # 2. Payment orders
    db.execute(text("DELETE FROM payment_orders WHERE user_id = :uid"), {"uid": user_id})
    # 3. Login sessions
    db.execute(text("DELETE FROM user_sessions WHERE user_id = :uid"), {"uid": user_id})
    # 4. Membership
    db.execute(text("DELETE FROM user_membership WHERE user_id = :uid"), {"uid": user_id})
    # 5. The user row itself (after its dependants, so no foreign key is left dangling)
    user = db.get(User, user_id)
    if user:
        db.delete(user)
    db.commit()
    return True