Paper_Trading / backend /app /database_user.py
superxu520's picture
feat: 增强用户数据库同步可靠性
2bf23f6
"""
用户与支付 SQLite 数据库模块
- 使用 SQLAlchemy 管理 user_data.db
- 支持从 Hugging Face Dataset 下载/上传用户库
"""
from __future__ import annotations
import os
import secrets
import hashlib
import hmac
import logging
import time
from dotenv import load_dotenv
# 加载 .env 环境变量
load_dotenv()
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Iterator, Optional
from sqlalchemy import (
create_engine,
Integer,
String,
DateTime,
Float,
ForeignKey,
Text,
Date,
text,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker, Session
from huggingface_hub import hf_hub_download, upload_file
logger = logging.getLogger(__name__)

# Hugging Face Dataset used to persist the user DB between restarts.
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "")
# Path of the DB file inside the dataset repository.
USER_DB_FILENAME = os.getenv("USER_DB_FILENAME", "data/user_data.db")

# Local SQLite location: /tmp on POSIX, a path relative to the working
# directory on Windows ("nt"). The USER_DB_PATH env var overrides both.
if os.name == "nt":
    _default_db_path = "backend/data/user_data.db"
else:
    _default_db_path = "/tmp/user_data.db"
USER_DB_PATH = os.getenv("USER_DB_PATH", _default_db_path)
class Base(DeclarativeBase):
    """Declarative base class shared by every ORM model in this module."""
def get_beijing_time() -> datetime:
    """Return the current Beijing time (UTC+8) as a naive ``datetime``.

    The tzinfo is stripped so the value can be stored in, and compared with,
    the timezone-less ``DateTime`` columns used throughout this module.
    """
    utc_plus_8 = timezone(timedelta(hours=8))
    aware_now = datetime.now(tz=utc_plus_8)
    return aware_now.replace(tzinfo=None)
class User(Base):
    """Registered account (table ``users``)."""

    __tablename__ = "users"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Login name; unique and indexed.
    username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
    # Stored hash string (the seeded admin's comes from hash_password(); checked by
    # verify_password()). Never the plaintext password.
    password_hash: Mapped[str] = mapped_column(String(128), nullable=False)
    # Naive Beijing-time (UTC+8) timestamp filled in on insert.
    created_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time)

    @property
    def is_admin(self) -> bool:
        """Whether this user is the built-in primary administrator.

        The username literal must match the admin account seeded by init_user_db().
        """
        return self.username == "583079759"
class UserMembership(Base):
    """VIP membership state, at most one row per user (table ``user_membership``)."""

    __tablename__ = "user_membership"
    # Primary key and foreign key to users.id, hence one row per user.
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), primary_key=True)
    # VIP expiry as naive Beijing time. extend_vip_membership() treats NULL
    # the same as an already-expired membership.
    vip_expire_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
    # Beijing-time timestamp set on insert and refreshed on every ORM update.
    updated_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time, onupdate=get_beijing_time)
class PaymentOrder(Base):
    """A payment order for a VIP purchase (table ``payment_orders``)."""

    __tablename__ = "payment_orders"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Public order number. create_payment_order() builds it from a Beijing-time
    # timestamp plus random hex characters.
    order_id: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False)
    amount: Mapped[float] = mapped_column(Float, nullable=False)
    pay_type: Mapped[int] = mapped_column(Integer, default=1) # 1: Alipay, 2: WeChat Pay
    vip_duration_months: Mapped[int] = mapped_column(Integer, default=1) # number of VIP months purchased
    status: Mapped[str] = mapped_column(String(32), default="pending", index=True) # pending, paid, expired
    created_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time)
    # Set when the order is marked "paid" and cleared when it is reset to
    # "pending" (see update_order_status()).
    paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
    # create_payment_order() stores "" here. Presumably intended for the raw
    # payment-callback body; TODO confirm.
    raw_payload: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
class UserSession(Base):
    """Login session keyed by an opaque token (table ``user_sessions``)."""

    __tablename__ = "user_sessions"
    # Session token; presumably generated by create_session_token(). Verify at call sites.
    token: Mapped[str] = mapped_column(String(128), primary_key=True)
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
    # Naive Beijing-time expiry. get_user_by_token() deletes the row once it has passed.
    expire_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=get_beijing_time)
class DailyUsage(Base):
    """Per-user, per-day usage counter (table ``daily_usage``)."""

    __tablename__ = "daily_usage"
    # Composite primary key (user_id, use_date): one row per user per Beijing-time day.
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), primary_key=True)
    # NOTE(review): this is a Date column, so stored values are `date` objects
    # (callers pass get_beijing_time().date()). The Mapped[datetime]
    # annotation is inaccurate.
    use_date: Mapped[datetime] = mapped_column(Date, primary_key=True)
    # Number of uses recorded for this user on this day.
    count: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
def _ensure_db_dir() -> None:
    """Create the directory that holds the SQLite file at ``USER_DB_PATH``."""
    db_parent = Path(USER_DB_PATH).parent
    db_parent.mkdir(exist_ok=True, parents=True)
def _db_url() -> str:
    """Build the SQLAlchemy URL for the SQLite file at ``USER_DB_PATH``.

    ``as_posix`` makes the URL use forward slashes on Windows as well.
    """
    posix_path = Path(USER_DB_PATH).as_posix()
    return "sqlite:///" + posix_path
# Module-wide engine and session factory bound to the SQLite file.
# check_same_thread=False lets sqlite3 connections be used from threads other
# than the one that created them (sqlite3 rejects that by default).
engine = create_engine(
    _db_url(),
    connect_args={"check_same_thread": False},
)
SessionLocal = sessionmaker(
    bind=engine,
    autocommit=False,
    autoflush=False,
)
def init_user_db() -> None:
    """Prepare user_data.db for use (intended to run once at startup).

    Steps, in order:
      1. Create the parent directory of ``USER_DB_PATH``.
      2. Pull the user DB, then the payment QR-code images, from the HF
         Dataset. ``sync_user_db_from_hf`` re-raises once its retries are
         exhausted, which aborts this function.
      3. Create any missing tables.
      4. Add ``payment_orders.vip_duration_months`` to databases created
         before that column existed.
      5. Seed the built-in admin account if it does not exist yet.
    """
    _ensure_db_dir()
    sync_user_db_from_hf()
    sync_qr_codes_from_hf()
    Base.metadata.create_all(bind=engine)
    _migrate_payment_orders_vip_duration()
    _seed_admin_user()


def _migrate_payment_orders_vip_duration() -> None:
    """Add ``payment_orders.vip_duration_months`` (INTEGER DEFAULT 1) if missing.

    The existing columns are inspected first, rather than attempting the
    ALTER and pattern-matching the "duplicate column" error text. Any other
    failure is logged as a warning and does not stop startup.
    """
    try:
        with engine.begin() as conn:
            # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
            existing_columns = {
                row[1] for row in conn.execute(text("PRAGMA table_info(payment_orders)"))
            }
            if "vip_duration_months" not in existing_columns:
                conn.execute(
                    text("ALTER TABLE payment_orders ADD COLUMN vip_duration_months INTEGER DEFAULT 1")
                )
                logger.info("Migrated payment_orders table: added vip_duration_months")
    except Exception:
        # Best-effort migration, but an unexpected failure must be visible.
        logger.warning("payment_orders migration (vip_duration_months) failed", exc_info=True)


def _seed_admin_user() -> None:
    """Create the built-in admin account plus a 36500-day VIP membership if missing."""
    # Must stay identical to the username checked by User.is_admin.
    admin_username = "583079759"
    # SECURITY: the fallback literal is a credential committed to source control.
    # Set ADMIN_PASSWORD in the environment so a fresh database gets a secret password.
    admin_password_raw = os.getenv("ADMIN_PASSWORD", "superXU520")
    with SessionLocal() as db:
        existing_admin = db.query(User).filter(User.username == admin_username).first()
        if existing_admin:
            return
        logger.info(f"Seeding super admin user: {admin_username}")
        admin_user = User(
            username=admin_username,
            password_hash=hash_password(admin_password_raw),
        )
        db.add(admin_user)
        db.commit()
        # Admin limits are presumably bypassed elsewhere in code; the membership
        # mostly makes the admin display as a VIP.
        extend_vip_membership(db, admin_user.id, days=36500)
def get_user_db() -> Iterator[Session]:
    """Yield a new ``SessionLocal`` session and close it when the consumer finishes.

    This is a generator, so the return annotation is ``Iterator[Session]``
    (``-> Session`` was wrong for a generator function). Presumably consumed
    as a yield-style dependency, e.g. FastAPI ``Depends``. Confirm at call sites.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        # Runs when the consumer closes or exhausts the generator.
        db.close()
def sync_user_db_from_hf() -> None:
    """Download the user DB from the HF Dataset and copy it to ``USER_DB_PATH``.

    Skipped (with an info log) when ``HF_TOKEN`` or ``DATASET_REPO_ID`` is
    unset. Otherwise the download is tried up to 5 times, 2 seconds apart.
    If every attempt fails, the last exception is re-raised so startup aborts
    instead of running on a local (possibly empty) database.

    NOTE(review): any exception counts as a failed attempt, presumably
    including the file not existing yet in the dataset on a first deploy.
    """
    if not (HF_TOKEN and DATASET_REPO_ID):
        logger.info("HF_TOKEN or DATASET_REPO_ID not set, skip user db download")
        return

    attempts = 5
    pause_seconds = 2
    for attempt_no in range(1, attempts + 1):
        try:
            local_copy = Path(
                hf_hub_download(
                    repo_id=DATASET_REPO_ID,
                    filename=USER_DB_FILENAME,
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
            )
            target = Path(USER_DB_PATH)
            # The download returns a local path (normally the HF cache);
            # copy it over unless it already is the target file.
            if local_copy.exists() and local_copy.resolve() != target.resolve():
                target.write_bytes(local_copy.read_bytes())
            logger.info(f"User db synced from HF: {USER_DB_FILENAME}")
            return
        except Exception as e:
            logger.warning(f"Attempt {attempt_no} failed to sync user db from HF: {e}")
            if attempt_no == attempts:
                logger.error(f"All {attempts} attempts failed to sync user db from HF. Application will not start properly.")
                raise
            logger.info(f"Retrying in {pause_seconds} seconds...")
            time.sleep(pause_seconds)
def sync_qr_codes_from_hf() -> None:
    """Copy the payment QR-code images from the HF Dataset into ``static/images``.

    Best-effort: returns silently when HF credentials are missing, and an
    image that fails to download is logged as a warning and skipped.
    """
    if not (HF_TOKEN and DATASET_REPO_ID):
        return

    # Resolve relative to this module so the target does not depend on the CWD.
    images_dir = Path(__file__).parent / "static" / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    qr_names = ("qrcode_alipay.png", "qrcode_wechatpay.png", "qrcode_wechat.png")
    for name in qr_names:
        try:
            cached = Path(
                hf_hub_download(
                    repo_id=DATASET_REPO_ID,
                    filename=f"data/images/{name}",
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
            )
            (images_dir / name).write_bytes(cached.read_bytes())
            logger.info(f"QR code synced: {name}")
        except Exception as e:
            logger.warning(f"Failed to sync QR code {name}: {e}")
def upload_user_db_to_hf() -> None:
    """Upload a consistent snapshot of the local user DB to the HF Dataset.

    Skipped (with a log message) when HF credentials are unset, when the DB
    file is missing, or when it is smaller than 1 KiB. The size check guards
    against overwriting the remote copy with an empty database.

    The live file is not uploaded directly. The engine is shared across
    threads (``check_same_thread=False``), so another request can commit
    while the upload reads the file, which would produce a torn copy.
    Instead, the DB is first copied with SQLite's online backup API into a
    temporary file, and that snapshot is uploaded. The upload is retried up
    to 3 times, 2 seconds apart; upload failures are logged, not raised.
    """
    import sqlite3
    import tempfile
    from contextlib import closing

    if not HF_TOKEN or not DATASET_REPO_ID:
        logger.info("HF_TOKEN or DATASET_REPO_ID not set, skip user db upload")
        return
    db_path = Path(USER_DB_PATH)
    if not db_path.exists():
        logger.warning("User db file not found, skip upload")
        return
    # Check if the database file is too small (likely empty), prevent overwriting
    file_size = db_path.stat().st_size
    if file_size < 1024:  # Less than 1KB, likely empty
        logger.warning(f"User db file is too small ({file_size} bytes), skip upload to prevent overwriting remote data")
        return

    with tempfile.TemporaryDirectory() as tmp_dir:
        snapshot_path = Path(tmp_dir) / db_path.name
        try:
            with closing(sqlite3.connect(str(db_path))) as source:
                with closing(sqlite3.connect(str(snapshot_path))) as target:
                    # Transactionally consistent copy; retries while writers hold locks.
                    source.backup(target)
        except sqlite3.Error as e:
            # Do not fall back to the raw file: if SQLite cannot read it,
            # uploading it could replace good remote data with a corrupt copy.
            logger.error(f"Failed to snapshot user db for upload, skip upload: {e}")
            return

        max_retries = 3
        retry_delay = 2  # seconds
        for attempt in range(max_retries):
            try:
                upload_file(
                    path_or_fileobj=str(snapshot_path),
                    path_in_repo=USER_DB_FILENAME,
                    repo_id=DATASET_REPO_ID,
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
                logger.info(f"User db uploaded to HF: {USER_DB_FILENAME}")
                return  # Success; the temp dir is cleaned up on exit.
            except Exception as e:
                logger.error(f"Attempt {attempt + 1} failed to upload user db to HF: {e}")
                if attempt < max_retries - 1:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    logger.error(f"All {max_retries} attempts failed to upload user db to HF")
def sync_user_db_after_update(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator: after ``func`` returns, push user_data.db to the HF Dataset.

    The upload runs synchronously once the wrapped call has returned. If the
    wrapped call raises, no upload happens. The wrapped function's return
    value is passed through unchanged.
    """
    @wraps(func)
    def _with_upload(*args, **kwargs):
        outcome = func(*args, **kwargs)
        upload_user_db_to_hf()
        return outcome

    return _with_upload
@sync_user_db_after_update
def register_user(db: Session, username: str, password_hash: str) -> User:
    """Insert a new ``User``, commit, and return it refreshed from the database.

    ``password_hash`` is stored as given, so callers must hash the password
    beforehand. The user DB is uploaded to HF afterwards via the decorator.
    """
    new_user = User(password_hash=password_hash, username=username)
    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    return new_user
_PBKDF2_SCHEME = "pbkdf2_sha256"
# OWASP Password Storage Cheat Sheet recommendation for PBKDF2-HMAC-SHA256.
_PBKDF2_DEFAULT_ITERATIONS = 600_000


def hash_password(password: str, *, iterations: int = _PBKDF2_DEFAULT_ITERATIONS) -> str:
    """Hash ``password`` for storage in ``User.password_hash``.

    Output format: ``pbkdf2_sha256$<iterations>$<salt_hex>$<digest_hex>``,
    using a random 16-byte salt. PBKDF2 replaces the former single-round
    salted SHA-256, which is cheap to brute-force. Hashes in the old
    ``<salt_hex>$<digest_hex>`` format are still accepted by
    ``verify_password``, so existing accounts keep working.

    Args:
        password: Plaintext password.
        iterations: PBKDF2 work factor (keyword-only; mainly lowered in tests).

    Returns:
        The encoded hash. It is 118 characters at the default iteration
        count, which fits the ``String(128)`` column.
    """
    salt_hex = secrets.token_hex(16)
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), bytes.fromhex(salt_hex), iterations
    ).hex()
    return f"{_PBKDF2_SCHEME}${iterations}${salt_hex}${digest}"


def verify_password(password: str, password_hash: str) -> bool:
    """Check ``password`` against a stored hash from ``hash_password``.

    Accepts both the PBKDF2 format and the legacy
    ``<salt>$<sha256("<salt>:<password>")>`` format. Empty or malformed
    hashes return False instead of raising. The comparison is constant-time.
    """
    if not password_hash:
        return False
    if password_hash.startswith(_PBKDF2_SCHEME + "$"):
        try:
            _, iterations_text, salt_hex, expected = password_hash.split("$", 3)
            actual = hashlib.pbkdf2_hmac(
                "sha256",
                password.encode("utf-8"),
                bytes.fromhex(salt_hex),
                int(iterations_text),
            ).hex()
        except (ValueError, OverflowError):
            # Wrong field count, bad hex salt, or non-positive/invalid iteration count.
            return False
    else:
        try:
            salt, expected = password_hash.split("$", 1)
        except ValueError:
            return False
        actual = hashlib.sha256(f"{salt}:{password}".encode("utf-8")).hexdigest()
    # Compare bytes so a corrupted non-ASCII stored value cannot raise TypeError.
    return hmac.compare_digest(expected.encode("utf-8"), actual.encode("utf-8"))
def create_session_token() -> str:
    """Return a fresh random URL-safe session token (48 random bytes, 64 characters)."""
    n_random_bytes = 48
    return secrets.token_urlsafe(n_random_bytes)
@sync_user_db_after_update
def create_payment_order(db: Session, user_id: int, amount: float, pay_type: int = 1, months: int = 1) -> str:
    """Create a ``pending`` payment order and return its generated order id.

    The order id is the Beijing-time timestamp ``YYYYMMDDHHMMSS`` followed by
    6 random hex characters.

    Args:
        db: Open session; the new order is committed on it.
        user_id: Owner of the order.
        amount: Amount to pay.
        pay_type: 1 = Alipay, 2 = WeChat Pay.
        months: Number of VIP months being purchased.
    """
    timestamp = get_beijing_time().strftime("%Y%m%d%H%M%S")
    order_id = timestamp + secrets.token_hex(3)
    db.add(
        PaymentOrder(
            order_id=order_id,
            user_id=user_id,
            amount=amount,
            pay_type=pay_type,
            vip_duration_months=months,
            status="pending",
            raw_payload="",
        )
    )
    db.commit()
    return order_id
@sync_user_db_after_update
def extend_vip_membership(db: Session, user_id: int, days: int = 30) -> datetime:
    """Add ``days`` of VIP time to ``user_id`` and return the new expiry (Beijing time).

    An unexpired membership is extended from its current expiry. A missing,
    NULL, or already-expired one is extended from now. A membership row is
    created if none exists.
    """
    membership = db.get(UserMembership, user_id)
    now = get_beijing_time()
    extension = timedelta(days=days)
    if membership is not None:
        current_expiry = membership.vip_expire_at
        start = current_expiry if (current_expiry and current_expiry > now) else now
        membership.vip_expire_at = start + extension
    else:
        membership = UserMembership(user_id=user_id, vip_expire_at=now + extension)
        db.add(membership)
    db.commit()
    db.refresh(membership)
    return membership.vip_expire_at or now
@sync_user_db_after_update
def update_order_status(db: Session, order_id: str, status: str) -> bool:
    """Set an order's status (manual admin toggle). Return False if the order is unknown.

    ``status`` is stored as given. Effect on ``paid_at``:
      - ``"paid"``: stamped with Beijing time unless already set.
      - ``"pending"``: cleared.
      - anything else: left unchanged.
    """
    order = (
        db.query(PaymentOrder)
        .filter(PaymentOrder.order_id == order_id)
        .first()
    )
    if order is None:
        return False

    order.status = status
    if status == "paid":
        if not order.paid_at:
            order.paid_at = get_beijing_time()
    elif status == "pending":
        order.paid_at = None
    db.commit()
    return True
def get_user_by_token(db: Session, token: str) -> Optional[User]:
    """Resolve a session token to its ``User``.

    Returns None for an unknown token. An expired token's session row is
    deleted (and committed) before returning None.
    """
    session_row = db.get(UserSession, token)
    if session_row is None:
        return None

    is_expired = session_row.expire_at < get_beijing_time()
    if is_expired:
        db.delete(session_row)
        db.commit()
        return None
    return db.get(User, session_row.user_id)
# Daily usage allowance, overridable via the FREE_DAILY_LIMIT env var (default 3).
# Presumably the cap for non-VIP users; it is enforced by callers, not in this module.
FREE_DAILY_LIMIT = int(os.environ.get("FREE_DAILY_LIMIT", "3"))
def get_daily_usage(db: Session, user_id: int) -> int:
    """Return how many uses ``user_id`` has recorded today (Beijing-time date); 0 if none."""
    key = {"user_id": user_id, "use_date": get_beijing_time().date()}
    usage = db.get(DailyUsage, key)
    return 0 if usage is None else usage.count
@sync_user_db_after_update
def increment_daily_usage(db: Session, user_id: int) -> int:
    """Record one more use for ``user_id`` today (Beijing-time date) and return the new count.

    NOTE(review): this is a read-modify-write in Python. Concurrent calls for
    the same user and day may lose an increment or collide on the insert.
    """
    today = get_beijing_time().date()
    usage = db.get(DailyUsage, {"user_id": user_id, "use_date": today})
    if usage is not None:
        usage.count += 1
    else:
        usage = DailyUsage(user_id=user_id, use_date=today, count=1)
        db.add(usage)
    db.commit()
    db.refresh(usage)
    return usage.count
@sync_user_db_after_update
def delete_user(db: Session, user_id: int) -> bool:
    """Permanently delete a user and every row that references them.

    Rows in ``daily_usage``, ``payment_orders``, ``user_sessions`` and
    ``user_membership`` are removed first, because they reference
    ``users.id``. The ``users`` row is removed last, and everything is
    committed together.

    Returns:
        True if a ``users`` row with ``user_id`` existed and was deleted,
        False otherwise. Dependent rows are still purged in the False case.
        Previously this always returned True, which made the result
        meaningless. It now matches ``update_order_status``'s not-found
        convention.
    """
    # Fixed statements with a bound parameter; no user input is interpolated.
    dependent_deletes = (
        "DELETE FROM daily_usage WHERE user_id = :uid",
        "DELETE FROM payment_orders WHERE user_id = :uid",
        "DELETE FROM user_sessions WHERE user_id = :uid",
        "DELETE FROM user_membership WHERE user_id = :uid",
    )
    for statement in dependent_deletes:
        db.execute(text(statement), {"uid": user_id})

    user = db.get(User, user_id)
    if user is not None:
        db.delete(user)
    db.commit()
    return user is not None