Spaces:
Running
Running
Upload app.py
Browse files
app.py
CHANGED
|
@@ -2,28 +2,28 @@
|
|
| 2 |
# 此文件专为 Hugging Face Spaces 部署优化
|
| 3 |
|
| 4 |
import os
|
| 5 |
-
import uuid
|
| 6 |
import time
|
| 7 |
-
|
| 8 |
from contextlib import asynccontextmanager
|
|
|
|
| 9 |
|
|
|
|
|
|
|
| 10 |
from fastapi import FastAPI, HTTPException, Query
|
| 11 |
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
from fastapi.responses import HTMLResponse, JSONResponse
|
| 13 |
from pydantic import BaseModel
|
| 14 |
-
import pymysql
|
| 15 |
from pymysql.cursors import DictCursor
|
| 16 |
-
from dbutils.pooled_db import PooledDB
|
| 17 |
|
| 18 |
# 数据库配置 - 从环境变量读取
|
| 19 |
DB_CONFIG = {
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
}
|
| 28 |
|
| 29 |
# 连接池
|
|
@@ -33,22 +33,24 @@ pool: Optional[PooledDB] = None
|
|
| 33 |
def init_pool():
|
| 34 |
"""初始化数据库连接池"""
|
| 35 |
global pool
|
| 36 |
-
if not DB_CONFIG[
|
| 37 |
print("⚠️ 警告: DB_PASSWORD 环境变量未设置!")
|
| 38 |
return
|
| 39 |
-
|
| 40 |
pool = PooledDB(
|
| 41 |
creator=pymysql,
|
| 42 |
-
maxconnections=5,
|
| 43 |
mincached=1,
|
| 44 |
maxcached=3,
|
| 45 |
blocking=True,
|
| 46 |
maxusage=None,
|
| 47 |
setsession=[],
|
| 48 |
ping=1,
|
| 49 |
-
**DB_CONFIG
|
|
|
|
|
|
|
|
|
|
| 50 |
)
|
| 51 |
-
print(f"✅ MySQL 连接池已初始化: {DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}")
|
| 52 |
|
| 53 |
|
| 54 |
def get_connection():
|
|
@@ -72,7 +74,7 @@ app = FastAPI(
|
|
| 72 |
title="Stock Analysis MySQL API",
|
| 73 |
description="MySQL API Server for Stock Analysis Agent (Hugging Face Spaces)",
|
| 74 |
version="1.0.0",
|
| 75 |
-
lifespan=lifespan
|
| 76 |
)
|
| 77 |
|
| 78 |
# CORS 中间件 - 允许 Cloudflare Worker 访问
|
|
@@ -87,6 +89,7 @@ app.add_middleware(
|
|
| 87 |
|
| 88 |
# ============ Pydantic 模型 ============
|
| 89 |
|
|
|
|
| 90 |
class SubmitRequest(BaseModel):
|
| 91 |
stock_code: str
|
| 92 |
stock_name: str
|
|
@@ -116,6 +119,7 @@ class RemoveDuplicatesRequest(BaseModel):
|
|
| 116 |
|
| 117 |
# ============ API 端点 ============
|
| 118 |
|
|
|
|
| 119 |
@app.get("/")
|
| 120 |
async def root():
|
| 121 |
"""健康检查"""
|
|
@@ -123,7 +127,7 @@ async def root():
|
|
| 123 |
"status": "ok",
|
| 124 |
"service": "MySQL API Server",
|
| 125 |
"platform": "Hugging Face Spaces",
|
| 126 |
-
"db_configured": bool(DB_CONFIG[
|
| 127 |
}
|
| 128 |
|
| 129 |
|
|
@@ -133,24 +137,28 @@ async def submit_task(req: SubmitRequest):
|
|
| 133 |
conn = get_connection()
|
| 134 |
# 获取 username(优先使用 username,其次 submitted_by)
|
| 135 |
username = req.username or req.submitted_by or None
|
| 136 |
-
|
| 137 |
try:
|
| 138 |
with conn.cursor() as cursor:
|
| 139 |
cursor.execute(
|
| 140 |
"SELECT id, submit_count FROM requests WHERE stock_code = %s AND stock_name = %s AND status IN ('pending', 'processing') LIMIT 1",
|
| 141 |
-
(req.stock_code, req.stock_name)
|
| 142 |
)
|
| 143 |
existing = cursor.fetchone()
|
| 144 |
-
|
| 145 |
if existing:
|
| 146 |
-
new_count = (existing.get(
|
| 147 |
cursor.execute(
|
| 148 |
"UPDATE requests SET submit_count = %s WHERE id = %s",
|
| 149 |
-
(new_count, existing[
|
| 150 |
)
|
| 151 |
conn.commit()
|
| 152 |
-
return {
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
request_id = str(uuid.uuid4())
|
| 155 |
now = int(time.time() * 1000)
|
| 156 |
# 处理 is_public 字段:确保值为 0 或 1
|
|
@@ -161,13 +169,24 @@ async def submit_task(req: SubmitRequest):
|
|
| 161 |
is_public = 0
|
| 162 |
else:
|
| 163 |
is_public = 1 # 默认公开
|
| 164 |
-
|
| 165 |
# 处理 is_vip 字段
|
| 166 |
is_vip = 1 if req.is_vip == 1 or req.is_vip == True else 0
|
| 167 |
-
|
| 168 |
cursor.execute(
|
| 169 |
"INSERT INTO requests (id, stock_code, stock_name, market, status, created_at, submit_count, username, is_public, is_vip) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
|
| 170 |
-
(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
)
|
| 172 |
conn.commit()
|
| 173 |
return {"success": True, "request_id": request_id, "username": username}
|
|
@@ -218,20 +237,20 @@ async def complete_task(req: CompleteRequest):
|
|
| 218 |
try:
|
| 219 |
with conn.cursor() as cursor:
|
| 220 |
now = int(time.time() * 1000)
|
| 221 |
-
|
| 222 |
cursor.execute(
|
| 223 |
"UPDATE requests SET status = %s, completed_at = %s WHERE id = %s",
|
| 224 |
-
(req.status, now, req.id)
|
| 225 |
)
|
| 226 |
-
|
| 227 |
-
if req.status ==
|
| 228 |
# 生成报告 ID
|
| 229 |
report_id = str(uuid.uuid4())
|
| 230 |
cursor.execute(
|
| 231 |
"INSERT INTO reports (id, stock_code, stock_name, html_content, created_at) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name), html_content = VALUES(html_content), created_at = VALUES(created_at)",
|
| 232 |
-
(report_id, req.stock_code, req.stock_name, req.html_content, now)
|
| 233 |
)
|
| 234 |
-
|
| 235 |
conn.commit()
|
| 236 |
return {"success": True}
|
| 237 |
except Exception as e:
|
|
@@ -246,14 +265,18 @@ async def get_report(code: str = Query(...)):
|
|
| 246 |
conn = get_connection()
|
| 247 |
try:
|
| 248 |
with conn.cursor() as cursor:
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
result = cursor.fetchone()
|
| 251 |
-
if not result or not result.get(
|
| 252 |
return HTMLResponse(
|
| 253 |
-
content=f"<h1>Report not found for {code}</h1>",
|
| 254 |
-
status_code=404
|
| 255 |
)
|
| 256 |
-
return HTMLResponse(content=result[
|
| 257 |
except Exception as e:
|
| 258 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 259 |
finally:
|
|
@@ -261,7 +284,9 @@ async def get_report(code: str = Query(...)):
|
|
| 261 |
|
| 262 |
|
| 263 |
@app.get("/api/history")
|
| 264 |
-
async def get_history(
|
|
|
|
|
|
|
| 265 |
"""获取历史报告列表,支持按日期和用户名筛选"""
|
| 266 |
conn = get_connection()
|
| 267 |
try:
|
|
@@ -270,29 +295,38 @@ async def get_history(date: Optional[str] = Query(None), username: Optional[str]
|
|
| 270 |
if username:
|
| 271 |
if date:
|
| 272 |
from datetime import datetime
|
| 273 |
-
|
|
|
|
| 274 |
start_ts = int(start_dt.timestamp() * 1000)
|
| 275 |
end_ts = start_ts + 86400000
|
| 276 |
-
cursor.execute(
|
|
|
|
| 277 |
SELECT * FROM requests
|
| 278 |
WHERE username = %s AND status = 'completed' AND completed_at >= %s AND completed_at < %s
|
| 279 |
ORDER BY completed_at DESC
|
| 280 |
-
""",
|
|
|
|
|
|
|
| 281 |
else:
|
| 282 |
-
cursor.execute(
|
|
|
|
| 283 |
SELECT * FROM requests
|
| 284 |
WHERE username = %s AND status IN ('completed', 'error', 'pending', 'processing')
|
| 285 |
ORDER BY created_at DESC
|
| 286 |
LIMIT 100
|
| 287 |
-
""",
|
|
|
|
|
|
|
| 288 |
elif date:
|
| 289 |
from datetime import datetime
|
| 290 |
-
|
|
|
|
| 291 |
start_ts = int(start_dt.timestamp() * 1000)
|
| 292 |
end_ts = start_ts + 86400000
|
| 293 |
-
|
| 294 |
# 首页调用(无username):只返回公开报告 (is_public = 1 或 is_public IS NULL)
|
| 295 |
-
cursor.execute(
|
|
|
|
| 296 |
SELECT r1.* FROM requests r1
|
| 297 |
INNER JOIN (
|
| 298 |
SELECT stock_code, MAX(completed_at) as max_completed_at
|
|
@@ -303,7 +337,9 @@ async def get_history(date: Optional[str] = Query(None), username: Optional[str]
|
|
| 303 |
) r2 ON r1.stock_code = r2.stock_code AND r1.completed_at = r2.max_completed_at
|
| 304 |
WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
|
| 305 |
ORDER BY r1.completed_at DESC
|
| 306 |
-
""",
|
|
|
|
|
|
|
| 307 |
else:
|
| 308 |
# 首页调用(无username,无date):只返回公开报告 (is_public = 1 或 is_public IS NULL)
|
| 309 |
cursor.execute("""
|
|
@@ -318,7 +354,7 @@ async def get_history(date: Optional[str] = Query(None), username: Optional[str]
|
|
| 318 |
WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
|
| 319 |
ORDER BY r1.completed_at DESC
|
| 320 |
""")
|
| 321 |
-
|
| 322 |
tasks = cursor.fetchall()
|
| 323 |
return {"tasks": tasks}
|
| 324 |
except Exception as e:
|
|
@@ -334,16 +370,16 @@ async def check_report(code: str = Query(...)):
|
|
| 334 |
try:
|
| 335 |
with conn.cursor() as cursor:
|
| 336 |
cursor.execute(
|
| 337 |
-
"SELECT stock_code, stock_name, created_at FROM reports WHERE stock_code = %s",
|
| 338 |
-
(code,)
|
| 339 |
)
|
| 340 |
result = cursor.fetchone()
|
| 341 |
if result:
|
| 342 |
return {
|
| 343 |
"exists": True,
|
| 344 |
-
"stock_code": result[
|
| 345 |
-
"stock_name": result[
|
| 346 |
-
"created_at": result[
|
| 347 |
}
|
| 348 |
return {"exists": False}
|
| 349 |
except Exception as e:
|
|
@@ -362,12 +398,16 @@ async def reorder_task(req: ReorderRequest):
|
|
| 362 |
"SELECT MIN(created_at) as min_time FROM requests WHERE status = 'pending'"
|
| 363 |
)
|
| 364 |
result = cursor.fetchone()
|
| 365 |
-
min_time =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
new_time = min_time - 1
|
| 367 |
-
|
| 368 |
cursor.execute(
|
| 369 |
"UPDATE requests SET created_at = %s WHERE id = %s AND status = 'pending'",
|
| 370 |
-
(new_time, req.id)
|
| 371 |
)
|
| 372 |
conn.commit()
|
| 373 |
return {"success": True, "new_time": new_time}
|
|
@@ -381,15 +421,17 @@ async def reorder_task(req: ReorderRequest):
|
|
| 381 |
async def remove_duplicates(req: RemoveDuplicatesRequest):
|
| 382 |
"""批量删除重复任务"""
|
| 383 |
if not req.ids:
|
| 384 |
-
return JSONResponse(
|
| 385 |
-
|
|
|
|
|
|
|
| 386 |
conn = get_connection()
|
| 387 |
try:
|
| 388 |
with conn.cursor() as cursor:
|
| 389 |
-
placeholders =
|
| 390 |
cursor.execute(
|
| 391 |
f"DELETE FROM requests WHERE id IN ({placeholders}) AND status = 'pending'",
|
| 392 |
-
tuple(req.ids)
|
| 393 |
)
|
| 394 |
deleted = cursor.rowcount
|
| 395 |
conn.commit()
|
|
@@ -419,29 +461,40 @@ async def get_queue():
|
|
| 419 |
|
| 420 |
# ============ 板块轮动报告 API ============
|
| 421 |
|
|
|
|
| 422 |
@app.get("/api/sector-rotation-reports")
|
| 423 |
async def get_sector_rotation_reports(limit: int = Query(default=10)):
|
| 424 |
"""获取最近的板块轮动报告列表"""
|
| 425 |
conn = get_connection()
|
| 426 |
try:
|
| 427 |
with conn.cursor() as cursor:
|
| 428 |
-
cursor.execute(
|
|
|
|
| 429 |
SELECT trade_date, updated_at
|
| 430 |
FROM sector_rotation_reports
|
| 431 |
ORDER BY trade_date DESC
|
| 432 |
LIMIT %s
|
| 433 |
-
""",
|
|
|
|
|
|
|
| 434 |
reports = cursor.fetchall()
|
| 435 |
-
|
| 436 |
# 格式化返回数据
|
| 437 |
result = []
|
| 438 |
for report in reports:
|
| 439 |
-
result.append(
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
return {"reports": result}
|
| 446 |
except Exception as e:
|
| 447 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
|
@@ -455,23 +508,26 @@ async def get_sector_rotation_report(trade_date: str):
|
|
| 455 |
conn = get_connection()
|
| 456 |
try:
|
| 457 |
with conn.cursor() as cursor:
|
| 458 |
-
cursor.execute(
|
|
|
|
| 459 |
SELECT trade_date, html_content, updated_at
|
| 460 |
FROM sector_rotation_reports
|
| 461 |
WHERE trade_date = %s
|
| 462 |
-
""",
|
|
|
|
|
|
|
| 463 |
report = cursor.fetchone()
|
| 464 |
-
|
| 465 |
if not report:
|
| 466 |
return JSONResponse(
|
| 467 |
status_code=404,
|
| 468 |
-
content={"error": f"未找到 {trade_date} 的板块轮动报告"}
|
| 469 |
)
|
| 470 |
-
|
| 471 |
return {
|
| 472 |
-
"trade_date": report[
|
| 473 |
-
"html_content": report[
|
| 474 |
-
"updated_at": report[
|
| 475 |
}
|
| 476 |
except Exception as e:
|
| 477 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
|
@@ -479,128 +535,149 @@ async def get_sector_rotation_report(trade_date: str):
|
|
| 479 |
conn.close()
|
| 480 |
|
| 481 |
|
| 482 |
-
@app.get("/api/sector-rotation-reports/{trade_date}/html")
|
| 483 |
-
async def get_sector_rotation_report_html(trade_date: str):
|
| 484 |
-
"""获取指定交易日的板块轮动报告(直接返回HTML页面)"""
|
| 485 |
-
conn = get_connection()
|
| 486 |
-
try:
|
| 487 |
with conn.cursor() as cursor:
|
| 488 |
-
cursor.execute(
|
|
|
|
| 489 |
SELECT html_content
|
| 490 |
FROM sector_rotation_reports
|
| 491 |
WHERE trade_date = %s
|
| 492 |
-
""",
|
|
|
|
|
|
|
| 493 |
report = cursor.fetchone()
|
| 494 |
-
|
| 495 |
-
if not report or not report[
|
| 496 |
return HTMLResponse(
|
| 497 |
content=f"<html><body><h1>未找到 {trade_date} 的板块轮动报告</h1></body></html>",
|
| 498 |
-
status_code=404
|
| 499 |
)
|
| 500 |
-
|
| 501 |
-
return HTMLResponse(content=report[
|
| 502 |
except Exception as e:
|
| 503 |
return HTMLResponse(
|
| 504 |
content=f"<html><body><h1>加载失败</h1><p>{str(e)}</p></body></html>",
|
| 505 |
-
status_code=500
|
| 506 |
)
|
| 507 |
-
finally:
|
| 508 |
-
conn.close()
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
# ============ 龙头股对比分析报告 API ============
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
if
|
| 603 |
-
|
| 604 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
print(f"🚀 Starting MySQL API Server on port {port}...")
|
| 606 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|
|
| 2 |
# 此文件专为 Hugging Face Spaces 部署优化
|
| 3 |
|
| 4 |
import os
|
|
|
|
| 5 |
import time
|
| 6 |
+
import uuid
|
| 7 |
from contextlib import asynccontextmanager
|
| 8 |
+
from typing import List, Optional
|
| 9 |
|
| 10 |
+
import pymysql
|
| 11 |
+
from dbutils.pooled_db import PooledDB
|
| 12 |
from fastapi import FastAPI, HTTPException, Query
|
| 13 |
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
from fastapi.responses import HTMLResponse, JSONResponse
|
| 15 |
from pydantic import BaseModel
|
|
|
|
| 16 |
from pymysql.cursors import DictCursor
|
|
|
|
| 17 |
|
| 18 |
# 数据库配置 - 从环境变量读取
|
| 19 |
DB_CONFIG = {
|
| 20 |
+
"host": os.getenv("DB_HOST", "114.116.200.230"),
|
| 21 |
+
"port": int(os.getenv("DB_PORT", 3306)),
|
| 22 |
+
"user": os.getenv("DB_USER", "aistock_admin"),
|
| 23 |
+
"password": os.getenv("DB_PASSWORD", ""), # 必须通过环境变量设置
|
| 24 |
+
"database": os.getenv("DB_NAME", "aistock"),
|
| 25 |
+
"charset": "utf8mb4",
|
| 26 |
+
"cursorclass": DictCursor,
|
| 27 |
}
|
| 28 |
|
| 29 |
# 连接池
|
|
|
|
| 33 |
def init_pool():
|
| 34 |
"""初始化数据库连接池"""
|
| 35 |
global pool
|
| 36 |
+
if not DB_CONFIG["password"]:
|
| 37 |
print("⚠️ 警告: DB_PASSWORD 环境变量未设置!")
|
| 38 |
return
|
| 39 |
+
|
| 40 |
pool = PooledDB(
|
| 41 |
creator=pymysql,
|
| 42 |
+
maxconnections=5, # Hugging Face 免费版资源有限
|
| 43 |
mincached=1,
|
| 44 |
maxcached=3,
|
| 45 |
blocking=True,
|
| 46 |
maxusage=None,
|
| 47 |
setsession=[],
|
| 48 |
ping=1,
|
| 49 |
+
**DB_CONFIG,
|
| 50 |
+
)
|
| 51 |
+
print(
|
| 52 |
+
f"✅ MySQL 连接池已初始化: {DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
|
| 53 |
)
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
def get_connection():
|
|
|
|
| 74 |
title="Stock Analysis MySQL API",
|
| 75 |
description="MySQL API Server for Stock Analysis Agent (Hugging Face Spaces)",
|
| 76 |
version="1.0.0",
|
| 77 |
+
lifespan=lifespan,
|
| 78 |
)
|
| 79 |
|
| 80 |
# CORS 中间件 - 允许 Cloudflare Worker 访问
|
|
|
|
| 89 |
|
| 90 |
# ============ Pydantic 模型 ============
|
| 91 |
|
| 92 |
+
|
| 93 |
class SubmitRequest(BaseModel):
|
| 94 |
stock_code: str
|
| 95 |
stock_name: str
|
|
|
|
| 119 |
|
| 120 |
# ============ API 端点 ============
|
| 121 |
|
| 122 |
+
|
| 123 |
@app.get("/")
|
| 124 |
async def root():
|
| 125 |
"""健康检查"""
|
|
|
|
| 127 |
"status": "ok",
|
| 128 |
"service": "MySQL API Server",
|
| 129 |
"platform": "Hugging Face Spaces",
|
| 130 |
+
"db_configured": bool(DB_CONFIG["password"]),
|
| 131 |
}
|
| 132 |
|
| 133 |
|
|
|
|
| 137 |
conn = get_connection()
|
| 138 |
# 获取 username(优先使用 username,其次 submitted_by)
|
| 139 |
username = req.username or req.submitted_by or None
|
| 140 |
+
|
| 141 |
try:
|
| 142 |
with conn.cursor() as cursor:
|
| 143 |
cursor.execute(
|
| 144 |
"SELECT id, submit_count FROM requests WHERE stock_code = %s AND stock_name = %s AND status IN ('pending', 'processing') LIMIT 1",
|
| 145 |
+
(req.stock_code, req.stock_name),
|
| 146 |
)
|
| 147 |
existing = cursor.fetchone()
|
| 148 |
+
|
| 149 |
if existing:
|
| 150 |
+
new_count = (existing.get("submit_count") or 1) + 1
|
| 151 |
cursor.execute(
|
| 152 |
"UPDATE requests SET submit_count = %s WHERE id = %s",
|
| 153 |
+
(new_count, existing["id"]),
|
| 154 |
)
|
| 155 |
conn.commit()
|
| 156 |
+
return {
|
| 157 |
+
"success": True,
|
| 158 |
+
"request_id": existing["id"],
|
| 159 |
+
"duplicate": True,
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
request_id = str(uuid.uuid4())
|
| 163 |
now = int(time.time() * 1000)
|
| 164 |
# 处理 is_public 字段:确保值为 0 或 1
|
|
|
|
| 169 |
is_public = 0
|
| 170 |
else:
|
| 171 |
is_public = 1 # 默认公开
|
| 172 |
+
|
| 173 |
# 处理 is_vip 字段
|
| 174 |
is_vip = 1 if req.is_vip == 1 or req.is_vip == True else 0
|
| 175 |
+
|
| 176 |
cursor.execute(
|
| 177 |
"INSERT INTO requests (id, stock_code, stock_name, market, status, created_at, submit_count, username, is_public, is_vip) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
|
| 178 |
+
(
|
| 179 |
+
request_id,
|
| 180 |
+
req.stock_code,
|
| 181 |
+
req.stock_name,
|
| 182 |
+
req.market,
|
| 183 |
+
"pending",
|
| 184 |
+
now,
|
| 185 |
+
1,
|
| 186 |
+
username,
|
| 187 |
+
is_public,
|
| 188 |
+
is_vip,
|
| 189 |
+
),
|
| 190 |
)
|
| 191 |
conn.commit()
|
| 192 |
return {"success": True, "request_id": request_id, "username": username}
|
|
|
|
| 237 |
try:
|
| 238 |
with conn.cursor() as cursor:
|
| 239 |
now = int(time.time() * 1000)
|
| 240 |
+
|
| 241 |
cursor.execute(
|
| 242 |
"UPDATE requests SET status = %s, completed_at = %s WHERE id = %s",
|
| 243 |
+
(req.status, now, req.id),
|
| 244 |
)
|
| 245 |
+
|
| 246 |
+
if req.status == "completed" and req.html_content and req.stock_code:
|
| 247 |
# 生成报告 ID
|
| 248 |
report_id = str(uuid.uuid4())
|
| 249 |
cursor.execute(
|
| 250 |
"INSERT INTO reports (id, stock_code, stock_name, html_content, created_at) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name), html_content = VALUES(html_content), created_at = VALUES(created_at)",
|
| 251 |
+
(report_id, req.stock_code, req.stock_name, req.html_content, now),
|
| 252 |
)
|
| 253 |
+
|
| 254 |
conn.commit()
|
| 255 |
return {"success": True}
|
| 256 |
except Exception as e:
|
|
|
|
| 265 |
conn = get_connection()
|
| 266 |
try:
|
| 267 |
with conn.cursor() as cursor:
|
| 268 |
+
# reports 表可能存在同一 stock_code 的多条历史记录(按 created_at 递增)
|
| 269 |
+
# 这里必须取最新的一条,否则会出现“列表时间是最新、打开报告却是旧内容”的错配
|
| 270 |
+
cursor.execute(
|
| 271 |
+
"SELECT html_content FROM reports WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
|
| 272 |
+
(code,),
|
| 273 |
+
)
|
| 274 |
result = cursor.fetchone()
|
| 275 |
+
if not result or not result.get("html_content"):
|
| 276 |
return HTMLResponse(
|
| 277 |
+
content=f"<h1>Report not found for {code}</h1>", status_code=404
|
|
|
|
| 278 |
)
|
| 279 |
+
return HTMLResponse(content=result["html_content"])
|
| 280 |
except Exception as e:
|
| 281 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 282 |
finally:
|
|
|
|
| 284 |
|
| 285 |
|
| 286 |
@app.get("/api/history")
|
| 287 |
+
async def get_history(
|
| 288 |
+
date: Optional[str] = Query(None), username: Optional[str] = Query(None)
|
| 289 |
+
):
|
| 290 |
"""获取历史报告列表,支持按日期和用户名筛选"""
|
| 291 |
conn = get_connection()
|
| 292 |
try:
|
|
|
|
| 295 |
if username:
|
| 296 |
if date:
|
| 297 |
from datetime import datetime
|
| 298 |
+
|
| 299 |
+
start_dt = datetime.strptime(date, "%Y-%m-%d")
|
| 300 |
start_ts = int(start_dt.timestamp() * 1000)
|
| 301 |
end_ts = start_ts + 86400000
|
| 302 |
+
cursor.execute(
|
| 303 |
+
"""
|
| 304 |
SELECT * FROM requests
|
| 305 |
WHERE username = %s AND status = 'completed' AND completed_at >= %s AND completed_at < %s
|
| 306 |
ORDER BY completed_at DESC
|
| 307 |
+
""",
|
| 308 |
+
(username, start_ts, end_ts),
|
| 309 |
+
)
|
| 310 |
else:
|
| 311 |
+
cursor.execute(
|
| 312 |
+
"""
|
| 313 |
SELECT * FROM requests
|
| 314 |
WHERE username = %s AND status IN ('completed', 'error', 'pending', 'processing')
|
| 315 |
ORDER BY created_at DESC
|
| 316 |
LIMIT 100
|
| 317 |
+
""",
|
| 318 |
+
(username,),
|
| 319 |
+
)
|
| 320 |
elif date:
|
| 321 |
from datetime import datetime
|
| 322 |
+
|
| 323 |
+
start_dt = datetime.strptime(date, "%Y-%m-%d")
|
| 324 |
start_ts = int(start_dt.timestamp() * 1000)
|
| 325 |
end_ts = start_ts + 86400000
|
| 326 |
+
|
| 327 |
# 首页调用(无username):只返回公开报告 (is_public = 1 或 is_public IS NULL)
|
| 328 |
+
cursor.execute(
|
| 329 |
+
"""
|
| 330 |
SELECT r1.* FROM requests r1
|
| 331 |
INNER JOIN (
|
| 332 |
SELECT stock_code, MAX(completed_at) as max_completed_at
|
|
|
|
| 337 |
) r2 ON r1.stock_code = r2.stock_code AND r1.completed_at = r2.max_completed_at
|
| 338 |
WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
|
| 339 |
ORDER BY r1.completed_at DESC
|
| 340 |
+
""",
|
| 341 |
+
(start_ts, end_ts),
|
| 342 |
+
)
|
| 343 |
else:
|
| 344 |
# 首页调用(无username,无date):只返回公开报告 (is_public = 1 或 is_public IS NULL)
|
| 345 |
cursor.execute("""
|
|
|
|
| 354 |
WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
|
| 355 |
ORDER BY r1.completed_at DESC
|
| 356 |
""")
|
| 357 |
+
|
| 358 |
tasks = cursor.fetchall()
|
| 359 |
return {"tasks": tasks}
|
| 360 |
except Exception as e:
|
|
|
|
| 370 |
try:
|
| 371 |
with conn.cursor() as cursor:
|
| 372 |
cursor.execute(
|
| 373 |
+
"SELECT stock_code, stock_name, created_at FROM reports WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
|
| 374 |
+
(code,),
|
| 375 |
)
|
| 376 |
result = cursor.fetchone()
|
| 377 |
if result:
|
| 378 |
return {
|
| 379 |
"exists": True,
|
| 380 |
+
"stock_code": result["stock_code"],
|
| 381 |
+
"stock_name": result["stock_name"],
|
| 382 |
+
"created_at": result["created_at"],
|
| 383 |
}
|
| 384 |
return {"exists": False}
|
| 385 |
except Exception as e:
|
|
|
|
| 398 |
"SELECT MIN(created_at) as min_time FROM requests WHERE status = 'pending'"
|
| 399 |
)
|
| 400 |
result = cursor.fetchone()
|
| 401 |
+
min_time = (
|
| 402 |
+
result["min_time"]
|
| 403 |
+
if result and result["min_time"]
|
| 404 |
+
else int(time.time() * 1000)
|
| 405 |
+
)
|
| 406 |
new_time = min_time - 1
|
| 407 |
+
|
| 408 |
cursor.execute(
|
| 409 |
"UPDATE requests SET created_at = %s WHERE id = %s AND status = 'pending'",
|
| 410 |
+
(new_time, req.id),
|
| 411 |
)
|
| 412 |
conn.commit()
|
| 413 |
return {"success": True, "new_time": new_time}
|
|
|
|
| 421 |
async def remove_duplicates(req: RemoveDuplicatesRequest):
|
| 422 |
"""批量删除重复任务"""
|
| 423 |
if not req.ids:
|
| 424 |
+
return JSONResponse(
|
| 425 |
+
status_code=400, content={"error": "Missing or invalid task IDs"}
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
conn = get_connection()
|
| 429 |
try:
|
| 430 |
with conn.cursor() as cursor:
|
| 431 |
+
placeholders = ", ".join(["%s"] * len(req.ids))
|
| 432 |
cursor.execute(
|
| 433 |
f"DELETE FROM requests WHERE id IN ({placeholders}) AND status = 'pending'",
|
| 434 |
+
tuple(req.ids),
|
| 435 |
)
|
| 436 |
deleted = cursor.rowcount
|
| 437 |
conn.commit()
|
|
|
|
| 461 |
|
| 462 |
# ============ 板块轮动报告 API ============
|
| 463 |
|
| 464 |
+
|
| 465 |
@app.get("/api/sector-rotation-reports")
|
| 466 |
async def get_sector_rotation_reports(limit: int = Query(default=10)):
|
| 467 |
"""获取最近的板块轮动报告列表"""
|
| 468 |
conn = get_connection()
|
| 469 |
try:
|
| 470 |
with conn.cursor() as cursor:
|
| 471 |
+
cursor.execute(
|
| 472 |
+
"""
|
| 473 |
SELECT trade_date, updated_at
|
| 474 |
FROM sector_rotation_reports
|
| 475 |
ORDER BY trade_date DESC
|
| 476 |
LIMIT %s
|
| 477 |
+
""",
|
| 478 |
+
(limit,),
|
| 479 |
+
)
|
| 480 |
reports = cursor.fetchall()
|
| 481 |
+
|
| 482 |
# 格式化返回数据
|
| 483 |
result = []
|
| 484 |
for report in reports:
|
| 485 |
+
result.append(
|
| 486 |
+
{
|
| 487 |
+
"trade_date": report["trade_date"],
|
| 488 |
+
"update_time": time.strftime(
|
| 489 |
+
"%Y-%m-%d %H:%M:%S",
|
| 490 |
+
time.localtime(report["updated_at"] / 1000),
|
| 491 |
+
)
|
| 492 |
+
if report["updated_at"]
|
| 493 |
+
else None,
|
| 494 |
+
"updated_at": report["updated_at"],
|
| 495 |
+
}
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
return {"reports": result}
|
| 499 |
except Exception as e:
|
| 500 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
|
|
|
| 508 |
conn = get_connection()
|
| 509 |
try:
|
| 510 |
with conn.cursor() as cursor:
|
| 511 |
+
cursor.execute(
|
| 512 |
+
"""
|
| 513 |
SELECT trade_date, html_content, updated_at
|
| 514 |
FROM sector_rotation_reports
|
| 515 |
WHERE trade_date = %s
|
| 516 |
+
""",
|
| 517 |
+
(trade_date,),
|
| 518 |
+
)
|
| 519 |
report = cursor.fetchone()
|
| 520 |
+
|
| 521 |
if not report:
|
| 522 |
return JSONResponse(
|
| 523 |
status_code=404,
|
| 524 |
+
content={"error": f"未找到 {trade_date} 的板块轮动报告"},
|
| 525 |
)
|
| 526 |
+
|
| 527 |
return {
|
| 528 |
+
"trade_date": report["trade_date"],
|
| 529 |
+
"html_content": report["html_content"],
|
| 530 |
+
"updated_at": report["updated_at"],
|
| 531 |
}
|
| 532 |
except Exception as e:
|
| 533 |
return JSONResponse(status_code=500, content={"error": str(e)})
|
|
|
|
| 535 |
conn.close()
|
| 536 |
|
| 537 |
|
| 538 |
+
@app.get("/api/sector-rotation-reports/{trade_date}/html")
|
| 539 |
+
async def get_sector_rotation_report_html(trade_date: str):
|
| 540 |
+
"""获取指定交易日的板块轮动报告(直接返回HTML页面)"""
|
| 541 |
+
conn = get_connection()
|
| 542 |
+
try:
|
| 543 |
with conn.cursor() as cursor:
|
| 544 |
+
cursor.execute(
|
| 545 |
+
"""
|
| 546 |
SELECT html_content
|
| 547 |
FROM sector_rotation_reports
|
| 548 |
WHERE trade_date = %s
|
| 549 |
+
""",
|
| 550 |
+
(trade_date,),
|
| 551 |
+
)
|
| 552 |
report = cursor.fetchone()
|
| 553 |
+
|
| 554 |
+
if not report or not report["html_content"]:
|
| 555 |
return HTMLResponse(
|
| 556 |
content=f"<html><body><h1>未找到 {trade_date} 的板块轮动报告</h1></body></html>",
|
| 557 |
+
status_code=404,
|
| 558 |
)
|
| 559 |
+
|
| 560 |
+
return HTMLResponse(content=report["html_content"])
|
| 561 |
except Exception as e:
|
| 562 |
return HTMLResponse(
|
| 563 |
content=f"<html><body><h1>加载失败</h1><p>{str(e)}</p></body></html>",
|
| 564 |
+
status_code=500,
|
| 565 |
)
|
| 566 |
+
finally:
|
| 567 |
+
conn.close()
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
# ============ 龙头股对比分析报告 API ============
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
@app.get("/api/longtou-compare-reports")
|
| 574 |
+
async def get_longtou_compare_reports(limit: int = Query(default=10)):
|
| 575 |
+
"""获取最近的龙头股对比分析报告列表"""
|
| 576 |
+
conn = get_connection()
|
| 577 |
+
try:
|
| 578 |
+
with conn.cursor() as cursor:
|
| 579 |
+
cursor.execute(
|
| 580 |
+
"""
|
| 581 |
+
SELECT trade_date, updated_at
|
| 582 |
+
FROM longtou_compare_reports
|
| 583 |
+
ORDER BY trade_date DESC
|
| 584 |
+
LIMIT %s
|
| 585 |
+
""",
|
| 586 |
+
(limit,),
|
| 587 |
+
)
|
| 588 |
+
reports = cursor.fetchall()
|
| 589 |
+
|
| 590 |
+
result = []
|
| 591 |
+
for report in reports:
|
| 592 |
+
result.append(
|
| 593 |
+
{
|
| 594 |
+
"trade_date": report["trade_date"],
|
| 595 |
+
"update_time": time.strftime(
|
| 596 |
+
"%Y-%m-%d %H:%M:%S",
|
| 597 |
+
time.localtime(report["updated_at"] / 1000),
|
| 598 |
+
)
|
| 599 |
+
if report["updated_at"]
|
| 600 |
+
else None,
|
| 601 |
+
"updated_at": report["updated_at"],
|
| 602 |
+
}
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
return {"reports": result}
|
| 606 |
+
except Exception as e:
|
| 607 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 608 |
+
finally:
|
| 609 |
+
conn.close()
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
@app.get("/api/longtou-compare-reports/{trade_date}")
|
| 613 |
+
async def get_longtou_compare_report(trade_date: str):
|
| 614 |
+
"""获取指定交易日的龙头股对比分析报告"""
|
| 615 |
+
conn = get_connection()
|
| 616 |
+
try:
|
| 617 |
+
with conn.cursor() as cursor:
|
| 618 |
+
cursor.execute(
|
| 619 |
+
"""
|
| 620 |
+
SELECT trade_date, html_content, updated_at
|
| 621 |
+
FROM longtou_compare_reports
|
| 622 |
+
WHERE trade_date = %s
|
| 623 |
+
""",
|
| 624 |
+
(trade_date,),
|
| 625 |
+
)
|
| 626 |
+
report = cursor.fetchone()
|
| 627 |
+
|
| 628 |
+
if not report:
|
| 629 |
+
return JSONResponse(
|
| 630 |
+
status_code=404,
|
| 631 |
+
content={"error": f"未找到 {trade_date} 的龙头股对比分析报告"},
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
return {
|
| 635 |
+
"trade_date": report["trade_date"],
|
| 636 |
+
"html_content": report["html_content"],
|
| 637 |
+
"updated_at": report["updated_at"],
|
| 638 |
+
}
|
| 639 |
+
except Exception as e:
|
| 640 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 641 |
+
finally:
|
| 642 |
+
conn.close()
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
@app.get("/api/longtou-compare-reports/{trade_date}/html")
|
| 646 |
+
async def get_longtou_compare_report_html(trade_date: str):
|
| 647 |
+
"""获取指定交易日的龙头股对比分析报告(直接返回HTML页面)"""
|
| 648 |
+
conn = get_connection()
|
| 649 |
+
try:
|
| 650 |
+
with conn.cursor() as cursor:
|
| 651 |
+
cursor.execute(
|
| 652 |
+
"""
|
| 653 |
+
SELECT html_content
|
| 654 |
+
FROM longtou_compare_reports
|
| 655 |
+
WHERE trade_date = %s
|
| 656 |
+
""",
|
| 657 |
+
(trade_date,),
|
| 658 |
+
)
|
| 659 |
+
report = cursor.fetchone()
|
| 660 |
+
|
| 661 |
+
if not report or not report["html_content"]:
|
| 662 |
+
return HTMLResponse(
|
| 663 |
+
content=f"<html><body><h1>未找到 {trade_date} 的龙头股对比分析报告</h1></body></html>",
|
| 664 |
+
status_code=404,
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
return HTMLResponse(content=report["html_content"])
|
| 668 |
+
except Exception as e:
|
| 669 |
+
return HTMLResponse(
|
| 670 |
+
content=f"<html><body><h1>加载失败</h1><p>{str(e)}</p></body></html>",
|
| 671 |
+
status_code=500,
|
| 672 |
+
)
|
| 673 |
+
finally:
|
| 674 |
+
conn.close()
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
# Hugging Face Spaces 使用 7860 端口
|
| 678 |
+
if __name__ == "__main__":
|
| 679 |
+
import uvicorn
|
| 680 |
+
|
| 681 |
+
port = int(os.getenv("PORT", 7860))
|
| 682 |
print(f"🚀 Starting MySQL API Server on port {port}...")
|
| 683 |
uvicorn.run(app, host="0.0.0.0", port=port)
|