""" FastAPI 路由定义 """ import os import json import logging import threading from datetime import datetime, timedelta from typing import Optional from fastapi import APIRouter, HTTPException, Header, Query, Depends, Request from fastapi.responses import HTMLResponse, PlainTextResponse from pydantic import BaseModel from sqlalchemy.orm import Session from .core import start_game, get_kline_by_code, GameStartResponse, KLineData from .database import get_db from .database_user import ( get_user_db, User, UserMembership, UserSession, PaymentOrder, hash_password, verify_password, create_session_token, get_user_by_token, extend_vip_membership, get_daily_usage, increment_daily_usage, FREE_DAILY_LIMIT, create_payment_order, register_user, delete_user as db_delete_user, sync_user_db_after_update, get_beijing_time, ) from .limiter import limiter # 导入同步函数 import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from scripts.sync_data import main as run_sync_task logger = logging.getLogger(__name__) router = APIRouter(prefix="/api") # 环境变量 ADMIN_SECRET = os.getenv("ADMIN_SECRET", "your-secret-key") SESSION_EXPIRE_DAYS = int(os.getenv("SESSION_EXPIRE_DAYS", "30")) RATE_LIMIT_GLOBAL_RAW = os.getenv("RATE_LIMIT_GLOBAL", "50/second") RATE_LIMIT_IP_RAW = os.getenv("RATE_LIMIT_IP", "5/second") # 确保频率限制字符串包含时间单位 (例如 /second) def _ensure_rate_limit_format(limit_str: str) -> str: if "/" not in limit_str: return f"{limit_str}/second" return limit_str RATE_LIMIT_GLOBAL = _ensure_rate_limit_format(RATE_LIMIT_GLOBAL_RAW) RATE_LIMIT_IP = _ensure_rate_limit_format(RATE_LIMIT_IP_RAW) class HealthResponse(BaseModel): """健康检查响应""" status: str database: str stocks_count: int class ErrorResponse(BaseModel): """错误响应""" error: str class RegisterRequest(BaseModel): username: str password: str class LoginRequest(BaseModel): username: str password: str class AuthResponse(BaseModel): token: str user_id: int username: str class UserInfoResponse(BaseModel): user_id: int username: str vip_expire_at: Optional[str] is_vip: bool class VipStatusResponse(BaseModel): is_vip: bool vip_expire_at: Optional[str] class CreatePaymentRequest(BaseModel): type: int = 2 # 1: 支付宝, 2: 微信 (默认为2) months: int = 1 # 购买月数 class CreatePaymentResponse(BaseModel): order_id: str price: float type: int def _get_token_from_header(authorization: Optional[str]) -> Optional[str]: if not authorization: return None if authorization.lower().startswith("bearer "): return authorization[7:] return authorization def check_vip_status(user_id: int, db: Session) -> dict: membership = db.get(UserMembership, user_id) now = get_beijing_time() if membership is None or membership.vip_expire_at is None: return {"is_vip": False, "vip_expire_at": None} expire_at = membership.vip_expire_at is_vip = expire_at > now return { "is_vip": is_vip, "vip_expire_at": expire_at.strftime('%Y-%m-%d %H:%M:%S') } def get_current_user( authorization: Optional[str] = Header(None, alias="Authorization"), db: Session = Depends(get_user_db), ) -> User: token = _get_token_from_header(authorization) if not token: raise HTTPException(status_code=401, detail="Missing authorization token") user = get_user_by_token(db, token) if not user: raise HTTPException(status_code=401, detail="Invalid or expired token") return user @router.get("/health", response_model=HealthResponse) async def health_check(): """健康检查接口""" try: db = get_db() count = db.conn.execute("SELECT COUNT(*) FROM stock_list").fetchone()[0] return HealthResponse( status="healthy", database="connected", stocks_count=count ) except Exception as e: logger.error(f"Health check failed: {e}") return HealthResponse( status="unhealthy", database="disconnected", stocks_count=0 ) @router.post("/v1/auth/register", response_model=AuthResponse) @limiter.limit(RATE_LIMIT_GLOBAL, key_func=lambda: "global") @limiter.limit(RATE_LIMIT_IP) async def register(request: Request, payload: RegisterRequest, db: Session = Depends(get_user_db)): username = payload.username.strip() if len(username) < 3 or len(payload.password) < 6: raise HTTPException(status_code=400, detail="用户名至少3位,密码至少6位") exists = db.query(User).filter(User.username == username).first() if exists: raise HTTPException(status_code=400, detail="用户名已存在") user = register_user(db, username, hash_password(payload.password)) token = create_session_token() session = UserSession( token=token, user_id=user.id, expire_at=get_beijing_time() + timedelta(days=SESSION_EXPIRE_DAYS) ) db.add(session) db.commit() return AuthResponse(token=token, user_id=user.id, username=user.username) @router.post("/v1/auth/login", response_model=AuthResponse) @limiter.limit(RATE_LIMIT_GLOBAL, key_func=lambda: "global") @limiter.limit(RATE_LIMIT_IP) async def login(request: Request, payload: LoginRequest, db: Session = Depends(get_user_db)): user = db.query(User).filter(User.username == payload.username.strip()).first() if not user or not verify_password(payload.password, user.password_hash): raise HTTPException(status_code=401, detail="用户名或密码错误") # 账号唯一登录限制:踢掉旧会话 db.query(UserSession).filter(UserSession.user_id == user.id).delete() token = create_session_token() session = UserSession( token=token, user_id=user.id, expire_at=get_beijing_time() + timedelta(days=SESSION_EXPIRE_DAYS) ) db.add(session) db.commit() return AuthResponse(token=token, user_id=user.id, username=user.username) @router.get("/v1/auth/me", response_model=UserInfoResponse) async def me(current_user: User = Depends(get_current_user), db: Session = Depends(get_user_db)): vip = check_vip_status(current_user.id, db) return UserInfoResponse( user_id=current_user.id, username=current_user.username, vip_expire_at=vip["vip_expire_at"], is_vip=vip["is_vip"], ) @router.get("/v1/vip/status", response_model=VipStatusResponse) async def vip_status(current_user: User = Depends(get_current_user), db: Session = Depends(get_user_db)): vip = check_vip_status(current_user.id, db) return VipStatusResponse(is_vip=vip["is_vip"], vip_expire_at=vip["vip_expire_at"]) @router.post("/v1/auth/logout") async def logout( current_user: User = Depends(get_current_user), authorization: Optional[str] = Header(None, alias="Authorization"), db: Session = Depends(get_user_db), ): """退出登录 - 删除当前 session token""" token = _get_token_from_header(authorization) if token: session_row = db.get(UserSession, token) if session_row: db.delete(session_row) db.commit() return {"status": "ok"} @router.post("/v1/payment/callback") async def payment_callback(request: Request, db: Session = Depends(get_user_db)): """ 支付手动回调占位 - 未来可用于接收第三方手动转账通知或管理员后台触发 """ return {"status": "ok"} @router.post("/v1/payment/create", response_model=CreatePaymentResponse) async def create_payment( payload: CreatePaymentRequest, request: Request, current_user: User = Depends(get_current_user), db: Session = Depends(get_user_db) ): """创建支付订单""" # 价格配置: 每月 10 元,包年 100 元 months = max(1, payload.months) if months == 12: price = 100.0 elif months > 12: years = months // 12 extra_months = months % 12 price = years * 100.0 + extra_months * 10.0 else: price = float(months * 10) # 1. 创建本地订单 order_id = create_payment_order(db, current_user.id, price, payload.type, months) # 2. 返回结果 return CreatePaymentResponse( order_id=order_id, price=price, type=payload.type ) @router.get("/v1/payment/check/{order_id}") async def check_payment_status( order_id: str, db: Session = Depends(get_user_db) ): """查询订单状态(手动模式下主要用于占位或未来人工确认)""" order = db.query(PaymentOrder).filter(PaymentOrder.order_id == order_id).first() if not order: return {"code": -1, "msg": "order not found"} return { "code": 1, "data": { "status": 2 if order.status == "paid" else 1 } } @router.get("/game/start", response_model=GameStartResponse) async def game_start( mode: str = Query(default="random", description="游戏模式"), market: Optional[str] = Query(default=None, description="板块类型"), current_user: User = Depends(get_current_user), db: Session = Depends(get_user_db), ): """ 开始游戏 - 获取盲盒数据 - 必须登录(未登录返回 401) - 非 VIP 每天限 FREE_DAILY_LIMIT 次 - VIP 无限制 """ vip_info = check_vip_status(current_user.id, db) is_vip = vip_info["is_vip"] # 非 VIP:检查每日次数 (管理员跳过) if not is_vip and not current_user.is_admin: used_today = get_daily_usage(db, current_user.id) if used_today >= FREE_DAILY_LIMIT: raise HTTPException( status_code=403, detail={ "code": "DAILY_LIMIT_EXCEEDED", "message": f"今日免费次数({FREE_DAILY_LIMIT}次)已用完,开通会员可无限使用", "used": used_today, "limit": FREE_DAILY_LIMIT, } ) # VIP 受限板块校验 if not is_vip and market in ['科创板', '北交所', '可转债']: raise HTTPException(status_code=403, detail="该板块仅限 VIP 会员使用") try: result = start_game(mode=mode, mask=True, market_type=market) # 成功后记录使用次数(VIP 不计) if not is_vip: increment_daily_usage(db, current_user.id) return result except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: logger.error(f"Game start failed: {e}") raise HTTPException(status_code=500, detail="Internal server error") @router.get("/v1/usage/today") async def get_usage_today( current_user: User = Depends(get_current_user), db: Session = Depends(get_user_db), ): """获取今日使用情况""" vip_info = check_vip_status(current_user.id, db) is_vip = vip_info["is_vip"] used = get_daily_usage(db, current_user.id) return { "used": used, "limit": None if is_vip else FREE_DAILY_LIMIT, "remaining": None if is_vip else max(0, FREE_DAILY_LIMIT - used), "is_vip": is_vip, } @router.get("/kline", response_model=list[KLineData]) async def get_kline( code: str = Query(..., description="股票代码"), start: Optional[str] = Query(None, description="起始日期 (YYYY-MM-DD)"), end: Optional[str] = Query(None, description="结束日期 (YYYY-MM-DD)") ): """ 获取指定股票的K线数据 """ try: klines = get_kline_by_code(code, start_date=start, end_date=end) if not klines: raise HTTPException(status_code=404, detail=f"No data found for stock {code}") return klines except HTTPException: raise except Exception as e: logger.error(f"Get kline failed: {e}") raise HTTPException(status_code=500, detail="Internal server error") @router.post("/admin/sync") async def trigger_sync( x_secret_key: Optional[str] = Header(None, alias="X-Secret-Key") ): """ 触发数据同步(需要管理员权限) """ if x_secret_key != ADMIN_SECRET: raise HTTPException(status_code=403, detail="Invalid secret key") try: # 在后台线程运行同步任务,避免阻塞 API thread = threading.Thread(target=run_sync_task) thread.start() logger.info("Manual sync started in background thread") return {"status": "sync_triggered", "message": "Data sync has been started in background"} except Exception as e: logger.error(f"Sync trigger failed: {e}") raise HTTPException(status_code=500, detail="Failed to start sync task") @router.get("/index/hs300") async def get_hs300_index( start: Optional[str] = Query(None, description="起始日期 (YYYY-MM-DD)"), end: Optional[str] = Query(None, description="结束日期 (YYYY-MM-DD)") ): """ 获取沪深300指数数据 """ try: db = get_db() query = "SELECT trade_date, close FROM stock_daily WHERE code = '000300'" params = [] if start: query += " AND trade_date >= ?" params.append(start) if end: query += " AND trade_date <= ?" params.append(end) query += " ORDER BY trade_date" result = db.conn.execute(query, params).fetchall() if not result: return [] return [ { "date": row[0].strftime('%Y-%m-%d') if hasattr(row[0], 'strftime') else str(row[0]), "close": float(row[1]) } for row in result ] except Exception as e: logger.error(f"Get HS300 index failed: {e}") raise HTTPException(status_code=500, detail="Internal server error") # --- Admin Management Endpoints --- class AdminUserItem(BaseModel): user_id: int username: str is_vip: bool vip_expire_at: Optional[str] created_at: str has_payment: bool # 是否有支付记录 last_login: Optional[str] # 最近登录时间 class AdminOrderItem(BaseModel): order_id: str amount: float pay_type: int status: str created_at: str paid_at: Optional[str] months: int class AdminUpdatePasswordRequest(BaseModel): new_password: str class AdminUpdateVipRequest(BaseModel): months: int class AdminUpdateOrderStatusRequest(BaseModel): status: str def get_current_user_admin(current_user: User = Depends(get_current_user)) -> User: """确认当前用户是管理员""" if not current_user.is_admin: raise HTTPException(status_code=403, detail="您没有管理员权限") return current_user @router.get("/v1/admin/users", response_model=list[AdminUserItem]) async def admin_get_users( db: Session = Depends(get_user_db), admin: User = Depends(get_current_user_admin) ): """获取所有用户列表 (管理员专用)""" users = db.query(User).all() result = [] for u in users: vip_info = check_vip_status(u.id, db) # 查询是否有支付记录 has_payment = db.query(PaymentOrder).filter(PaymentOrder.user_id == u.id).count() > 0 # 查询最近登录时间 (从 UserSession 中获取最新的 created_at) last_session = db.query(UserSession).filter(UserSession.user_id == u.id).order_by(UserSession.created_at.desc()).first() last_login = last_session.created_at.strftime('%Y/%m/%d %H:%M:%S') if last_session and last_session.created_at else None # 格式化 VIP 到期时间 vip_expire_formatted = None if vip_info["vip_expire_at"]: try: # 将字符串解析为 datetime 对象,然后重新格式化为所需格式 from datetime import datetime dt = datetime.strptime(vip_info["vip_expire_at"], '%Y-%m-%d %H:%M:%S') vip_expire_formatted = dt.strftime('%Y/%m/%d %H:%M:%S') except ValueError: # 如果解析失败,使用原始值 vip_expire_formatted = vip_info["vip_expire_at"].replace('-', '/') result.append(AdminUserItem( user_id=u.id, username=u.username, is_vip=vip_info["is_vip"], vip_expire_at=vip_expire_formatted, created_at=u.created_at.strftime('%Y/%m/%d %H:%M:%S') if u.created_at else "", has_payment=has_payment, last_login=last_login )) return result @router.post("/v1/admin/user/{user_id}/password") @sync_user_db_after_update async def admin_update_password( user_id: int, payload: AdminUpdatePasswordRequest, db: Session = Depends(get_user_db), admin: User = Depends(get_current_user_admin) ): """修改指定用户密码 (管理员专用)""" user = db.query(User).filter(User.id == user_id).first() if not user: raise HTTPException(status_code=404, detail="用户不存在") from .database_user import hash_password user.password_hash = hash_password(payload.new_password) db.commit() return {"status": "ok"} @router.post("/v1/admin/user/{user_id}/vip") @sync_user_db_after_update async def admin_update_vip( user_id: int, payload: AdminUpdateVipRequest, db: Session = Depends(get_user_db), admin: User = Depends(get_current_user_admin) ): """手动延长用户会员 (管理员专用)""" user = db.query(User).filter(User.id == user_id).first() if not user: raise HTTPException(status_code=404, detail="用户不存在") # 计算天数: 1个月按30天计 extension_days = max(1, payload.months) * 30 extend_vip_membership(db, user_id, days=extension_days) return {"status": "ok"} @router.get("/v1/admin/user/{user_id}/orders", response_model=list[AdminOrderItem]) async def admin_get_user_orders( user_id: int, db: Session = Depends(get_user_db), admin: User = Depends(get_current_user_admin) ): """获取指定用户的订单历史 (管理员专用)""" orders = db.query(PaymentOrder).filter(PaymentOrder.user_id == user_id).order_by(PaymentOrder.created_at.desc()).all() return [ AdminOrderItem( order_id=o.order_id, amount=o.amount, pay_type=o.pay_type, status=o.status, created_at=o.created_at.strftime('%Y-%m-%d %H:%M:%S') if o.created_at else "", paid_at=o.paid_at.strftime('%Y-%m-%d %H:%M:%S') if o.paid_at else None, months=o.vip_duration_months ) for o in orders ] @router.put("/v1/admin/order/{order_id}/status") async def admin_update_order_status( order_id: str, payload: AdminUpdateOrderStatusRequest, db: Session = Depends(get_user_db), admin: User = Depends(get_current_user_admin) ): """手动修改订单状态 (管理员专用)""" from .database_user import update_order_status success = update_order_status(db, order_id, payload.status) if not success: raise HTTPException(status_code=404, detail="订单不存在") return {"status": "ok"} @router.delete("/v1/admin/user/{user_id}") async def admin_delete_user( user_id: int, db: Session = Depends(get_user_db), admin: User = Depends(get_current_user_admin) ): """彻底删除用户及其数据 (管理员专用)""" # 允许删除非超级管理员自己(如果需要的话,但通常不允许删除特定ID,这里简单处理:不允许删除超级管理员) user_to_delete = db.get(User, user_id) if not user_to_delete: raise HTTPException(status_code=404, detail="用户不存在") if user_to_delete.username == "583079759": raise HTTPException(status_code=403, detail="系统保护:无法删除超级管理员账号") db_delete_user(db, user_id) return {"status": "ok"}