""" Kronos Stock Predictor — RESTful API ===================================== POST /api/v1/predict → { "task_id": "uuid" } GET /api/v1/predict/{id} → { "status": "pending|done|failed", "result": {...} } POST /api/v1/predict/batch → { "batch_id": "uuid", "task_ids": [...] } GET /api/v1/predict/batch/{id} → { "batch_id", "status", "total", "done", "failed", "tasks": [...] } GET /api/v1/cache → cache contents & expiry info GET /health → { "status": "ok" } """ import asyncio import logging import uuid from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from datetime import datetime, time, timedelta, timezone from time import perf_counter from typing import Literal, List import pandas as pd from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import data_fetcher import predictor as pred_module logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # ── Timezone / market-close helpers ────────────────────────────────────────── _CST = timezone(timedelta(hours=8)) _MARKET_CLOSE = time(15, 0) # A-share close: 15:00 CST def _next_cache_expiry() -> datetime: """ Return the UTC datetime of the NEXT A-share market close (15:00 CST on a weekday), which is when new candle data becomes available and the cache should be invalidated. Chinese public holidays are intentionally ignored: on those days market data does not advance, so a cache hit is harmless. """ now_cst = datetime.now(_CST) today_close = now_cst.replace(hour=15, minute=0, second=0, microsecond=0) if now_cst.weekday() < 5 and now_cst < today_close: # Before today's close on a weekday → expire at today 15:00 CST expiry_cst = today_close else: # After close, or on a weekend → find next weekday's 15:00 CST candidate = now_cst + timedelta(days=1) while candidate.weekday() >= 5: # skip Sat(5) and Sun(6) candidate += timedelta(days=1) expiry_cst = candidate.replace(hour=15, minute=0, second=0, microsecond=0) return expiry_cst.astimezone(timezone.utc) # ── Result cache ────────────────────────────────────────────────────────────── # key : (symbol, lookback, pred_len, sample_count, mode, include_volume) # value : {"result": dict, "expires_at": datetime(UTC), "cached_at": datetime(UTC)} _cache: dict[tuple, dict] = {} def _cache_key(req: "PredictRequest") -> tuple: return (req.symbol, req.lookback, req.pred_len, req.sample_count, req.mode, req.include_volume) def _get_cached(req: "PredictRequest") -> dict | None: entry = _cache.get(_cache_key(req)) if entry and datetime.now(timezone.utc) < entry["expires_at"]: return entry return None def _set_cache(req: "PredictRequest", result: dict) -> None: now_utc = datetime.now(timezone.utc) _cache[_cache_key(req)] = { "result": result, "expires_at": _next_cache_expiry(), "cached_at": now_utc, } logger.info( "Cached %s, expires at %s CST", req.symbol, _cache[_cache_key(req)]["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"), ) # ── Task / Batch store ──────────────────────────────────────────────────────── _tasks: dict[str, dict] = {} _batches: dict[str, dict] = {} _executor = ThreadPoolExecutor(max_workers=2) # ── Startup: eagerly load the model so the first request isn't slow ─────────── @asynccontextmanager async def lifespan(app: FastAPI): loop = asyncio.get_event_loop() logger.info("Pre-loading Kronos predictor …") await loop.run_in_executor(_executor, pred_module.get_predictor) logger.info("Kronos predictor ready.") yield app = FastAPI( title="Kronos Stock Predictor API", version="1.0.0", description=( "Monte-Carlo probabilistic stock forecasting powered by the " "Kronos foundation model (Tsinghua University)." ), lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ── Request / Response schemas ──────────────────────────────────────────────── class PredictRequest(BaseModel): symbol: str = Field( ..., examples=["603777", "600900.SH"], description="A 股代码;支持 6 位代码或带市场后缀(如 600900.SH)", ) lookback: int = Field( default=512, ge=20, le=512, description="回看历史 K 线根数(最多 512,不足时自动截断)", ) pred_len: int = Field( default=5, ge=1, le=60, description="预测未来交易日数(建议 ≤ 30,超过时返回 confidence_warning)", ) sample_count: int = Field( default=30, ge=1, le=100, description="MC 蒙特卡洛采样次数", ) mode: Literal["simple", "advanced"] = Field( default="simple", description="simple: 仅返回均值 + 交易区间;advanced: 追加 OHLC 均值及收盘 CI", ) include_volume: bool = Field( default=False, description="mode=advanced 时是否额外返回成交量预测(默认关闭)", ) # ── Response builder ────────────────────────────────────────────────────────── def _build_response(req: PredictRequest, base_date: str, pred_mean, ci, trading_low, trading_high, direction_prob, last_close, y_timestamp) -> dict: bullish_prob = float(direction_prob) direction_signal = "bullish" if bullish_prob >= 0.5 else "bearish" signal_prob = bullish_prob if direction_signal == "bullish" else (1 - bullish_prob) bands = [] for i in range(req.pred_len): band: dict = { "date": str(y_timestamp.iloc[i].date()), "step": i + 1, "mean_close": round(float(pred_mean["close"].iloc[i]), 4), "trading_low": round(float(trading_low[i]), 4), "trading_high": round(float(trading_high[i]), 4), "uncertainty": round( float((trading_high[i] - trading_low[i]) / last_close), 4 ), } if req.mode == "advanced": band.update({ "mean_open": round(float(pred_mean["open"].iloc[i]), 4), "mean_high": round(float(pred_mean["high"].iloc[i]), 4), "mean_low": round(float(pred_mean["low"].iloc[i]), 4), "close_ci_low": round(float(ci["close"]["low"][i]), 4), "close_ci_high": round(float(ci["close"]["high"][i]), 4), }) bands.append(band) result: dict = { "symbol": req.symbol, "base_date": base_date, "pred_len": req.pred_len, "confidence": 95, "confidence_warning": req.pred_len > 30, "direction": { "signal": direction_signal, "probability": round(signal_prob, 4), }, "summary": { "mean_close": round(float(pred_mean["close"].iloc[-1]), 4), "range_low": round(float(trading_low.min()), 4), "range_high": round(float(trading_high.max()), 4), "range_width": round(float(trading_high.max() - trading_low.min()), 4), }, "bands": bands, } if req.mode == "advanced" and req.include_volume: result["volume"] = [ { "date": str(y_timestamp.iloc[i].date()), "mean_volume": round(float(pred_mean["volume"].iloc[i])), "volume_ci_low": round(float(ci["volume"]["low"][i])), "volume_ci_high": round(float(ci["volume"]["high"][i])), } for i in range(req.pred_len) ] return result # ── Background task ─────────────────────────────────────────────────────────── def _run_prediction(task_id: str, req: PredictRequest) -> None: t_total_start = perf_counter() try: # ── Cache check ─────────────────────────────────────────────────────── cache_entry = _get_cached(req) if cache_entry is not None: total_ms = (perf_counter() - t_total_start) * 1000 logger.info( "Cache hit for %s (expires %s CST, total=%.1fms)", req.symbol, cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"), total_ms, ) _tasks[task_id] = { "status": "done", "result": {**cache_entry["result"], "cached": True, "cache_expires_at": cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z")}, "error": None, } return # ── Full inference ──────────────────────────────────────────────────── t_fetch_start = perf_counter() x_df, x_timestamp, last_trade_date = data_fetcher.fetch_stock_data( req.symbol, req.lookback ) fetch_ms = (perf_counter() - t_fetch_start) * 1000 t_calendar_start = perf_counter() y_timestamp = data_fetcher.get_future_trading_dates(last_trade_date, req.pred_len) calendar_ms = (perf_counter() - t_calendar_start) * 1000 t_infer_start = perf_counter() pred_mean, ci, trading_low, trading_high, direction_prob, last_close = ( pred_module.run_mc_prediction( x_df, x_timestamp, y_timestamp, req.pred_len, req.sample_count ) ) infer_ms = (perf_counter() - t_infer_start) * 1000 t_build_start = perf_counter() base_date = str(pd.to_datetime(last_trade_date, format="%Y%m%d").date()) result = _build_response( req, base_date, pred_mean, ci, trading_low, trading_high, direction_prob, last_close, y_timestamp, ) build_ms = (perf_counter() - t_build_start) * 1000 # ── Store in cache ──────────────────────────────────────────────────── t_cache_start = perf_counter() _set_cache(req, result) cache_entry = _cache[_cache_key(req)] cache_ms = (perf_counter() - t_cache_start) * 1000 _tasks[task_id] = { "status": "done", "result": {**result, "cached": False, "cache_expires_at": cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z")}, "error": None, } total_ms = (perf_counter() - t_total_start) * 1000 logger.info( "Task %s timing symbol=%s fetch=%.1fms calendar=%.1fms infer=%.1fms build=%.1fms cache=%.1fms total=%.1fms", task_id, req.symbol, fetch_ms, calendar_ms, infer_ms, build_ms, cache_ms, total_ms, ) except Exception as exc: total_ms = (perf_counter() - t_total_start) * 1000 logger.exception("Task %s failed after %.1fms", task_id, total_ms) _tasks[task_id] = {"status": "failed", "result": None, "error": str(exc)} # ── Routes ──────────────────────────────────────────────────────────────────── @app.post( "/api/v1/predict", summary="提交预测任务", response_description="任务 ID,用于轮询结果", ) async def submit_predict(req: PredictRequest): """ 提交一个蒙特卡洛预测任务,立即返回 `task_id`。 通过 `GET /api/v1/predict/{task_id}` 轮询结果。 """ task_id = str(uuid.uuid4()) _tasks[task_id] = {"status": "pending", "result": None, "error": None} _executor.submit(_run_prediction, task_id, req) return {"task_id": task_id} @app.get( "/api/v1/predict/{task_id}", summary="查询预测结果", ) async def get_predict_result(task_id: str): """ 轮询预测任务状态。 - `status: "pending"` — 正在计算 - `status: "done"` — 完成,`result` 字段包含预测数据 - `status: "failed"` — 失败,`error` 字段包含错误信息 """ task = _tasks.get(task_id) if task is None: raise HTTPException(status_code=404, detail=f"Task {task_id!r} not found") return task @app.get("/api/v1/cache", summary="查看缓存状态") async def get_cache(symbol: str | None = None): """ 列出有效的缓存条目及其过期时间。 - 不传参数:返回全部 - `?symbol=000063.SZ`:只返回该股票的所有参数组合 """ now_utc = datetime.now(timezone.utc) entries = [] for key, entry in _cache.items(): if symbol and key[0] != symbol: continue remaining = (entry["expires_at"] - now_utc).total_seconds() if remaining > 0: entries.append({ "symbol": key[0], "lookback": key[1], "pred_len": key[2], "sample_count": key[3], "mode": key[4], "include_volume": key[5], "cached_at": entry["cached_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z"), "expires_at": entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z"), "ttl_seconds": int(remaining), "result": entry["result"], }) return {"count": len(entries), "entries": entries} @app.get("/health", summary="健康检查") async def health(): return {"status": "ok"} # ── Batch schemas ───────────────────────────────────────────────────────────── class BatchPredictRequest(BaseModel): requests: List[PredictRequest] = Field( ..., min_length=1, max_length=20, description="预测请求列表(最多 20 个)", ) # ── Batch helper ────────────────────────────────────────────────────────────── def _batch_status(batch_id: str) -> dict: batch = _batches[batch_id] task_ids = batch["task_ids"] tasks = [{"task_id": tid, **_tasks[tid]} for tid in task_ids] n_done = sum(1 for t in tasks if t["status"] == "done") n_failed = sum(1 for t in tasks if t["status"] == "failed") n_total = len(task_ids) if n_done + n_failed == n_total: overall = "done" if n_failed == 0 else ("failed" if n_done == 0 else "partial") else: overall = "pending" return { "batch_id": batch_id, "status": overall, "total": n_total, "done": n_done, "failed": n_failed, "tasks": tasks, } # ── Batch routes ────────────────────────────────────────────────────────────── @app.post( "/api/v1/predict/batch", summary="批量提交预测任务", response_description="batch_id 及每个子任务的 task_id 列表", ) async def submit_batch(req: BatchPredictRequest): """ 一次提交多支股票(或多组参数)的预测任务,立即返回 `batch_id` 和 `task_ids`。 所有子任务并发进入同一 executor 队列,通过 `GET /api/v1/predict/batch/{batch_id}` 统一查询进度及结果。 """ batch_id = str(uuid.uuid4()) task_ids = [] for r in req.requests: task_id = str(uuid.uuid4()) _tasks[task_id] = {"status": "pending", "result": None, "error": None} _executor.submit(_run_prediction, task_id, r) task_ids.append(task_id) _batches[batch_id] = {"task_ids": task_ids} return {"batch_id": batch_id, "task_ids": task_ids} @app.get( "/api/v1/predict/batch/{batch_id}", summary="查询批量任务进度及结果", ) async def get_batch_result(batch_id: str): """ 轮询批量任务整体状态: - `status: "pending"` — 仍有子任务在计算 - `status: "done"` — 全部成功 - `status: "partial"` — 部分成功、部分失败 - `status: "failed"` — 全部失败 `tasks` 数组包含每个子任务的完整状态与结果。 """ if batch_id not in _batches: raise HTTPException(status_code=404, detail=f"Batch {batch_id!r} not found") return _batch_status(batch_id)