"""
Kronos Stock Predictor — RESTful API
=====================================
POST /api/v1/predict             → { "task_id": "uuid" }
GET  /api/v1/predict/{id}        → { "status": "pending|done|failed", "result": {...} }
POST /api/v1/predict/batch       → { "batch_id": "uuid", "task_ids": [...] }
GET  /api/v1/predict/batch/{id}  → { "batch_id", "status", "total", "done", "failed", "tasks": [...] }
GET  /api/v1/cache               → cache contents & expiry info
GET  /health                     → { "status": "ok" }
"""
| import asyncio | |
| import logging | |
| import uuid | |
| from concurrent.futures import ThreadPoolExecutor | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime, time, timedelta, timezone | |
| from time import perf_counter | |
| from typing import Literal, List | |
| import pandas as pd | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import data_fetcher | |
| import predictor as pred_module | |
# Root logger at INFO; this module logger carries all request/cache diagnostics.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ── Timezone / market-close helpers ───────────────────────────────────────────
_CST = timezone(timedelta(hours=8))  # China Standard Time (UTC+8, no DST)
_MARKET_CLOSE = time(15, 0)          # A-share close: 15:00 CST


def _next_cache_expiry() -> datetime:
    """
    Return the UTC datetime of the NEXT A-share market close (15:00 CST on a
    weekday), which is when new candle data becomes available and the cache
    should be invalidated.

    Chinese public holidays are intentionally ignored: on those days market
    data does not advance, so a cache hit is harmless.
    """
    now_cst = datetime.now(_CST)
    # Use the _MARKET_CLOSE constant instead of re-hard-coding 15:00 twice.
    today_close = datetime.combine(now_cst.date(), _MARKET_CLOSE, tzinfo=_CST)
    if now_cst.weekday() < 5 and now_cst < today_close:
        # Before today's close on a weekday → expire at today 15:00 CST
        expiry_cst = today_close
    else:
        # After close, or on a weekend → find next weekday's 15:00 CST
        candidate = now_cst + timedelta(days=1)
        while candidate.weekday() >= 5:  # skip Sat(5) and Sun(6)
            candidate += timedelta(days=1)
        expiry_cst = datetime.combine(candidate.date(), _MARKET_CLOSE, tzinfo=_CST)
    return expiry_cst.astimezone(timezone.utc)
| # โโ Result cache โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| # key : (symbol, lookback, pred_len, sample_count, mode, include_volume) | |
| # value : {"result": dict, "expires_at": datetime(UTC), "cached_at": datetime(UTC)} | |
| _cache: dict[tuple, dict] = {} | |
| def _cache_key(req: "PredictRequest") -> tuple: | |
| return (req.symbol, req.lookback, req.pred_len, | |
| req.sample_count, req.mode, req.include_volume) | |
| def _get_cached(req: "PredictRequest") -> dict | None: | |
| entry = _cache.get(_cache_key(req)) | |
| if entry and datetime.now(timezone.utc) < entry["expires_at"]: | |
| return entry | |
| return None | |
def _set_cache(req: "PredictRequest", result: dict) -> None:
    """
    Store *result* for *req*, expiring at the next A-share market close.

    Builds the entry once and reuses it for both the store and the log line —
    the original recomputed `_cache_key(req)` three times and re-looked-up the
    entry it had just written.
    """
    entry = {
        "result": result,
        "expires_at": _next_cache_expiry(),
        "cached_at": datetime.now(timezone.utc),
    }
    _cache[_cache_key(req)] = entry
    # Log expiry in CST (the market's local timezone) for operator readability.
    logger.info(
        "Cached %s, expires at %s CST",
        req.symbol,
        entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"),
    )
# ── Task / Batch store ────────────────────────────────────────────────────────
# In-memory, process-local stores; entries live for the process lifetime.
_tasks: dict[str, dict] = {}    # task_id  -> {"status", "result", "error"}
_batches: dict[str, dict] = {}  # batch_id -> {"task_ids": [...]}
# Bounded worker pool: at most two predictions run concurrently.
_executor = ThreadPoolExecutor(max_workers=2)
# ── Startup: eagerly load the model so the first request isn't slow ───────────
@asynccontextmanager  # FastAPI's `lifespan=` requires an async context-manager factory
async def lifespan(app: FastAPI):
    """Pre-load the Kronos predictor before the app starts serving requests."""
    # get_running_loop() is the correct call inside a coroutine
    # (get_event_loop() is deprecated in this context).
    loop = asyncio.get_running_loop()
    logger.info("Pre-loading Kronos predictor …")
    await loop.run_in_executor(_executor, pred_module.get_predictor)
    logger.info("Kronos predictor ready.")
    yield
# Application instance; `lifespan` pre-loads the model before the first request.
app = FastAPI(
    title="Kronos Stock Predictor API",
    version="1.0.0",
    description=(
        "Monte-Carlo probabilistic stock forecasting powered by the "
        "Kronos foundation model (Tsinghua University)."
    ),
    lifespan=lifespan,
)
# NOTE(review): CORS is wide open (any origin/method/header) — acceptable for a
# public demo; tighten `allow_origins` before exposing anything sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Request / Response schemas ────────────────────────────────────────────────
class PredictRequest(BaseModel):
    """Parameters for a single Monte-Carlo stock prediction task."""

    # Descriptions were mojibake (mis-encoded Chinese) and rendered as garbage
    # in the OpenAPI docs; replaced with readable English equivalents.
    symbol: str = Field(
        ...,
        examples=["603777", "600900.SH"],
        description="A-share symbol: 6-digit code, optionally with market suffix (e.g. 600900.SH)",
    )
    lookback: int = Field(
        default=512,
        ge=20,
        le=512,
        description="Number of historical candles to condition on (max 512; auto-truncated when fewer exist)",
    )
    pred_len: int = Field(
        default=5,
        ge=1,
        le=60,
        description="Number of future trading days to predict (<= 30 recommended; beyond that the response sets confidence_warning)",
    )
    sample_count: int = Field(
        default=30,
        ge=1,
        le=100,
        description="Number of Monte-Carlo samples",
    )
    mode: Literal["simple", "advanced"] = Field(
        default="simple",
        description="simple: mean close + trading band only; advanced: adds OHLC means and close confidence interval",
    )
    include_volume: bool = Field(
        default=False,
        description="When mode=advanced, additionally return volume predictions (off by default)",
    )
# ── Response builder ──────────────────────────────────────────────────────────
def _build_response(req: PredictRequest, base_date: str, pred_mean, ci,
                    trading_low, trading_high, direction_prob, last_close,
                    y_timestamp) -> dict:
    """
    Assemble the JSON-serializable prediction payload.

    Always includes per-day bands (date, mean close, trading range,
    uncertainty); in "advanced" mode each band additionally carries OHLC means
    and the close CI, and — when `include_volume` is set — a parallel `volume`
    list is attached. `confidence_warning` flags horizons beyond 30 days.
    """
    p_up = float(direction_prob)
    if p_up >= 0.5:
        signal, p_signal = "bullish", p_up
    else:
        signal, p_signal = "bearish", 1 - p_up

    advanced = req.mode == "advanced"
    bands = []
    for idx in range(req.pred_len):
        lo = float(trading_low[idx])
        hi = float(trading_high[idx])
        band: dict = {
            "date": str(y_timestamp.iloc[idx].date()),
            "step": idx + 1,
            "mean_close": round(float(pred_mean["close"].iloc[idx]), 4),
            "trading_low": round(lo, 4),
            "trading_high": round(hi, 4),
            # Band width relative to the last known close.
            "uncertainty": round(float((hi - lo) / last_close), 4),
        }
        if advanced:
            band["mean_open"] = round(float(pred_mean["open"].iloc[idx]), 4)
            band["mean_high"] = round(float(pred_mean["high"].iloc[idx]), 4)
            band["mean_low"] = round(float(pred_mean["low"].iloc[idx]), 4)
            band["close_ci_low"] = round(float(ci["close"]["low"][idx]), 4)
            band["close_ci_high"] = round(float(ci["close"]["high"][idx]), 4)
        bands.append(band)

    lo_all = float(trading_low.min())
    hi_all = float(trading_high.max())
    payload: dict = {
        "symbol": req.symbol,
        "base_date": base_date,
        "pred_len": req.pred_len,
        "confidence": 95,
        "confidence_warning": req.pred_len > 30,
        "direction": {
            "signal": signal,
            "probability": round(p_signal, 4),
        },
        "summary": {
            "mean_close": round(float(pred_mean["close"].iloc[-1]), 4),
            "range_low": round(lo_all, 4),
            "range_high": round(hi_all, 4),
            "range_width": round(hi_all - lo_all, 4),
        },
        "bands": bands,
    }
    if advanced and req.include_volume:
        payload["volume"] = [
            {
                "date": str(y_timestamp.iloc[idx].date()),
                "mean_volume": round(float(pred_mean["volume"].iloc[idx])),
                "volume_ci_low": round(float(ci["volume"]["low"][idx])),
                "volume_ci_high": round(float(ci["volume"]["high"][idx])),
            }
            for idx in range(req.pred_len)
        ]
    return payload
# ── Background task ───────────────────────────────────────────────────────────
def _run_prediction(task_id: str, req: PredictRequest) -> None:
    """
    Worker entry point (runs on `_executor` threads).

    Publishes the outcome into `_tasks[task_id]` as
    {"status": "done"|"failed", "result", "error"}; never raises — any
    exception is caught and recorded as status "failed".
    """
    t_total_start = perf_counter()
    try:
        # ── Cache check ───────────────────────────────────────────────────────
        cache_entry = _get_cached(req)
        if cache_entry is not None:
            total_ms = (perf_counter() - t_total_start) * 1000
            logger.info(
                "Cache hit for %s (expires %s CST, total=%.1fms)",
                req.symbol,
                cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"),
                total_ms,
            )
            _tasks[task_id] = {
                "status": "done",
                # Expose cache provenance so clients can tell fresh vs cached.
                "result": {**cache_entry["result"], "cached": True,
                           "cache_expires_at": cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z")},
                "error": None,
            }
            return
        # ── Full inference ────────────────────────────────────────────────────
        t_fetch_start = perf_counter()
        x_df, x_timestamp, last_trade_date = data_fetcher.fetch_stock_data(
            req.symbol, req.lookback
        )
        fetch_ms = (perf_counter() - t_fetch_start) * 1000
        t_calendar_start = perf_counter()
        y_timestamp = data_fetcher.get_future_trading_dates(last_trade_date, req.pred_len)
        calendar_ms = (perf_counter() - t_calendar_start) * 1000
        t_infer_start = perf_counter()
        pred_mean, ci, trading_low, trading_high, direction_prob, last_close = (
            pred_module.run_mc_prediction(
                x_df, x_timestamp, y_timestamp, req.pred_len, req.sample_count
            )
        )
        infer_ms = (perf_counter() - t_infer_start) * 1000
        t_build_start = perf_counter()
        # last_trade_date is formatted "YYYYMMDD" here; normalize to ISO date.
        base_date = str(pd.to_datetime(last_trade_date, format="%Y%m%d").date())
        result = _build_response(
            req, base_date, pred_mean, ci,
            trading_low, trading_high, direction_prob, last_close, y_timestamp,
        )
        build_ms = (perf_counter() - t_build_start) * 1000
        # ── Store in cache ────────────────────────────────────────────────────
        t_cache_start = perf_counter()
        _set_cache(req, result)
        cache_entry = _cache[_cache_key(req)]
        cache_ms = (perf_counter() - t_cache_start) * 1000
        _tasks[task_id] = {
            "status": "done",
            "result": {**result, "cached": False,
                       "cache_expires_at": cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z")},
            "error": None,
        }
        total_ms = (perf_counter() - t_total_start) * 1000
        # Per-stage timing breakdown for performance diagnostics.
        logger.info(
            "Task %s timing symbol=%s fetch=%.1fms calendar=%.1fms infer=%.1fms build=%.1fms cache=%.1fms total=%.1fms",
            task_id,
            req.symbol,
            fetch_ms,
            calendar_ms,
            infer_ms,
            build_ms,
            cache_ms,
            total_ms,
        )
    except Exception as exc:
        total_ms = (perf_counter() - t_total_start) * 1000
        logger.exception("Task %s failed after %.1fms", task_id, total_ms)
        _tasks[task_id] = {"status": "failed", "result": None, "error": str(exc)}
# ── Routes ────────────────────────────────────────────────────────────────────
# Route decorators were missing entirely — the endpoints documented in the
# module docstring were never registered on `app`.
@app.post("/api/v1/predict")
async def submit_predict(req: PredictRequest):
    """
    Submit a Monte-Carlo prediction task; returns a `task_id` immediately.
    Poll `GET /api/v1/predict/{task_id}` for the result.
    """
    task_id = str(uuid.uuid4())
    _tasks[task_id] = {"status": "pending", "result": None, "error": None}
    _executor.submit(_run_prediction, task_id, req)
    return {"task_id": task_id}
# Registered AFTER the static paths it could otherwise shadow would matter;
# `/api/v1/predict/batch/{id}` has an extra path segment so it cannot collide.
@app.get("/api/v1/predict/{task_id}")
async def get_predict_result(task_id: str):
    """
    Poll a prediction task's status:
    - `status: "pending"` — still computing
    - `status: "done"`    — finished; `result` holds the prediction payload
    - `status: "failed"`  — failed; `error` holds the message
    """
    task = _tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail=f"Task {task_id!r} not found")
    return task
@app.get("/api/v1/cache")
async def get_cache(symbol: str | None = None):
    """
    List all live cache entries with their expiry info.

    - No query parameter: return every entry.
    - `?symbol=000063.SZ`: only entries (all parameter combinations) for that
      symbol.
    """
    now_utc = datetime.now(timezone.utc)
    entries = []
    for key, entry in _cache.items():
        if symbol and key[0] != symbol:
            continue
        remaining = (entry["expires_at"] - now_utc).total_seconds()
        if remaining <= 0:
            continue  # expired entries are hidden from the listing
        entries.append({
            "symbol": key[0],
            "lookback": key[1],
            "pred_len": key[2],
            "sample_count": key[3],
            "mode": key[4],
            "include_volume": key[5],
            "cached_at": entry["cached_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z"),
            "expires_at": entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z"),
            "ttl_seconds": int(remaining),
            "result": entry["result"],
        })
    return {"count": len(entries), "entries": entries}
@app.get("/health")
async def health():
    """Liveness probe; always returns {"status": "ok"}."""
    return {"status": "ok"}
# ── Batch schemas ─────────────────────────────────────────────────────────────
class BatchPredictRequest(BaseModel):
    """Wrapper for submitting up to 20 prediction requests in one call."""

    # Description was mojibake; replaced with a readable English equivalent.
    requests: List[PredictRequest] = Field(
        ...,
        min_length=1,
        max_length=20,
        description="List of prediction requests (max 20)",
    )
# ── Batch helper ──────────────────────────────────────────────────────────────
def _batch_status(batch_id: str) -> dict:
    """
    Aggregate the per-task states of a batch into one status payload.

    Overall status is "pending" while any sub-task is unfinished; once all
    have settled it is "done" (none failed), "failed" (none succeeded) or
    "partial" (mixed).
    """
    ids = _batches[batch_id]["task_ids"]
    tasks = [{"task_id": tid, **_tasks[tid]} for tid in ids]
    done = sum(t["status"] == "done" for t in tasks)
    failed = sum(t["status"] == "failed" for t in tasks)
    total = len(ids)
    if done + failed < total:
        overall = "pending"
    elif failed == 0:
        overall = "done"
    elif done == 0:
        overall = "failed"
    else:
        overall = "partial"
    return {
        "batch_id": batch_id,
        "status": overall,
        "total": total,
        "done": done,
        "failed": failed,
        "tasks": tasks,
    }
# ── Batch routes ──────────────────────────────────────────────────────────────
@app.post("/api/v1/predict/batch")
async def submit_batch(req: BatchPredictRequest):
    """
    Submit prediction tasks for several stocks (or parameter combinations) at
    once; returns `batch_id` and `task_ids` immediately.

    All sub-tasks enter the same executor queue concurrently; poll
    `GET /api/v1/predict/batch/{batch_id}` for aggregate progress and results.
    """
    batch_id = str(uuid.uuid4())
    task_ids = []
    for r in req.requests:
        task_id = str(uuid.uuid4())
        _tasks[task_id] = {"status": "pending", "result": None, "error": None}
        _executor.submit(_run_prediction, task_id, r)
        task_ids.append(task_id)
    _batches[batch_id] = {"task_ids": task_ids}
    return {"batch_id": batch_id, "task_ids": task_ids}
@app.get("/api/v1/predict/batch/{batch_id}")
async def get_batch_result(batch_id: str):
    """
    Poll overall batch status:
    - `status: "pending"` — some sub-tasks still computing
    - `status: "done"`    — all succeeded
    - `status: "partial"` — some succeeded, some failed
    - `status: "failed"`  — all failed
    `tasks` holds each sub-task's full status and result.
    """
    if batch_id not in _batches:
        raise HTTPException(status_code=404, detail=f"Batch {batch_id!r} not found")
    return _batch_status(batch_id)