# kronos-api / app.py
# fengwm
# fix: report direction probability as signal confidence
# 4a3dd5b
"""
Kronos Stock Predictor โ€” RESTful API
=====================================
POST /api/v1/predict โ†’ { "task_id": "uuid" }
GET /api/v1/predict/{id} โ†’ { "status": "pending|done|failed", "result": {...} }
POST /api/v1/predict/batch โ†’ { "batch_id": "uuid", "task_ids": [...] }
GET /api/v1/predict/batch/{id} โ†’ { "batch_id", "status", "total", "done", "failed", "tasks": [...] }
GET /api/v1/cache โ†’ cache contents & expiry info
GET /health โ†’ { "status": "ok" }
"""
import asyncio
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, time, timedelta, timezone
from time import perf_counter
from typing import Literal, List
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import data_fetcher
import predictor as pred_module
# Module-wide logging at INFO so cache hits and per-task timing lines show up.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ── Timezone / market-close helpers ──────────────────────────────────────────
_CST = timezone(timedelta(hours=8))  # China Standard Time (UTC+8, no DST)
_MARKET_CLOSE = time(15, 0)          # A-share close: 15:00 CST


def _next_cache_expiry() -> datetime:
    """
    Return the UTC datetime of the NEXT A-share market close (15:00 CST on a
    weekday), which is when new candle data becomes available and the cache
    should be invalidated.

    Chinese public holidays are intentionally ignored: on those days market
    data does not advance, so a cache hit is harmless.
    """
    now_cst = datetime.now(_CST)
    # Single source of truth: build today's close from _MARKET_CLOSE instead
    # of repeating the 15:00 literal (it was previously hard-coded twice).
    today_close = datetime.combine(now_cst.date(), _MARKET_CLOSE, tzinfo=_CST)
    if now_cst.weekday() < 5 and now_cst < today_close:
        # Before today's close on a weekday → expire at today 15:00 CST.
        expiry_cst = today_close
    else:
        # After close, or on a weekend → find the next weekday's 15:00 CST.
        candidate = now_cst + timedelta(days=1)
        while candidate.weekday() >= 5:  # skip Sat(5) and Sun(6)
            candidate += timedelta(days=1)
        expiry_cst = datetime.combine(candidate.date(), _MARKET_CLOSE, tzinfo=_CST)
    return expiry_cst.astimezone(timezone.utc)
# โ”€โ”€ Result cache โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# key : (symbol, lookback, pred_len, sample_count, mode, include_volume)
# value : {"result": dict, "expires_at": datetime(UTC), "cached_at": datetime(UTC)}
_cache: dict[tuple, dict] = {}
def _cache_key(req: "PredictRequest") -> tuple:
return (req.symbol, req.lookback, req.pred_len,
req.sample_count, req.mode, req.include_volume)
def _get_cached(req: "PredictRequest") -> dict | None:
entry = _cache.get(_cache_key(req))
if entry and datetime.now(timezone.utc) < entry["expires_at"]:
return entry
return None
def _set_cache(req: "PredictRequest", result: dict) -> None:
    """Store *result* for req's parameter combination until the next market close.

    Overwrites any previous entry for the same key.  The key and expiry are
    computed once and reused (previously the key was rebuilt three times and
    the fresh entry re-fetched from the dict just to log its expiry).
    """
    key = _cache_key(req)
    expires_at = _next_cache_expiry()
    _cache[key] = {
        "result": result,
        "expires_at": expires_at,
        "cached_at": datetime.now(timezone.utc),
    }
    logger.info(
        "Cached %s, expires at %s CST",
        req.symbol,
        expires_at.astimezone(_CST).strftime("%Y-%m-%d %H:%M"),
    )
# ── Task / Batch store ───────────────────────────────────────────────────────
# task_id  -> {"status": "pending"|"done"|"failed", "result": dict|None, "error": str|None}
_tasks: dict[str, dict] = {}
# batch_id -> {"task_ids": [task_id, ...]}; overall batch status is derived on read.
_batches: dict[str, dict] = {}
# Shared worker pool for all predictions (single + batch); max_workers bounds
# concurrent inferences.  NOTE(review): both stores are in-memory, process-local
# and never pruned — entries accumulate until restart.
_executor = ThreadPoolExecutor(max_workers=2)
# ── Startup: eagerly load the model so the first request isn't slow ──────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Warm up the Kronos predictor in a worker thread before serving requests.

    Uses get_running_loop() — get_event_loop() is deprecated when called from
    inside a running coroutine and may raise in future Python versions.
    """
    loop = asyncio.get_running_loop()
    logger.info("Pre-loading Kronos predictor โ€ฆ")
    await loop.run_in_executor(_executor, pred_module.get_predictor)
    logger.info("Kronos predictor ready.")
    yield
app = FastAPI(
    title="Kronos Stock Predictor API",
    version="1.0.0",
    description=(
        "Monte-Carlo probabilistic stock forecasting powered by the "
        "Kronos foundation model (Tsinghua University)."
    ),
    lifespan=lifespan,  # pre-loads the model at startup
)
# NOTE(review): CORS is wide open (any origin/method/header) — acceptable for an
# internal tool, but tighten allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Request / Response schemas ───────────────────────────────────────────────
class PredictRequest(BaseModel):
    """Parameters for one Monte-Carlo stock prediction.

    The full field tuple doubles as the result-cache key (see _cache_key), so
    two requests with identical parameters share one cached prediction.
    """
    # A-share symbol: 6-digit code, optionally with a market suffix (e.g. .SH).
    symbol: str = Field(
        ...,
        examples=["603777", "600900.SH"],
        description="A ่‚กไปฃ็ ๏ผ›ๆ”ฏๆŒ 6 ไฝไปฃ็ ๆˆ–ๅธฆๅธ‚ๅœบๅŽ็ผ€๏ผˆๅฆ‚ 600900.SH๏ผ‰",
    )
    # Number of historical candles fed to the model (truncated if fewer exist).
    lookback: int = Field(
        default=512,
        ge=20,
        le=512,
        description="ๅ›ž็œ‹ๅކๅฒ K ็บฟๆ นๆ•ฐ๏ผˆๆœ€ๅคš 512๏ผŒไธ่ถณๆ—ถ่‡ชๅŠจๆˆชๆ–ญ๏ผ‰",
    )
    # Forecast horizon in trading days; > 30 sets confidence_warning in the result.
    pred_len: int = Field(
        default=5,
        ge=1,
        le=60,
        description="้ข„ๆต‹ๆœชๆฅไบคๆ˜“ๆ—ฅๆ•ฐ๏ผˆๅปบ่ฎฎ โ‰ค 30๏ผŒ่ถ…่ฟ‡ๆ—ถ่ฟ”ๅ›ž confidence_warning๏ผ‰",
    )
    # Number of Monte-Carlo sampling passes.
    sample_count: int = Field(
        default=30,
        ge=1,
        le=100,
        description="MC ่’™็‰นๅกๆด›้‡‡ๆ ทๆฌกๆ•ฐ",
    )
    # "simple": mean close + trading band only; "advanced": adds OHLC means + close CI.
    mode: Literal["simple", "advanced"] = Field(
        default="simple",
        description="simple: ไป…่ฟ”ๅ›žๅ‡ๅ€ผ + ไบคๆ˜“ๅŒบ้—ด๏ผ›advanced: ่ฟฝๅŠ  OHLC ๅ‡ๅ€ผๅŠๆ”ถ็›˜ CI",
    )
    # Only honored when mode == "advanced" (see _build_response).
    include_volume: bool = Field(
        default=False,
        description="mode=advanced ๆ—ถๆ˜ฏๅฆ้ขๅค–่ฟ”ๅ›žๆˆไบค้‡้ข„ๆต‹๏ผˆ้ป˜่ฎคๅ…ณ้—ญ๏ผ‰",
    )
# ── Response builder ─────────────────────────────────────────────────────────
def _build_response(req: PredictRequest, base_date: str, pred_mean, ci,
                    trading_low, trading_high, direction_prob, last_close,
                    y_timestamp) -> dict:
    """Assemble the JSON-serializable payload for one finished prediction.

    Args:
        req: originating request; controls pred_len / mode / include_volume.
        base_date: ISO date of the last observed candle.
        pred_mean: per-step mean OHLC(+volume) table (DataFrame-like, .iloc).
        ci: per-field confidence intervals, e.g. ci["close"]["low"][i].
        trading_low / trading_high: per-step trading-band bounds.
        direction_prob: probability of a bullish move, P(up).
        last_close: last observed close; normalizes the band width.
        y_timestamp: future trading dates (pandas-like, .iloc access).

    Returns:
        dict with direction signal, per-step bands, summary and, in advanced
        mode with include_volume, a volume forecast.
    """
    p_up = float(direction_prob)
    if p_up >= 0.5:
        signal, signal_prob = "bullish", p_up
    else:
        signal, signal_prob = "bearish", 1 - p_up

    bands = []
    for i in range(req.pred_len):
        row: dict = {
            "date": str(y_timestamp.iloc[i].date()),
            "step": i + 1,
            "mean_close": round(float(pred_mean["close"].iloc[i]), 4),
            "trading_low": round(float(trading_low[i]), 4),
            "trading_high": round(float(trading_high[i]), 4),
            # Band width relative to the last close ≈ per-step uncertainty.
            "uncertainty": round(
                float((trading_high[i] - trading_low[i]) / last_close), 4
            ),
        }
        if req.mode == "advanced":
            row["mean_open"] = round(float(pred_mean["open"].iloc[i]), 4)
            row["mean_high"] = round(float(pred_mean["high"].iloc[i]), 4)
            row["mean_low"] = round(float(pred_mean["low"].iloc[i]), 4)
            row["close_ci_low"] = round(float(ci["close"]["low"][i]), 4)
            row["close_ci_high"] = round(float(ci["close"]["high"][i]), 4)
        bands.append(row)

    payload: dict = {
        "symbol": req.symbol,
        "base_date": base_date,
        "pred_len": req.pred_len,
        "confidence": 95,
        "confidence_warning": req.pred_len > 30,
        "direction": {
            "signal": signal,
            "probability": round(signal_prob, 4),
        },
        "summary": {
            "mean_close": round(float(pred_mean["close"].iloc[-1]), 4),
            "range_low": round(float(trading_low.min()), 4),
            "range_high": round(float(trading_high.max()), 4),
            "range_width": round(float(trading_high.max() - trading_low.min()), 4),
        },
        "bands": bands,
    }

    if req.mode == "advanced" and req.include_volume:
        volume_rows = []
        for i in range(req.pred_len):
            volume_rows.append({
                "date": str(y_timestamp.iloc[i].date()),
                "mean_volume": round(float(pred_mean["volume"].iloc[i])),
                "volume_ci_low": round(float(ci["volume"]["low"][i])),
                "volume_ci_high": round(float(ci["volume"]["high"][i])),
            })
        payload["volume"] = volume_rows
    return payload
# ── Background task ──────────────────────────────────────────────────────────
def _run_prediction(task_id: str, req: PredictRequest) -> None:
    """Worker-thread entry point: populate _tasks[task_id] with a result or error.

    Serves from the result cache when a live entry exists; otherwise fetches
    history, runs the Monte-Carlo inference, builds the response and caches it.
    Never raises — any failure is recorded on the task with status "failed".
    Each stage is individually timed for the structured log line at the end.
    """
    t_total_start = perf_counter()
    try:
        # ── Cache check ──────────────────────────────────────────────────────
        cache_entry = _get_cached(req)
        if cache_entry is not None:
            total_ms = (perf_counter() - t_total_start) * 1000
            logger.info(
                "Cache hit for %s (expires %s CST, total=%.1fms)",
                req.symbol,
                cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M"),
                total_ms,
            )
            # Cached payloads are flagged ("cached": True) so clients can tell
            # a reused result from a fresh inference.
            _tasks[task_id] = {
                "status": "done",
                "result": {**cache_entry["result"], "cached": True,
                           "cache_expires_at": cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z")},
                "error": None,
            }
            return
        # ── Full inference ───────────────────────────────────────────────────
        t_fetch_start = perf_counter()
        x_df, x_timestamp, last_trade_date = data_fetcher.fetch_stock_data(
            req.symbol, req.lookback
        )
        fetch_ms = (perf_counter() - t_fetch_start) * 1000
        t_calendar_start = perf_counter()
        # Future trading dates for the forecast horizon (skips non-trading days).
        y_timestamp = data_fetcher.get_future_trading_dates(last_trade_date, req.pred_len)
        calendar_ms = (perf_counter() - t_calendar_start) * 1000
        t_infer_start = perf_counter()
        pred_mean, ci, trading_low, trading_high, direction_prob, last_close = (
            pred_module.run_mc_prediction(
                x_df, x_timestamp, y_timestamp, req.pred_len, req.sample_count
            )
        )
        infer_ms = (perf_counter() - t_infer_start) * 1000
        t_build_start = perf_counter()
        # last_trade_date arrives as YYYYMMDD; normalize to ISO for the client.
        base_date = str(pd.to_datetime(last_trade_date, format="%Y%m%d").date())
        result = _build_response(
            req, base_date, pred_mean, ci,
            trading_low, trading_high, direction_prob, last_close, y_timestamp,
        )
        build_ms = (perf_counter() - t_build_start) * 1000
        # ── Store in cache ───────────────────────────────────────────────────
        t_cache_start = perf_counter()
        _set_cache(req, result)
        # Re-read the entry to report the expiry _set_cache just assigned.
        cache_entry = _cache[_cache_key(req)]
        cache_ms = (perf_counter() - t_cache_start) * 1000
        _tasks[task_id] = {
            "status": "done",
            "result": {**result, "cached": False,
                       "cache_expires_at": cache_entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z")},
            "error": None,
        }
        total_ms = (perf_counter() - t_total_start) * 1000
        logger.info(
            "Task %s timing symbol=%s fetch=%.1fms calendar=%.1fms infer=%.1fms build=%.1fms cache=%.1fms total=%.1fms",
            task_id,
            req.symbol,
            fetch_ms,
            calendar_ms,
            infer_ms,
            build_ms,
            cache_ms,
            total_ms,
        )
    except Exception as exc:
        # Catch-all is intentional: this runs on a worker thread, so an escaped
        # exception would vanish silently instead of reaching the client.
        total_ms = (perf_counter() - t_total_start) * 1000
        logger.exception("Task %s failed after %.1fms", task_id, total_ms)
        _tasks[task_id] = {"status": "failed", "result": None, "error": str(exc)}
# ── Routes ───────────────────────────────────────────────────────────────────
@app.post(
    "/api/v1/predict",
    summary="ๆไบค้ข„ๆต‹ไปปๅŠก",
    response_description="ไปปๅŠก ID๏ผŒ็”จไบŽ่ฝฎ่ฏข็ป“ๆžœ",
)
async def submit_predict(req: PredictRequest):
    """Queue one Monte-Carlo prediction task and return its `task_id` at once.

    Poll `GET /api/v1/predict/{task_id}` for the result.
    """
    new_id = str(uuid.uuid4())
    _tasks[new_id] = {"status": "pending", "result": None, "error": None}
    _executor.submit(_run_prediction, new_id, req)
    return {"task_id": new_id}
@app.get(
    "/api/v1/predict/{task_id}",
    summary="ๆŸฅ่ฏข้ข„ๆต‹็ป“ๆžœ",
)
async def get_predict_result(task_id: str):
    """Poll a prediction task's state.

    - `status: "pending"` — still computing
    - `status: "done"`    — finished; `result` holds the prediction payload
    - `status: "failed"`  — failed; `error` holds the message
    """
    try:
        return _tasks[task_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Task {task_id!r} not found") from None
@app.get("/api/v1/cache", summary="ๆŸฅ็œ‹็ผ“ๅญ˜็Šถๆ€")
async def get_cache(symbol: str | None = None):
"""
ๅˆ—ๅ‡บๆœ‰ๆ•ˆ็š„็ผ“ๅญ˜ๆก็›ฎๅŠๅ…ถ่ฟ‡ๆœŸๆ—ถ้—ดใ€‚
- ไธไผ ๅ‚ๆ•ฐ๏ผš่ฟ”ๅ›žๅ…จ้ƒจ
- `?symbol=000063.SZ`๏ผšๅช่ฟ”ๅ›ž่ฏฅ่‚ก็ฅจ็š„ๆ‰€ๆœ‰ๅ‚ๆ•ฐ็ป„ๅˆ
"""
now_utc = datetime.now(timezone.utc)
entries = []
for key, entry in _cache.items():
if symbol and key[0] != symbol:
continue
remaining = (entry["expires_at"] - now_utc).total_seconds()
if remaining > 0:
entries.append({
"symbol": key[0],
"lookback": key[1],
"pred_len": key[2],
"sample_count": key[3],
"mode": key[4],
"include_volume": key[5],
"cached_at": entry["cached_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z"),
"expires_at": entry["expires_at"].astimezone(_CST).strftime("%Y-%m-%d %H:%M:%S %Z"),
"ttl_seconds": int(remaining),
"result": entry["result"],
})
return {"count": len(entries), "entries": entries}
@app.get("/health", summary="ๅฅๅบทๆฃ€ๆŸฅ")
async def health():
return {"status": "ok"}
# ── Batch schemas ────────────────────────────────────────────────────────────
class BatchPredictRequest(BaseModel):
    """Wrapper for submitting several prediction requests in one call."""
    # 1–20 sub-requests; each becomes an independent task in the shared executor.
    requests: List[PredictRequest] = Field(
        ...,
        min_length=1,
        max_length=20,
        description="้ข„ๆต‹่ฏทๆฑ‚ๅˆ—่กจ๏ผˆๆœ€ๅคš 20 ไธช๏ผ‰",
    )
# ── Batch helper ─────────────────────────────────────────────────────────────
def _batch_status(batch_id: str) -> dict:
    """Aggregate every sub-task of a batch into one progress report.

    Overall status: "pending" while any task is unfinished, then "done" /
    "failed" / "partial" depending on how many sub-tasks succeeded.
    """
    ids = _batches[batch_id]["task_ids"]
    report = [{"task_id": tid, **_tasks[tid]} for tid in ids]
    n_total = len(ids)
    n_done = sum(t["status"] == "done" for t in report)
    n_failed = sum(t["status"] == "failed" for t in report)
    if n_done + n_failed < n_total:
        overall = "pending"
    elif n_failed == 0:
        overall = "done"
    elif n_done == 0:
        overall = "failed"
    else:
        overall = "partial"
    return {
        "batch_id": batch_id,
        "status": overall,
        "total": n_total,
        "done": n_done,
        "failed": n_failed,
        "tasks": report,
    }
# ── Batch routes ─────────────────────────────────────────────────────────────
@app.post(
    "/api/v1/predict/batch",
    summary="ๆ‰น้‡ๆไบค้ข„ๆต‹ไปปๅŠก",
    response_description="batch_id ๅŠๆฏไธชๅญไปปๅŠก็š„ task_id ๅˆ—่กจ",
)
async def submit_batch(req: BatchPredictRequest):
    """Submit several prediction requests at once.

    Returns `batch_id` and the per-request `task_ids` immediately; all
    sub-tasks are queued on the same executor.  Poll
    `GET /api/v1/predict/batch/{batch_id}` for combined progress and results.
    """
    batch_id = str(uuid.uuid4())
    task_ids: list[str] = []
    for sub_req in req.requests:
        tid = str(uuid.uuid4())
        _tasks[tid] = {"status": "pending", "result": None, "error": None}
        _executor.submit(_run_prediction, tid, sub_req)
        task_ids.append(tid)
    _batches[batch_id] = {"task_ids": task_ids}
    return {"batch_id": batch_id, "task_ids": task_ids}
@app.get(
    "/api/v1/predict/batch/{batch_id}",
    summary="ๆŸฅ่ฏขๆ‰น้‡ไปปๅŠก่ฟ›ๅบฆๅŠ็ป“ๆžœ",
)
async def get_batch_result(batch_id: str):
    """Poll the overall state of a batch.

    - `status: "pending"` — some sub-tasks still computing
    - `status: "done"`    — all succeeded
    - `status: "partial"` — some succeeded, some failed
    - `status: "failed"`  — all failed
    The `tasks` array carries each sub-task's full state and result.
    """
    if batch_id in _batches:
        return _batch_status(batch_id)
    raise HTTPException(status_code=404, detail=f"Batch {batch_id!r} not found")