# Hugging Face Spaces - MySQL API Server
# This file is tuned specifically for deployment on Hugging Face Spaces.
#
# Endpoints that touch MySQL are declared with plain `def` rather than
# `async def`: pymysql and the DBUtils pool are blocking, and FastAPI runs
# `def` path operations in a worker threadpool, so a slow query no longer
# stalls the event loop (and every other request) while it waits.
import html
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import List, Optional

import pymysql
from dbutils.pooled_db import PooledDB
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from pymysql.cursors import DictCursor

# Database settings, read from environment variables (with fallbacks).
DB_CONFIG = {
    "host": os.getenv("DB_HOST", "114.116.200.230"),
    "port": int(os.getenv("DB_PORT", 3306)),
    "user": os.getenv("DB_USER", "aistock_admin"),
    "password": os.getenv("DB_PASSWORD", ""),  # must be supplied via the environment
    "database": os.getenv("DB_NAME", "aistock"),
    "charset": "utf8mb4",
    "cursorclass": DictCursor,
}

# Shared connection pool; stays None when DB_PASSWORD is not configured.
pool: Optional[PooledDB] = None


def init_pool():
    """Create the global connection pool, or print a warning and skip if no password is set."""
    global pool
    if not DB_CONFIG["password"]:
        print("⚠️ 警告: DB_PASSWORD 环境变量未设置!")
        return
    pool = PooledDB(
        creator=pymysql,
        maxconnections=5,  # the free Hugging Face tier has limited resources
        mincached=1,
        maxcached=3,
        blocking=True,  # when all connections are busy, wait instead of raising
        maxusage=None,
        setsession=[],
        ping=1,
        **DB_CONFIG,
    )
    print(
        f"✅ MySQL 连接池已初始化: {DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
    )


def get_connection():
    """Borrow a connection from the pool.

    Raises:
        HTTPException: status 500 when the pool was never initialised
            (i.e. DB_PASSWORD was not set at startup).
    """
    if not pool:
        raise HTTPException(status_code=500, detail="数据库未配置")
    return pool.connection()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the pool on startup and close it on shutdown."""
    init_pool()
    yield
    if pool:
        pool.close()


# FastAPI application
app = FastAPI(
    title="Stock Analysis MySQL API",
    description="MySQL API Server for Stock Analysis Agent (Hugging Face Spaces)",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware - lets the Cloudflare Worker reach this API.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; tighten the origin list if browsers ever call this API directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ============ Pydantic models ============


class SubmitRequest(BaseModel):
    """Payload for /api/submit."""

    stock_code: str
    stock_name: str
    market: Optional[str] = ""
    username: Optional[str] = None  # user who submitted the task
    submitted_by: Optional[str] = None  # same as username; kept for the Cloudflare Worker
    submitted_at: Optional[str] = None  # submission time
    is_public: Optional[int] = 1  # report visibility: 1 = public (default), 0 = private
    is_vip: Optional[int] = 0  # task submitted by a VIP user: 1 = VIP, 0 = regular


class CompleteRequest(BaseModel):
    """Payload for /api/complete."""

    id: str
    stock_code: Optional[str] = None
    stock_name: Optional[str] = None
    html_content: Optional[str] = None
    status: Optional[str] = "completed"


class ReorderRequest(BaseModel):
    id: str


class RemoveDuplicatesRequest(BaseModel):
    ids: List[str]


# ============ API endpoints ============


@app.get("/")
async def root():
    """Health check; reports whether a DB password is configured. Does not touch the DB."""
    return {
        "status": "ok",
        "service": "MySQL API Server",
        "platform": "Hugging Face Spaces",
        "db_configured": bool(DB_CONFIG["password"]),
    }


@app.post("/api/submit")
def submit_task(req: SubmitRequest):
    """Queue a new analysis request, or bump the counter of an identical open one.

    A request counts as a duplicate when a row with the same stock_code and
    stock_name is still 'pending' or 'processing'. In that case its
    submit_count is incremented and its id is returned with duplicate=True.
    Database errors are returned as a JSON 500 response.
    """
    conn = get_connection()
    # Prefer `username`; fall back to `submitted_by`. Empty strings become None.
    username = req.username or req.submitted_by or None
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT id, submit_count FROM requests WHERE stock_code = %s AND stock_name = %s AND status IN ('pending', 'processing') LIMIT 1",
                (req.stock_code, req.stock_name),
            )
            existing = cursor.fetchone()
            if existing:
                new_count = (existing.get("submit_count") or 1) + 1
                cursor.execute(
                    "UPDATE requests SET submit_count = %s WHERE id = %s",
                    (new_count, existing["id"]),
                )
                conn.commit()
                return {
                    "success": True,
                    "request_id": existing["id"],
                    "duplicate": True,
                }

            request_id = str(uuid.uuid4())
            now = int(time.time() * 1000)  # epoch milliseconds
            # Only an explicit 0/False makes the report private; any other
            # value (including None) keeps the default of public.
            is_public = 0 if req.is_public == 0 else 1
            # Only an explicit 1/True marks the task as VIP.
            is_vip = 1 if req.is_vip == 1 else 0

            cursor.execute(
                "INSERT INTO requests (id, stock_code, stock_name, market, status, created_at, submit_count, username, is_public, is_vip) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
                (
                    request_id,
                    req.stock_code,
                    req.stock_name,
                    req.market,
                    "pending",
                    now,
                    1,
                    username,
                    is_public,
                    is_vip,
                ),
            )
            conn.commit()
            return {"success": True, "request_id": request_id, "username": username}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()


@app.get("/api/pending")
def get_pending_tasks():
    """Return every 'pending' request, oldest first."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT * FROM requests WHERE status = 'pending' ORDER BY created_at ASC"
            )
            tasks = cursor.fetchall()
            return {"tasks": tasks}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()


@app.get("/api/status")
def get_task_status(id: str = Query(...)):
    """Return the full request row for `id`, or a JSON 404 if it does not exist."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute("SELECT * FROM requests WHERE id = %s", (id,))
            result = cursor.fetchone()
            if not result:
                return JSONResponse(status_code=404, content={"error": "Not found"})
            return result
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()


@app.post("/api/complete")
def complete_task(req: CompleteRequest):
    """Mark a request finished and, on success, store its HTML report.

    The request's status and completed_at are always updated. A report row
    is written only when the status is 'completed' and both html_content and
    stock_code are present.
    """
    conn = get_connection()
    # An explicit `"status": null` would otherwise write NULL and strand the
    # task outside every queue; fall back to the model's default instead.
    status = req.status or "completed"
    try:
        with conn.cursor() as cursor:
            now = int(time.time() * 1000)
            cursor.execute(
                "UPDATE requests SET status = %s, completed_at = %s WHERE id = %s",
                (status, now, req.id),
            )

            if status == "completed" and req.html_content and req.stock_code:
                # stock_code is the stable report id, so each stock keeps only
                # its latest report.
                report_id = str(req.stock_code)
                # Legacy data: the reports table may hold several rows for the
                # same stock_code; remove them before upserting.
                cursor.execute(
                    "DELETE FROM reports WHERE stock_code = %s AND id <> %s",
                    (req.stock_code, report_id),
                )
                cursor.execute(
                    "INSERT INTO reports (id, stock_code, stock_name, html_content, created_at) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name), html_content = VALUES(html_content), created_at = VALUES(created_at)",
                    (report_id, req.stock_code, req.stock_name, req.html_content, now),
                )

            conn.commit()
            return {"success": True}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()


@app.get("/api/report")
def get_report(code: str = Query(...)):
    """Serve the newest stored HTML report for stock `code`.

    Returns a 404 HTML page when there is no report and a 500 HTML page on
    database errors. User input and error text are HTML-escaped before being
    echoed back.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            # The reports table may hold several historical rows for one
            # stock_code; take the newest, otherwise the list view's timestamp
            # and the opened report's content can disagree.
            cursor.execute(
                "SELECT html_content FROM reports WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
                (code,),
            )
            result = cursor.fetchone()
            if not result or not result.get("html_content"):
                return HTMLResponse(
                    content=f"<h1>未找到 {html.escape(code)} 的报告</h1>",
                    status_code=404,
                )
            return HTMLResponse(content=result["html_content"])
    except Exception as e:
        return HTMLResponse(
            content=f"<h1>服务器错误</h1><pre>{html.escape(str(e))}</pre>",
            status_code=500,
        )
    finally:
        conn.close()


# ============ Leader-stock comparison report API ============


def _ms_to_local_str(ms):
    """Format an epoch-milliseconds value in server-local time; None for falsy input."""
    if not ms:
        return None
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(ms / 1000))


@app.get("/api/longtou-compare-reports")
def get_longtou_compare_reports(limit: int = Query(default=10, ge=0)):
    """List the most recent leader-stock comparison reports, newest trade_date first.

    Args:
        limit: maximum number of rows to return; negative values are rejected
            with a 422 instead of reaching MySQL as an invalid LIMIT.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, updated_at
                FROM longtou_compare_reports
                ORDER BY trade_date DESC
                LIMIT %s
                """,
                (limit,),
            )
            reports = cursor.fetchall()

            result = [
                {
                    "trade_date": report["trade_date"],
                    "update_time": _ms_to_local_str(report["updated_at"]),
                    "updated_at": report["updated_at"],
                }
                for report in reports
            ]
            return {"reports": result}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()


@app.get("/api/longtou-compare-reports/{trade_date}")
def get_longtou_compare_report(trade_date: str):
    """Return the comparison report for `trade_date` as JSON, or a JSON 404."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, html_content, updated_at
                FROM longtou_compare_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()
            if not report:
                return JSONResponse(
                    status_code=404,
                    content={"error": f"未找到 {trade_date} 的龙头股对比分析报告"},
                )
            return {
                "trade_date": report["trade_date"],
                "html_content": report["html_content"],
                "updated_at": report["updated_at"],
            }
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()


@app.get("/api/longtou-compare-reports/{trade_date}/html")
def get_longtou_compare_report_html(trade_date: str):
    """Serve the comparison report for `trade_date` directly as an HTML page.

    Returns a 404 HTML page when the report is missing or empty and a 500
    HTML page on database errors; echoed text is HTML-escaped.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT html_content
                FROM longtou_compare_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()
            if not report or not report["html_content"]:
                return HTMLResponse(
                    content=f"<h1>未找到 {html.escape(trade_date)} 的龙头股对比分析报告</h1>",
                    status_code=404,
                )
            return HTMLResponse(content=report["html_content"])
    except Exception as e:
        return HTMLResponse(
            content=f"<h1>服务器错误</h1><pre>{html.escape(str(e))}</pre>",
            status_code=500,
        )
    finally:
        conn.close()


# Hugging Face Spaces serves on port 7860 by default.
if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    print(f"🚀 Starting MySQL API Server on port {port}...")
    uvicorn.run(app, host="0.0.0.0", port=port)