fromozu commited on
Commit
2ea95d7
·
verified ·
1 Parent(s): 0624461

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +267 -190
app.py CHANGED
@@ -2,28 +2,28 @@
2
  # 此文件专为 Hugging Face Spaces 部署优化
3
 
4
  import os
5
- import uuid
6
  import time
7
- from typing import Optional, List
8
  from contextlib import asynccontextmanager
 
9
 
 
 
10
  from fastapi import FastAPI, HTTPException, Query
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from fastapi.responses import HTMLResponse, JSONResponse
13
  from pydantic import BaseModel
14
- import pymysql
15
  from pymysql.cursors import DictCursor
16
- from dbutils.pooled_db import PooledDB
17
 
18
  # 数据库配置 - 从环境变量读取
19
  DB_CONFIG = {
20
- 'host': os.getenv('DB_HOST', '114.116.200.230'),
21
- 'port': int(os.getenv('DB_PORT', 3306)),
22
- 'user': os.getenv('DB_USER', 'aistock_admin'),
23
- 'password': os.getenv('DB_PASSWORD', ''), # 必须通过环境变量设置
24
- 'database': os.getenv('DB_NAME', 'aistock'),
25
- 'charset': 'utf8mb4',
26
- 'cursorclass': DictCursor
27
  }
28
 
29
  # 连接池
@@ -33,22 +33,24 @@ pool: Optional[PooledDB] = None
33
  def init_pool():
34
  """初始化数据库连接池"""
35
  global pool
36
- if not DB_CONFIG['password']:
37
  print("⚠️ 警告: DB_PASSWORD 环境变量未设置!")
38
  return
39
-
40
  pool = PooledDB(
41
  creator=pymysql,
42
- maxconnections=5, # Hugging Face 免费版资源有限
43
  mincached=1,
44
  maxcached=3,
45
  blocking=True,
46
  maxusage=None,
47
  setsession=[],
48
  ping=1,
49
- **DB_CONFIG
 
 
 
50
  )
51
- print(f"✅ MySQL 连接池已初始化: {DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}")
52
 
53
 
54
  def get_connection():
@@ -72,7 +74,7 @@ app = FastAPI(
72
  title="Stock Analysis MySQL API",
73
  description="MySQL API Server for Stock Analysis Agent (Hugging Face Spaces)",
74
  version="1.0.0",
75
- lifespan=lifespan
76
  )
77
 
78
  # CORS 中间件 - 允许 Cloudflare Worker 访问
@@ -87,6 +89,7 @@ app.add_middleware(
87
 
88
  # ============ Pydantic 模型 ============
89
 
 
90
  class SubmitRequest(BaseModel):
91
  stock_code: str
92
  stock_name: str
@@ -116,6 +119,7 @@ class RemoveDuplicatesRequest(BaseModel):
116
 
117
  # ============ API 端点 ============
118
 
 
119
  @app.get("/")
120
  async def root():
121
  """健康检查"""
@@ -123,7 +127,7 @@ async def root():
123
  "status": "ok",
124
  "service": "MySQL API Server",
125
  "platform": "Hugging Face Spaces",
126
- "db_configured": bool(DB_CONFIG['password'])
127
  }
128
 
129
 
@@ -133,24 +137,28 @@ async def submit_task(req: SubmitRequest):
133
  conn = get_connection()
134
  # 获取 username(优先使用 username,其次 submitted_by)
135
  username = req.username or req.submitted_by or None
136
-
137
  try:
138
  with conn.cursor() as cursor:
139
  cursor.execute(
140
  "SELECT id, submit_count FROM requests WHERE stock_code = %s AND stock_name = %s AND status IN ('pending', 'processing') LIMIT 1",
141
- (req.stock_code, req.stock_name)
142
  )
143
  existing = cursor.fetchone()
144
-
145
  if existing:
146
- new_count = (existing.get('submit_count') or 1) + 1
147
  cursor.execute(
148
  "UPDATE requests SET submit_count = %s WHERE id = %s",
149
- (new_count, existing['id'])
150
  )
151
  conn.commit()
152
- return {"success": True, "request_id": existing['id'], "duplicate": True}
153
-
 
 
 
 
154
  request_id = str(uuid.uuid4())
155
  now = int(time.time() * 1000)
156
  # 处理 is_public 字段:确保值为 0 或 1
@@ -161,13 +169,24 @@ async def submit_task(req: SubmitRequest):
161
  is_public = 0
162
  else:
163
  is_public = 1 # 默认公开
164
-
165
  # 处理 is_vip 字段
166
  is_vip = 1 if req.is_vip == 1 or req.is_vip == True else 0
167
-
168
  cursor.execute(
169
  "INSERT INTO requests (id, stock_code, stock_name, market, status, created_at, submit_count, username, is_public, is_vip) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
170
- (request_id, req.stock_code, req.stock_name, req.market, 'pending', now, 1, username, is_public, is_vip)
 
 
 
 
 
 
 
 
 
 
 
171
  )
172
  conn.commit()
173
  return {"success": True, "request_id": request_id, "username": username}
@@ -218,20 +237,20 @@ async def complete_task(req: CompleteRequest):
218
  try:
219
  with conn.cursor() as cursor:
220
  now = int(time.time() * 1000)
221
-
222
  cursor.execute(
223
  "UPDATE requests SET status = %s, completed_at = %s WHERE id = %s",
224
- (req.status, now, req.id)
225
  )
226
-
227
- if req.status == 'completed' and req.html_content and req.stock_code:
228
  # 生成报告 ID
229
  report_id = str(uuid.uuid4())
230
  cursor.execute(
231
  "INSERT INTO reports (id, stock_code, stock_name, html_content, created_at) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name), html_content = VALUES(html_content), created_at = VALUES(created_at)",
232
- (report_id, req.stock_code, req.stock_name, req.html_content, now)
233
  )
234
-
235
  conn.commit()
236
  return {"success": True}
237
  except Exception as e:
@@ -246,14 +265,18 @@ async def get_report(code: str = Query(...)):
246
  conn = get_connection()
247
  try:
248
  with conn.cursor() as cursor:
249
- cursor.execute("SELECT html_content FROM reports WHERE stock_code = %s", (code,))
 
 
 
 
 
250
  result = cursor.fetchone()
251
- if not result or not result.get('html_content'):
252
  return HTMLResponse(
253
- content=f"<h1>Report not found for {code}</h1>",
254
- status_code=404
255
  )
256
- return HTMLResponse(content=result['html_content'])
257
  except Exception as e:
258
  return JSONResponse(status_code=500, content={"error": str(e)})
259
  finally:
@@ -261,7 +284,9 @@ async def get_report(code: str = Query(...)):
261
 
262
 
263
  @app.get("/api/history")
264
- async def get_history(date: Optional[str] = Query(None), username: Optional[str] = Query(None)):
 
 
265
  """获取历史报告列表,支持按日期和用户名筛选"""
266
  conn = get_connection()
267
  try:
@@ -270,29 +295,38 @@ async def get_history(date: Optional[str] = Query(None), username: Optional[str]
270
  if username:
271
  if date:
272
  from datetime import datetime
273
- start_dt = datetime.strptime(date, '%Y-%m-%d')
 
274
  start_ts = int(start_dt.timestamp() * 1000)
275
  end_ts = start_ts + 86400000
276
- cursor.execute("""
 
277
  SELECT * FROM requests
278
  WHERE username = %s AND status = 'completed' AND completed_at >= %s AND completed_at < %s
279
  ORDER BY completed_at DESC
280
- """, (username, start_ts, end_ts))
 
 
281
  else:
282
- cursor.execute("""
 
283
  SELECT * FROM requests
284
  WHERE username = %s AND status IN ('completed', 'error', 'pending', 'processing')
285
  ORDER BY created_at DESC
286
  LIMIT 100
287
- """, (username,))
 
 
288
  elif date:
289
  from datetime import datetime
290
- start_dt = datetime.strptime(date, '%Y-%m-%d')
 
291
  start_ts = int(start_dt.timestamp() * 1000)
292
  end_ts = start_ts + 86400000
293
-
294
  # 首页调用(无username):只返回公开报告 (is_public = 1 或 is_public IS NULL)
295
- cursor.execute("""
 
296
  SELECT r1.* FROM requests r1
297
  INNER JOIN (
298
  SELECT stock_code, MAX(completed_at) as max_completed_at
@@ -303,7 +337,9 @@ async def get_history(date: Optional[str] = Query(None), username: Optional[str]
303
  ) r2 ON r1.stock_code = r2.stock_code AND r1.completed_at = r2.max_completed_at
304
  WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
305
  ORDER BY r1.completed_at DESC
306
- """, (start_ts, end_ts))
 
 
307
  else:
308
  # 首页调用(无username,无date):只返回公开报告 (is_public = 1 或 is_public IS NULL)
309
  cursor.execute("""
@@ -318,7 +354,7 @@ async def get_history(date: Optional[str] = Query(None), username: Optional[str]
318
  WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
319
  ORDER BY r1.completed_at DESC
320
  """)
321
-
322
  tasks = cursor.fetchall()
323
  return {"tasks": tasks}
324
  except Exception as e:
@@ -334,16 +370,16 @@ async def check_report(code: str = Query(...)):
334
  try:
335
  with conn.cursor() as cursor:
336
  cursor.execute(
337
- "SELECT stock_code, stock_name, created_at FROM reports WHERE stock_code = %s",
338
- (code,)
339
  )
340
  result = cursor.fetchone()
341
  if result:
342
  return {
343
  "exists": True,
344
- "stock_code": result['stock_code'],
345
- "stock_name": result['stock_name'],
346
- "created_at": result['created_at']
347
  }
348
  return {"exists": False}
349
  except Exception as e:
@@ -362,12 +398,16 @@ async def reorder_task(req: ReorderRequest):
362
  "SELECT MIN(created_at) as min_time FROM requests WHERE status = 'pending'"
363
  )
364
  result = cursor.fetchone()
365
- min_time = result['min_time'] if result and result['min_time'] else int(time.time() * 1000)
 
 
 
 
366
  new_time = min_time - 1
367
-
368
  cursor.execute(
369
  "UPDATE requests SET created_at = %s WHERE id = %s AND status = 'pending'",
370
- (new_time, req.id)
371
  )
372
  conn.commit()
373
  return {"success": True, "new_time": new_time}
@@ -381,15 +421,17 @@ async def reorder_task(req: ReorderRequest):
381
  async def remove_duplicates(req: RemoveDuplicatesRequest):
382
  """批量删除重复任务"""
383
  if not req.ids:
384
- return JSONResponse(status_code=400, content={"error": "Missing or invalid task IDs"})
385
-
 
 
386
  conn = get_connection()
387
  try:
388
  with conn.cursor() as cursor:
389
- placeholders = ', '.join(['%s'] * len(req.ids))
390
  cursor.execute(
391
  f"DELETE FROM requests WHERE id IN ({placeholders}) AND status = 'pending'",
392
- tuple(req.ids)
393
  )
394
  deleted = cursor.rowcount
395
  conn.commit()
@@ -419,29 +461,40 @@ async def get_queue():
419
 
420
  # ============ 板块轮动报告 API ============
421
 
 
422
  @app.get("/api/sector-rotation-reports")
423
  async def get_sector_rotation_reports(limit: int = Query(default=10)):
424
  """获取最近的板块轮动报告列表"""
425
  conn = get_connection()
426
  try:
427
  with conn.cursor() as cursor:
428
- cursor.execute("""
 
429
  SELECT trade_date, updated_at
430
  FROM sector_rotation_reports
431
  ORDER BY trade_date DESC
432
  LIMIT %s
433
- """, (limit,))
 
 
434
  reports = cursor.fetchall()
435
-
436
  # 格式化返回数据
437
  result = []
438
  for report in reports:
439
- result.append({
440
- "trade_date": report['trade_date'],
441
- "update_time": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(report['updated_at'] / 1000)) if report['updated_at'] else None,
442
- "updated_at": report['updated_at']
443
- })
444
-
 
 
 
 
 
 
 
445
  return {"reports": result}
446
  except Exception as e:
447
  return JSONResponse(status_code=500, content={"error": str(e)})
@@ -455,23 +508,26 @@ async def get_sector_rotation_report(trade_date: str):
455
  conn = get_connection()
456
  try:
457
  with conn.cursor() as cursor:
458
- cursor.execute("""
 
459
  SELECT trade_date, html_content, updated_at
460
  FROM sector_rotation_reports
461
  WHERE trade_date = %s
462
- """, (trade_date,))
 
 
463
  report = cursor.fetchone()
464
-
465
  if not report:
466
  return JSONResponse(
467
  status_code=404,
468
- content={"error": f"未找到 {trade_date} 的板块轮动报告"}
469
  )
470
-
471
  return {
472
- "trade_date": report['trade_date'],
473
- "html_content": report['html_content'],
474
- "updated_at": report['updated_at']
475
  }
476
  except Exception as e:
477
  return JSONResponse(status_code=500, content={"error": str(e)})
@@ -479,128 +535,149 @@ async def get_sector_rotation_report(trade_date: str):
479
  conn.close()
480
 
481
 
482
- @app.get("/api/sector-rotation-reports/{trade_date}/html")
483
- async def get_sector_rotation_report_html(trade_date: str):
484
- """获取指定交易日的板块轮动报告(直接返回HTML页面)"""
485
- conn = get_connection()
486
- try:
487
  with conn.cursor() as cursor:
488
- cursor.execute("""
 
489
  SELECT html_content
490
  FROM sector_rotation_reports
491
  WHERE trade_date = %s
492
- """, (trade_date,))
 
 
493
  report = cursor.fetchone()
494
-
495
- if not report or not report['html_content']:
496
  return HTMLResponse(
497
  content=f"<html><body><h1>未找到 {trade_date} 的板块轮动报告</h1></body></html>",
498
- status_code=404
499
  )
500
-
501
- return HTMLResponse(content=report['html_content'])
502
  except Exception as e:
503
  return HTMLResponse(
504
  content=f"<html><body><h1>加载失败</h1><p>{str(e)}</p></body></html>",
505
- status_code=500
506
  )
507
- finally:
508
- conn.close()
509
-
510
-
511
- # ============ 龙头股对比分析报告 API ============
512
-
513
- @app.get("/api/longtou-compare-reports")
514
- async def get_longtou_compare_reports(limit: int = Query(default=10)):
515
- """获取最近的龙头股对比分析报告列表"""
516
- conn = get_connection()
517
- try:
518
- with conn.cursor() as cursor:
519
- cursor.execute("""
520
- SELECT trade_date, updated_at
521
- FROM longtou_compare_reports
522
- ORDER BY trade_date DESC
523
- LIMIT %s
524
- """, (limit,))
525
- reports = cursor.fetchall()
526
-
527
- result = []
528
- for report in reports:
529
- result.append({
530
- "trade_date": report['trade_date'],
531
- "update_time": time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(report['updated_at'] / 1000)) if report['updated_at'] else None,
532
- "updated_at": report['updated_at']
533
- })
534
-
535
- return {"reports": result}
536
- except Exception as e:
537
- return JSONResponse(status_code=500, content={"error": str(e)})
538
- finally:
539
- conn.close()
540
-
541
-
542
- @app.get("/api/longtou-compare-reports/{trade_date}")
543
- async def get_longtou_compare_report(trade_date: str):
544
- """获取指定交易日的龙头股对比分析报告"""
545
- conn = get_connection()
546
- try:
547
- with conn.cursor() as cursor:
548
- cursor.execute("""
549
- SELECT trade_date, html_content, updated_at
550
- FROM longtou_compare_reports
551
- WHERE trade_date = %s
552
- """, (trade_date,))
553
- report = cursor.fetchone()
554
-
555
- if not report:
556
- return JSONResponse(
557
- status_code=404,
558
- content={"error": f"未找到 {trade_date} 的龙头股对比分析报告"}
559
- )
560
-
561
- return {
562
- "trade_date": report['trade_date'],
563
- "html_content": report['html_content'],
564
- "updated_at": report['updated_at']
565
- }
566
- except Exception as e:
567
- return JSONResponse(status_code=500, content={"error": str(e)})
568
- finally:
569
- conn.close()
570
-
571
-
572
- @app.get("/api/longtou-compare-reports/{trade_date}/html")
573
- async def get_longtou_compare_report_html(trade_date: str):
574
- """获取指定交易日的龙头股对比分析报告(直接返回HTML页面)"""
575
- conn = get_connection()
576
- try:
577
- with conn.cursor() as cursor:
578
- cursor.execute("""
579
- SELECT html_content
580
- FROM longtou_compare_reports
581
- WHERE trade_date = %s
582
- """, (trade_date,))
583
- report = cursor.fetchone()
584
-
585
- if not report or not report['html_content']:
586
- return HTMLResponse(
587
- content=f"<html><body><h1>未找到 {trade_date} 的龙头股对比分析报告</h1></body></html>",
588
- status_code=404
589
- )
590
-
591
- return HTMLResponse(content=report['html_content'])
592
- except Exception as e:
593
- return HTMLResponse(
594
- content=f"<html><body><h1>加载失败</h1><p>{str(e)}</p></body></html>",
595
- status_code=500
596
- )
597
- finally:
598
- conn.close()
599
-
600
-
601
- # Hugging Face Spaces 使用 7860 端口
602
- if __name__ == "__main__":
603
- import uvicorn
604
- port = int(os.getenv('PORT', 7860))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
  print(f"🚀 Starting MySQL API Server on port {port}...")
606
  uvicorn.run(app, host="0.0.0.0", port=port)
 
2
  # 此文件专为 Hugging Face Spaces 部署优化
3
 
4
  import os
 
5
  import time
6
+ import uuid
7
  from contextlib import asynccontextmanager
8
+ from typing import List, Optional
9
 
10
+ import pymysql
11
+ from dbutils.pooled_db import PooledDB
12
  from fastapi import FastAPI, HTTPException, Query
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import HTMLResponse, JSONResponse
15
  from pydantic import BaseModel
 
16
  from pymysql.cursors import DictCursor
 
17
 
18
  # 数据库配置 - 从环境变量读取
19
  DB_CONFIG = {
20
+ "host": os.getenv("DB_HOST", "114.116.200.230"),
21
+ "port": int(os.getenv("DB_PORT", 3306)),
22
+ "user": os.getenv("DB_USER", "aistock_admin"),
23
+ "password": os.getenv("DB_PASSWORD", ""), # 必须通过环境变量设置
24
+ "database": os.getenv("DB_NAME", "aistock"),
25
+ "charset": "utf8mb4",
26
+ "cursorclass": DictCursor,
27
  }
28
 
29
  # 连接池
 
33
  def init_pool():
34
  """初始化数据库连接池"""
35
  global pool
36
+ if not DB_CONFIG["password"]:
37
  print("⚠️ 警告: DB_PASSWORD 环境变量未设置!")
38
  return
39
+
40
  pool = PooledDB(
41
  creator=pymysql,
42
+ maxconnections=5, # Hugging Face 免费版资源有限
43
  mincached=1,
44
  maxcached=3,
45
  blocking=True,
46
  maxusage=None,
47
  setsession=[],
48
  ping=1,
49
+ **DB_CONFIG,
50
+ )
51
+ print(
52
+ f"✅ MySQL 连接池已初始化: {DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}"
53
  )
 
54
 
55
 
56
  def get_connection():
 
74
  title="Stock Analysis MySQL API",
75
  description="MySQL API Server for Stock Analysis Agent (Hugging Face Spaces)",
76
  version="1.0.0",
77
+ lifespan=lifespan,
78
  )
79
 
80
  # CORS 中间件 - 允许 Cloudflare Worker 访问
 
89
 
90
  # ============ Pydantic 模型 ============
91
 
92
+
93
  class SubmitRequest(BaseModel):
94
  stock_code: str
95
  stock_name: str
 
119
 
120
  # ============ API 端点 ============
121
 
122
+
123
  @app.get("/")
124
  async def root():
125
  """健康检查"""
 
127
  "status": "ok",
128
  "service": "MySQL API Server",
129
  "platform": "Hugging Face Spaces",
130
+ "db_configured": bool(DB_CONFIG["password"]),
131
  }
132
 
133
 
 
137
  conn = get_connection()
138
  # 获取 username(优先使用 username,其次 submitted_by)
139
  username = req.username or req.submitted_by or None
140
+
141
  try:
142
  with conn.cursor() as cursor:
143
  cursor.execute(
144
  "SELECT id, submit_count FROM requests WHERE stock_code = %s AND stock_name = %s AND status IN ('pending', 'processing') LIMIT 1",
145
+ (req.stock_code, req.stock_name),
146
  )
147
  existing = cursor.fetchone()
148
+
149
  if existing:
150
+ new_count = (existing.get("submit_count") or 1) + 1
151
  cursor.execute(
152
  "UPDATE requests SET submit_count = %s WHERE id = %s",
153
+ (new_count, existing["id"]),
154
  )
155
  conn.commit()
156
+ return {
157
+ "success": True,
158
+ "request_id": existing["id"],
159
+ "duplicate": True,
160
+ }
161
+
162
  request_id = str(uuid.uuid4())
163
  now = int(time.time() * 1000)
164
  # 处理 is_public 字段:确保值为 0 或 1
 
169
  is_public = 0
170
  else:
171
  is_public = 1 # 默认公开
172
+
173
  # 处理 is_vip 字段
174
  is_vip = 1 if req.is_vip == 1 or req.is_vip == True else 0
175
+
176
  cursor.execute(
177
  "INSERT INTO requests (id, stock_code, stock_name, market, status, created_at, submit_count, username, is_public, is_vip) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
178
+ (
179
+ request_id,
180
+ req.stock_code,
181
+ req.stock_name,
182
+ req.market,
183
+ "pending",
184
+ now,
185
+ 1,
186
+ username,
187
+ is_public,
188
+ is_vip,
189
+ ),
190
  )
191
  conn.commit()
192
  return {"success": True, "request_id": request_id, "username": username}
 
237
  try:
238
  with conn.cursor() as cursor:
239
  now = int(time.time() * 1000)
240
+
241
  cursor.execute(
242
  "UPDATE requests SET status = %s, completed_at = %s WHERE id = %s",
243
+ (req.status, now, req.id),
244
  )
245
+
246
+ if req.status == "completed" and req.html_content and req.stock_code:
247
  # 生成报告 ID
248
  report_id = str(uuid.uuid4())
249
  cursor.execute(
250
  "INSERT INTO reports (id, stock_code, stock_name, html_content, created_at) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name), html_content = VALUES(html_content), created_at = VALUES(created_at)",
251
+ (report_id, req.stock_code, req.stock_name, req.html_content, now),
252
  )
253
+
254
  conn.commit()
255
  return {"success": True}
256
  except Exception as e:
 
265
  conn = get_connection()
266
  try:
267
  with conn.cursor() as cursor:
268
+ # reports 表可能存在同一 stock_code 的多条历史记录(按 created_at 递增)
269
+ # 这里必须取最新的一条,否则会出现“列表时间是最新、打开报告却是旧内容”的错配
270
+ cursor.execute(
271
+ "SELECT html_content FROM reports WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
272
+ (code,),
273
+ )
274
  result = cursor.fetchone()
275
+ if not result or not result.get("html_content"):
276
  return HTMLResponse(
277
+ content=f"<h1>Report not found for {code}</h1>", status_code=404
 
278
  )
279
+ return HTMLResponse(content=result["html_content"])
280
  except Exception as e:
281
  return JSONResponse(status_code=500, content={"error": str(e)})
282
  finally:
 
284
 
285
 
286
  @app.get("/api/history")
287
+ async def get_history(
288
+ date: Optional[str] = Query(None), username: Optional[str] = Query(None)
289
+ ):
290
  """获取历史报告列表,支持按日期和用户名筛选"""
291
  conn = get_connection()
292
  try:
 
295
  if username:
296
  if date:
297
  from datetime import datetime
298
+
299
+ start_dt = datetime.strptime(date, "%Y-%m-%d")
300
  start_ts = int(start_dt.timestamp() * 1000)
301
  end_ts = start_ts + 86400000
302
+ cursor.execute(
303
+ """
304
  SELECT * FROM requests
305
  WHERE username = %s AND status = 'completed' AND completed_at >= %s AND completed_at < %s
306
  ORDER BY completed_at DESC
307
+ """,
308
+ (username, start_ts, end_ts),
309
+ )
310
  else:
311
+ cursor.execute(
312
+ """
313
  SELECT * FROM requests
314
  WHERE username = %s AND status IN ('completed', 'error', 'pending', 'processing')
315
  ORDER BY created_at DESC
316
  LIMIT 100
317
+ """,
318
+ (username,),
319
+ )
320
  elif date:
321
  from datetime import datetime
322
+
323
+ start_dt = datetime.strptime(date, "%Y-%m-%d")
324
  start_ts = int(start_dt.timestamp() * 1000)
325
  end_ts = start_ts + 86400000
326
+
327
  # 首页调用(无username):只返回公开报告 (is_public = 1 或 is_public IS NULL)
328
+ cursor.execute(
329
+ """
330
  SELECT r1.* FROM requests r1
331
  INNER JOIN (
332
  SELECT stock_code, MAX(completed_at) as max_completed_at
 
337
  ) r2 ON r1.stock_code = r2.stock_code AND r1.completed_at = r2.max_completed_at
338
  WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
339
  ORDER BY r1.completed_at DESC
340
+ """,
341
+ (start_ts, end_ts),
342
+ )
343
  else:
344
  # 首页调用(无username,无date):只返回公开报告 (is_public = 1 或 is_public IS NULL)
345
  cursor.execute("""
 
354
  WHERE r1.status = 'completed' AND (r1.is_public = 1 OR r1.is_public IS NULL)
355
  ORDER BY r1.completed_at DESC
356
  """)
357
+
358
  tasks = cursor.fetchall()
359
  return {"tasks": tasks}
360
  except Exception as e:
 
370
  try:
371
  with conn.cursor() as cursor:
372
  cursor.execute(
373
+ "SELECT stock_code, stock_name, created_at FROM reports WHERE stock_code = %s ORDER BY created_at DESC LIMIT 1",
374
+ (code,),
375
  )
376
  result = cursor.fetchone()
377
  if result:
378
  return {
379
  "exists": True,
380
+ "stock_code": result["stock_code"],
381
+ "stock_name": result["stock_name"],
382
+ "created_at": result["created_at"],
383
  }
384
  return {"exists": False}
385
  except Exception as e:
 
398
  "SELECT MIN(created_at) as min_time FROM requests WHERE status = 'pending'"
399
  )
400
  result = cursor.fetchone()
401
+ min_time = (
402
+ result["min_time"]
403
+ if result and result["min_time"]
404
+ else int(time.time() * 1000)
405
+ )
406
  new_time = min_time - 1
407
+
408
  cursor.execute(
409
  "UPDATE requests SET created_at = %s WHERE id = %s AND status = 'pending'",
410
+ (new_time, req.id),
411
  )
412
  conn.commit()
413
  return {"success": True, "new_time": new_time}
 
421
  async def remove_duplicates(req: RemoveDuplicatesRequest):
422
  """批量删除重复任务"""
423
  if not req.ids:
424
+ return JSONResponse(
425
+ status_code=400, content={"error": "Missing or invalid task IDs"}
426
+ )
427
+
428
  conn = get_connection()
429
  try:
430
  with conn.cursor() as cursor:
431
+ placeholders = ", ".join(["%s"] * len(req.ids))
432
  cursor.execute(
433
  f"DELETE FROM requests WHERE id IN ({placeholders}) AND status = 'pending'",
434
+ tuple(req.ids),
435
  )
436
  deleted = cursor.rowcount
437
  conn.commit()
 
461
 
462
  # ============ 板块轮动报告 API ============
463
 
464
+
465
  @app.get("/api/sector-rotation-reports")
466
  async def get_sector_rotation_reports(limit: int = Query(default=10)):
467
  """获取最近的板块轮动报告列表"""
468
  conn = get_connection()
469
  try:
470
  with conn.cursor() as cursor:
471
+ cursor.execute(
472
+ """
473
  SELECT trade_date, updated_at
474
  FROM sector_rotation_reports
475
  ORDER BY trade_date DESC
476
  LIMIT %s
477
+ """,
478
+ (limit,),
479
+ )
480
  reports = cursor.fetchall()
481
+
482
  # 格式化返回数据
483
  result = []
484
  for report in reports:
485
+ result.append(
486
+ {
487
+ "trade_date": report["trade_date"],
488
+ "update_time": time.strftime(
489
+ "%Y-%m-%d %H:%M:%S",
490
+ time.localtime(report["updated_at"] / 1000),
491
+ )
492
+ if report["updated_at"]
493
+ else None,
494
+ "updated_at": report["updated_at"],
495
+ }
496
+ )
497
+
498
  return {"reports": result}
499
  except Exception as e:
500
  return JSONResponse(status_code=500, content={"error": str(e)})
 
508
  conn = get_connection()
509
  try:
510
  with conn.cursor() as cursor:
511
+ cursor.execute(
512
+ """
513
  SELECT trade_date, html_content, updated_at
514
  FROM sector_rotation_reports
515
  WHERE trade_date = %s
516
+ """,
517
+ (trade_date,),
518
+ )
519
  report = cursor.fetchone()
520
+
521
  if not report:
522
  return JSONResponse(
523
  status_code=404,
524
+ content={"error": f"未找到 {trade_date} 的板块轮动报告"},
525
  )
526
+
527
  return {
528
+ "trade_date": report["trade_date"],
529
+ "html_content": report["html_content"],
530
+ "updated_at": report["updated_at"],
531
  }
532
  except Exception as e:
533
  return JSONResponse(status_code=500, content={"error": str(e)})
 
535
  conn.close()
536
 
537
 
538
+ @app.get("/api/sector-rotation-reports/{trade_date}/html")
539
+ async def get_sector_rotation_report_html(trade_date: str):
540
+ """获取指定交易日的板块轮动报告(直接返回HTML页面)"""
541
+ conn = get_connection()
542
+ try:
543
  with conn.cursor() as cursor:
544
+ cursor.execute(
545
+ """
546
  SELECT html_content
547
  FROM sector_rotation_reports
548
  WHERE trade_date = %s
549
+ """,
550
+ (trade_date,),
551
+ )
552
  report = cursor.fetchone()
553
+
554
+ if not report or not report["html_content"]:
555
  return HTMLResponse(
556
  content=f"<html><body><h1>未找到 {trade_date} 的板块轮动报告</h1></body></html>",
557
+ status_code=404,
558
  )
559
+
560
+ return HTMLResponse(content=report["html_content"])
561
  except Exception as e:
562
  return HTMLResponse(
563
  content=f"<html><body><h1>加载失败</h1><p>{str(e)}</p></body></html>",
564
+ status_code=500,
565
  )
566
+ finally:
567
+ conn.close()
568
+
569
+
570
+ # ============ 龙头股对比分析报告 API ============
571
+
572
+
573
+ @app.get("/api/longtou-compare-reports")
574
+ async def get_longtou_compare_reports(limit: int = Query(default=10)):
575
+ """获取最近的龙头股对比分析报告列表"""
576
+ conn = get_connection()
577
+ try:
578
+ with conn.cursor() as cursor:
579
+ cursor.execute(
580
+ """
581
+ SELECT trade_date, updated_at
582
+ FROM longtou_compare_reports
583
+ ORDER BY trade_date DESC
584
+ LIMIT %s
585
+ """,
586
+ (limit,),
587
+ )
588
+ reports = cursor.fetchall()
589
+
590
+ result = []
591
+ for report in reports:
592
+ result.append(
593
+ {
594
+ "trade_date": report["trade_date"],
595
+ "update_time": time.strftime(
596
+ "%Y-%m-%d %H:%M:%S",
597
+ time.localtime(report["updated_at"] / 1000),
598
+ )
599
+ if report["updated_at"]
600
+ else None,
601
+ "updated_at": report["updated_at"],
602
+ }
603
+ )
604
+
605
+ return {"reports": result}
606
+ except Exception as e:
607
+ return JSONResponse(status_code=500, content={"error": str(e)})
608
+ finally:
609
+ conn.close()
610
+
611
+
612
+ @app.get("/api/longtou-compare-reports/{trade_date}")
613
+ async def get_longtou_compare_report(trade_date: str):
614
+ """获取指定交易日的龙头股对比分析报告"""
615
+ conn = get_connection()
616
+ try:
617
+ with conn.cursor() as cursor:
618
+ cursor.execute(
619
+ """
620
+ SELECT trade_date, html_content, updated_at
621
+ FROM longtou_compare_reports
622
+ WHERE trade_date = %s
623
+ """,
624
+ (trade_date,),
625
+ )
626
+ report = cursor.fetchone()
627
+
628
+ if not report:
629
+ return JSONResponse(
630
+ status_code=404,
631
+ content={"error": f"未找到 {trade_date} 的龙头股对比分析报告"},
632
+ )
633
+
634
+ return {
635
+ "trade_date": report["trade_date"],
636
+ "html_content": report["html_content"],
637
+ "updated_at": report["updated_at"],
638
+ }
639
+ except Exception as e:
640
+ return JSONResponse(status_code=500, content={"error": str(e)})
641
+ finally:
642
+ conn.close()
643
+
644
+
645
+ @app.get("/api/longtou-compare-reports/{trade_date}/html")
646
+ async def get_longtou_compare_report_html(trade_date: str):
647
+ """获取指定交易日的龙头股对比分析报告(直接返回HTML页面)"""
648
+ conn = get_connection()
649
+ try:
650
+ with conn.cursor() as cursor:
651
+ cursor.execute(
652
+ """
653
+ SELECT html_content
654
+ FROM longtou_compare_reports
655
+ WHERE trade_date = %s
656
+ """,
657
+ (trade_date,),
658
+ )
659
+ report = cursor.fetchone()
660
+
661
+ if not report or not report["html_content"]:
662
+ return HTMLResponse(
663
+ content=f"<html><body><h1>未找到 {trade_date} 的龙头股对比分析报告</h1></body></html>",
664
+ status_code=404,
665
+ )
666
+
667
+ return HTMLResponse(content=report["html_content"])
668
+ except Exception as e:
669
+ return HTMLResponse(
670
+ content=f"<html><body><h1>加载失败</h1><p>{str(e)}</p></body></html>",
671
+ status_code=500,
672
+ )
673
+ finally:
674
+ conn.close()
675
+
676
+
677
+ # Hugging Face Spaces 使用 7860 端口
678
+ if __name__ == "__main__":
679
+ import uvicorn
680
+
681
+ port = int(os.getenv("PORT", 7860))
682
  print(f"🚀 Starting MySQL API Server on port {port}...")
683
  uvicorn.run(app, host="0.0.0.0", port=port)