kronos-forecast / app.py
Jenak5's picture
Update app.py
42bd6a2 verified
"""
Kronos Forecast API — HuggingFace Spaces Backend
Fine-tuned on 17 years of NQ + ES data. v4.1.
Added caching to avoid yfinance rate limits.
"""
import os
import sys
import logging
import time
from datetime import datetime
from typing import Optional
import numpy as np
import pandas as pd
import torch
import yfinance as yf
sys.path.insert(0, "/app/kronos")
from model import Kronos, KronosTokenizer, KronosPredictor
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
# Configure root logging at INFO and create the logger used throughout this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("kronos-api")

# The FastAPI application object that the route decorators below attach to.
app = FastAPI(title="Kronos Forecast API", version="4.1.0")

# CORS policy: requests from any origin with any header are accepted, but only
# the GET method is allowed.
# NOTE(review): allow_origins=["*"] is fully open — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET"],
    allow_headers=["*"],
)

# Shared predictor instance. It stays None until the startup hook (load_model)
# assigns it; /forecast answers 503 while it is still None.
predictor: Optional[KronosPredictor] = None
# ---------------------------------------------------------------------------
# Response cache and request lookup tables
# ---------------------------------------------------------------------------
# CACHE maps "<instrument>_<timeframe>" to a (stored_at, response) tuple.
# Entries younger than CACHE_TTL seconds are served as-is so repeated requests
# do not hit yfinance again (rate limits).
CACHE = dict()
CACHE_TTL = 5 * 60  # seconds

# Public instrument code -> yfinance futures ticker.
TICKER_MAP = dict(NQ="NQ=F", ES="ES=F")

# Per-timeframe download settings and forecast horizon (in bars). The "4h"
# entry downloads 1h bars and carries a "resample" rule to aggregate them.
TIMEFRAME_MAP = {
    "5m": dict(interval="5m", period="5d", forecast_bars=12),
    "15m": dict(interval="15m", period="30d", forecast_bars=8),
    "1h": dict(interval="1h", period="60d", forecast_bars=8),
    "4h": dict(interval="1h", period="60d", forecast_bars=8, resample="4h"),
}
@app.on_event("startup")
async def load_model():
    """Startup hook: build the global ``predictor``.

    The base tokenizer and model checkpoints are loaded first. Fine-tuned
    weights are then downloaded from the Hub (authenticated with the
    ``HF_TOKEN`` environment variable, if set) and applied on top with
    ``strict=False``. Each overlay is best-effort: any failure is logged as a
    warning and startup continues with whatever weights are already in place.
    """
    global predictor

    access_token = os.environ.get("HF_TOKEN")
    logger.info("Loading fine-tuned Kronos NQ+ES v2 model …")

    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-mini")

    # (log label, module receiving the weights, Hub repo holding them);
    # the tokenizer is handled before the model.
    overlays = (
        ("tokenizer", tokenizer, "Jenak5/Kronos-NQ-ES-Tokenizer-v2"),
        ("predictor", model, "Jenak5/Kronos-NQ-ES-v2"),
    )
    for label, module, repo_id in overlays:
        try:
            weights_file = hf_hub_download(
                repo_id=repo_id,
                filename="model.safetensors",
                token=access_token,
            )
            module.load_state_dict(load_file(weights_file), strict=False)
            logger.info(f"Fine-tuned {label} weights loaded.")
        except Exception as exc:
            logger.warning(f"{label.capitalize()} weights not loaded: {exc}")

    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    logger.info("Kronos NQ+ES v4.1 ready.")
# Response schema for a single OHLC candle (historical or forecast).
class CandleOut(BaseModel):
    timestamp: str  # filled with str(<timestamp>) by the code that builds candles
    open: float     # the four prices are rounded to 2 decimals by their producers
    high: float
    low: float
    close: float
# Response schema of GET /forecast.
class ForecastResponse(BaseModel):
    instrument: str                    # instrument code as requested ("NQ" or "ES")
    timeframe: str                     # timeframe as requested ("5m", "15m", "1h", "4h")
    generated_at: str                  # UTC ISO-8601 creation time with a trailing "Z"
    historical: list[CandleOut]        # last 50 fetched candles
    forecast_mean: list[CandleOut]     # per-bar mean across the sampled forecasts
    forecast_upper: list[CandleOut]    # per-bar 90th percentile across samples
    forecast_lower: list[CandleOut]    # per-bar 10th percentile across samples
    direction: str                     # "BULLISH", "BEARISH" or "NEUTRAL" (see calc_direction)
    confidence: float                  # calc_direction score, rounded to 1 decimal
    volatility_ratio: float            # calc_vol_ratio result, rounded to 2 decimals
    trading_context: str               # text produced by get_trading_context
    cached: bool = False               # set to True when the response is served from CACHE
def fetch_candles(
    ticker: str,
    interval: str,
    period: str,
    resample: Optional[str] = None,
) -> pd.DataFrame:
    """Download OHLC candles from yfinance and normalise them.

    Args:
        ticker: yfinance symbol, e.g. ``"NQ=F"``.
        interval: Bar interval passed to ``yf.download``.
        period: Look-back period passed to ``yf.download``.
        resample: Optional pandas resample rule (e.g. ``"4h"``). When given,
            the downloaded bars are aggregated into bars of that size.

    Returns:
        DataFrame with columns ``timestamp, open, high, low, close`` and no
        missing values.

    Raises:
        HTTPException: 502 when the download is empty, when the expected
            columns are absent, or when no candle survives cleaning and
            resampling.
    """
    raw = yf.download(ticker, period=period, interval=interval, progress=False)
    if raw is None or raw.empty:
        raise HTTPException(status_code=502, detail=f"No data for {ticker}")

    df = raw.reset_index()

    # yfinance may return MultiIndex columns (field, ticker); keep the field.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = [c[0] if isinstance(c, tuple) else c for c in df.columns]
    df.columns = [str(c).lower() for c in df.columns]

    # The time column arrives as "Datetime" or "Date"; prefer "datetime" when
    # both exist so only one column is ever renamed.
    for candidate in ("datetime", "date"):
        if candidate in df.columns:
            df = df.rename(columns={candidate: "timestamp"})
            break

    wanted = ["timestamp", "open", "high", "low", "close"]
    missing = [name for name in wanted if name not in df.columns]
    if missing:
        # Report an upstream format problem instead of leaking a KeyError.
        raise HTTPException(
            status_code=502,
            detail=f"Unexpected data format for {ticker}: missing {missing}",
        )

    df["timestamp"] = pd.to_datetime(df["timestamp"])
    df = df[wanted].dropna()

    if resample:
        df = (
            df.set_index("timestamp")
            .resample(resample)
            .agg({"open": "first", "high": "max", "low": "min", "close": "last"})
            .dropna()
            .reset_index()
        )

    # Cleaning/resampling can remove every row; treat that like an empty download.
    if df.empty:
        raise HTTPException(status_code=502, detail=f"No data for {ticker}")
    return df
def run_forecast(df: pd.DataFrame, forecast_bars: int, n_samples: int = 10):
    """Sample several forecasts from the global predictor and summarise them.

    Args:
        df: Candles with ``timestamp, open, high, low, close`` columns. Only
            the most recent 400 rows (at most) are used as context.
        forecast_bars: Number of future bars to predict.
        n_samples: Number of independent ``predictor.predict`` calls.

    Returns:
        ``(mean, upper, lower, future_ts)``: the per-bar mean, 90th percentile
        and 10th percentile across samples (OHLC in the last axis), plus the
        Series of future timestamps handed to the predictor.

    Raises:
        ValueError: If ``forecast_bars`` or ``n_samples`` is below 1, or if
            fewer than two candles are available (no bar spacing can be
            derived).
    """
    if forecast_bars < 1 or n_samples < 1:
        raise ValueError("forecast_bars and n_samples must both be >= 1")

    lookback = min(len(df), 400)
    x_df = df.tail(lookback).reset_index(drop=True)
    stamps = x_df["timestamp"]
    if len(stamps) < 2:
        raise ValueError("run_forecast needs at least 2 candles to derive the bar spacing")
    last_ts = stamps.iloc[-1]

    # infer_freq raises ValueError for fewer than 3 timestamps; treat that the
    # same as "no regular frequency found".
    try:
        freq = pd.infer_freq(stamps)
    except ValueError:
        freq = None

    if freq is None:
        # Irregular series (e.g. session or weekend gaps). Use the most common
        # spacing rather than the last one, which may itself be a gap.
        step = stamps.diff().dropna().mode().iloc[0]
        future_ts = pd.Series([last_ts + step * (i + 1) for i in range(forecast_bars)])
    else:
        # date_range includes the start, so request one extra and drop it.
        future_ts = pd.Series(
            pd.date_range(start=last_ts, periods=forecast_bars + 1, freq=freq)[1:]
        )

    ohlc = ["open", "high", "low", "close"]
    samples = np.array([
        predictor.predict(
            df=x_df[ohlc],
            x_timestamp=stamps,
            y_timestamp=future_ts,
            pred_len=forecast_bars,
            T=0.3,
            top_p=0.5,
            sample_count=1,
        )[ohlc].values
        for _ in range(n_samples)
    ])

    # Axis 0 indexes the samples; reduce over it to get per-bar statistics.
    mean = samples.mean(axis=0)
    upper = np.percentile(samples, 90, axis=0)
    lower = np.percentile(samples, 10, axis=0)
    return mean, upper, lower, future_ts
def calc_direction(
    mean_candles: np.ndarray,
    last_close: float,
    *,
    neutral_band_pct: float = 0.10,
) -> tuple[str, float]:
    """Classify the forecast as bullish, bearish or neutral.

    The percentage move from ``last_close`` to the close of the final forecast
    bar (column 3 of the last row of ``mean_candles``) is compared with
    ``neutral_band_pct``.

    Args:
        mean_candles: 2-D array of forecast candles with close in column 3.
        last_close: Most recent observed close.
        neutral_band_pct: Moves within +/- this many percent are "NEUTRAL".

    Returns:
        ``(label, confidence)`` where ``label`` is ``"BULLISH"``,
        ``"BEARISH"`` or ``"NEUTRAL"`` and ``confidence`` is a plain ``float``:
        ``min(|pct| * 30, 95)`` for directional calls and
        ``max(50 - |pct| * 150, 10)`` for neutral ones. A zero or non-finite
        input yields ``("NEUTRAL", 10.0)`` because no percentage move can be
        computed.
    """
    final_close = float(mean_candles[-1, 3])
    last_close = float(last_close)

    # Avoid dividing by zero and propagating inf/NaN into the API response.
    if last_close == 0 or not np.isfinite(last_close) or not np.isfinite(final_close):
        return "NEUTRAL", 10.0

    pct_change = (final_close - last_close) / last_close * 100
    magnitude = abs(pct_change)

    if pct_change > neutral_band_pct:
        return "BULLISH", float(min(magnitude * 30, 95))
    if pct_change < -neutral_band_pct:
        return "BEARISH", float(min(magnitude * 30, 95))
    return "NEUTRAL", float(max(50 - magnitude * 150, 10))
def calc_vol_ratio(mean_candles: np.ndarray, hist_df: pd.DataFrame) -> float:
    """Return forecast bar range relative to recent historical bar range.

    The mean high-low range of the forecast candles (columns 1 and 2 of
    ``mean_candles``) is divided by the mean high-low range of the last
    ``len(mean_candles)`` rows of ``hist_df``.

    Args:
        mean_candles: 2-D array of forecast candles (high in column 1, low in
            column 2).
        hist_df: Historical candles with ``high`` and ``low`` columns.

    Returns:
        The ratio as a plain ``float``. ``1.0`` is returned when no ratio can
        be formed: empty forecast, empty history, or a historical mean range
        that is zero or non-finite.
    """
    pred_ranges = mean_candles[:, 1] - mean_candles[:, 2]
    hist_ranges = (
        (hist_df["high"] - hist_df["low"])
        .tail(len(mean_candles))
        .to_numpy(dtype=float)
    )

    # The mean of an empty array is NaN; fall back to the neutral ratio instead.
    if pred_ranges.size == 0 or hist_ranges.size == 0:
        return 1.0

    hist_mean = hist_ranges.mean()
    if hist_mean == 0 or not np.isfinite(hist_mean):
        return 1.0
    return float(pred_ranges.mean() / hist_mean)
def get_trading_context(vol_ratio: float, timeframe: str, direction: str, confidence: float) -> str:
    """Return the human-readable trading note for a forecast.

    The note is selected by ``timeframe`` and, for ``"4h"`` and ``"1h"``, by
    ``vol_ratio`` bands. The statistics quoted in the messages are fixed text;
    nothing here computes them.

    Args:
        vol_ratio: Forecast-to-historical range ratio.
        timeframe: ``"5m"``, ``"15m"``, ``"1h"`` or ``"4h"``.
        direction: Direction label interpolated into some messages.
        confidence: Confidence score, shown without decimals in the 15m note.

    Returns:
        The message, or ``""`` for an unrecognised timeframe.
    """
    if timeframe == "4h":
        if vol_ratio < 0.6:
            return (
                "COMPRESSED VOL — Magic Hour reversion 94.6% reliable (5.4% fail rate). "
                "Full size on fade-to-midpoint setups. Expect reversion within 1 bar."
            )
        if vol_ratio < 0.8:
            return (
                "LOW VOL — Magic Hour reversion 93.6% reliable (6.4% fail rate). "
                "Strong conditions for fade-to-midpoint trades."
            )
        if vol_ratio < 1.2:
            return (
                "NORMAL VOL — Magic Hour reversion 90.3% reliable. "
                "Standard conditions, trade normally."
            )
        return (
            "ELEVATED VOL — Magic Hour reversion drops to 84.5% (15.5% fail rate). "
            "Reduce size on reversion trades or wait for compression."
        )

    if timeframe == "1h":
        if vol_ratio < 1.2:
            return (
                "LOW VOL — Sniper window ELITE zones at 65.9% WR. "
                f"Full size if ELITE zone (>=18pt NQ / >=5pt ES) aligns with {direction} bias."
            )
        if vol_ratio < 1.5:
            return (
                "NORMAL VOL — Sniper window at 59.8% WR. "
                "Trade ELITE zones at full size, GOOD zones at half size."
            )
        return (
            "ELEVATED VOL — Sniper window accuracy drops. "
            "Half size only, require ELITE zone confirmation."
        )

    if timeframe == "15m":
        return (
            "15m structure — Use to time entries within your active window. "
            f"{direction} bias with {confidence:.0f}% confidence."
        )

    if timeframe == "5m":
        return (
            "5m tactical — Use for precise entry timing. "
            "Look for candle confirmation at your level before entering."
        )

    return ""
def candles_to_list(arr: np.ndarray, timestamps) -> list[CandleOut]:
    """Pair each timestamp with the matching row of ``arr`` as a ``CandleOut``.

    Row ``i`` of ``arr`` supplies open/high/low/close from columns 0-3, each
    converted to ``float`` and rounded to 2 decimals; the timestamp is
    stringified with ``str``.
    """
    def price(row: int, col: int) -> float:
        return round(float(arr[row, col]), 2)

    return [
        CandleOut(
            timestamp=str(stamp),
            open=price(row, 0),
            high=price(row, 1),
            low=price(row, 2),
            close=price(row, 3),
        )
        for row, stamp in enumerate(timestamps)
    ]
@app.get("/forecast", response_model=ForecastResponse)
async def get_forecast(
    instrument: str = Query("NQ", pattern="^(NQ|ES)$"),
    timeframe: str = Query("1h", pattern="^(5m|15m|1h|4h)$"),
):
    """Serve a forecast for one instrument/timeframe pair.

    Answers 503 until the predictor has been loaded. A cached response younger
    than ``CACHE_TTL`` seconds is returned with ``cached`` set to True;
    otherwise candles are fetched, a forecast is sampled, and the resulting
    response is cached under ``"<instrument>_<timeframe>"``. The cache
    timestamp is the time taken at the start of the request, before fetching
    and forecasting.
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model still loading")

    # --- cache lookup ------------------------------------------------------
    cache_key = f"{instrument}_{timeframe}"
    now = time.time()
    entry = CACHE.get(cache_key)
    if entry is not None:
        stored_at, stored_response = entry
        age = now - stored_at
        if age < CACHE_TTL:
            logger.info(f"Serving cached response for {cache_key} (age: {int(age)}s)")
            stored_response.cached = True
            return stored_response

    # --- fetch + forecast --------------------------------------------------
    ticker = TICKER_MAP[instrument]
    settings = TIMEFRAME_MAP[timeframe]
    df = fetch_candles(
        ticker,
        settings["interval"],
        settings["period"],
        settings.get("resample"),
    )
    logger.info(f"Fetched {len(df)} candles for {instrument} @ {timeframe}")

    mean, upper, lower, future_ts = run_forecast(df, settings["forecast_bars"])

    direction, confidence = calc_direction(mean, float(df["close"].iloc[-1]))
    vol_ratio = calc_vol_ratio(mean, df)
    context = get_trading_context(vol_ratio, timeframe, direction, confidence)

    # The last 50 candles, converted the same way as the forecast candles.
    recent = df.tail(50)
    historical = candles_to_list(
        recent[["open", "high", "low", "close"]].to_numpy(),
        recent["timestamp"],
    )

    # --- assemble + cache --------------------------------------------------
    response = ForecastResponse(
        instrument=instrument,
        timeframe=timeframe,
        generated_at=datetime.utcnow().isoformat() + "Z",
        historical=historical,
        forecast_mean=candles_to_list(mean, future_ts),
        forecast_upper=candles_to_list(upper, future_ts),
        forecast_lower=candles_to_list(lower, future_ts),
        direction=direction,
        confidence=round(confidence, 1),
        volatility_ratio=round(vol_ratio, 2),
        trading_context=context,
        cached=False,
    )
    CACHE[cache_key] = (now, response)
    return response
@app.get("/health")
async def health():
    """Report liveness, whether the predictor is loaded, and cache contents.

    ``cache_ages`` gives, for every cache key, the whole number of seconds
    since that entry was stored.
    """
    cached_keys = [*CACHE]

    ages = {}
    for key, entry in CACHE.items():
        stored_at = entry[0]
        ages[key] = int(time.time() - stored_at)

    return {
        "status": "ok",
        "model_loaded": predictor is not None,
        "cache_keys": cached_keys,
        "cache_ages": ages,
    }