"""
Kronos Forecast API — HuggingFace Spaces backend.
Fine-tuned on 17 years of NQ + ES data. v4.1.
Responses are cached for a few minutes to avoid yfinance rate limits.
"""
import logging
import os
import sys
import threading
import time
from datetime import datetime, timezone
from typing import Optional

import numpy as np
import pandas as pd
import torch
import yfinance as yf
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from safetensors.torch import load_file

# Kronos's `model` package is imported from /app/kronos, so that path must be
# on sys.path before the import below.
sys.path.insert(0, "/app/kronos")
from model import Kronos, KronosTokenizer, KronosPredictor  # noqa: E402
logging.basicConfig(level="INFO")
logger = logging.getLogger(name="kronos-api")

# Read-only public API: browsers on any origin may issue GET requests.
app = FastAPI(version="4.1.0", title="Kronos Forecast API")
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET"],
    allow_origins=["*"],
)

# Global model handle: stays None until the startup hook has built it.
predictor: Optional[KronosPredictor] = None
# -------------------------------------------------------
# CACHE — keep /forecast responses in memory for 5 minutes (CACHE_TTL)
# to avoid hitting yfinance rate limits.
# -------------------------------------------------------
# "<instrument>_<timeframe>" -> (unix time stored, ForecastResponse)
CACHE = {}
CACHE_TTL = 5 * 60  # seconds

# API instrument code -> Yahoo Finance futures symbol passed to yf.download.
TICKER_MAP = dict(NQ="NQ=F", ES="ES=F")

# API timeframe -> yfinance download settings plus the number of bars to forecast.
# "4h" is built by downloading 1h bars and resampling them (see fetch_candles).
TIMEFRAME_MAP = {
    "5m": dict(interval="5m", period="5d", forecast_bars=12),
    "15m": dict(interval="15m", period="30d", forecast_bars=8),
    "1h": dict(interval="1h", period="60d", forecast_bars=8),
    "4h": dict(interval="1h", period="60d", forecast_bars=8, resample="4h"),
}
def _load_finetuned_weights(
    module: torch.nn.Module, repo_id: str, label: str, token: Optional[str]
) -> None:
    """Overlay ``model.safetensors`` from Hub repo ``repo_id`` onto ``module`` (best effort).

    Any failure (download, file parsing, or a tensor shape mismatch, which
    ``load_state_dict`` rejects even with ``strict=False``) is logged and
    swallowed, so the API can still start on the base checkpoint.

    ``strict=False`` silently tolerates key mismatches, so the missing and
    unexpected key report is logged. A checkpoint that matches none of the
    module's keys is reported as "not loaded" rather than "loaded".

    Args:
        module: Model whose parameters are overwritten in place.
        repo_id: Hugging Face Hub repository holding ``model.safetensors``.
        label: Lower-case name used in log messages ("tokenizer"/"predictor").
        token: Optional Hub access token (for private repositories).
    """
    name = label.capitalize()
    try:
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors",
            token=token,
        )
        report = module.load_state_dict(load_file(weights_path), strict=False)
    except Exception as e:  # deliberate best effort: keep the base weights
        logger.warning(f"{name} weights not loaded: {e}")
        return

    missing, unexpected = report.missing_keys, report.unexpected_keys
    if missing and len(missing) == len(module.state_dict()):
        # Nothing matched (e.g. a key-prefix mismatch), so the module is unchanged.
        logger.warning(
            f"{name} weights not loaded: none of the {len(missing)} model keys were "
            f"found in {repo_id} ({len(unexpected)} unexpected checkpoint keys)"
        )
        return
    if missing or unexpected:
        logger.warning(
            f"{name} checkpoint only partially matched: "
            f"{len(missing)} missing keys, {len(unexpected)} unexpected keys"
        )
    logger.info(f"Fine-tuned {label} weights loaded.")


@app.on_event("startup")
async def load_model():
    """Build the global ``predictor`` once at application startup.

    Loads the base Kronos tokenizer and model from the Hub, then overlays the
    fine-tuned NQ+ES weights on each one. If either overlay fails, the base
    weights are kept and a warning is logged. The predictor runs on CPU with a
    512-bar context cap.

    Reads the optional ``HF_TOKEN`` environment variable for Hub access.
    """
    global predictor
    hf_token = os.environ.get("HF_TOKEN")
    logger.info("Loading fine-tuned Kronos NQ+ES v2 model β¦")
    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-mini")
    _load_finetuned_weights(tokenizer, "Jenak5/Kronos-NQ-ES-Tokenizer-v2", "tokenizer", hf_token)
    _load_finetuned_weights(model, "Jenak5/Kronos-NQ-ES-v2", "predictor", hf_token)
    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    logger.info("Kronos NQ+ES v4.1 ready.")
class CandleOut(BaseModel):
    """One OHLC candle as returned by the API.

    Used for both the historical bars and the forecast bands. The builders in
    this module (``candles_to_list`` and ``get_forecast``) round prices to two
    decimals and fill ``timestamp`` with ``str()`` of a pandas timestamp.
    """

    # str() of a pandas Timestamp; carries a UTC offset when the source
    # timestamps are tz-aware.
    timestamp: str
    open: float
    high: float
    low: float
    close: float
class ForecastResponse(BaseModel):
    """Payload returned by ``GET /forecast``."""

    instrument: str  # "NQ" or "ES" (enforced by the endpoint's Query pattern)
    timeframe: str  # "5m", "15m", "1h" or "4h"
    generated_at: str  # ISO-8601 UTC time with a trailing "Z"
    historical: list[CandleOut]  # the most recent (up to 50) input candles
    # The three forecast series come from repeated model samples. Each OHLC
    # field is aggregated independently per bar: mean, 90th and 10th percentile.
    forecast_mean: list[CandleOut]
    forecast_upper: list[CandleOut]
    forecast_lower: list[CandleOut]
    direction: str  # "BULLISH", "BEARISH" or "NEUTRAL" (see calc_direction)
    # Heuristic score from calc_direction (capped at 95); not a calibrated probability.
    confidence: float
    # Mean forecast bar range divided by the mean range of recent historical bars.
    volatility_ratio: float
    trading_context: str  # canned note from get_trading_context
    cached: bool = False  # True when this response was served from CACHE
def fetch_candles(ticker: str, interval: str, period: str, resample: Optional[str] = None) -> pd.DataFrame:
    """Download OHLC bars from Yahoo Finance and normalise them for the model.

    Args:
        ticker: yfinance symbol, e.g. ``"NQ=F"``.
        interval: yfinance bar interval (``"5m"``, ``"1h"``, ...).
        period: yfinance look-back window (``"5d"``, ``"60d"``, ...).
        resample: Optional pandas offset alias (e.g. ``"4h"``). When given, bars
            are aggregated into OHLC bins of that width and empty bins are dropped.

    Returns:
        DataFrame with columns ``timestamp, open, high, low, close`` and no NaNs.

    Raises:
        HTTPException(502): yfinance returned nothing, the download had no
            recognisable timestamp column, or no complete candle remained.
    """
    raw = yf.download(ticker, period=period, interval=interval, progress=False)
    if raw.empty:
        raise HTTPException(status_code=502, detail=f"No data for {ticker}")

    df = raw.reset_index()
    # yfinance may return (field, ticker) MultiIndex columns; keep only the field.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = [c[0] if isinstance(c, tuple) else c for c in df.columns]
    df.columns = [str(c).lower() for c in df.columns]

    # Normalise whichever timestamp column name the download used.
    for time_col in ("datetime", "date"):
        if time_col in df.columns:
            df = df.rename(columns={time_col: "timestamp"})
            break
    if "timestamp" not in df.columns:
        raise HTTPException(status_code=502, detail=f"No timestamp column in data for {ticker}")

    df["timestamp"] = pd.to_datetime(df["timestamp"])
    df = df[["timestamp", "open", "high", "low", "close"]].dropna()

    if resample:
        df = (
            df.set_index("timestamp")
            .resample(resample)
            .agg({"open": "first", "high": "max", "low": "min", "close": "last"})
            .dropna()
            .reset_index()
        )

    # Callers index the last close and build forecasts from these rows, so an
    # empty frame must fail here with a clear upstream-data error.
    if df.empty:
        raise HTTPException(status_code=502, detail=f"No complete candles for {ticker}")
    return df
def _future_timestamps(ts: pd.Series, n_bars: int) -> pd.Series:
    """Return ``n_bars`` timestamps that continue ``ts`` past its last value.

    Uses the inferred pandas frequency when ``ts`` is regular. Otherwise it steps
    by the most common spacing between consecutive bars. In futures data with
    session breaks, the final gap alone can be a break rather than one bar, and
    stepping by it would spread the forecast bars far too widely.
    """
    last = ts.iloc[-1]
    # pd.infer_freq raises on fewer than three values, so only try it when possible.
    freq = pd.infer_freq(ts) if len(ts) >= 3 else None
    if freq is not None:
        return pd.Series(pd.date_range(start=last, periods=n_bars + 1, freq=freq)[1:])
    step = ts.diff().dropna().mode().iloc[0]
    return pd.Series([last + step * (i + 1) for i in range(n_bars)])


def run_forecast(
    df: pd.DataFrame,
    forecast_bars: int,
    n_samples: int = 10,
    *,
    lookback: int = 400,
    temperature: float = 0.3,
    top_p: float = 0.5,
):
    """Sample the global Kronos ``predictor`` repeatedly and summarise the spread.

    Args:
        df: Candles with ``timestamp, open, high, low, close`` columns. The last
            ``lookback`` rows are used as context, so ``df`` is assumed to be
            ordered oldest-first.
        forecast_bars: Number of future bars to predict.
        n_samples: Number of independent predictor calls to aggregate (>= 1).
        lookback: Maximum number of most-recent bars fed to the model.
        temperature: Sampling temperature, passed to ``predictor.predict`` as ``T``.
        top_p: Nucleus-sampling threshold passed to ``predictor.predict``.

    Returns:
        ``(mean, upper, lower, future_ts)``. The first three are arrays with
        OHLC columns and (presumably) one row per forecast bar: the per-element
        mean, 90th and 10th percentile across samples. ``future_ts`` is a Series
        of the forecast bars' timestamps.

    Raises:
        ValueError: if ``n_samples < 1`` or fewer than two candles are available.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    x_df = df.tail(min(len(df), lookback)).reset_index(drop=True)
    if len(x_df) < 2:
        raise ValueError("run_forecast needs at least two candles of context")
    future_ts = _future_timestamps(x_df["timestamp"], forecast_bars)

    ohlc = ["open", "high", "low", "close"]
    samples = []
    for _ in range(n_samples):
        # One sample per call (sample_count=1), so the spread across calls can
        # be turned into percentile bands below.
        pred_df = predictor.predict(
            df=x_df[ohlc],
            x_timestamp=x_df["timestamp"],
            y_timestamp=future_ts,
            pred_len=forecast_bars,
            T=temperature,
            top_p=top_p,
            sample_count=1,
        )
        samples.append(pred_df[ohlc].values)
    stacked = np.array(samples)  # axis 0 = sample index
    mean = stacked.mean(axis=0)
    upper = np.percentile(stacked, 90, axis=0)
    lower = np.percentile(stacked, 10, axis=0)
    return mean, upper, lower, future_ts
def calc_direction(mean_candles: np.ndarray, last_close: float, *, threshold_pct: float = 0.10):
    """Classify the forecast's final move relative to the last observed close.

    Args:
        mean_candles: Forecast candles with OHLC columns; only the last row's
            close (column 3) is used.
        last_close: Most recent actual close price; must be non-zero.
        threshold_pct: Absolute percentage move needed to call a direction.

    Returns:
        ``(direction, confidence)``. ``direction`` is "BULLISH", "BEARISH" or
        "NEUTRAL". ``confidence`` is a plain float heuristic:

        * trending calls score 30 per 1% move, capped at 95;
        * neutral calls score 50, reduced by 150 per 1% move, floored at 10.

    Raises:
        ValueError: if ``last_close`` is zero. With numpy division this would
            otherwise yield ±inf and a spurious 95% confidence.
    """
    if last_close == 0:
        raise ValueError("last_close must be non-zero")
    final_close = float(mean_candles[-1, 3])
    pct_change = (final_close - last_close) / last_close * 100
    if pct_change > threshold_pct:
        return "BULLISH", float(min(abs(pct_change) * 30, 95))
    if pct_change < -threshold_pct:
        return "BEARISH", float(min(abs(pct_change) * 30, 95))
    return "NEUTRAL", float(max(50 - abs(pct_change) * 150, 10))
def calc_vol_ratio(mean_candles: np.ndarray, hist_df: pd.DataFrame):
    """Compare the forecast's average bar range with recent realised bar ranges.

    Args:
        mean_candles: Forecast candles with OHLC columns (high = 1, low = 2).
        hist_df: Historical candles with ``high`` and ``low`` columns. The last
            ``len(mean_candles)`` rows are used as the baseline.

    Returns:
        ``mean(forecast high - low) / mean(historical high - low)`` as a float.
        Returns the neutral value 1.0 when there is nothing to compare or the
        historical mean range is zero or non-finite, so callers never get a
        NaN/inf that cannot be serialised to JSON.
    """
    pred_ranges = mean_candles[:, 1] - mean_candles[:, 2]
    hist_ranges = (hist_df["high"] - hist_df["low"]).tail(len(mean_candles)).to_numpy()
    if pred_ranges.size == 0 or hist_ranges.size == 0:
        return 1.0
    hist_mean = float(hist_ranges.mean())
    if hist_mean == 0 or not np.isfinite(hist_mean):
        return 1.0
    return float(pred_ranges.mean() / hist_mean)
def get_trading_context(vol_ratio: float, timeframe: str, direction: str, confidence: float) -> str:
    """Return a canned trading note for the forecast's volatility regime.

    The win/fail rates quoted in the messages are fixed text from earlier
    backtests; nothing here recomputes them.

    Args:
        vol_ratio: Forecast-to-historical bar-range ratio (see ``calc_vol_ratio``).
        timeframe: One of ``"5m"``, ``"15m"``, ``"1h"``, ``"4h"``.
        direction: Direction label, interpolated into the 1h low-vol and 15m notes.
        confidence: Shown (rounded to an integer) in the 15m note only.

    Returns:
        The note, or ``""`` for an unrecognised timeframe.
    """
    if timeframe == "15m":
        return f"15m structure β Use to time entries within your active window. {direction} bias with {confidence:.0f}% confidence."
    if timeframe == "5m":
        return "5m tactical β Use for precise entry timing. Look for candle confirmation at your level before entering."

    # Each band is (exclusive upper bound on vol_ratio, note), checked in
    # ascending order. Anything at or above the last bound (or NaN) falls through.
    if timeframe == "4h":
        bands = (
            (0.6, "COMPRESSED VOL β Magic Hour reversion 94.6% reliable (5.4% fail rate). Full size on fade-to-midpoint setups. Expect reversion within 1 bar."),
            (0.8, "LOW VOL β Magic Hour reversion 93.6% reliable (6.4% fail rate). Strong conditions for fade-to-midpoint trades."),
            (1.2, "NORMAL VOL β Magic Hour reversion 90.3% reliable. Standard conditions, trade normally."),
        )
        fallback = "ELEVATED VOL β Magic Hour reversion drops to 84.5% (15.5% fail rate). Reduce size on reversion trades or wait for compression."
    elif timeframe == "1h":
        bands = (
            (1.2, f"LOW VOL β Sniper window ELITE zones at 65.9% WR. Full size if ELITE zone (>=18pt NQ / >=5pt ES) aligns with {direction} bias."),
            (1.5, "NORMAL VOL β Sniper window at 59.8% WR. Trade ELITE zones at full size, GOOD zones at half size."),
        )
        fallback = "ELEVATED VOL β Sniper window accuracy drops. Half size only, require ELITE zone confirmation."
    else:
        return ""

    for ceiling, note in bands:
        if vol_ratio < ceiling:
            return note
    return fallback
def candles_to_list(arr: np.ndarray, timestamps) -> list[CandleOut]:
    """Convert an OHLC array into ``CandleOut`` models, one per timestamp.

    Row ``k`` of ``arr`` (columns open, high, low, close) is paired with the
    ``k``-th item of ``timestamps``. Prices are rounded to 2 decimals and
    timestamps are rendered with ``str()``. ``arr`` must have at least as many
    rows as there are timestamps.
    """
    def price(row: int, col: int) -> float:
        return round(float(arr[row, col]), 2)

    return [
        CandleOut(
            timestamp=str(stamp),
            open=price(row, 0),
            high=price(row, 1),
            low=price(row, 2),
            close=price(row, 3),
        )
        for row, stamp in enumerate(timestamps)
    ]
# Serialises cache-miss work (yfinance download + model inference). get_forecast
# is a plain ``def``, so FastAPI runs it in a worker thread. Without this lock,
# concurrent misses would each download data and run the model.
_FORECAST_LOCK = threading.Lock()


def _cached_forecast(cache_key: str, now: float) -> Optional[ForecastResponse]:
    """Return the cached response for ``cache_key`` if younger than CACHE_TTL, else None."""
    entry = CACHE.get(cache_key)
    if entry is None:
        return None
    cached_time, cached_response = entry
    if now - cached_time >= CACHE_TTL:
        return None
    logger.info(f"Serving cached response for {cache_key} (age: {int(now - cached_time)}s)")
    return cached_response


def _compute_forecast(instrument: str, timeframe: str) -> dict:
    """Fetch candles, run the model and return ForecastResponse field values (all but ``cached``)."""
    tf_cfg = TIMEFRAME_MAP[timeframe]
    df = fetch_candles(
        TICKER_MAP[instrument], tf_cfg["interval"], tf_cfg["period"], tf_cfg.get("resample")
    )
    logger.info(f"Fetched {len(df)} candles for {instrument} @ {timeframe}")

    mean, upper, lower, future_ts = run_forecast(df, tf_cfg["forecast_bars"])
    last_close = float(df["close"].iloc[-1])
    direction, confidence = calc_direction(mean, last_close)
    vol_ratio = calc_vol_ratio(mean, df)

    hist_tail = df.tail(50)
    return {
        "instrument": instrument,
        "timeframe": timeframe,
        # Same "YYYY-MM-DDTHH:MM:SS.ffffffZ" format as the former utcnow()-based
        # value, without relying on the deprecated naive-UTC API.
        "generated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
        "historical": candles_to_list(
            hist_tail[["open", "high", "low", "close"]].to_numpy(), hist_tail["timestamp"]
        ),
        "forecast_mean": candles_to_list(mean, future_ts),
        "forecast_upper": candles_to_list(upper, future_ts),
        "forecast_lower": candles_to_list(lower, future_ts),
        "direction": direction,
        "confidence": round(confidence, 1),
        "volatility_ratio": round(vol_ratio, 2),
        "trading_context": get_trading_context(vol_ratio, timeframe, direction, confidence),
    }


@app.get("/forecast", response_model=ForecastResponse)
def get_forecast(
    instrument: str = Query("NQ", pattern="^(NQ|ES)$"),
    timeframe: str = Query("1h", pattern="^(5m|15m|1h|4h)$"),
):
    """Return a cached-or-fresh Kronos forecast for ``instrument`` at ``timeframe``.

    This is a plain ``def`` rather than ``async def`` because the yfinance
    download and model inference block. FastAPI runs sync endpoints in a worker
    thread, which keeps the event loop (and therefore /health and cache hits)
    responsive during a long forecast.

    The first response for a key has ``cached=False``. Responses served from
    CACHE within CACHE_TTL seconds have ``cached=True``. The cache stores its
    own copy, so no response object is mutated after it is handed to FastAPI.

    Raises:
        HTTPException(503): the startup hook has not finished loading the model.
        HTTPException(502): propagated from ``fetch_candles`` when market data
            is unavailable.
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model still loading")

    cache_key = f"{instrument}_{timeframe}"
    hit = _cached_forecast(cache_key, time.time())
    if hit is not None:
        return hit

    with _FORECAST_LOCK:
        # Another request may have filled the cache while this one waited.
        now = time.time()
        hit = _cached_forecast(cache_key, now)
        if hit is not None:
            return hit
        payload = _compute_forecast(instrument, timeframe)
        CACHE[cache_key] = (now, ForecastResponse(**payload, cached=True))
    return ForecastResponse(**payload, cached=False)
@app.get("/health")
async def health():
    """Liveness probe: model status plus the keys and ages of cached forecasts.

    Takes one snapshot of CACHE and one clock reading, so ``cache_keys`` and
    ``cache_ages`` always describe the same set of entries. Ages are whole
    seconds since each entry was stored.
    """
    now = time.time()
    entries = list(CACHE.items())
    return {
        "status": "ok",
        "model_loaded": predictor is not None,
        "cache_keys": [key for key, _ in entries],
        "cache_ages": {key: int(now - entry[0]) for key, entry in entries},
    }