"""
Kronos Forecast API — HuggingFace Spaces Backend.

Fine-tuned on 17 years of NQ + ES data.
v4.1: added response caching to avoid yfinance rate limits.
"""
import os
import sys
import logging
import threading
import time
from datetime import datetime, timezone
from typing import Optional

import numpy as np
import pandas as pd
import torch
import yfinance as yf

sys.path.insert(0, "/app/kronos")
from model import Kronos, KronosTokenizer, KronosPredictor

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("kronos-api")

app = FastAPI(title="Kronos Forecast API", version="4.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET"],
    allow_headers=["*"],
)

# Set by the startup handler; None until the model has finished loading.
predictor: Optional[KronosPredictor] = None

# ═══════════════════════════════════════════════════════
# CACHE — store responses for 5 minutes to avoid rate limits
# ═══════════════════════════════════════════════════════
# cache_key ("<instrument>_<timeframe>") -> (unix time the forecast was started,
# ForecastResponse with cached=True). Entries are never mutated after insertion.
CACHE = {}
CACHE_TTL = 300  # 5 minutes

# /forecast runs in FastAPI's threadpool. This lock keeps cache-miss work (yfinance
# download + model inference) one-at-a-time, and together with the re-check of the
# cache inside the lock stops concurrent misses for the same key from all recomputing.
_FORECAST_LOCK = threading.Lock()

OHLC_COLUMNS = ["open", "high", "low", "close"]

# Most recent bars fed to the model; kept below the predictor's max_context (512).
MAX_LOOKBACK = 400

TICKER_MAP = {"NQ": "NQ=F", "ES": "ES=F"}

TIMEFRAME_MAP = {
    "5m": {"interval": "5m", "period": "5d", "forecast_bars": 12},
    "15m": {"interval": "15m", "period": "30d", "forecast_bars": 8},
    "1h": {"interval": "1h", "period": "60d", "forecast_bars": 8},
    # No native 4h interval is requested; 1h bars are resampled into 4h bars.
    "4h": {"interval": "1h", "period": "60d", "forecast_bars": 8, "resample": "4h"},
}


def _load_finetuned_weights(module, repo_id: str, label: str, token: Optional[str]) -> None:
    """Overlay fine-tuned ``model.safetensors`` weights from *repo_id* onto *module*.

    Best effort: any download/load failure is logged and *module* keeps the base
    pretrained weights it was created with. Loading is non-strict, so key mismatches
    are logged explicitly; otherwise a checkpoint that matched nothing would still
    be reported as loaded.

    Args:
        module: object whose ``load_state_dict`` receives the weights.
        repo_id: HuggingFace Hub repository holding ``model.safetensors``.
        label: "tokenizer" or "predictor"; used in log messages.
        token: optional HF access token passed to ``hf_hub_download``.
    """
    try:
        path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", token=token)
        result = module.load_state_dict(load_file(path), strict=False)
    except Exception as e:
        logger.warning(f"{label.capitalize()} weights not loaded: {e}")
        return
    # torch's load_state_dict returns (missing_keys, unexpected_keys); tolerate other returns.
    missing = getattr(result, "missing_keys", None) or []
    unexpected = getattr(result, "unexpected_keys", None) or []
    if missing or unexpected:
        logger.warning(
            f"{label.capitalize()} checkpoint key mismatch: "
            f"{len(missing)} missing, {len(unexpected)} unexpected"
        )
    logger.info(f"Fine-tuned {label} weights loaded.")


@app.on_event("startup")
async def load_model():
    """Build the base Kronos tokenizer/model, overlay fine-tuned weights, set ``predictor``."""
    global predictor
    hf_token = os.environ.get("HF_TOKEN")
    logger.info("Loading fine-tuned Kronos NQ+ES v2 model …")
    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-mini")

    _load_finetuned_weights(tokenizer, "Jenak5/Kronos-NQ-ES-Tokenizer-v2", "tokenizer", hf_token)
    _load_finetuned_weights(model, "Jenak5/Kronos-NQ-ES-v2", "predictor", hf_token)

    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    logger.info("Kronos NQ+ES v4.1 ready.")


class CandleOut(BaseModel):
    """One OHLC candle in API output; prices are rounded to 2 decimals by the builders."""

    timestamp: str
    open: float
    high: float
    low: float
    close: float


class ForecastResponse(BaseModel):
    """Payload of ``GET /forecast``."""

    instrument: str
    timeframe: str
    generated_at: str  # UTC ISO-8601 with trailing "Z"
    historical: list[CandleOut]  # last (up to) 50 fetched candles
    forecast_mean: list[CandleOut]
    forecast_upper: list[CandleOut]  # 90th percentile across samples
    forecast_lower: list[CandleOut]  # 10th percentile across samples
    direction: str  # BULLISH / BEARISH / NEUTRAL
    confidence: float
    volatility_ratio: float
    trading_context: str
    cached: bool = False


def fetch_candles(ticker: str, interval: str, period: str, resample: Optional[str] = None) -> pd.DataFrame:
    """Download OHLC candles from yfinance and normalise them.

    Args:
        ticker: yfinance symbol, e.g. "NQ=F".
        interval: yfinance bar interval, e.g. "1h".
        period: yfinance lookback period, e.g. "60d".
        resample: optional pandas offset alias (e.g. "4h") to aggregate bars into.

    Returns:
        DataFrame with columns ``timestamp, open, high, low, close`` and at least two rows.

    Raises:
        HTTPException(502): the download failed, returned nothing usable, or had an
            unexpected column layout.
    """
    try:
        raw = yf.download(ticker, period=period, interval=interval, progress=False)
    except Exception as exc:  # network / rate-limit failures surface as assorted types
        logger.exception(f"yfinance download failed for {ticker}")
        raise HTTPException(status_code=502, detail=f"Data fetch failed for {ticker}") from exc
    if raw is None or raw.empty:
        raise HTTPException(status_code=502, detail=f"No data for {ticker}")

    df = raw.reset_index()
    # yfinance may return MultiIndex columns (field, ticker); keep only the field level.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = [c[0] if isinstance(c, tuple) else c for c in df.columns]
    df = df.rename(columns=lambda c: str(c).lower())
    # The index column comes back as "Datetime" or "Date"; normalise either to "timestamp".
    if "datetime" in df.columns:
        df = df.rename(columns={"datetime": "timestamp"})
    elif "date" in df.columns:
        df = df.rename(columns={"date": "timestamp"})

    missing = [c for c in ["timestamp", *OHLC_COLUMNS] if c not in df.columns]
    if missing:
        raise HTTPException(
            status_code=502, detail=f"Unexpected data columns for {ticker}: missing {missing}"
        )

    df["timestamp"] = pd.to_datetime(df["timestamp"])
    df = df[["timestamp", *OHLC_COLUMNS]].dropna()

    if resample:
        df = (
            df.set_index("timestamp")
            .resample(resample)
            .agg({"open": "first", "high": "max", "low": "min", "close": "last"})
            .dropna()
            .reset_index()
        )

    # run_forecast needs at least two bars to derive the bar spacing.
    if len(df) < 2:
        raise HTTPException(status_code=502, detail=f"Not enough candles for {ticker}")
    return df


def _future_timestamps(timestamps: pd.Series, forecast_bars: int) -> pd.Series:
    """Extrapolate ``forecast_bars`` timestamps following the last entry of *timestamps*."""
    last_ts = timestamps.iloc[-1]
    # infer_freq needs >= 3 points and returns None when spacing is irregular (e.g. gaps).
    freq = pd.infer_freq(timestamps) if len(timestamps) >= 3 else None
    if freq is not None:
        return pd.Series(pd.date_range(start=last_ts, periods=forecast_bars + 1, freq=freq)[1:])

    # Use the typical (median) bar spacing rather than only the final gap, so a session
    # break between the last two bars does not stretch every forecast step.
    steps = timestamps.diff().dropna()
    steps = steps[steps > pd.Timedelta(0)]
    if steps.empty:
        raise ValueError("Cannot infer bar spacing: need two distinct timestamps")
    step = steps.median()
    return pd.Series([last_ts + step * (i + 1) for i in range(forecast_bars)])


def run_forecast(df: pd.DataFrame, forecast_bars: int, n_samples: int = 10):
    """Sample the Kronos predictor repeatedly and summarise the sampled paths.

    Args:
        df: candles as returned by ``fetch_candles`` (timestamp + OHLC columns).
        forecast_bars: number of future bars to predict.
        n_samples: number of independent predictor calls (one sample each).

    Returns:
        ``(mean, upper, lower, future_ts)``: OHLC arrays with one row per predicted bar
        holding the element-wise mean, 90th and 10th percentile across samples, plus the
        Series of forecast timestamps.

    Raises:
        ValueError: ``n_samples < 1`` or the bar spacing cannot be derived.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")

    x_df = df.tail(MAX_LOOKBACK).reset_index(drop=True)
    future_ts = _future_timestamps(x_df["timestamp"], forecast_bars)

    # One call per sample (sample_count=1) so the spread across calls can be measured.
    samples = []
    for _ in range(n_samples):
        pred_df = predictor.predict(
            df=x_df[OHLC_COLUMNS],
            x_timestamp=x_df["timestamp"],
            y_timestamp=future_ts,
            pred_len=forecast_bars,
            T=0.3,
            top_p=0.5,
            sample_count=1,
        )
        samples.append(pred_df[OHLC_COLUMNS].to_numpy())
    stacked = np.array(samples)

    mean = stacked.mean(axis=0)
    upper = np.percentile(stacked, 90, axis=0)
    lower = np.percentile(stacked, 10, axis=0)
    return mean, upper, lower, future_ts


def calc_direction(mean_candles: np.ndarray, last_close: float):
    """Classify the final forecast mean close relative to *last_close*.

    Moves beyond ±0.10% are BULLISH/BEARISH with confidence 30 × |% move| capped at 95;
    anything smaller is NEUTRAL with confidence 50 − 150 × |% move| floored at 10.

    Returns:
        ``(direction, confidence)``.
    """
    final_close = mean_candles[-1, 3]
    pct_change = (final_close - last_close) / last_close * 100
    magnitude = abs(pct_change)
    if pct_change > 0.10:
        return "BULLISH", min(magnitude * 30, 95)
    if pct_change < -0.10:
        return "BEARISH", min(magnitude * 30, 95)
    return "NEUTRAL", max(50 - magnitude * 150, 10)


def calc_vol_ratio(mean_candles: np.ndarray, hist_df: pd.DataFrame):
    """Ratio of mean forecast bar range (high − low) to the mean range of the last
    ``len(mean_candles)`` historical bars; 1.0 when the historical range is empty or zero."""
    pred_ranges = mean_candles[:, 1] - mean_candles[:, 2]
    hist_ranges = (hist_df["high"] - hist_df["low"]).tail(len(mean_candles)).to_numpy()
    if hist_ranges.size == 0:
        return 1.0
    hist_mean = hist_ranges.mean()
    # A NaN/zero denominator would produce a non-JSON-serialisable or infinite ratio.
    if hist_mean == 0 or not np.isfinite(hist_mean):
        return 1.0
    return float(pred_ranges.mean() / hist_mean)


def get_trading_context(vol_ratio: float, timeframe: str, direction: str, confidence: float) -> str:
    """Generate trading context based on backtested rules."""
    if timeframe == "4h":
        if vol_ratio < 0.6:
            return "COMPRESSED VOL — Magic Hour reversion 94.6% reliable (5.4% fail rate). Full size on fade-to-midpoint setups. Expect reversion within 1 bar."
        elif vol_ratio < 0.8:
            return "LOW VOL — Magic Hour reversion 93.6% reliable (6.4% fail rate). Strong conditions for fade-to-midpoint trades."
        elif vol_ratio < 1.2:
            return "NORMAL VOL — Magic Hour reversion 90.3% reliable. Standard conditions, trade normally."
        else:
            return "ELEVATED VOL — Magic Hour reversion drops to 84.5% (15.5% fail rate). Reduce size on reversion trades or wait for compression."

    if timeframe == "1h":
        if vol_ratio < 1.2:
            return f"LOW VOL — Sniper window ELITE zones at 65.9% WR. Full size if ELITE zone (>=18pt NQ / >=5pt ES) aligns with {direction} bias."
        elif vol_ratio < 1.5:
            return "NORMAL VOL — Sniper window at 59.8% WR. Trade ELITE zones at full size, GOOD zones at half size."
        else:
            return "ELEVATED VOL — Sniper window accuracy drops. Half size only, require ELITE zone confirmation."

    if timeframe == "15m":
        return f"15m structure — Use to time entries within your active window. {direction} bias with {confidence:.0f}% confidence."

    if timeframe == "5m":
        return "5m tactical — Use for precise entry timing. Look for candle confirmation at your level before entering."

    return ""


def candles_to_list(arr: np.ndarray, timestamps) -> list[CandleOut]:
    """Pair each timestamp with the same-index row of *arr* (columns open, high, low,
    close) and build ``CandleOut`` models with prices rounded to 2 decimals."""
    return [
        CandleOut(
            timestamp=str(ts),
            open=round(float(arr[i, 0]), 2),
            high=round(float(arr[i, 1]), 2),
            low=round(float(arr[i, 2]), 2),
            close=round(float(arr[i, 3]), 2),
        )
        for i, ts in enumerate(timestamps)
    ]


def _cached_response(cache_key: str) -> Optional[ForecastResponse]:
    """Return the cached response for *cache_key* if younger than CACHE_TTL, else None."""
    entry = CACHE.get(cache_key)
    if entry is None:
        return None
    cached_time, cached_response = entry
    age = time.time() - cached_time
    if age >= CACHE_TTL:
        return None
    logger.info(f"Serving cached response for {cache_key} (age: {int(age)}s)")
    return cached_response


def _compute_forecast_fields(instrument: str, timeframe: str) -> dict:
    """Fetch candles, run the model and assemble every ForecastResponse field except ``cached``."""
    ticker = TICKER_MAP[instrument]
    tf_cfg = TIMEFRAME_MAP[timeframe]
    df = fetch_candles(ticker, tf_cfg["interval"], tf_cfg["period"], tf_cfg.get("resample"))
    logger.info(f"Fetched {len(df)} candles for {instrument} @ {timeframe}")

    mean, upper, lower, future_ts = run_forecast(df, tf_cfg["forecast_bars"])
    last_close = float(df["close"].iloc[-1])
    direction, confidence = calc_direction(mean, last_close)
    vol_ratio = calc_vol_ratio(mean, df)
    context = get_trading_context(vol_ratio, timeframe, direction, confidence)

    hist_tail = df.tail(50)
    return {
        "instrument": instrument,
        "timeframe": timeframe,
        # UTC ISO-8601 with a trailing "Z" (datetime.utcnow() is deprecated).
        "generated_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "historical": candles_to_list(
            hist_tail[OHLC_COLUMNS].to_numpy(dtype=float), hist_tail["timestamp"]
        ),
        "forecast_mean": candles_to_list(mean, future_ts),
        "forecast_upper": candles_to_list(upper, future_ts),
        "forecast_lower": candles_to_list(lower, future_ts),
        "direction": direction,
        "confidence": round(confidence, 1),
        "volatility_ratio": round(vol_ratio, 2),
        "trading_context": context,
    }


@app.get("/forecast", response_model=ForecastResponse)
def get_forecast(
    instrument: str = Query("NQ", pattern="^(NQ|ES)$"),
    timeframe: str = Query("1h", pattern="^(5m|15m|1h|4h)$"),
):
    """Return a Kronos forecast for *instrument* at *timeframe*, cached for CACHE_TTL seconds.

    Declared with plain ``def`` so FastAPI runs it in its threadpool: the yfinance
    download and repeated model inference block, and inside an ``async def`` they
    would stall the event loop (including ``/health``) for the whole computation.

    Raises:
        HTTPException(503): the model has not finished loading.
        HTTPException(502): market data could not be fetched (see ``fetch_candles``).
    """
    if predictor is None:
        raise HTTPException(status_code=503, detail="Model still loading")

    cache_key = f"{instrument}_{timeframe}"
    cached_response = _cached_response(cache_key)
    if cached_response is not None:
        return cached_response

    with _FORECAST_LOCK:
        # Another request may have filled the cache while this one waited for the lock.
        cached_response = _cached_response(cache_key)
        if cached_response is not None:
            return cached_response
        started_at = time.time()
        fields = _compute_forecast_fields(instrument, timeframe)
        # Cache a separate cached=True instance so no returned object is ever mutated.
        CACHE[cache_key] = (started_at, ForecastResponse(**fields, cached=True))

    return ForecastResponse(**fields, cached=False)


@app.get("/health")
async def health():
    """Report liveness, whether the model is loaded, and the age of each cache entry."""
    # Copy first: forecast worker threads may insert into CACHE while we iterate.
    snapshot = dict(CACHE)
    now = time.time()
    return {
        "status": "ok",
        "model_loaded": predictor is not None,
        "cache_keys": list(snapshot),
        "cache_ages": {key: int(now - stamp) for key, (stamp, _) in snapshot.items()},
    }