# aistock-api / app.py
# Uploaded by fromozu ("Upload app.py", commit d67ba32).
# Hugging Face Spaces - MySQL API Server
# 此文件专为 Hugging Face Spaces 部署优化
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import List, Optional
import pymysql
from dbutils.pooled_db import PooledDB
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from pymysql.cursors import DictCursor
# Database configuration, read from environment variables.
# NOTE(review): DB_HOST falls back to a hard-coded public IP address; consider
# requiring it from the environment (like DB_PASSWORD) so it is not baked into source.
DB_CONFIG = dict(
    host=os.getenv("DB_HOST", "114.116.200.230"),
    port=int(os.getenv("DB_PORT", 3306)),
    user=os.getenv("DB_USER", "aistock_admin"),
    # No usable default: init_pool() refuses to build a pool while this is empty.
    password=os.getenv("DB_PASSWORD", ""),
    database=os.getenv("DB_NAME", "aistock"),
    charset="utf8mb4",
    # Rows are returned as dicts; the endpoints index results by column name.
    cursorclass=DictCursor,
)

# Shared connection pool; stays None until init_pool() builds it.
pool: Optional[PooledDB] = None
def init_pool():
    """Build the module-level MySQL connection pool from DB_CONFIG.

    When DB_PASSWORD is not configured, a warning is printed and ``pool`` is
    left as ``None``. Every endpoint then fails through ``get_connection``.
    Because ``mincached=1`` opens a connection while the pool is constructed,
    an unreachable database makes this call raise.
    """
    global pool
    if DB_CONFIG["password"]:
        pool = PooledDB(
            creator=pymysql,
            maxconnections=5,  # Hugging Face free tier has limited resources
            mincached=1,  # idle connections opened up front
            maxcached=3,
            blocking=True,  # wait for a free connection instead of erroring at the cap
            maxusage=None,
            setsession=[],
            ping=1,  # DBUtils: check the connection when it is fetched from the pool
            **DB_CONFIG,
        )
        target = f"{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
        print(f"✅ MySQL 连接池已初始化: {target}")
    else:
        print("⚠️ 警告: DB_PASSWORD 环境变量未设置!")
def get_connection():
    """Borrow a connection from the shared pool.

    Callers must ``close()`` the returned connection to hand it back to the pool.

    Raises:
        HTTPException: status 500 ("数据库未配置") when the pool was never built.
    """
    if pool:
        return pool.connection()
    raise HTTPException(status_code=500, detail="数据库未配置")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build the DB pool on startup, close it on shutdown."""
    init_pool()
    yield  # the application serves requests while suspended here
    # pool stays None when DB_PASSWORD was missing; only close a pool that exists.
    active_pool = pool
    if active_pool:
        active_pool.close()
# FastAPI application; the lifespan hook opens and closes the MySQL pool.
app = FastAPI(
    lifespan=lifespan,
    title="Stock Analysis MySQL API",
    version="1.0.0",
    description="MySQL API Server for Stock Analysis Agent (Hugging Face Spaces)",
)
# CORS middleware: let browsers on any origin call this API (originally added so
# the Cloudflare Worker can reach it).
# allow_credentials is False because no endpoint in this file reads cookies or
# auth headers. FastAPI documents that a wildcard allow_origins must not be
# combined with allow_credentials=True. With that combination Starlette echoes
# the caller's Origin together with Access-Control-Allow-Credentials: true, which
# would let any website make credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============ Pydantic 模型 ============
class SubmitRequest(BaseModel):
    """Body of POST /api/submit: a stock analysis task to enqueue."""

    stock_code: str
    stock_name: str
    market: Optional[str] = ""
    username: Optional[str] = None  # user submitting the task
    submitted_by: Optional[str] = None  # same meaning as username; kept for the Cloudflare Worker client
    submitted_at: Optional[str] = None  # submission time (accepted, but submit_task in this file does not read it)
    is_public: Optional[int] = 1  # report visibility: 1 = public (default), 0 = private
    is_vip: Optional[int] = 0  # task submitted by a VIP user: 1 = VIP, 0 = regular
class CompleteRequest(BaseModel):
    """Body of POST /api/complete: the final status (and optional report) of a task."""

    id: str  # requests.id of the task being finished
    stock_code: Optional[str] = None  # also used as the reports row id
    stock_name: Optional[str] = None
    html_content: Optional[str] = None  # report HTML; only stored when status is "completed" and stock_code is set
    status: Optional[str] = "completed"  # written as-is to requests.status
class ReorderRequest(BaseModel):
    """Body of POST /api/reorder."""

    id: str  # id of the pending task to move to the front of the queue
class RemoveDuplicatesRequest(BaseModel):
    """Body of POST /api/remove-duplicates."""

    ids: List[str]  # task ids to delete; only rows still 'pending' are removed
# ============ API 端点 ============
@app.get("/")
async def root():
    """Health check: liveness plus whether a database password is configured."""
    configured = bool(DB_CONFIG["password"])
    return dict(
        status="ok",
        service="MySQL API Server",
        platform="Hugging Face Spaces",
        db_configured=configured,
    )
@app.post("/api/submit")
async def submit_task(req: SubmitRequest):
    """Enqueue a stock analysis task, merging duplicates of an open task.

    If a task for the same stock_code/stock_name is still 'pending' or
    'processing', its submit_count is incremented and its id is returned with
    ``duplicate: True`` instead of inserting a new row.

    Returns:
        ``{"success", "request_id", "duplicate"}`` for a merged duplicate;
        ``{"success", "request_id", "username"}`` for a newly inserted task;
        a 500 JSONResponse ``{"error": ...}`` on any database error.
    """
    conn = get_connection()
    # Prefer username, fall back to submitted_by (Cloudflare Worker field);
    # empty strings are normalised to None.
    username = req.username or req.submitted_by or None
    try:
        with conn.cursor() as cursor:
            # NOTE(review): this SELECT-then-INSERT is not atomic. It is only
            # race-free because this async handler makes blocking calls without
            # awaiting, so it runs to completion on the event loop. Revisit if the
            # handler becomes a threadpool `def` or runs in several workers.
            cursor.execute(
                "SELECT id, submit_count FROM requests WHERE stock_code = %s AND stock_name = %s AND status IN ('pending', 'processing') LIMIT 1",
                (req.stock_code, req.stock_name),
            )
            existing = cursor.fetchone()
            if existing:
                # A NULL/0 count is treated as 1 before incrementing.
                new_count = (existing.get("submit_count") or 1) + 1
                cursor.execute(
                    "UPDATE requests SET submit_count = %s WHERE id = %s",
                    (new_count, existing["id"]),
                )
                conn.commit()
                return {
                    "success": True,
                    "request_id": existing["id"],
                    "duplicate": True,
                }

            request_id = str(uuid.uuid4())
            now = int(time.time() * 1000)  # epoch milliseconds
            # Only an explicit 0 (or False) makes a report private; None or any
            # other value keeps the public default of 1.
            is_public = 0 if req.is_public == 0 else 1
            # Only an explicit 1 (or True) marks a VIP task.
            is_vip = 1 if req.is_vip == 1 else 0
            cursor.execute(
                "INSERT INTO requests (id, stock_code, stock_name, market, status, created_at, submit_count, username, is_public, is_vip) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
                (
                    request_id,
                    req.stock_code,
                    req.stock_name,
                    req.market,
                    "pending",
                    now,
                    1,
                    username,
                    is_public,
                    is_vip,
                ),
            )
            conn.commit()
            return {"success": True, "request_id": request_id, "username": username}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()
@app.get("/api/pending")
async def get_pending_tasks():
    """Return every 'pending' task, oldest created_at first, as ``{"tasks": [...]}``."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT * FROM requests WHERE status = 'pending' ORDER BY created_at ASC"
            )
            pending = cursor.fetchall()
        return {"tasks": pending}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.get("/api/status")
async def get_task_status(id: str = Query(...)):
    """Return the full requests row for task ``id``; 404 JSON when it does not exist."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute("SELECT * FROM requests WHERE id = %s", (id,))
            row = cursor.fetchone()
        if row:
            return row
        return JSONResponse(status_code=404, content={"error": "Not found"})
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.post("/api/complete")
async def complete_task(req: CompleteRequest):
    """Record a task's final status and, on success, store its report.

    The requests row gets ``req.status`` and a completed_at timestamp (epoch ms).
    When the status is "completed" and both html_content and stock_code are
    present, the report is upserted into ``reports`` under id = stock_code.

    NOTE(review): success is returned even if ``req.id`` matches no row, and a
    None stock_name overwrites the stored one on upsert. Confirm both are intended.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            finished_at = int(time.time() * 1000)
            cursor.execute(
                "UPDATE requests SET status = %s, completed_at = %s WHERE id = %s",
                (req.status, finished_at, req.id),
            )
            should_store_report = (
                req.status == "completed" and req.html_content and req.stock_code
            )
            if should_store_report:
                # stock_code doubles as a stable report id, so each stock keeps
                # only its latest report.
                report_key = str(req.stock_code)
                # Legacy data: reports may hold several rows for one stock_code;
                # drop the extras before writing.
                cursor.execute(
                    "DELETE FROM reports WHERE stock_code = %s AND id <> %s",
                    (req.stock_code, report_key),
                )
                cursor.execute(
                    "INSERT INTO reports (id, stock_code, stock_name, html_content, created_at) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name), html_content = VALUES(html_content), created_at = VALUES(created_at)",
                    (report_key, req.stock_code, req.stock_name, req.html_content, finished_at),
                )
            conn.commit()
            return {"success": True}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.get("/api/report")
async def get_report(code: str = Query(...)):
    """Return the newest stored report for a stock as an HTML page.

    Args:
        code: stock code, matched against reports.stock_code.

    Returns:
        The stored HTML. A 404 HTML message when no report (or an empty one)
        exists; a 500 JSON ``{"error": ...}`` on database errors.
    """
    import html

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            # reports may contain several historical rows per stock_code. Always
            # take the newest, otherwise the list shows the latest time while the
            # opened report is stale.
            cursor.execute(
                "SELECT html_content FROM reports WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
                (code,),
            )
            result = cursor.fetchone()
            if not result or not result.get("html_content"):
                # `code` comes straight from the query string. Escape it so it
                # cannot inject markup or script into the page (reflected XSS).
                return HTMLResponse(
                    content=f"<h1>Report not found for {html.escape(code)}</h1>",
                    status_code=404,
                )
            return HTMLResponse(content=result["html_content"])
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()
def _day_range_ms(date: str):
    """Return the ``(start, end)`` epoch-millisecond bounds of a YYYY-MM-DD day.

    ``end`` is exclusive (start + 24h). The date is parsed into a naive datetime,
    so the boundaries follow the server's local timezone.
    NOTE(review): confirm the server timezone matches what users mean by "a day".

    Raises:
        ValueError: if ``date`` is not in YYYY-MM-DD format.
        OverflowError, OSError: if the platform cannot convert the date to a timestamp.
    """
    from datetime import datetime

    start_ts = int(datetime.strptime(date, "%Y-%m-%d").timestamp() * 1000)
    return start_ts, start_ts + 86400000


@app.get("/api/history")
async def get_history(
    date: Optional[str] = Query(None), username: Optional[str] = Query(None)
):
    """List historical tasks/reports, optionally filtered by day and by user.

    - ``username`` and ``date``: that user's tasks completed on that day,
      newest completion first.
    - ``username`` only: that user's 100 most recently created tasks in any of
      the completed/error/pending/processing states.
    - no ``username`` (homepage): for each stock_code, the latest completed
      public task (is_public = 1 or NULL), restricted to that day when
      ``date`` is given.

    Returns:
        ``{"tasks": [...]}``. A 400 JSON error for a malformed ``date``; a 500
        JSON ``{"error": ...}`` on database errors.
    """
    day_range = None
    if date:
        try:
            day_range = _day_range_ms(date)
        except (ValueError, OverflowError, OSError) as e:
            # A malformed date is a client error, not a server failure.
            return JSONResponse(status_code=400, content={"error": str(e)})

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            if username and day_range is not None:
                cursor.execute(
                    """
                    SELECT * FROM requests
                    WHERE username = %s AND status = 'completed'
                      AND completed_at >= %s AND completed_at < %s
                    ORDER BY completed_at DESC
                    """,
                    (username, *day_range),
                )
            elif username:
                cursor.execute(
                    """
                    SELECT * FROM requests
                    WHERE username = %s
                      AND status IN ('completed', 'error', 'pending', 'processing')
                    ORDER BY created_at DESC
                    LIMIT 100
                    """,
                    (username,),
                )
            elif day_range is not None:
                # Homepage with a date: latest public completed task per stock that day.
                cursor.execute(
                    """
                    SELECT r1.* FROM requests r1
                    INNER JOIN (
                        SELECT stock_code, MAX(completed_at) as max_completed_at
                        FROM requests
                        WHERE status = 'completed' AND completed_at >= %s AND completed_at < %s
                          AND (is_public = 1 OR is_public IS NULL)
                        GROUP BY stock_code
                    ) r2 ON r1.stock_code = r2.stock_code AND r1.completed_at = r2.max_completed_at
                    WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
                    ORDER BY r1.completed_at DESC
                    """,
                    day_range,
                )
            else:
                # Homepage without a date: latest public completed task per stock.
                cursor.execute(
                    """
                    SELECT r1.* FROM requests r1
                    INNER JOIN (
                        SELECT stock_code, MAX(completed_at) as max_completed_at
                        FROM requests
                        WHERE status = 'completed'
                          AND (is_public = 1 OR is_public IS NULL)
                        GROUP BY stock_code
                    ) r2 ON r1.stock_code = r2.stock_code AND r1.completed_at = r2.max_completed_at
                    WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
                    ORDER BY r1.completed_at DESC
                    """
                )
            tasks = cursor.fetchall()
        return {"tasks": tasks}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        conn.close()
@app.get("/api/check-report")
async def check_report(code: str = Query(...)):
    """Report whether a stored report exists for stock ``code``.

    Returns ``{"exists": False}``, or ``{"exists": True}`` plus stock_code,
    stock_name and created_at from the newest matching reports row.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT stock_code, stock_name, created_at FROM reports WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
                (code,),
            )
            latest = cursor.fetchone()
        if not latest:
            return {"exists": False}
        summary = {"exists": True}
        for column in ("stock_code", "stock_name", "created_at"):
            summary[column] = latest[column]
        return summary
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.post("/api/reorder")
async def reorder_task(req: ReorderRequest):
    """Move a pending task to the front of the queue.

    The task's created_at is set to one less than the earliest pending
    created_at (or now - 1 ms when there is no usable minimum). The queue
    endpoints sort by created_at ascending, so the task then comes first.

    NOTE(review): success is reported even when ``req.id`` is not a pending
    task and the UPDATE matches no row.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT MIN(created_at) as min_time FROM requests WHERE status = 'pending'"
            )
            row = cursor.fetchone()
            if row and row["min_time"]:
                earliest = row["min_time"]
            else:
                earliest = int(time.time() * 1000)
            new_time = earliest - 1
            cursor.execute(
                "UPDATE requests SET created_at = %s WHERE id = %s AND status = 'pending'",
                (new_time, req.id),
            )
            conn.commit()
            return {"success": True, "new_time": new_time}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.post("/api/remove-duplicates")
async def remove_duplicates(req: RemoveDuplicatesRequest):
    """Delete the given task ids, but only rows that are still 'pending'.

    Returns ``{"success": True, "deleted": <rows removed>}``, or a 400 JSON
    error when ``ids`` is empty.
    """
    ids = req.ids
    if not ids:
        return JSONResponse(
            status_code=400, content={"error": "Missing or invalid task IDs"}
        )
    conn = get_connection()
    try:
        # One %s placeholder per id keeps the statement fully parameterised.
        marks = ", ".join("%s" for _ in ids)
        with conn.cursor() as cursor:
            cursor.execute(
                f"DELETE FROM requests WHERE id IN ({marks}) AND status = 'pending'",
                tuple(ids),
            )
            removed = cursor.rowcount
            conn.commit()
            return {"success": True, "deleted": removed}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.get("/api/queue")
async def get_queue():
    """Return the active queue ('pending' and 'processing' tasks), oldest first."""
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT * FROM requests WHERE status IN ('pending', 'processing') ORDER BY created_at ASC"
            )
            queued = cursor.fetchall()
        return {"tasks": queued}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
# ============ 板块轮动报告 API ============
@app.get("/api/sector-rotation-reports")
async def get_sector_rotation_reports(limit: int = Query(default=10)):
    """List the newest sector-rotation reports (metadata only, no HTML).

    Args:
        limit: maximum number of rows, passed straight to SQL ``LIMIT``.
            NOTE(review): unbounded; a negative value surfaces as a 500.

    Returns:
        ``{"reports": [...]}`` ordered by trade_date descending. Each item has
        ``trade_date``; ``update_time``, which is updated_at (treated as epoch ms)
        rendered as "%Y-%m-%d %H:%M:%S" in server local time, or None when
        updated_at is empty; and the raw ``updated_at``. A 500 JSON
        ``{"error": ...}`` on errors.
    """

    def summarize(row):
        updated_ms = row["updated_at"]
        readable = (
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(updated_ms / 1000))
            if updated_ms
            else None
        )
        return {
            "trade_date": row["trade_date"],
            "update_time": readable,
            "updated_at": updated_ms,
        }

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, updated_at
                FROM sector_rotation_reports
                ORDER BY trade_date DESC
                LIMIT %s
                """,
                (limit,),
            )
            rows = cursor.fetchall()
        return {"reports": [summarize(row) for row in rows]}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.get("/api/sector-rotation-reports/{trade_date}")
async def get_sector_rotation_report(trade_date: str):
    """Return one trading day's sector-rotation report as JSON.

    The payload holds trade_date, html_content and updated_at. A 404 JSON error
    is returned when no row exists for ``trade_date``.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, html_content, updated_at
                FROM sector_rotation_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            row = cursor.fetchone()
        if not row:
            return JSONResponse(
                status_code=404,
                content={"error": f"未找到 {trade_date} 的板块轮动报告"},
            )
        return {key: row[key] for key in ("trade_date", "html_content", "updated_at")}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.get("/api/sector-rotation-reports/{trade_date}/html")
async def get_sector_rotation_report_html(trade_date: str):
    """Return one trading day's sector-rotation report directly as an HTML page.

    Args:
        trade_date: path parameter matched against sector_rotation_reports.trade_date.

    Returns:
        The stored HTML. A 404 HTML page when the report is missing or empty; a
        500 HTML page containing the error text on failure.
    """
    import html

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT html_content
                FROM sector_rotation_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()
            if not report or not report["html_content"]:
                # trade_date comes from the URL. Escape it so it cannot inject
                # markup or script into the page (reflected XSS).
                return HTMLResponse(
                    content=f"<html><body><h1>未找到 {html.escape(trade_date)} 的板块轮动报告</h1></body></html>",
                    status_code=404,
                )
            return HTMLResponse(content=report["html_content"])
    except Exception as e:
        # Exception text is not trusted HTML either.
        return HTMLResponse(
            content=f"<html><body><h1>加载失败</h1><p>{html.escape(str(e))}</p></body></html>",
            status_code=500,
        )
    finally:
        conn.close()
# ============ 龙头股对比分析报告 API ============
@app.get("/api/longtou-compare-reports")
async def get_longtou_compare_reports(limit: int = Query(default=10)):
    """List the newest leading-stock comparison reports (metadata only, no HTML).

    Args:
        limit: maximum number of rows, passed straight to SQL ``LIMIT``.
            NOTE(review): unbounded; a negative value surfaces as a 500.

    Returns:
        ``{"reports": [...]}`` ordered by trade_date descending. Each item has
        ``trade_date``; ``update_time``, which is updated_at (treated as epoch ms)
        rendered as "%Y-%m-%d %H:%M:%S" in server local time, or None when
        updated_at is empty; and the raw ``updated_at``. A 500 JSON
        ``{"error": ...}`` on errors.
    """

    def summarize(row):
        updated_ms = row["updated_at"]
        readable = (
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(updated_ms / 1000))
            if updated_ms
            else None
        )
        return {
            "trade_date": row["trade_date"],
            "update_time": readable,
            "updated_at": updated_ms,
        }

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, updated_at
                FROM longtou_compare_reports
                ORDER BY trade_date DESC
                LIMIT %s
                """,
                (limit,),
            )
            rows = cursor.fetchall()
        return {"reports": [summarize(row) for row in rows]}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.get("/api/longtou-compare-reports/{trade_date}")
async def get_longtou_compare_report(trade_date: str):
    """Return one trading day's leading-stock comparison report as JSON.

    The payload holds trade_date, html_content and updated_at. A 404 JSON error
    is returned when no row exists for ``trade_date``.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT trade_date, html_content, updated_at
                FROM longtou_compare_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            row = cursor.fetchone()
        if not row:
            return JSONResponse(
                status_code=404,
                content={"error": f"未找到 {trade_date} 的龙头股对比分析报告"},
            )
        return {key: row[key] for key in ("trade_date", "html_content", "updated_at")}
    except Exception as err:
        return JSONResponse(status_code=500, content={"error": str(err)})
    finally:
        conn.close()
@app.get("/api/longtou-compare-reports/{trade_date}/html")
async def get_longtou_compare_report_html(trade_date: str):
    """Return one trading day's leading-stock comparison report directly as HTML.

    Args:
        trade_date: path parameter matched against longtou_compare_reports.trade_date.

    Returns:
        The stored HTML. A 404 HTML page when the report is missing or empty; a
        500 HTML page containing the error text on failure.
    """
    import html

    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT html_content
                FROM longtou_compare_reports
                WHERE trade_date = %s
                """,
                (trade_date,),
            )
            report = cursor.fetchone()
            if not report or not report["html_content"]:
                # trade_date comes from the URL. Escape it so it cannot inject
                # markup or script into the page (reflected XSS).
                return HTMLResponse(
                    content=f"<html><body><h1>未找到 {html.escape(trade_date)} 的龙头股对比分析报告</h1></body></html>",
                    status_code=404,
                )
            return HTMLResponse(content=report["html_content"])
    except Exception as e:
        # Exception text is not trusted HTML either.
        return HTMLResponse(
            content=f"<html><body><h1>加载失败</h1><p>{html.escape(str(e))}</p></body></html>",
            status_code=500,
        )
    finally:
        conn.close()
# Script entry point. Hugging Face Spaces serves on port 7860 by default;
# the PORT environment variable overrides it.
if __name__ == "__main__":
    import uvicorn

    listen_port = int(os.getenv("PORT", 7860))
    print(f"🚀 Starting MySQL API Server on port {listen_port}...")
    uvicorn.run(app, port=listen_port, host="0.0.0.0")