Spaces:
Running
Running
| """ | |
| FastAPI 路由定义 | |
| """ | |
| import os | |
| import json | |
| import logging | |
| import threading | |
| from datetime import datetime, timedelta | |
| from typing import Optional | |
| from fastapi import APIRouter, HTTPException, Header, Query, Depends, Request | |
| from fastapi.responses import HTMLResponse, PlainTextResponse | |
| from pydantic import BaseModel | |
| from sqlalchemy.orm import Session | |
| from .core import start_game, get_kline_by_code, GameStartResponse, KLineData | |
| from .database import get_db | |
| from .database_user import ( | |
| get_user_db, | |
| User, | |
| UserMembership, | |
| UserSession, | |
| PaymentOrder, | |
| hash_password, | |
| verify_password, | |
| create_session_token, | |
| get_user_by_token, | |
| extend_vip_membership, | |
| get_daily_usage, | |
| increment_daily_usage, | |
| FREE_DAILY_LIMIT, | |
| create_payment_order, | |
| register_user, | |
| delete_user as db_delete_user, | |
| sync_user_db_after_update, | |
| get_beijing_time, | |
| ) | |
| from .limiter import limiter | |
| # 导入同步函数 | |
| import sys | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from scripts.sync_data import main as run_sync_task | |
# Module-level logger and the router shared by the handlers in this file;
# the router is created with the "/api" prefix.
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api")
# Environment variables
# NOTE(review): ADMIN_SECRET falls back to a publicly known placeholder when
# the variable is unset -- confirm every deployment overrides it.
ADMIN_SECRET = os.getenv("ADMIN_SECRET", "your-secret-key")
# int() raises ValueError at import time if the variable is not an integer.
SESSION_EXPIRE_DAYS = int(os.getenv("SESSION_EXPIRE_DAYS", "30"))
# Raw rate-limit settings as read from the environment.
RATE_LIMIT_GLOBAL_RAW = os.getenv("RATE_LIMIT_GLOBAL", "50/second")
RATE_LIMIT_IP_RAW = os.getenv("RATE_LIMIT_IP", "5/second")
| # 确保频率限制字符串包含时间单位 (例如 /second) | |
| def _ensure_rate_limit_format(limit_str: str) -> str: | |
| if "/" not in limit_str: | |
| return f"{limit_str}/second" | |
| return limit_str | |
# Normalised rate-limit strings: "/second" is appended when the raw value has
# no "/" in it.
RATE_LIMIT_GLOBAL = _ensure_rate_limit_format(RATE_LIMIT_GLOBAL_RAW)
RATE_LIMIT_IP = _ensure_rate_limit_format(RATE_LIMIT_IP_RAW)
class HealthResponse(BaseModel):
    """Health check response."""
    # "healthy" or "unhealthy", as set by health_check.
    status: str
    # "connected" or "disconnected", as set by health_check.
    database: str
    # Row count of the stock_list table (0 when the check fails).
    stocks_count: int
class ErrorResponse(BaseModel):
    """Error response."""
    # Error message text.
    error: str
class RegisterRequest(BaseModel):
    """Request body for registration."""
    # Stripped by the register handler and must then be at least 3 characters.
    username: str
    # Must be at least 6 characters (checked by the register handler).
    password: str
class LoginRequest(BaseModel):
    """Request body for login."""
    # Stripped of surrounding whitespace before the user lookup.
    username: str
    password: str
class AuthResponse(BaseModel):
    """Response returned by the register and login handlers."""
    # Newly created session token.
    token: str
    user_id: int
    username: str
class UserInfoResponse(BaseModel):
    """Profile of the current user, including VIP state."""
    user_id: int
    username: str
    # Formatted as "%Y-%m-%d %H:%M:%S" by check_vip_status, or None when the
    # user has no membership expiry recorded.
    vip_expire_at: Optional[str]
    is_vip: bool
class VipStatusResponse(BaseModel):
    """VIP state of the current user."""
    is_vip: bool
    # Formatted as "%Y-%m-%d %H:%M:%S" by check_vip_status, or None.
    vip_expire_at: Optional[str]
class CreatePaymentRequest(BaseModel):
    """Request body for creating a payment order."""
    type: int = 2  # 1: Alipay, 2: WeChat (defaults to 2)
    months: int = 1  # number of months to purchase
class CreatePaymentResponse(BaseModel):
    """Response returned after a payment order has been created."""
    order_id: str
    # Price computed by create_payment.
    price: float
    # Echo of the payment type from the request.
    type: int
| def _get_token_from_header(authorization: Optional[str]) -> Optional[str]: | |
| if not authorization: | |
| return None | |
| if authorization.lower().startswith("bearer "): | |
| return authorization[7:] | |
| return authorization | |
def check_vip_status(user_id: int, db: Session) -> dict:
    """Report whether a user's VIP membership is currently active.

    Args:
        user_id: Key used to look up the UserMembership row.
        db: User-database session.

    Returns:
        A dict with ``is_vip`` (bool) and ``vip_expire_at`` (a
        ``%Y-%m-%d %H:%M:%S`` string, or None when no expiry is recorded).
    """
    record = db.get(UserMembership, user_id)
    current_time = get_beijing_time()
    expiry = record.vip_expire_at if record is not None else None
    if expiry is None:
        return {"is_vip": False, "vip_expire_at": None}
    # The expiry string is returned even when the membership has lapsed.
    return {
        "is_vip": expiry > current_time,
        "vip_expire_at": expiry.strftime('%Y-%m-%d %H:%M:%S'),
    }
def get_current_user(
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: Session = Depends(get_user_db),
) -> User:
    """Dependency that resolves the caller's User from the Authorization header.

    Raises:
        HTTPException: 401 when no token is supplied, or when
            get_user_by_token finds no user for it.
    """
    token = _get_token_from_header(authorization)
    if token:
        found = get_user_by_token(db, token)
        if found:
            return found
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    raise HTTPException(status_code=401, detail="Missing authorization token")
async def health_check():
    """Health check endpoint.

    Counts the rows of ``stock_list``; any exception is logged and reported
    as an "unhealthy" response rather than raised.
    """
    try:
        row = get_db().conn.execute("SELECT COUNT(*) FROM stock_list").fetchone()
        report = HealthResponse(
            status="healthy",
            database="connected",
            stocks_count=row[0],
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        report = HealthResponse(
            status="unhealthy",
            database="disconnected",
            stocks_count=0,
        )
    return report
async def register(request: Request, payload: RegisterRequest, db: Session = Depends(get_user_db)):
    """Create a user account and open a first session for it.

    ``request`` is unused in the body (presumably required by a rate-limit
    decorator that is not visible here -- confirm).

    Raises:
        HTTPException: 400 when the stripped username is shorter than 3
            characters, the password is shorter than 6, or the username is
            already taken.
    """
    name = payload.username.strip()
    if not (len(name) >= 3 and len(payload.password) >= 6):
        raise HTTPException(status_code=400, detail="用户名至少3位,密码至少6位")
    if db.query(User).filter(User.username == name).first():
        raise HTTPException(status_code=400, detail="用户名已存在")
    password_hash = hash_password(payload.password)
    new_user = register_user(db, name, password_hash)
    # Open a session that expires SESSION_EXPIRE_DAYS days from now.
    session_token = create_session_token()
    expires = get_beijing_time() + timedelta(days=SESSION_EXPIRE_DAYS)
    db.add(UserSession(token=session_token, user_id=new_user.id, expire_at=expires))
    db.commit()
    return AuthResponse(
        token=session_token,
        user_id=new_user.id,
        username=new_user.username,
    )
async def login(request: Request, payload: LoginRequest, db: Session = Depends(get_user_db)):
    """Authenticate a user and issue a fresh session token.

    ``request`` is unused in the body (presumably required by a rate-limit
    decorator that is not visible here -- confirm).

    Raises:
        HTTPException: 401 when the user does not exist or the password does
            not verify.
    """
    name = payload.username.strip()
    account = db.query(User).filter(User.username == name).first()
    if not (account and verify_password(payload.password, account.password_hash)):
        raise HTTPException(status_code=401, detail="用户名或密码错误")
    # Single-login policy: drop every existing session of this account.
    db.query(UserSession).filter(UserSession.user_id == account.id).delete()
    session_token = create_session_token()
    expires = get_beijing_time() + timedelta(days=SESSION_EXPIRE_DAYS)
    db.add(UserSession(token=session_token, user_id=account.id, expire_at=expires))
    db.commit()
    return AuthResponse(
        token=session_token,
        user_id=account.id,
        username=account.username,
    )
async def me(current_user: User = Depends(get_current_user), db: Session = Depends(get_user_db)):
    """Return the authenticated user's id, name and VIP state."""
    status = check_vip_status(current_user.id, db)
    expire_text = status["vip_expire_at"]
    active = status["is_vip"]
    return UserInfoResponse(
        user_id=current_user.id,
        username=current_user.username,
        vip_expire_at=expire_text,
        is_vip=active,
    )
async def vip_status(current_user: User = Depends(get_current_user), db: Session = Depends(get_user_db)):
    """Return only the VIP state of the authenticated user."""
    status = check_vip_status(current_user.id, db)
    return VipStatusResponse(
        is_vip=status["is_vip"],
        vip_expire_at=status["vip_expire_at"],
    )
async def logout(
    current_user: User = Depends(get_current_user),
    authorization: Optional[str] = Header(None, alias="Authorization"),
    db: Session = Depends(get_user_db),
):
    """Log out by deleting the session row for the current token.

    Always returns ``{"status": "ok"}``, whether or not a row was found.
    """
    token = _get_token_from_header(authorization)
    # The UserSession row is looked up by the token itself.
    row = db.get(UserSession, token) if token else None
    if row:
        db.delete(row)
        db.commit()
    return {"status": "ok"}
async def payment_callback(request: Request, db: Session = Depends(get_user_db)):
    """Placeholder for a manual payment callback.

    Reserved for future use, e.g. receiving third-party manual transfer
    notifications or being triggered from the admin back office. It
    currently ignores its inputs and acknowledges every call.
    """
    acknowledgement = {"status": "ok"}
    return acknowledgement
async def create_payment(
    payload: CreatePaymentRequest,
    request: Request,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_user_db)
):
    """Create a payment order for the authenticated user.

    Pricing: 10 per month, 100 per full block of 12 months. Month counts
    below 1 are treated as 1.
    """
    months = max(1, payload.months)
    # One formula covers all cases: whole years at the yearly price plus the
    # leftover months at the monthly price.
    full_years, leftover = divmod(months, 12)
    price = full_years * 100.0 + leftover * 10.0
    # 1. Create the local order.
    new_order_id = create_payment_order(db, current_user.id, price, payload.type, months)
    # 2. Return the result.
    return CreatePaymentResponse(order_id=new_order_id, price=price, type=payload.type)
async def check_payment_status(
    order_id: str,
    db: Session = Depends(get_user_db)
):
    """Look up an order's status (in manual mode mostly a placeholder for
    future human confirmation).

    Returns ``{"code": -1, ...}`` when the order is unknown; otherwise
    ``{"code": 1, "data": {"status": 2}}`` for a paid order and status 1 for
    anything else.
    """
    # NOTE(review): no authentication dependency is declared on this handler,
    # so any caller who knows an order_id can query it -- confirm intended.
    found = db.query(PaymentOrder).filter(PaymentOrder.order_id == order_id).first()
    if found:
        status_code = 2 if found.status == "paid" else 1
        return {"code": 1, "data": {"status": status_code}}
    return {"code": -1, "msg": "order not found"}
async def game_start(
    mode: str = Query(default="random", description="游戏模式"),
    market: Optional[str] = Query(default=None, description="板块类型"),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_user_db),
):
    """Start a game and return the masked "blind box" data.

    - Login is required (401 otherwise, via get_current_user).
    - Non-VIP users are limited to FREE_DAILY_LIMIT games per day.
    - VIP users are unlimited.

    Raises:
        HTTPException: 403 when the daily quota is used up or the market is
            VIP-only; 400 when start_game raises ValueError; 500 on any other
            failure.
    """
    is_vip = check_vip_status(current_user.id, db)["is_vip"]
    # Non-VIP: enforce the daily quota (administrators are exempt).
    if not (is_vip or current_user.is_admin):
        used_today = get_daily_usage(db, current_user.id)
        if used_today >= FREE_DAILY_LIMIT:
            quota_detail = {
                "code": "DAILY_LIMIT_EXCEEDED",
                "message": f"今日免费次数({FREE_DAILY_LIMIT}次)已用完,开通会员可无限使用",
                "used": used_today,
                "limit": FREE_DAILY_LIMIT,
            }
            raise HTTPException(status_code=403, detail=quota_detail)
    # Markets reserved for VIP members.
    vip_only_markets = ('科创板', '北交所', '可转债')
    if not is_vip and market in vip_only_markets:
        raise HTTPException(status_code=403, detail="该板块仅限 VIP 会员使用")
    try:
        game = start_game(mode=mode, mask=True, market_type=market)
        # Count the play only after a successful start (VIP is not counted).
        if not is_vip:
            increment_daily_usage(db, current_user.id)
        return game
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Game start failed: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
async def get_usage_today(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_user_db),
):
    """Return today's usage for the authenticated user.

    For VIP users ``limit`` and ``remaining`` are None; otherwise they
    reflect FREE_DAILY_LIMIT, with ``remaining`` never below 0.
    """
    is_vip = check_vip_status(current_user.id, db)["is_vip"]
    used = get_daily_usage(db, current_user.id)
    if is_vip:
        limit = None
        remaining = None
    else:
        limit = FREE_DAILY_LIMIT
        remaining = max(0, FREE_DAILY_LIMIT - used)
    return {
        "used": used,
        "limit": limit,
        "remaining": remaining,
        "is_vip": is_vip,
    }
async def get_kline(
    code: str = Query(..., description="股票代码"),
    start: Optional[str] = Query(None, description="起始日期 (YYYY-MM-DD)"),
    end: Optional[str] = Query(None, description="结束日期 (YYYY-MM-DD)")
):
    """Return K-line data for the given stock code.

    Raises:
        HTTPException: 404 when no data is found; 500 on any other failure.
    """
    try:
        candles = get_kline_by_code(code, start_date=start, end_date=end)
        if candles:
            return candles
        raise HTTPException(status_code=404, detail=f"No data found for stock {code}")
    except HTTPException:
        # Let the 404 above through instead of turning it into a 500.
        raise
    except Exception as e:
        logger.error(f"Get kline failed: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
async def trigger_sync(
    x_secret_key: Optional[str] = Header(None, alias="X-Secret-Key")
):
    """Trigger a data sync (administrator secret required).

    The ``X-Secret-Key`` header must equal ADMIN_SECRET. The sync task is
    started on a background thread so the request returns immediately.

    Raises:
        HTTPException: 403 when the header is missing or does not match;
            500 when the background thread cannot be started.
    """
    # Function-scope import: hmac is standard library and only needed here.
    import hmac

    # Constant-time comparison avoids leaking the secret through response
    # timing. Compared as bytes because compare_digest rejects non-ASCII str.
    # The explicit None check keeps a missing header rejected even when
    # ADMIN_SECRET is an empty string.
    authorised = x_secret_key is not None and hmac.compare_digest(
        x_secret_key.encode("utf-8"), ADMIN_SECRET.encode("utf-8")
    )
    if not authorised:
        raise HTTPException(status_code=403, detail="Invalid secret key")
    try:
        # Run the sync on a background thread so the API is not blocked.
        thread = threading.Thread(target=run_sync_task)
        thread.start()
        logger.info("Manual sync started in background thread")
        return {"status": "sync_triggered", "message": "Data sync has been started in background"}
    except Exception as e:
        logger.error(f"Sync trigger failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to start sync task")
async def get_hs300_index(
    start: Optional[str] = Query(None, description="起始日期 (YYYY-MM-DD)"),
    end: Optional[str] = Query(None, description="结束日期 (YYYY-MM-DD)")
):
    """Return CSI 300 index closes (code '000300') ordered by trade date.

    Returns:
        A list of ``{"date": "YYYY-MM-DD", "close": float}`` dicts, empty
        when no rows match.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        # Date bounds are bound as parameters, never interpolated.
        conditions = ["code = '000300'"]
        params = []
        if start:
            conditions.append("trade_date >= ?")
            params.append(start)
        if end:
            conditions.append("trade_date <= ?")
            params.append(end)
        query = (
            "SELECT trade_date, close FROM stock_daily WHERE "
            + " AND ".join(conditions)
            + " ORDER BY trade_date"
        )
        rows = get_db().conn.execute(query, params).fetchall() or []
        points = []
        for row in rows:
            day = row[0]
            # Date-like values are formatted; anything else is stringified.
            label = day.strftime('%Y-%m-%d') if hasattr(day, 'strftime') else str(day)
            points.append({"date": label, "close": float(row[1])})
        return points
    except Exception as e:
        logger.error(f"Get HS300 index failed: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
| # --- Admin Management Endpoints --- | |
class AdminUserItem(BaseModel):
    """One row of the admin user list."""
    user_id: int
    username: str
    is_vip: bool
    # Filled by admin_get_users as "%Y/%m/%d %H:%M:%S", or None.
    vip_expire_at: Optional[str]
    # Filled by admin_get_users as "%Y/%m/%d %H:%M:%S", or "" when unknown.
    created_at: str
    has_payment: bool  # whether the user has any payment record
    last_login: Optional[str]  # most recent login time
class AdminOrderItem(BaseModel):
    """One payment order as shown in the admin order history."""
    order_id: str
    amount: float
    pay_type: int
    status: str
    # Filled by admin_get_user_orders as "%Y-%m-%d %H:%M:%S", or "".
    created_at: str
    # Filled by admin_get_user_orders as "%Y-%m-%d %H:%M:%S", or None.
    paid_at: Optional[str]
    # Populated from the order's vip_duration_months.
    months: int
class AdminUpdatePasswordRequest(BaseModel):
    """Request body for an admin password reset."""
    # Plain-text password; hashed by the handler before it is stored.
    new_password: str
class AdminUpdateVipRequest(BaseModel):
    """Request body for manually extending a membership."""
    # Months to add; the handler treats values below 1 as 1.
    months: int
class AdminUpdateOrderStatusRequest(BaseModel):
    """Request body for manually changing an order's status."""
    # Passed unchanged to update_order_status.
    status: str
def get_current_user_admin(current_user: User = Depends(get_current_user)) -> User:
    """Dependency that only lets administrators through.

    Raises:
        HTTPException: 403 when the authenticated user is not an admin.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=403, detail="您没有管理员权限")
async def admin_get_users(
    db: Session = Depends(get_user_db),
    admin: User = Depends(get_current_user_admin)
):
    """List every user with VIP, payment and last-login info (admin only).

    Returns:
        A list of AdminUserItem. Timestamps use ``%Y/%m/%d %H:%M:%S``.
    """
    # NOTE(review): this issues several queries per user; acceptable for a
    # small user table, worth batching if the list grows.
    items = []
    for u in db.query(User).all():
        vip_info = check_vip_status(u.id, db)
        # Existence check: one row is enough, no need to count them all.
        has_payment = (
            db.query(PaymentOrder).filter(PaymentOrder.user_id == u.id).first()
            is not None
        )
        # Last login = creation time of the user's newest session row.
        last_session = (
            db.query(UserSession)
            .filter(UserSession.user_id == u.id)
            .order_by(UserSession.created_at.desc())
            .first()
        )
        if last_session and last_session.created_at:
            last_login = last_session.created_at.strftime('%Y/%m/%d %H:%M:%S')
        else:
            last_login = None
        # check_vip_status returns "%Y-%m-%d %H:%M:%S"; swapping the date
        # separator gives the display format directly, with no need to parse
        # and re-format the string.
        expire_text = vip_info["vip_expire_at"]
        vip_expire_formatted = expire_text.replace('-', '/') if expire_text else None
        items.append(AdminUserItem(
            user_id=u.id,
            username=u.username,
            is_vip=vip_info["is_vip"],
            vip_expire_at=vip_expire_formatted,
            created_at=u.created_at.strftime('%Y/%m/%d %H:%M:%S') if u.created_at else "",
            has_payment=has_payment,
            last_login=last_login,
        ))
    return items
async def admin_update_password(
    user_id: int,
    payload: AdminUpdatePasswordRequest,
    db: Session = Depends(get_user_db),
    admin: User = Depends(get_current_user_admin)
):
    """Set a new password for the given user (admin only).

    Raises:
        HTTPException: 404 when the user does not exist; 400 when the new
            password is shorter than 6 characters.
    """
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    # Same minimum length that registration enforces, so an admin reset
    # cannot leave an account with an empty or trivially short password.
    if len(payload.new_password) < 6:
        raise HTTPException(status_code=400, detail="密码至少6位")
    # hash_password is already imported at module level.
    user.password_hash = hash_password(payload.new_password)
    db.commit()
    # NOTE(review): existing sessions of this user stay valid after the
    # reset, and sync_user_db_after_update is imported by this module but not
    # called here -- confirm whether either is needed.
    return {"status": "ok"}
async def admin_update_vip(
    user_id: int,
    payload: AdminUpdateVipRequest,
    db: Session = Depends(get_user_db),
    admin: User = Depends(get_current_user_admin)
):
    """Manually extend a user's membership (admin only).

    Raises:
        HTTPException: 404 when the user does not exist.
    """
    target = db.query(User).filter(User.id == user_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="用户不存在")
    # One month counts as 30 days; requests below 1 month are treated as 1.
    months = max(1, payload.months)
    extend_vip_membership(db, user_id, days=months * 30)
    return {"status": "ok"}
async def admin_get_user_orders(
    user_id: int,
    db: Session = Depends(get_user_db),
    admin: User = Depends(get_current_user_admin)
):
    """Return a user's order history, newest first (admin only)."""
    stamp = '%Y-%m-%d %H:%M:%S'
    history = (
        db.query(PaymentOrder)
        .filter(PaymentOrder.user_id == user_id)
        .order_by(PaymentOrder.created_at.desc())
        .all()
    )
    items = []
    for o in history:
        created_text = o.created_at.strftime(stamp) if o.created_at else ""
        # Unpaid orders have no paid_at and are reported as None.
        paid_text = o.paid_at.strftime(stamp) if o.paid_at else None
        items.append(AdminOrderItem(
            order_id=o.order_id,
            amount=o.amount,
            pay_type=o.pay_type,
            status=o.status,
            created_at=created_text,
            paid_at=paid_text,
            months=o.vip_duration_months,
        ))
    return items
async def admin_update_order_status(
    order_id: str,
    payload: AdminUpdateOrderStatusRequest,
    db: Session = Depends(get_user_db),
    admin: User = Depends(get_current_user_admin)
):
    """Manually change an order's status (admin only).

    Raises:
        HTTPException: 404 when update_order_status reports failure.
    """
    # Imported here because the module-level import list does not include it.
    from .database_user import update_order_status
    if update_order_status(db, order_id, payload.status):
        return {"status": "ok"}
    raise HTTPException(status_code=404, detail="订单不存在")
async def admin_delete_user(
    user_id: int,
    db: Session = Depends(get_user_db),
    admin: User = Depends(get_current_user_admin)
):
    """Permanently delete a user and their data (admin only).

    Raises:
        HTTPException: 404 when the user does not exist; 403 when the target
            is the protected super-administrator account.
    """
    target = db.get(User, user_id)
    if not target:
        raise HTTPException(status_code=404, detail="用户不存在")
    # Kept simple: the super administrator is identified by a fixed username
    # and can never be deleted.
    is_super_admin = target.username == "583079759"
    if is_super_admin:
        raise HTTPException(status_code=403, detail="系统保护:无法删除超级管理员账号")
    db_delete_user(db, user_id)
    return {"status": "ok"}