Spaces:
Sleeping
Sleeping
| """ | |
| Kronos Forecast API — HuggingFace Spaces Backend | |
| Fine-tuned on 17 years of NQ + ES data. v4.1. | |
| Added caching to avoid yfinance rate limits. | |
| """ | |
| import os | |
| import sys | |
| import logging | |
| import time | |
| from datetime import datetime | |
| from typing import Optional | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import yfinance as yf | |
| sys.path.insert(0, "/app/kronos") | |
| from model import Kronos, KronosTokenizer, KronosPredictor | |
| from fastapi import FastAPI, HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from safetensors.torch import load_file | |
| from huggingface_hub import hf_hub_download | |
# Process-wide logging plus the named logger used throughout this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("kronos-api")

app = FastAPI(title="Kronos Forecast API", version="4.1.0")
# CORS: any origin may call the API, but only with GET requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET"],
    allow_headers=["*"],
)

# Assigned by load_model(); remains None until the model has been built.
predictor: Optional[KronosPredictor] = None

# ---------------------------------------------------------------
# CACHE: store responses for 5 minutes to avoid rate limits.
# Keys are "<instrument>_<timeframe>"; values are
# (stored_at_epoch_seconds, response) tuples.
# ---------------------------------------------------------------
CACHE = {}
CACHE_TTL = 300  # seconds (5 minutes)

# API instrument code -> ticker symbol passed to yfinance.
TICKER_MAP = dict(NQ="NQ=F", ES="ES=F")

# Per-timeframe download settings and forecast horizon (in bars).
# "4h" downloads 1h bars and carries a "resample" rule for aggregation.
TIMEFRAME_MAP = {
    "5m": dict(interval="5m", period="5d", forecast_bars=12),
    "15m": dict(interval="15m", period="30d", forecast_bars=8),
    "1h": dict(interval="1h", period="60d", forecast_bars=8),
    "4h": dict(interval="1h", period="60d", forecast_bars=8, resample="4h"),
}
def _load_finetuned_weights(module, repo_id: str, label: str, token: Optional[str]) -> bool:
    """Overlay fine-tuned weights from a Hub repo onto ``module`` (best effort).

    Downloads ``model.safetensors`` from ``repo_id`` and loads it with
    ``strict=False``. Any failure is logged as a warning and the module keeps
    the weights it already had.

    Args:
        module: Object exposing ``load_state_dict`` (the tokenizer or model).
        repo_id: Hugging Face Hub repository holding the fine-tuned weights.
        label: Human-readable name used in log messages ("Tokenizer", ...).
        token: Hub access token, or ``None`` for anonymous access.

    Returns:
        ``True`` if the weights were loaded, ``False`` otherwise.
    """
    try:
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors",
            token=token,
        )
        result = module.load_state_dict(load_file(weights_path), strict=False)
    except Exception as exc:  # deliberate best effort: keep the base weights
        logger.warning("%s weights not loaded: %s", label, exc)
        return False

    # strict=False hides key mismatches; surface them so a partially applied
    # checkpoint does not go unnoticed.
    missing = list(getattr(result, "missing_keys", ()))
    unexpected = list(getattr(result, "unexpected_keys", ()))
    if missing or unexpected:
        logger.warning(
            "%s weights loaded with %d missing and %d unexpected keys",
            label, len(missing), len(unexpected),
        )
    logger.info("Fine-tuned %s weights loaded.", label.lower())
    return True


async def load_model():
    """Build the global ``predictor`` from base weights plus fine-tuned overlays.

    Loads the base tokenizer and model, then tries to overlay the fine-tuned
    NQ+ES weights for each. ``HF_TOKEN`` is read from the environment and
    passed to the Hub downloads. The global ``predictor`` is assigned only
    after both components are ready.

    NOTE(review): no startup-hook registration is visible in this view;
    confirm this coroutine is wired to application startup.
    """
    global predictor
    hf_token = os.environ.get("HF_TOKEN")
    logger.info("Loading fine-tuned Kronos NQ+ES v2 model …")

    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-mini")

    _load_finetuned_weights(
        tokenizer, "Jenak5/Kronos-NQ-ES-Tokenizer-v2", "Tokenizer", hf_token
    )
    _load_finetuned_weights(model, "Jenak5/Kronos-NQ-ES-v2", "Predictor", hf_token)

    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
    logger.info("Kronos NQ+ES v4.1 ready.")
class CandleOut(BaseModel):
    # One OHLC bar as serialized in API responses.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.)
    timestamp: str  # bar time rendered with str()
    open: float
    high: float
    low: float
    close: float
class ForecastResponse(BaseModel):
    # Payload returned by the forecast endpoint.
    # (Comments instead of a docstring keep the generated schema unchanged.)

    # --- request echo / metadata ---
    instrument: str
    timeframe: str
    generated_at: str  # UTC ISO-8601 string with a trailing "Z"

    # --- candle series ---
    historical: list[CandleOut]  # most recent fetched bars
    forecast_mean: list[CandleOut]  # per-bar mean across sampled forecasts
    forecast_upper: list[CandleOut]  # 90th percentile across samples
    forecast_lower: list[CandleOut]  # 10th percentile across samples

    # --- derived signals ---
    direction: str  # "BULLISH", "BEARISH" or "NEUTRAL"
    confidence: float
    volatility_ratio: float  # forecast bar range relative to recent history
    trading_context: str

    # True when the response was served from the in-memory cache.
    cached: bool = False
def fetch_candles(ticker: str, interval: str, period: str, resample: Optional[str] = None) -> pd.DataFrame:
    """Download OHLC candles from yfinance and normalize them.

    Args:
        ticker: yfinance symbol, e.g. ``"NQ=F"``.
        interval: Bar interval passed to ``yf.download``.
        period: Look-back period passed to ``yf.download``.
        resample: Optional pandas resample rule (e.g. ``"4h"``); when given,
            bars are aggregated as first open / max high / min low / last close.

    Returns:
        DataFrame with columns ``timestamp, open, high, low, close`` and no
        missing values.

    Raises:
        HTTPException: 502 when no data comes back, when the expected columns
            are absent, or when no rows survive cleaning.
    """
    raw = yf.download(ticker, period=period, interval=interval, progress=False)
    if raw is None or raw.empty:
        raise HTTPException(status_code=502, detail=f"No data for {ticker}")

    df = raw.reset_index()
    # Multi-level column headers: keep only the first level (the field name).
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)
    df = df.rename(columns=lambda name: str(name).lower())

    # The time column arrives as "datetime" or "date"; prefer "datetime".
    for candidate in ("datetime", "date"):
        if candidate in df.columns:
            df = df.rename(columns={candidate: "timestamp"})
            break

    required = ["timestamp", "open", "high", "low", "close"]
    missing = [col for col in required if col not in df.columns]
    if missing:
        # Report an upstream format problem instead of leaking a KeyError.
        raise HTTPException(
            status_code=502,
            detail=f"Unexpected data format for {ticker}: missing {missing}",
        )

    df["timestamp"] = pd.to_datetime(df["timestamp"])
    df = df[required].dropna()

    if resample:
        df = (
            df.set_index("timestamp")
            .resample(resample)
            .agg({"open": "first", "high": "max", "low": "min", "close": "last"})
            .dropna()
            .reset_index()
        )

    if df.empty:
        raise HTTPException(status_code=502, detail=f"No usable candles for {ticker}")
    return df
def _future_timestamps(ts: pd.Series, forecast_bars: int) -> pd.Series:
    """Extrapolate ``forecast_bars`` timestamps past the last value of ``ts``."""
    last_ts = ts.iloc[-1]

    freq = None
    if len(ts) >= 3:  # pd.infer_freq raises ValueError on fewer than 3 values
        try:
            freq = pd.infer_freq(ts)
        except (TypeError, ValueError):
            freq = None

    if freq is not None:
        # date_range includes the start point, so ask for one extra and drop it.
        return pd.Series(
            pd.date_range(start=last_ts, periods=forecast_bars + 1, freq=freq)[1:]
        )

    # Irregular series (e.g. session gaps): use the median spacing rather than
    # the last two rows, which may themselves straddle a gap.
    step = ts.diff().dropna().median()
    return pd.Series([last_ts + step * (i + 1) for i in range(forecast_bars)])


def run_forecast(df: pd.DataFrame, forecast_bars: int, n_samples: int = 10):
    """Sample several model forecasts and summarize them per bar.

    Uses at most the last 400 rows of ``df`` as context, draws ``n_samples``
    independent predictions from the global ``predictor`` and reduces them
    across the sample axis.

    Args:
        df: Candles with ``timestamp, open, high, low, close`` columns.
        forecast_bars: Number of future bars to predict (>= 1).
        n_samples: Number of independent predictions to draw (>= 1).

    Returns:
        ``(mean, upper, lower, future_ts)`` where the first three are arrays of
        shape ``(forecast_bars, 4)`` in open/high/low/close order (mean, 90th
        and 10th percentile) and ``future_ts`` holds the forecast timestamps.

    Raises:
        ValueError: If ``forecast_bars`` or ``n_samples`` is below 1.
        HTTPException: 502 when fewer than two candles are available.
    """
    if forecast_bars < 1:
        raise ValueError("forecast_bars must be at least 1")
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")
    if len(df) < 2:
        raise HTTPException(status_code=502, detail="Not enough candles to forecast")

    x_df = df.tail(400).reset_index(drop=True)
    future_ts = _future_timestamps(x_df["timestamp"], forecast_bars)

    samples = []
    for _ in range(n_samples):
        # One sampled path per call; the spread across calls gives the bands.
        pred_df = predictor.predict(
            df=x_df[["open", "high", "low", "close"]],
            x_timestamp=x_df["timestamp"],
            y_timestamp=future_ts,
            pred_len=forecast_bars,
            T=0.3,
            top_p=0.5,
            sample_count=1,
        )
        samples.append(pred_df[["open", "high", "low", "close"]].values)

    stacked = np.array(samples)
    mean = stacked.mean(axis=0)
    upper = np.percentile(stacked, 90, axis=0)
    lower = np.percentile(stacked, 10, axis=0)
    return mean, upper, lower, future_ts
def calc_direction(mean_candles: np.ndarray, last_close: float):
    """Classify the forecast move and score it.

    Compares the close of the final forecast bar (column 3 of the last row)
    with ``last_close`` as a percentage change.

    Returns:
        ``(label, confidence)``: "BULLISH" above +0.10 %, "BEARISH" below
        -0.10 %, both scored as ``min(|pct| * 30, 95)``; otherwise "NEUTRAL"
        scored as ``max(50 - |pct| * 150, 10)``.
    """
    closing_price = mean_candles[-1, 3]
    move_pct = (closing_price - last_close) / last_close * 100
    magnitude = abs(move_pct)

    is_up = move_pct > 0.10
    is_down = move_pct < -0.10
    if not (is_up or is_down):
        return "NEUTRAL", max(50 - magnitude * 150, 10)

    label = "BULLISH" if is_up else "BEARISH"
    return label, min(magnitude * 30, 95)
def calc_vol_ratio(mean_candles: np.ndarray, hist_df: pd.DataFrame):
    """Return average forecast bar range divided by average recent bar range.

    The forecast range is column 1 minus column 2 of ``mean_candles``; the
    historical range is ``high - low`` over the last ``len(mean_candles)``
    rows of ``hist_df``. Returns 1.0 when the historical average is zero.
    """
    n_bars = len(mean_candles)
    forecast_avg_range = (mean_candles[:, 1] - mean_candles[:, 2]).mean()

    recent = hist_df.tail(n_bars)
    hist_avg_range = (recent["high"] - recent["low"]).values.mean()

    if hist_avg_range == 0:
        return 1.0
    return float(forecast_avg_range / hist_avg_range)
def get_trading_context(vol_ratio: float, timeframe: str, direction: str, confidence: float) -> str:
    """Generate trading context based on backtested rules.

    Args:
        vol_ratio: Forecast-to-historical volatility ratio.
        timeframe: One of "4h", "1h", "15m", "5m"; anything else yields "".
        direction: Bias label interpolated into some messages.
        confidence: Confidence score, shown as a whole number on 15m.

    Returns:
        A human-readable guidance string, or "" for an unknown timeframe.
    """
    if timeframe == "4h":
        if vol_ratio < 0.6:
            return (
                "COMPRESSED VOL — Magic Hour reversion 94.6% reliable (5.4% fail rate). "
                "Full size on fade-to-midpoint setups. Expect reversion within 1 bar."
            )
        elif vol_ratio < 0.8:
            return (
                "LOW VOL — Magic Hour reversion 93.6% reliable (6.4% fail rate). "
                "Strong conditions for fade-to-midpoint trades."
            )
        elif vol_ratio < 1.2:
            return (
                "NORMAL VOL — Magic Hour reversion 90.3% reliable. "
                "Standard conditions, trade normally."
            )
        else:
            return (
                "ELEVATED VOL — Magic Hour reversion drops to 84.5% (15.5% fail rate). "
                "Reduce size on reversion trades or wait for compression."
            )

    if timeframe == "1h":
        if vol_ratio < 1.2:
            return (
                "LOW VOL — Sniper window ELITE zones at 65.9% WR. "
                f"Full size if ELITE zone (>=18pt NQ / >=5pt ES) aligns with {direction} bias."
            )
        elif vol_ratio < 1.5:
            return (
                "NORMAL VOL — Sniper window at 59.8% WR. "
                "Trade ELITE zones at full size, GOOD zones at half size."
            )
        else:
            return (
                "ELEVATED VOL — Sniper window accuracy drops. "
                "Half size only, require ELITE zone confirmation."
            )

    if timeframe == "15m":
        return (
            "15m structure — Use to time entries within your active window. "
            f"{direction} bias with {confidence:.0f}% confidence."
        )

    if timeframe == "5m":
        return (
            "5m tactical — Use for precise entry timing. "
            "Look for candle confirmation at your level before entering."
        )

    return ""
def candles_to_list(arr: np.ndarray, timestamps) -> list[CandleOut]:
    """Pair each timestamp with the matching row of ``arr`` as a ``CandleOut``.

    Row ``i`` of ``arr`` supplies open/high/low/close from columns 0-3, each
    converted to ``float`` and rounded to 2 decimals; the timestamp is
    rendered with ``str()``.
    """

    def _price(row_idx: int, col_idx: int) -> float:
        return round(float(arr[row_idx, col_idx]), 2)

    return [
        CandleOut(
            timestamp=str(stamp),
            open=_price(idx, 0),
            high=_price(idx, 1),
            low=_price(idx, 2),
            close=_price(idx, 3),
        )
        for idx, stamp in enumerate(timestamps)
    ]
async def get_forecast(
    instrument: str = Query("NQ", pattern="^(NQ|ES)$"),
    timeframe: str = Query("1h", pattern="^(5m|15m|1h|4h)$"),
):
    """Return a forecast for ``instrument`` at ``timeframe``, using the cache.

    A cached response younger than ``CACHE_TTL`` seconds is returned with
    ``cached=True``. Otherwise candles are fetched, the model is sampled, and
    the new response is stored in ``CACHE`` before being returned.

    Raises:
        HTTPException: 503 while the model has not finished loading; errors
            raised by the fetch/forecast helpers propagate unchanged.

    NOTE(review): no route decorator is visible in this view; confirm how this
    handler is registered on ``app``.
    """
    # Local import: the file header only imports ``datetime`` from this module.
    from datetime import timezone

    if predictor is None:
        raise HTTPException(status_code=503, detail="Model still loading")

    # Check cache first
    cache_key = f"{instrument}_{timeframe}"
    entry = CACHE.get(cache_key)
    if entry is not None:
        cached_at, cached_response = entry
        age = time.time() - cached_at
        if age < CACHE_TTL:
            logger.info(
                "Serving cached response for %s (age: %ds)", cache_key, int(age)
            )
            cached_response.cached = True
            return cached_response

    ticker = TICKER_MAP[instrument]
    tf_cfg = TIMEFRAME_MAP[timeframe]

    # NOTE(review): fetch_candles and run_forecast are synchronous and run on
    # the event loop here, so other requests wait until they finish. Moving
    # them to a worker thread needs confirmation that the predictor tolerates
    # concurrent calls.
    df = fetch_candles(
        ticker, tf_cfg["interval"], tf_cfg["period"], tf_cfg.get("resample")
    )
    logger.info("Fetched %d candles for %s @ %s", len(df), instrument, timeframe)

    mean, upper, lower, future_ts = run_forecast(df, tf_cfg["forecast_bars"])
    last_close = float(df["close"].iloc[-1])
    direction, confidence = calc_direction(mean, last_close)
    vol_ratio = calc_vol_ratio(mean, df)
    context = get_trading_context(vol_ratio, timeframe, direction, confidence)

    # Last 50 bars, converted by the same helper used for the forecast series.
    hist_tail = df.tail(50)
    historical = candles_to_list(
        hist_tail[["open", "high", "low", "close"]].to_numpy(),
        hist_tail["timestamp"],
    )

    # Timezone-aware "now" (utcnow() is deprecated); tzinfo is stripped so the
    # string keeps its "...Z" form rather than "+00:00Z".
    generated_at = (
        datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    )

    response = ForecastResponse(
        instrument=instrument,
        timeframe=timeframe,
        generated_at=generated_at,
        historical=historical,
        forecast_mean=candles_to_list(mean, future_ts),
        forecast_upper=candles_to_list(upper, future_ts),
        forecast_lower=candles_to_list(lower, future_ts),
        direction=direction,
        confidence=round(confidence, 1),
        volatility_ratio=round(vol_ratio, 2),
        trading_context=context,
        cached=False,
    )

    # Stamp the entry when it is stored, not when the request began, so the
    # time spent fetching and forecasting does not shorten its TTL.
    CACHE[cache_key] = (time.time(), response)
    return response
async def health():
    """Report service status, model readiness and the state of the cache.

    ``cache_ages`` maps each cache key to the whole seconds elapsed since the
    entry's stored timestamp.
    """
    ages = {}
    for key, entry in CACHE.items():
        stored_at = entry[0]
        ages[key] = int(time.time() - stored_at)

    model_ready = predictor is not None
    return {
        "status": "ok",
        "model_loaded": model_ready,
        "cache_keys": [*CACHE],
        "cache_ages": ages,
    }