# Hugging Face Spaces - MySQL API Server
# This file is tuned for deployment on Hugging Face Spaces.

import html
import os
import time
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List, Optional, Tuple

import pymysql
from dbutils.pooled_db import PooledDB
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from pymysql.cursors import DictCursor

# Database configuration - read from environment variables.
DB_CONFIG = {
    "host": os.getenv("DB_HOST", "114.116.200.230"),
    "port": int(os.getenv("DB_PORT", 3306)),
    "user": os.getenv("DB_USER", "aistock_admin"),
    "password": os.getenv("DB_PASSWORD", ""),  # must be provided via the environment
    "database": os.getenv("DB_NAME", "aistock"),
    "charset": "utf8mb4",
    "cursorclass": DictCursor,
}

# Connection pool; stays None until init_pool() succeeds.
pool: Optional[PooledDB] = None

# Length of one day in milliseconds (timestamps in this file are epoch ms).
_DAY_MS = 86400000


def init_pool():
    """Create the global connection pool.

    Does nothing (besides printing a warning) when DB_PASSWORD is empty, in
    which case ``pool`` stays ``None`` and get_connection() raises.
    """
    global pool
    if not DB_CONFIG["password"]:
        print("⚠️ 警告: DB_PASSWORD 环境变量未设置!")
        return

    pool = PooledDB(
        creator=pymysql,
        maxconnections=5,  # the free Hugging Face tier has limited resources
        mincached=1,
        maxcached=3,
        blocking=True,
        maxusage=None,
        setsession=[],
        ping=1,
        **DB_CONFIG,
    )
    print(
        f"✅ MySQL 连接池已初始化: {DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
    )


def get_connection():
    """Return a connection from the pool.

    Raises:
        HTTPException: status 500 when the pool has not been initialised.
    """
    if not pool:
        raise HTTPException(status_code=500, detail="数据库未配置")
    return pool.connection()


def _json_error(exc: Exception) -> JSONResponse:
    """Build the 500 JSON response used by every endpoint's catch-all handler."""
    return JSONResponse(status_code=500, content={"error": str(exc)})


def _html_message(
    title: str, status_code: int, detail: Optional[str] = None
) -> HTMLResponse:
    """Build a minimal HTML page with an escaped heading and optional detail.

    Both ``title`` and ``detail`` may contain request-supplied text (a stock
    code, a trade date, an exception message), so they are HTML-escaped before
    being embedded in the page.
    """
    # NOTE(review): the heading/paragraph markup is an assumption about how
    # these small pages were originally laid out - TODO confirm.
    body = f"<h1>{html.escape(title)}</h1>"
    if detail is not None:
        body += f"<p>{html.escape(detail)}</p>"
    return HTMLResponse(content=body, status_code=status_code)


def _day_range_ms(date: str) -> Tuple[int, int]:
    """Return ``(start, end)`` epoch-millisecond bounds for a ``YYYY-MM-DD`` day.

    ``end`` is exclusive. The naive datetime is interpreted by
    ``datetime.timestamp()`` in the process's local timezone.

    Raises:
        ValueError: if ``date`` does not match ``%Y-%m-%d``.
    """
    start_ts = int(datetime.strptime(date, "%Y-%m-%d").timestamp() * 1000)
    return start_ts, start_ts + _DAY_MS


def _summarize_reports(rows) -> list:
    """Convert report rows into the list payload returned by the list endpoints.

    ``updated_at`` is divided by 1000 before formatting, i.e. it is treated as
    epoch milliseconds; a falsy value yields ``update_time = None``.
    """
    return [
        {
            "trade_date": row["trade_date"],
            "update_time": (
                time.strftime(
                    "%Y-%m-%d %H:%M:%S",
                    time.localtime(row["updated_at"] / 1000),
                )
                if row["updated_at"]
                else None
            ),
            "updated_at": row["updated_at"],
        }
        for row in rows
    ]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: open the pool on startup, close it on shutdown."""
    init_pool()
    yield
    if pool:
        pool.close()


# FastAPI application
app = FastAPI(
    title="Stock Analysis MySQL API",
    description="MySQL API Server for Stock Analysis Agent (Hugging Face Spaces)",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware - lets the Cloudflare Worker call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE(review): the handlers below are ``async def`` but use the synchronous
# pymysql driver, so each query blocks the event loop while it runs.


# ============ Pydantic models ============
class SubmitRequest(BaseModel):
    """Payload for POST /api/submit."""

    stock_code: str
    stock_name: str
    market: Optional[str] = ""
    username: Optional[str] = None  # user who submitted the task
    submitted_by: Optional[str] = None  # same as username; kept for the Cloudflare Worker
    submitted_at: Optional[str] = None  # submission time
    is_public: Optional[int] = 1  # report visibility: 1 = public (default), 0 = private
    is_vip: Optional[int] = 0  # task submitted by a VIP user: 1 = VIP, 0 = regular


class CompleteRequest(BaseModel):
    """Payload for POST /api/complete."""

    id: str
    stock_code: Optional[str] = None
    stock_name: Optional[str] = None
    html_content: Optional[str] = None
    status: Optional[str] = "completed"


class ReorderRequest(BaseModel):
    """Payload for POST /api/reorder."""

    id: str


class RemoveDuplicatesRequest(BaseModel):
    """Payload for POST /api/remove-duplicates."""

    ids: List[str]


# ============ API endpoints ============
@app.get("/")
async def root():
    """Health check."""
    return {
        "status": "ok",
        "service": "MySQL API Server",
        "platform": "Hugging Face Spaces",
        "db_configured": bool(DB_CONFIG["password"]),
    }


@app.post("/api/submit")
async def submit_task(req: SubmitRequest):
    """Submit a new task.

    If a pending/processing request for the same stock code and name already
    exists, its ``submit_count`` is incremented and its id is returned with
    ``duplicate: True``; otherwise a new pending row is inserted.
    """
    conn = get_connection()

    # Prefer ``username``; fall back to ``submitted_by``.
    username = req.username or req.submitted_by or None

    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT id, submit_count FROM requests "
                "WHERE stock_code = %s AND stock_name = %s "
                "AND status IN ('pending', 'processing') LIMIT 1",
                (req.stock_code, req.stock_name),
            )
            existing = cursor.fetchone()

            if existing:
                new_count = (existing.get("submit_count") or 1) + 1
                cursor.execute(
                    "UPDATE requests SET submit_count = %s WHERE id = %s",
                    (new_count, existing["id"]),
                )
                conn.commit()
                return {
                    "success": True,
                    "request_id": existing["id"],
                    "duplicate": True,
                }

            request_id = str(uuid.uuid4())
            now = int(time.time() * 1000)

            # Normalise the flags to 0/1. Anything other than an explicit 0
            # (including None) counts as public; only an explicit 1 is VIP.
            is_public = 0 if req.is_public == 0 else 1
            is_vip = 1 if req.is_vip == 1 else 0

            cursor.execute(
                "INSERT INTO requests (id, stock_code, stock_name, market, status, "
                "created_at, submit_count, username, is_public, is_vip) "
                "VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
                (
                    request_id,
                    req.stock_code,
                    req.stock_name,
                    req.market,
                    "pending",
                    now,
                    1,
                    username,
                    is_public,
                    is_vip,
                ),
            )
            conn.commit()
            return {"success": True, "request_id": request_id, "username": username}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/pending")
async def get_pending_tasks():
    """Return all pending tasks, oldest first."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT * FROM requests WHERE status = 'pending' ORDER BY created_at ASC"
            )
            tasks = cursor.fetchall()
            return {"tasks": tasks}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/status")
async def get_task_status(id: str = Query(...)):
    """Return the request row with the given id, or a 404 JSON error."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute("SELECT * FROM requests WHERE id = %s", (id,))
            result = cursor.fetchone()
            if not result:
                return JSONResponse(status_code=404, content={"error": "Not found"})
            return result
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.post("/api/complete")
async def complete_task(req: CompleteRequest):
    """Mark a task as finished and, when completed with content, store its report."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            now = int(time.time() * 1000)

            cursor.execute(
                "UPDATE requests SET status = %s, completed_at = %s WHERE id = %s",
                (req.status, now, req.id),
            )

            if req.status == "completed" and req.html_content and req.stock_code:
                # Use stock_code as a stable id so that each stock keeps only
                # its latest report.
                report_id = str(req.stock_code)

                # Legacy data: the reports table may hold several rows for the
                # same stock_code; remove them before writing.
                cursor.execute(
                    "DELETE FROM reports WHERE stock_code = %s AND id <> %s",
                    (req.stock_code, report_id),
                )

                cursor.execute(
                    "INSERT INTO reports (id, stock_code, stock_name, html_content, created_at) "
                    "VALUES (%s, %s, %s, %s, %s) "
                    "ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name), "
                    "html_content = VALUES(html_content), created_at = VALUES(created_at)",
                    (report_id, req.stock_code, req.stock_name, req.html_content, now),
                )

            conn.commit()
            return {"success": True}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/report")
async def get_report(code: str = Query(...)):
    """Return the stored report HTML for a stock code, or an HTML 404 page."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            # The reports table may hold several historical rows for the same
            # stock_code. Always take the newest one; otherwise the list could
            # show the latest time while the opened report is stale.
            cursor.execute(
                "SELECT html_content FROM reports WHERE stock_code = %s "
                "ORDER BY created_at DESC LIMIT 1",
                (code,),
            )
            result = cursor.fetchone()
            if not result or not result.get("html_content"):
                return _html_message(f"Report not found for {code}", 404)
            return HTMLResponse(content=result["html_content"])
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/history")
async def get_history(
    date: Optional[str] = Query(None), username: Optional[str] = Query(None)
):
    """Return the report history, optionally filtered by date and username.

    With ``username`` the user's own requests are returned; without it only
    public reports are returned, one (the latest completed) per stock code.
    A ``date`` that is not ``YYYY-MM-DD`` yields a 400 JSON error.
    """
    day_range = None
    if date:
        # Validate before touching the database so a malformed date is
        # reported as a client error rather than a generic 500.
        try:
            day_range = _day_range_ms(date)
        except ValueError:
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid date format, expected YYYY-MM-DD"},
            )

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            if username and day_range:
                # A specific user's completed tasks on the given day.
                sql = """
                    SELECT * FROM requests
                    WHERE username = %s
                      AND status = 'completed'
                      AND completed_at >= %s
                      AND completed_at < %s
                    ORDER BY completed_at DESC
                """
                params = (username, day_range[0], day_range[1])
            elif username:
                # A specific user's most recent tasks in any state.
                sql = """
                    SELECT * FROM requests
                    WHERE username = %s
                      AND status IN ('completed', 'error', 'pending', 'processing')
                    ORDER BY created_at DESC
                    LIMIT 100
                """
                params = (username,)
            elif day_range:
                # Home page (no username): public reports only
                # (is_public = 1 or is_public IS NULL), for the given day.
                sql = """
                    SELECT r1.* FROM requests r1
                    INNER JOIN (
                        SELECT stock_code, MAX(completed_at) as max_completed_at
                        FROM requests
                        WHERE status = 'completed'
                          AND completed_at >= %s
                          AND completed_at < %s
                          AND (is_public = 1 OR is_public IS NULL)
                        GROUP BY stock_code
                    ) r2 ON r1.stock_code = r2.stock_code
                        AND r1.completed_at = r2.max_completed_at
                    WHERE r1.status = 'completed'
                      AND (r1.is_public = 1 OR r1.is_public IS NULL)
                    ORDER BY r1.completed_at DESC
                """
                params = day_range
            else:
                # Home page (no username, no date): public reports only
                # (is_public = 1 or is_public IS NULL).
                sql = """
                    SELECT r1.* FROM requests r1
                    INNER JOIN (
                        SELECT stock_code, MAX(completed_at) as max_completed_at
                        FROM requests
                        WHERE status = 'completed'
                          AND (is_public = 1 OR is_public IS NULL)
                        GROUP BY stock_code
                    ) r2 ON r1.stock_code = r2.stock_code
                        AND r1.completed_at = r2.max_completed_at
                    WHERE r1.status = 'completed'
                      AND (r1.is_public = 1 OR r1.is_public IS NULL)
                    ORDER BY r1.completed_at DESC
                """
                params = None

            cursor.execute(sql, params)
            tasks = cursor.fetchall()
            return {"tasks": tasks}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/check-report")
async def check_report(code: str = Query(...)):
    """Report whether a stored report exists for the given stock code."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT stock_code, stock_name, created_at FROM reports "
                "WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
                (code,),
            )
            result = cursor.fetchone()
            if result:
                return {
                    "exists": True,
                    "stock_code": result["stock_code"],
                    "stock_name": result["stock_name"],
                    "created_at": result["created_at"],
                }
            return {"exists": False}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.post("/api/reorder")
async def reorder_task(req: ReorderRequest):
    """Move a pending task to the front of the queue.

    The queue is ordered by ``created_at``, so the task's ``created_at`` is set
    to one millisecond before the earliest pending task (or before "now" when
    no pending task has a usable timestamp).
    """
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT MIN(created_at) as min_time FROM requests WHERE status = 'pending'"
            )
            result = cursor.fetchone()
            min_time = (
                result["min_time"]
                if result and result["min_time"]
                else int(time.time() * 1000)
            )
            new_time = min_time - 1

            cursor.execute(
                "UPDATE requests SET created_at = %s WHERE id = %s AND status = 'pending'",
                (new_time, req.id),
            )
            conn.commit()
            return {"success": True, "new_time": new_time}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.post("/api/remove-duplicates")
async def remove_duplicates(req: RemoveDuplicatesRequest):
    """Delete the given tasks in bulk; only rows still pending are removed."""
    if not req.ids:
        return JSONResponse(
            status_code=400, content={"error": "Missing or invalid task IDs"}
        )

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            # Only the placeholder list is interpolated; the ids themselves are
            # passed as bound parameters.
            placeholders = ", ".join(["%s"] * len(req.ids))
            cursor.execute(
                f"DELETE FROM requests WHERE id IN ({placeholders}) AND status = 'pending'",
                tuple(req.ids),
            )
            deleted = cursor.rowcount
            conn.commit()
            return {"success": True, "deleted": deleted}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/queue")
async def get_queue():
    """Return the full task queue (pending and processing), oldest first."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT * FROM requests WHERE status IN ('pending', 'processing') "
                "ORDER BY created_at ASC"
            )
            tasks = cursor.fetchall()
            return {"tasks": tasks}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


# ============ Sector rotation report API ============
@app.get("/api/sector-rotation-reports")
async def get_sector_rotation_reports(limit: int = Query(default=10, ge=0)):
    """Return the most recent sector rotation reports (dates and update times)."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, updated_at
                FROM sector_rotation_reports
                ORDER BY trade_date DESC
                LIMIT %s
                """,
                (limit,),
            )
            reports = cursor.fetchall()
            return {"reports": _summarize_reports(reports)}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/sector-rotation-reports/{trade_date}")
async def get_sector_rotation_report(trade_date: str):
    """Return the sector rotation report for a trade date as JSON."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, html_content, updated_at
                FROM sector_rotation_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()

            if not report:
                return JSONResponse(
                    status_code=404,
                    content={"error": f"未找到 {trade_date} 的板块轮动报告"},
                )

            return {
                "trade_date": report["trade_date"],
                "html_content": report["html_content"],
                "updated_at": report["updated_at"],
            }
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/sector-rotation-reports/{trade_date}/html")
async def get_sector_rotation_report_html(trade_date: str):
    """Return the sector rotation report for a trade date as an HTML page."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT html_content
                FROM sector_rotation_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()

            if not report or not report["html_content"]:
                return _html_message(f"未找到 {trade_date} 的板块轮动报告", 404)

            return HTMLResponse(content=report["html_content"])
    except Exception as e:
        return _html_message("加载失败", 500, str(e))
    finally:
        conn.close()


# ============ Leading-stock comparison report API ============
@app.get("/api/longtou-compare-reports")
async def get_longtou_compare_reports(limit: int = Query(default=10, ge=0)):
    """Return the most recent leading-stock comparison reports."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, updated_at
                FROM longtou_compare_reports
                ORDER BY trade_date DESC
                LIMIT %s
                """,
                (limit,),
            )
            reports = cursor.fetchall()
            return {"reports": _summarize_reports(reports)}
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/longtou-compare-reports/{trade_date}")
async def get_longtou_compare_report(trade_date: str):
    """Return the leading-stock comparison report for a trade date as JSON."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, html_content, updated_at
                FROM longtou_compare_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()

            if not report:
                return JSONResponse(
                    status_code=404,
                    content={"error": f"未找到 {trade_date} 的龙头股对比分析报告"},
                )

            return {
                "trade_date": report["trade_date"],
                "html_content": report["html_content"],
                "updated_at": report["updated_at"],
            }
    except Exception as e:
        return _json_error(e)
    finally:
        conn.close()


@app.get("/api/longtou-compare-reports/{trade_date}/html")
async def get_longtou_compare_report_html(trade_date: str):
    """Return the leading-stock comparison report for a trade date as HTML."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT html_content
                FROM longtou_compare_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()

            if not report or not report["html_content"]:
                return _html_message(f"未找到 {trade_date} 的龙头股对比分析报告", 404)

            return HTMLResponse(content=report["html_content"])
    except Exception as e:
        return _html_message("加载失败", 500, str(e))
    finally:
        conn.close()


# Hugging Face Spaces serves on port 7860.
if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    print(f"🚀 Starting MySQL API Server on port {port}...")
    uvicorn.run(app, host="0.0.0.0", port=port)