Spaces:
Running
Running
| # Hugging Face Spaces - MySQL API Server | |
| # 此文件专为 Hugging Face Spaces 部署优化 | |
| import os | |
| import time | |
| import uuid | |
| from contextlib import asynccontextmanager | |
| from typing import List, Optional | |
| import pymysql | |
| from dbutils.pooled_db import PooledDB | |
| from fastapi import FastAPI, HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from pydantic import BaseModel | |
| from pymysql.cursors import DictCursor | |
# Database settings; every value except charset/cursorclass can be overridden
# through an environment variable. DB_PASSWORD defaults to an empty string and
# must be supplied through the environment.
DB_CONFIG = dict(
    host=os.getenv("DB_HOST", "114.116.200.230"),
    port=int(os.getenv("DB_PORT", 3306)),
    user=os.getenv("DB_USER", "aistock_admin"),
    password=os.getenv("DB_PASSWORD", ""),
    database=os.getenv("DB_NAME", "aistock"),
    charset="utf8mb4",
    cursorclass=DictCursor,  # rows come back as dicts keyed by column name
)

# Module-level connection pool; None until a pool has been created.
pool: Optional[PooledDB] = None
def init_pool():
    """Create the module-level MySQL connection pool from DB_CONFIG.

    When DB_CONFIG["password"] is empty, only a warning is printed and the
    global ``pool`` is left unchanged.
    """
    global pool
    if DB_CONFIG["password"]:
        pool = PooledDB(
            creator=pymysql,
            maxconnections=5,  # Hugging Face free tier has limited resources
            mincached=1,
            maxcached=3,
            blocking=True,
            maxusage=None,
            setsession=[],
            ping=1,
            **DB_CONFIG,
        )
        print(
            f"✅ MySQL 连接池已初始化: {DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
        )
        return
    print("⚠️ 警告: DB_PASSWORD 环境变量未设置!")
def get_connection():
    """Borrow a connection from the module-level pool.

    Raises:
        HTTPException: status 500 when no pool is available.
    """
    if pool:
        return pool.connection()
    raise HTTPException(status_code=500, detail="数据库未配置")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create the DB pool on startup, close it on shutdown.

    Wrapped with ``asynccontextmanager`` so it can be passed as the ``lifespan``
    argument of ``FastAPI``; a bare async generator function is not a context
    manager.
    """
    init_pool()
    try:
        yield
    finally:
        # Runs even if shutdown is triggered by an exception thrown into the
        # generator, so pooled connections are always released.
        if pool:
            pool.close()
# FastAPI application instance; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Stock Analysis MySQL API",
    description="MySQL API Server for Stock Analysis Agent (Hugging Face Spaces)",
)
# CORS middleware - lets the Cloudflare Worker call this API.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm that credentialed cross-origin requests from any site are intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
| # ============ Pydantic 模型 ============ | |
# Request body accepted by submit_task.
class SubmitRequest(BaseModel):
    stock_code: str
    stock_name: str
    market: Optional[str] = ""
    # User who submitted the task.
    username: Optional[str] = None
    # Same meaning as `username`; kept for Cloudflare Worker compatibility.
    submitted_by: Optional[str] = None
    # Submission time.
    submitted_at: Optional[str] = None
    # Whether the report is public: 1 = public (default), 0 = private.
    is_public: Optional[int] = 1
    # Task submitted by a VIP user: 1 = VIP, 0 = regular.
    is_vip: Optional[int] = 0
# Request body accepted by complete_task.
class CompleteRequest(BaseModel):
    # Identifier of the row in `requests` being finished.
    id: str
    stock_code: Optional[str] = None
    stock_name: Optional[str] = None
    # Report HTML; complete_task stores it only when status is "completed"
    # and stock_code is also provided.
    html_content: Optional[str] = None
    status: Optional[str] = "completed"
# Request body accepted by reorder_task.
class ReorderRequest(BaseModel):
    # Identifier of the pending request to move to the front of the queue.
    id: str
# Request body accepted by remove_duplicates.
class RemoveDuplicatesRequest(BaseModel):
    # Identifiers of the requests to delete.
    ids: List[str]
| # ============ API 端点 ============ | |
async def root():
    """Health check.

    Returns a static service description plus ``db_configured``, which is
    True when DB_CONFIG holds a non-empty password.
    """
    has_password = bool(DB_CONFIG["password"])
    return {
        "status": "ok",
        "service": "MySQL API Server",
        "platform": "Hugging Face Spaces",
        "db_configured": has_password,
    }
async def submit_task(req: SubmitRequest):
    """Submit a new analysis task, or bump the counter of an identical open one.

    If a request with the same stock_code and stock_name is already 'pending'
    or 'processing', its submit_count is incremented and its id is returned
    with ``duplicate: True``. Otherwise a new 'pending' row is inserted.

    Returns:
        A dict with ``success`` and ``request_id`` (plus ``duplicate`` or
        ``username``), or a JSONResponse with status 500 and ``{"error": ...}``
        when any exception is raised while talking to the database.
    """
    conn = get_connection()
    # Prefer `username`, fall back to `submitted_by`; empty strings become None.
    username = req.username or req.submitted_by or None
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT id FROM requests WHERE stock_code = %s AND stock_name = %s AND status IN ('pending', 'processing') LIMIT 1",
                (req.stock_code, req.stock_name),
            )
            existing = cursor.fetchone()
            if existing:
                # Increment inside the database so concurrent duplicate
                # submissions cannot overwrite each other's count. A NULL or 0
                # counter is treated as 1 before incrementing, as before.
                cursor.execute(
                    "UPDATE requests SET submit_count = COALESCE(NULLIF(submit_count, 0), 1) + 1 WHERE id = %s",
                    (existing["id"],),
                )
                conn.commit()
                return {
                    "success": True,
                    "request_id": existing["id"],
                    "duplicate": True,
                }

            request_id = str(uuid.uuid4())
            now = int(time.time() * 1000)  # creation time in milliseconds
            # Normalize flags to 0/1: only an explicit 0/False makes a report
            # private; only an explicit 1/True marks a task as VIP.
            is_public = 0 if req.is_public == 0 else 1
            is_vip = 1 if req.is_vip == 1 else 0
            cursor.execute(
                "INSERT INTO requests (id, stock_code, stock_name, market, status, created_at, submit_count, username, is_public, is_vip) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
                (
                    request_id,
                    req.stock_code,
                    req.stock_name,
                    req.market,
                    "pending",
                    now,
                    1,
                    username,
                    is_public,
                    is_vip,
                ),
            )
            conn.commit()
            return {"success": True, "request_id": request_id, "username": username}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()
async def get_pending_tasks():
    """Return every 'pending' request, oldest created_at first.

    Returns:
        ``{"tasks": [...]}``, or a JSONResponse with status 500 and
        ``{"error": ...}`` if the query raises.
    """
    db = get_connection()
    try:
        with db.cursor() as cur:
            cur.execute(
                "SELECT * FROM requests WHERE status = 'pending' ORDER BY created_at ASC"
            )
            rows = cur.fetchall()
        return {"tasks": rows}
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def get_task_status(id: str = Query(...)):
    """Look up a single request row by its id.

    Returns:
        The row as a dict; a JSONResponse 404 ``{"error": "Not found"}`` when
        no row matches; a JSONResponse 500 ``{"error": ...}`` if the query raises.
    """
    db = get_connection()
    try:
        with db.cursor() as cur:
            cur.execute("SELECT * FROM requests WHERE id = %s", (id,))
            row = cur.fetchone()
        if row:
            return row
        return JSONResponse(status_code=404, content={"error": "Not found"})
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def complete_task(req: CompleteRequest):
    """Mark a request as finished and, when applicable, store its report.

    The request row gets ``req.status`` and a millisecond ``completed_at``.
    A report row is written only when the status is "completed" and both
    ``html_content`` and ``stock_code`` are provided.

    Returns:
        ``{"success": True}``, or a JSONResponse with status 500 and
        ``{"error": ...}`` if any statement raises.
    """
    db = get_connection()
    try:
        finished_at = int(time.time() * 1000)
        has_report = bool(
            req.status == "completed" and req.html_content and req.stock_code
        )
        with db.cursor() as cur:
            cur.execute(
                "UPDATE requests SET status = %s, completed_at = %s WHERE id = %s",
                (req.status, finished_at, req.id),
            )
            if has_report:
                # The stock code doubles as a stable report id so each stock
                # keeps only its latest report.
                report_id = str(req.stock_code)
                # Legacy data may hold several rows for one stock_code under
                # other ids; remove those before writing.
                cur.execute(
                    "DELETE FROM reports WHERE stock_code = %s AND id <> %s",
                    (req.stock_code, report_id),
                )
                cur.execute(
                    "INSERT INTO reports (id, stock_code, stock_name, html_content, created_at) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name), html_content = VALUES(html_content), created_at = VALUES(created_at)",
                    (
                        report_id,
                        req.stock_code,
                        req.stock_name,
                        req.html_content,
                        finished_at,
                    ),
                )
            db.commit()
        return {"success": True}
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def get_report(code: str = Query(...)):
    """Return the stored report HTML for a stock code.

    Returns:
        An HTMLResponse with the report; an HTMLResponse 404 when no report
        (or an empty one) exists; a JSONResponse 500 ``{"error": ...}`` if the
        query raises.
    """
    from html import escape

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            # The reports table may hold several historical rows for one
            # stock_code; always take the newest so the content matches the
            # timestamp shown in listings.
            cursor.execute(
                "SELECT html_content FROM reports WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
                (code,),
            )
            result = cursor.fetchone()
        if not result or not result.get("html_content"):
            # `code` comes straight from the query string; escape it so it
            # cannot inject markup into the 404 page.
            return HTMLResponse(
                content=f"<h1>Report not found for {escape(code)}</h1>",
                status_code=404,
            )
        return HTMLResponse(content=result["html_content"])
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()
async def get_history(
    date: Optional[str] = Query(None), username: Optional[str] = Query(None)
):
    """List historical requests, optionally filtered by day and/or username.

    - username and date: that user's completed requests finished on that day.
    - username only: that user's 100 most recently created requests in any of
      the statuses completed/error/pending/processing.
    - date only: public completed requests finished on that day, latest per
      stock_code.
    - neither: public completed requests, latest per stock_code.

    "Public" means ``is_public = 1`` or ``is_public IS NULL``.

    Returns:
        ``{"tasks": [...]}``; a JSONResponse 400 ``{"error": ...}`` when
        ``date`` is not in ``YYYY-MM-DD`` form; a JSONResponse 500
        ``{"error": ...}`` if the query raises.
    """
    day_bounds = None
    if date:
        try:
            day_bounds = _day_bounds_ms(date)
        except ValueError as e:
            # A malformed date is a client error, not a server failure.
            return JSONResponse(status_code=400, content={"error": str(e)})

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            if username and day_bounds:
                start_ts, end_ts = day_bounds
                cursor.execute(
                    """
                    SELECT * FROM requests
                    WHERE username = %s AND status = 'completed' AND completed_at >= %s AND completed_at < %s
                    ORDER BY completed_at DESC
                    """,
                    (username, start_ts, end_ts),
                )
            elif username:
                cursor.execute(
                    """
                    SELECT * FROM requests
                    WHERE username = %s AND status IN ('completed', 'error', 'pending', 'processing')
                    ORDER BY created_at DESC
                    LIMIT 100
                    """,
                    (username,),
                )
            elif day_bounds:
                start_ts, end_ts = day_bounds
                # Home-page call (no username): public reports only.
                cursor.execute(
                    """
                    SELECT r1.* FROM requests r1
                    INNER JOIN (
                        SELECT stock_code, MAX(completed_at) as max_completed_at
                        FROM requests
                        WHERE status = 'completed' AND completed_at >= %s AND completed_at < %s
                        AND (is_public = 1 OR is_public IS NULL)
                        GROUP BY stock_code
                    ) r2 ON r1.stock_code = r2.stock_code AND r1.completed_at = r2.max_completed_at
                    WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
                    ORDER BY r1.completed_at DESC
                    """,
                    (start_ts, end_ts),
                )
            else:
                # Home-page call (no username, no date): public reports only.
                cursor.execute("""
                    SELECT r1.* FROM requests r1
                    INNER JOIN (
                        SELECT stock_code, MAX(completed_at) as max_completed_at
                        FROM requests
                        WHERE status = 'completed'
                        AND (is_public = 1 OR is_public IS NULL)
                        GROUP BY stock_code
                    ) r2 ON r1.stock_code = r2.stock_code AND r1.completed_at = r2.max_completed_at
                    WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
                    ORDER BY r1.completed_at DESC
                    """)
            tasks = cursor.fetchall()
        return {"tasks": tasks}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()


def _day_bounds_ms(date):
    """Return ``(start_ms, end_ms)`` for the day given as ``YYYY-MM-DD``.

    The parsed datetime is naive, so ``timestamp()`` interprets it in the
    server's local timezone. ``end_ms`` is exactly 86400000 ms after
    ``start_ms`` and is meant to be used as an exclusive upper bound.

    Raises:
        ValueError: if ``date`` does not match ``%Y-%m-%d``.
    """
    from datetime import datetime

    start_ts = int(datetime.strptime(date, "%Y-%m-%d").timestamp() * 1000)
    return start_ts, start_ts + 86400000
async def check_report(code: str = Query(...)):
    """Report whether a stored report exists for a stock code.

    Returns:
        ``{"exists": True, "stock_code", "stock_name", "created_at"}`` taken
        from the newest matching row, ``{"exists": False}`` when there is none,
        or a JSONResponse 500 ``{"error": ...}`` if the query raises.
    """
    db = get_connection()
    try:
        with db.cursor() as cur:
            cur.execute(
                "SELECT stock_code, stock_name, created_at FROM reports WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
                (code,),
            )
            row = cur.fetchone()
        if not row:
            return {"exists": False}
        summary = {"exists": True}
        for column in ("stock_code", "stock_name", "created_at"):
            summary[column] = row[column]
        return summary
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def reorder_task(req: ReorderRequest):
    """Move a pending task to the front of the queue.

    The task's ``created_at`` is set to one less than the smallest
    ``created_at`` among pending requests (or one less than the current time
    in milliseconds when that minimum is missing or falsy). The UPDATE only
    touches rows that are still 'pending'; the response does not indicate
    whether a row was actually changed.

    Returns:
        ``{"success": True, "new_time": ...}``, or a JSONResponse 500
        ``{"error": ...}`` if a statement raises.
    """
    db = get_connection()
    try:
        with db.cursor() as cur:
            cur.execute(
                "SELECT MIN(created_at) as min_time FROM requests WHERE status = 'pending'"
            )
            row = cur.fetchone()
            if row and row["min_time"]:
                earliest = row["min_time"]
            else:
                earliest = int(time.time() * 1000)
            new_time = earliest - 1
            cur.execute(
                "UPDATE requests SET created_at = %s WHERE id = %s AND status = 'pending'",
                (new_time, req.id),
            )
            db.commit()
        return {"success": True, "new_time": new_time}
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def remove_duplicates(req: RemoveDuplicatesRequest):
    """Delete the listed requests, restricted to those still 'pending'.

    Returns:
        ``{"success": True, "deleted": <row count>}``; a JSONResponse 400 when
        ``ids`` is empty; a JSONResponse 500 ``{"error": ...}`` if the
        statement raises.
    """
    if not req.ids:
        return JSONResponse(
            status_code=400, content={"error": "Missing or invalid task IDs"}
        )

    db = get_connection()
    try:
        # One %s marker per id; the values themselves are passed as parameters.
        markers = ", ".join("%s" for _ in req.ids)
        with db.cursor() as cur:
            cur.execute(
                f"DELETE FROM requests WHERE id IN ({markers}) AND status = 'pending'",
                tuple(req.ids),
            )
            removed = cur.rowcount
            db.commit()
        return {"success": True, "deleted": removed}
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def get_queue():
    """Return the full task queue: 'pending' and 'processing' requests, oldest first.

    Returns:
        ``{"tasks": [...]}``, or a JSONResponse 500 ``{"error": ...}`` if the
        query raises.
    """
    db = get_connection()
    try:
        with db.cursor() as cur:
            cur.execute(
                "SELECT * FROM requests WHERE status IN ('pending', 'processing') ORDER BY created_at ASC"
            )
            queued = cur.fetchall()
        return {"tasks": queued}
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
| # ============ 板块轮动报告 API ============ | |
async def get_sector_rotation_reports(limit: int = Query(default=10)):
    """List the most recent sector-rotation reports, newest trade_date first.

    Each entry carries ``trade_date``, the raw ``updated_at`` value and
    ``update_time``: ``updated_at`` divided by 1000 and formatted as local time
    ``YYYY-MM-DD HH:MM:SS``, or None when ``updated_at`` is falsy.

    Returns:
        ``{"reports": [...]}``, or a JSONResponse 500 ``{"error": ...}`` if
        the query or the formatting raises.
    """

    def _format_ms(stamp):
        if not stamp:
            return None
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(stamp / 1000))

    db = get_connection()
    try:
        with db.cursor() as cur:
            cur.execute(
                """
                SELECT trade_date, updated_at
                FROM sector_rotation_reports
                ORDER BY trade_date DESC
                LIMIT %s
                """,
                (limit,),
            )
            rows = cur.fetchall()
        listing = [
            {
                "trade_date": row["trade_date"],
                "update_time": _format_ms(row["updated_at"]),
                "updated_at": row["updated_at"],
            }
            for row in rows
        ]
        return {"reports": listing}
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def get_sector_rotation_report(trade_date: str):
    """Return the sector-rotation report stored for one trade date as JSON.

    Returns:
        A dict with ``trade_date``, ``html_content`` and ``updated_at``; a
        JSONResponse 404 when no row matches; a JSONResponse 500
        ``{"error": ...}`` if the query raises.
    """
    db = get_connection()
    try:
        with db.cursor() as cur:
            cur.execute(
                """
                SELECT trade_date, html_content, updated_at
                FROM sector_rotation_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            row = cur.fetchone()
        if row:
            return {
                field: row[field]
                for field in ("trade_date", "html_content", "updated_at")
            }
        return JSONResponse(
            status_code=404,
            content={"error": f"未找到 {trade_date} 的板块轮动报告"},
        )
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def get_sector_rotation_report_html(trade_date: str):
    """Return the sector-rotation report for one trade date as an HTML page.

    Returns:
        An HTMLResponse with the stored HTML; an HTMLResponse 404 when the row
        is missing or its html_content is empty; an HTMLResponse 500 with the
        error text if the query raises.
    """
    from html import escape

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT html_content
                FROM sector_rotation_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()
        if not report or not report["html_content"]:
            # trade_date is caller-supplied; escape it before embedding in HTML.
            return HTMLResponse(
                content=f"<html><body><h1>未找到 {escape(trade_date)} 的板块轮动报告</h1></body></html>",
                status_code=404,
            )
        return HTMLResponse(content=report["html_content"])
    except Exception as e:
        # Exception text can echo request data, so it is escaped as well.
        return HTMLResponse(
            content=f"<html><body><h1>加载失败</h1><p>{escape(str(e))}</p></body></html>",
            status_code=500,
        )
    finally:
        conn.close()
| # ============ 龙头股对比分析报告 API ============ | |
async def get_longtou_compare_reports(limit: int = Query(default=10)):
    """List the most recent leading-stock comparison reports, newest trade_date first.

    Each entry carries ``trade_date``, the raw ``updated_at`` value and
    ``update_time``: ``updated_at`` divided by 1000 and formatted as local time
    ``YYYY-MM-DD HH:MM:SS``, or None when ``updated_at`` is falsy.

    Returns:
        ``{"reports": [...]}``, or a JSONResponse 500 ``{"error": ...}`` if
        the query or the formatting raises.
    """

    def _format_ms(stamp):
        if not stamp:
            return None
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(stamp / 1000))

    db = get_connection()
    try:
        with db.cursor() as cur:
            cur.execute(
                """
                SELECT trade_date, updated_at
                FROM longtou_compare_reports
                ORDER BY trade_date DESC
                LIMIT %s
                """,
                (limit,),
            )
            rows = cur.fetchall()
        listing = [
            {
                "trade_date": row["trade_date"],
                "update_time": _format_ms(row["updated_at"]),
                "updated_at": row["updated_at"],
            }
            for row in rows
        ]
        return {"reports": listing}
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def get_longtou_compare_report(trade_date: str):
    """Return the leading-stock comparison report for one trade date as JSON.

    Returns:
        A dict with ``trade_date``, ``html_content`` and ``updated_at``; a
        JSONResponse 404 when no row matches; a JSONResponse 500
        ``{"error": ...}`` if the query raises.
    """
    db = get_connection()
    try:
        with db.cursor() as cur:
            cur.execute(
                """
                SELECT trade_date, html_content, updated_at
                FROM longtou_compare_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            row = cur.fetchone()
        if row:
            return {
                field: row[field]
                for field in ("trade_date", "html_content", "updated_at")
            }
        return JSONResponse(
            status_code=404,
            content={"error": f"未找到 {trade_date} 的龙头股对比分析报告"},
        )
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        db.close()
async def get_longtou_compare_report_html(trade_date: str):
    """Return the leading-stock comparison report for one trade date as an HTML page.

    Returns:
        An HTMLResponse with the stored HTML; an HTMLResponse 404 when the row
        is missing or its html_content is empty; an HTMLResponse 500 with the
        error text if the query raises.
    """
    from html import escape

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT html_content
                FROM longtou_compare_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()
        if not report or not report["html_content"]:
            # trade_date is caller-supplied; escape it before embedding in HTML.
            return HTMLResponse(
                content=f"<html><body><h1>未找到 {escape(trade_date)} 的龙头股对比分析报告</h1></body></html>",
                status_code=404,
            )
        return HTMLResponse(content=report["html_content"])
    except Exception as e:
        # Exception text can echo request data, so it is escaped as well.
        return HTMLResponse(
            content=f"<html><body><h1>加载失败</h1><p>{escape(str(e))}</p></body></html>",
            status_code=500,
        )
    finally:
        conn.close()
# Script entry point. Hugging Face Spaces uses port 7860; the PORT
# environment variable overrides it.
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    print(f"🚀 Starting MySQL API Server on port {port}...")
    uvicorn.run(app, port=port, host="0.0.0.0")