Spaces:
Running
Running
import json
import os
import sqlite3
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import closing
from datetime import datetime, timezone

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
import torch
from plotly.subplots import make_subplots

from model import Kronos, KronosTokenizer, KronosPredictor
import autotune
# ----- Load model once at startup -----
# Pick the device first so the startup message reports where the model really runs.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Loading Kronos-small on {device}...")
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
# max_context=512 is why every handler rejects lookback + pred_len > 512.
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
print(f"Model loaded on {device}")

FMP_KEY = os.getenv("FMP_API_KEY")
if not FMP_KEY:
    # Without a key every market-data request fails; say so once at startup.
    print("[warn] FMP_API_KEY is not set; FMP data requests will fail")

# Serializes all calls into the single shared `predictor` across Gradio events.
PREDICT_LOCK = threading.Lock()
# SQLite cache lives next to this file.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "forecasts.db")
# Serializes forecast inserts (see save_forecast).
DB_LOCK = threading.Lock()
# ----- SQLite cache -----
def init_db():
    """Create the ``forecasts`` table if it does not exist, then ask
    ``autotune`` to create its tuning table in the same database file.

    The table stores one row per forecast run: run parameters, headline
    numbers and the P50 forecast bars serialized as JSON (``pred_json``).
    """
    # sqlite3's connection context manager only commits/rolls back; it does
    # not close the connection, so wrap it in closing() as well.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:
            conn.execute("""CREATE TABLE IF NOT EXISTS forecasts (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ts TEXT NOT NULL,
                symbol TEXT NOT NULL,
                interval TEXT NOT NULL,
                lookback INTEGER NOT NULL,
                pred_len INTEGER NOT NULL,
                sample_count INTEGER NOT NULL,
                temperature REAL NOT NULL,
                top_p REAL NOT NULL,
                last_close REAL NOT NULL,
                forecast_close REAL NOT NULL,
                expected_return_pct REAL NOT NULL,
                pred_json TEXT NOT NULL
            )""")
    autotune.init_tuning_table(DB_PATH)


init_db()
def save_forecast(symbol, interval, lookback, pred_len, sample_count, temperature, top_p,
                  last_close, forecast_close, expected_return_pct, pred_df):
    """Insert one forecast run into the ``forecasts`` table.

    Args:
        symbol, interval, lookback, pred_len, sample_count, temperature, top_p:
            the run parameters (numeric ones are coerced to int/float).
        last_close, forecast_close, expected_return_pct: headline numbers.
        pred_df: forecast bars indexed by forecast timestamp. The index is
            written as a string ``timestamps`` field in JSON records so that
            ``reopen_from_history`` can rebuild the frame.

    The row's ``ts`` is the current UTC time (ISO-8601, seconds precision).
    """
    payload = pred_df.reset_index().rename(columns={"index": "timestamps"})
    payload["timestamps"] = pd.to_datetime(payload["timestamps"]).astype(str)
    pred_json = payload.to_json(orient="records")
    ts = datetime.now(timezone.utc).isoformat(timespec="seconds")
    # closing() releases the connection; the inner `with conn` commits the
    # insert (or rolls it back on error).
    with DB_LOCK, closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:
            conn.execute("""INSERT INTO forecasts
                (ts, symbol, interval, lookback, pred_len, sample_count, temperature, top_p,
                 last_close, forecast_close, expected_return_pct, pred_json)
                VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""",
                (ts, symbol, interval, int(lookback), int(pred_len), int(sample_count),
                 float(temperature), float(top_p), float(last_close), float(forecast_close),
                 float(expected_return_pct), pred_json))
def list_forecasts(limit=200):
    """Return up to ``limit`` stored forecasts, newest first, for the History table.

    Prices are rounded to 4 dp and the expected return to 3 dp in SQL.
    """
    # closing() guarantees the connection is released (the sqlite3 context
    # manager on its own only manages transactions).
    with closing(sqlite3.connect(DB_PATH)) as conn:
        rows = conn.execute("""SELECT id, ts, symbol, interval, lookback, pred_len, sample_count,
                                      ROUND(last_close, 4), ROUND(forecast_close, 4),
                                      ROUND(expected_return_pct, 3)
                               FROM forecasts ORDER BY id DESC LIMIT ?""", (limit,)).fetchall()
    return pd.DataFrame(rows, columns=[
        "id", "ts", "symbol", "interval", "lookback", "pred_len", "sample_count",
        "last_close", "forecast_close", "expected_return_pct"
    ])
def load_forecast(forecast_id):
    """Return the full ``forecasts`` row (all columns, table order) for
    ``forecast_id``, or None when no such row exists."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        return conn.execute("SELECT * FROM forecasts WHERE id = ?",
                            (int(forecast_id),)).fetchone()
# ----- Data fetch -----
def fetch_fmp(symbol: str, interval: str, n_bars: int) -> pd.DataFrame:
    """Download the most recent ``n_bars`` OHLCV bars for ``symbol`` from FMP.

    Uses FMP's ``historical-chart/{interval}/{symbol}`` endpoint, sorts bars
    oldest-first and adds an ``amount`` column (close * volume).

    Returns:
        DataFrame with columns timestamps, open, high, low, close, volume,
        amount; at most ``n_bars`` rows with a fresh RangeIndex.

    Raises:
        gr.Error: on network/HTTP/JSON-decoding failures, when FMP answers
            with something other than a list of bars (e.g. an error object),
            or when no bars are returned. Messages never contain the API key.
    """
    from urllib.parse import quote

    def _redact(text):
        # requests puts the full request URL (including ?apikey=...) into its
        # exception messages, and those messages are shown in the UI.
        return text.replace(FMP_KEY, "***") if FMP_KEY else text

    # Percent-encode path segments so user input cannot change the endpoint;
    # pass the key via `params` rather than string formatting.
    url = ("https://financialmodelingprep.com/api/v3/historical-chart/"
           f"{quote(str(interval), safe='')}/{quote(str(symbol), safe='')}")
    try:
        r = requests.get(url, params={"apikey": FMP_KEY}, timeout=30)
        r.raise_for_status()
        payload = r.json()
    except (requests.RequestException, ValueError) as e:
        # `from None` keeps the key-bearing original exception out of tracebacks.
        raise gr.Error(_redact(f"FMP request failed for {symbol} at {interval}: {e}")) from None
    if not isinstance(payload, list):
        # FMP reports problems (bad key, unknown symbol, limits) as a JSON object.
        raise gr.Error(_redact(
            f"Unexpected FMP response for {symbol} at {interval}: {str(payload)[:200]}"))
    df = pd.DataFrame(payload).rename(columns={"date": "timestamps"})
    if df.empty:
        raise gr.Error(f"No data from FMP for {symbol} at {interval}")
    df["timestamps"] = pd.to_datetime(df["timestamps"])
    df = df.sort_values("timestamps").reset_index(drop=True)
    # Approximate traded value; the model input expects an `amount` column.
    df["amount"] = df["close"] * df["volume"]
    df = df[["timestamps", "open", "high", "low", "close", "volume", "amount"]]
    return df.tail(n_bars).reset_index(drop=True)
def fetch_fmp_safe(symbol, interval, n_bars):
    """Non-raising wrapper around ``fetch_fmp`` for fan-out fetching.

    Returns a ``(symbol, df, error)`` triple: ``(symbol, df, None)`` on
    success, ``(symbol, None, message)`` when any exception was raised.
    """
    frame, problem = None, None
    try:
        frame = fetch_fmp(symbol, interval, n_bars)
    except Exception as exc:  # best-effort: callers report failures per symbol
        problem = str(exc)
    return symbol, frame, problem
| # ----- Forecast helpers ----- | |
| def _percentiles_from_samples(sample_dfs): | |
| samples = np.stack([d.values for d in sample_dfs], axis=0) # (S, T, F) | |
| cols = sample_dfs[0].columns.tolist() | |
| p10 = np.percentile(samples, 10, axis=0) | |
| p50 = np.percentile(samples, 50, axis=0) | |
| p90 = np.percentile(samples, 90, axis=0) | |
| mean_v = samples.mean(axis=0) | |
| return cols, p10, p50, p90, mean_v, samples | |
def _build_chart(df, pred_index, p10_close, p50_close, p90_close, vol_mean,
                 title, sample_count_label=None):
    """Build the two-row forecast figure.

    Row 1: historical candlesticks, a shaded P10–P90 close band and a dashed
    P50 close line. Row 2: historical volume bars plus the forecast mean
    volume line.

    Args:
        df: history with ``timestamps``, ``open``, ``high``, ``low``,
            ``close`` and ``volume`` columns.
        pred_index: x values (forecast timestamps) for the horizon.
        p10_close, p50_close, p90_close: per-step close percentiles aligned
            with ``pred_index``.
        vol_mean: per-step forecast volume aligned with ``pred_index``.
        title: base title of the price panel.
        sample_count_label: when not None, appended as "(MC, n=...)".

    Returns:
        plotly.graph_objects.Figure
    """
    title_full = title if sample_count_label is None else f"{title} (MC, n={sample_count_label})"
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        row_heights=[0.75, 0.25], vertical_spacing=0.03,
                        subplot_titles=(title_full, "Volume"))
    fig.add_trace(go.Candlestick(x=df["timestamps"], open=df["open"], high=df["high"],
                                 low=df["low"], close=df["close"], name="History"), row=1, col=1)
    # Invisible P10 line: it only serves as the lower edge of the band below.
    fig.add_trace(go.Scatter(x=pred_index, y=p10_close, mode="lines",
                             line=dict(width=0, color="rgba(0,206,209,0)"),
                             name="P10", showlegend=False, hoverinfo="skip"), row=1, col=1)
    # Must be added right after the P10 trace: fill="tonexty" shades down to
    # the previously added trace, producing the P10–P90 band.
    fig.add_trace(go.Scatter(x=pred_index, y=p90_close, mode="lines",
                             line=dict(width=0, color="rgba(0,206,209,0)"),
                             fill="tonexty", fillcolor="rgba(0,206,209,0.22)",
                             name="P10–P90 close"), row=1, col=1)
    fig.add_trace(go.Scatter(x=pred_index, y=p50_close, mode="lines",
                             line=dict(color="#00CED1", width=2, dash="dash"),
                             name="P50 close"), row=1, col=1)
    fig.add_trace(go.Bar(x=df["timestamps"], y=df["volume"], name="Vol", marker_color="#888"), row=2, col=1)
    fig.add_trace(go.Scatter(x=pred_index, y=vol_mean, mode="lines",
                             line=dict(color="#FFD700", width=2),
                             name="Vol mean (fcst)"), row=2, col=1)
    fig.update_layout(height=700, template="plotly_dark", xaxis_rangeslider_visible=False,
                      showlegend=True, margin=dict(l=20, r=20, t=50, b=20))
    return fig
def run_forecast(symbol, interval, lookback, pred_len, sample_count, temperature, top_p,
                 persist=True):
    """Fetch history for one symbol, run a Monte-Carlo Kronos forecast and
    package the results for the Forecast tab.

    Args:
        symbol: ticker; surrounding whitespace is ignored and case is normalized.
        interval: FMP interval string (e.g. "5min").
        lookback: number of history bars fed to the model.
        pred_len: forecast horizon in bars (lookback + pred_len must be ≤ 512).
        sample_count: number of sample paths drawn by the predictor.
        temperature, top_p: sampling knobs, passed as ``T`` / ``top_p``.
        persist: when True, store the P50 forecast via ``save_forecast``.
            A storage failure is logged and does not fail the request.

    Returns:
        (fig, summary, pred_out): the plotly figure, a Metric/Value DataFrame
        and the P50 forecast bars with a ``timestamps`` column.

    Raises:
        gr.Error: invalid inputs, too little history, or fetch failures.
    """
    lookback, pred_len, sample_count = int(lookback), int(pred_len), int(sample_count)
    if lookback + pred_len > 512:
        raise gr.Error(f"lookback + pred_len must be ≤ 512 (got {lookback+pred_len})")
    # Strip as the Watchlist does; a stray space would otherwise end up in the FMP URL.
    symbol = (symbol or "").strip().upper()
    if not symbol:
        raise gr.Error("Provide a symbol")
    # Fetch a few spare bars, then keep the most recent `lookback`.
    df = fetch_fmp(symbol, interval, lookback + 10)
    df = df.tail(lookback).reset_index(drop=True)
    if len(df) < 2:
        # The bar spacing below needs two timestamps; with one, the median diff
        # is NaT and every future timestamp would be NaT as well.
        raise gr.Error(f"Not enough history for {symbol} at {interval} (got {len(df)} bar(s))")
    # Future timestamps continue at the median historical bar spacing.
    step = df["timestamps"].diff().median()
    y_timestamp = pd.Series([df["timestamps"].iloc[-1] + step * (i + 1) for i in range(pred_len)])
    x_df = df[["open", "high", "low", "close", "volume", "amount"]]
    x_timestamp = df["timestamps"]
    # The predictor is shared by every tab; hold the lock for the inference call.
    with PREDICT_LOCK:
        sample_dfs = predictor.predict(
            df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
            pred_len=pred_len, T=temperature, top_p=top_p,
            sample_count=sample_count, verbose=False, return_samples=True,
        )
    cols, p10, p50, p90, mean_v, _ = _percentiles_from_samples(sample_dfs)
    close_idx = cols.index("close")
    vol_idx = cols.index("volume")
    # The P50 path is the point forecast that is summarized, returned and stored.
    pred_df = pd.DataFrame(p50, columns=cols, index=y_timestamp)
    fig = _build_chart(df, pred_df.index, p10[:, close_idx], p50[:, close_idx], p90[:, close_idx],
                       mean_v[:, vol_idx],
                       title=f"{symbol} {interval} — Kronos Forecast",
                       sample_count_label=sample_count)
    last_close = float(df["close"].iloc[-1])
    forecast_close = float(pred_df["close"].iloc[-1])
    expected_return_pct = (forecast_close / last_close - 1.0) * 100.0
    summary = pd.DataFrame({
        "Metric": ["Last close", "Forecast close", "Expected change %", "Forecast high", "Forecast low"],
        "Value": [
            f"{last_close:.2f}",
            f"{forecast_close:.2f}",
            f"{expected_return_pct:+.2f}%",
            f"{pred_df['high'].max():.2f}",
            f"{pred_df['low'].min():.2f}",
        ]
    })
    pred_out = pred_df.reset_index().rename(columns={"index": "timestamps"})
    if persist:
        try:
            save_forecast(symbol, interval, lookback, pred_len, sample_count, temperature, top_p,
                          last_close, forecast_close, expected_return_pct, pred_df)
        except Exception as e:
            # Deliberately best-effort: a cache failure must not hide the forecast.
            print(f"[cache] save failed: {e}")
    return fig, summary, pred_out
| # ----- Watchlist ----- | |
| _SPARK_BARS = "▁▂▃▄▅▆▇█" | |
| def _sparkline_text(prices, target_len=24): | |
| arr = np.asarray(prices, dtype=float) | |
| if arr.size < 2 or not np.all(np.isfinite(arr)): | |
| return "" | |
| if arr.size > target_len: | |
| idx = np.linspace(0, arr.size - 1, target_len).astype(int) | |
| arr = arr[idx] | |
| pmin, pmax = float(arr.min()), float(arr.max()) | |
| rng = max(pmax - pmin, 1e-9) | |
| bins = np.clip(((arr - pmin) / rng * (len(_SPARK_BARS) - 1)).astype(int), 0, len(_SPARK_BARS) - 1) | |
| arrow = "▲" if arr[-1] >= arr[0] else "▼" | |
| return arrow + " " + "".join(_SPARK_BARS[i] for i in bins) | |
def run_watchlist(symbols_csv, interval, lookback, pred_len,
                  temperature=1.0, top_p=0.9, sample_count=30):
    """Forecast several symbols in one batched Kronos call and rank them.

    Symbols are fetched concurrently. Fetch failures and too-short histories
    are collected instead of raised. All valid symbols are forecast with
    ``predictor.predict_batch``, and the table is sorted by expected P50 return.

    Args:
        symbols_csv: comma-separated tickers. Blanks are ignored, case is
            normalized and duplicates are dropped (first occurrence wins).
        interval: FMP interval string, e.g. "5min".
        lookback: history bars fed to the model per symbol.
        pred_len: forecast horizon in bars.
        temperature, top_p, sample_count: sampling knobs. The defaults are the
            values the Watchlist tab has always used.

    Returns:
        (DataFrame, str): the ranked table and a note listing skipped symbols.

    Raises:
        gr.Error: no symbols, lookback + pred_len > 512, or no symbol had
            usable data.
    """
    raw = [s.strip().upper() for s in (symbols_csv or "").split(",") if s.strip()]
    # A duplicate would cost a full extra forecast and repeat a row in the table.
    symbols = list(dict.fromkeys(raw))
    if not symbols:
        raise gr.Error("Provide at least one symbol")
    lookback, pred_len = int(lookback), int(pred_len)
    if lookback + pred_len > 512:
        raise gr.Error(f"lookback + pred_len must be ≤ 512 (got {lookback+pred_len})")

    # Fetching is network-bound, so fan it out over a small thread pool.
    with ThreadPoolExecutor(max_workers=min(8, len(symbols))) as ex:
        fetched = list(ex.map(lambda s: fetch_fmp_safe(s, interval, lookback + 10), symbols))

    df_list, x_ts_list, y_ts_list, valid = [], [], [], []
    errors = []
    for sym, df, err in fetched:
        if err is not None or df is None or len(df) < lookback:
            errors.append(f"{sym}: {err or 'insufficient data'}")
            continue
        df = df.tail(lookback).reset_index(drop=True)
        # Future timestamps continue at the median historical bar spacing.
        step = df["timestamps"].diff().median()
        y_ts = pd.Series([df["timestamps"].iloc[-1] + step * (i + 1) for i in range(pred_len)])
        df_list.append(df[["open", "high", "low", "close", "volume", "amount"]])
        x_ts_list.append(df["timestamps"])
        y_ts_list.append(y_ts)
        valid.append((sym, df))
    if not df_list:
        raise gr.Error("No valid symbols fetched. " + "; ".join(errors))

    with PREDICT_LOCK:
        per_symbol_samples = predictor.predict_batch(
            df_list=df_list, x_timestamp_list=x_ts_list, y_timestamp_list=y_ts_list,
            pred_len=pred_len, T=float(temperature), top_p=float(top_p),
            sample_count=int(sample_count), verbose=False, return_samples=True,
        )

    rows = []
    for (sym, hist_df), sample_dfs in zip(valid, per_symbol_samples):
        cols, p10, p50, p90, _, _ = _percentiles_from_samples(sample_dfs)
        ci = cols.index("close")
        last_close = float(hist_df["close"].iloc[-1])
        p50_close = p50[:, ci]
        forecast_close = float(p50_close[-1])
        expected_return_pct = (forecast_close / last_close - 1.0) * 100.0
        # Std-dev (%) of the P50 path's log returns; 0 for a one-bar horizon
        # instead of the NaN np.std would give on an empty array.
        log_rets = np.diff(np.log(np.maximum(p50_close, 1e-9)))
        forecast_vol = float(np.std(log_rets) * 100.0) if log_rets.size else 0.0
        # Confidence = 1 - mean relative P10–P90 width, clipped to [0, 1].
        spread = (p90[:, ci] - p10[:, ci]) / np.maximum(np.abs(p50_close), 1e-9)
        confidence = float(np.clip(1.0 - float(np.mean(spread)), 0.0, 1.0))
        rows.append([
            sym,
            round(last_close, 4),
            round(forecast_close, 4),
            round(expected_return_pct, 3),
            round(forecast_vol, 3),
            round(confidence, 3),
            _sparkline_text(p50_close),
        ])
    rows.sort(key=lambda r: r[3], reverse=True)
    out = pd.DataFrame(rows, columns=[
        "symbol", "last_close", "forecast_close", "expected_return_pct",
        "forecast_volatility", "kronos_confidence_score", "sparkline"
    ])
    note = "" if not errors else f"Skipped: {'; '.join(errors)}"
    return out, note
# ----- History tab -----
def reopen_from_history(history_df: pd.DataFrame, evt: gr.SelectData):
    """Re-render a stored forecast picked in the History table.

    Reads the clicked row's ``id`` and loads the full record with
    ``load_forecast``. It rebuilds the stored P50 bars from ``pred_json`` and
    overlays them on history fetched now from FMP for the same
    symbol/interval/lookback.

    Args:
        history_df: current contents of the History table (the frame
            produced by ``list_forecasts``).
        evt: Gradio select event. When ``evt.index`` is a list/tuple its
            first element is used as the row (presumably [row, col]).

    Returns:
        (plotly Figure, DataFrame): the chart and a Field/Value summary.

    Raises:
        gr.Error: empty table or unknown id. Errors from ``fetch_fmp`` propagate.
    """
    if history_df is None or len(history_df) == 0:
        raise gr.Error("History empty")
    row_idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
    fid = int(history_df.iloc[row_idx]["id"])
    rec = load_forecast(fid)
    if rec is None:
        raise gr.Error(f"Forecast #{fid} not found")
    # Column order of the forecasts table (load_forecast does SELECT *).
    (_id, ts, symbol, interval, lookback, pred_len, sample_count, T, top_p,
     last_close, forecast_close, expected_return_pct, pred_json) = rec
    pred_records = json.loads(pred_json)
    pred_df = pd.DataFrame(pred_records)
    pred_df["timestamps"] = pd.to_datetime(pred_df["timestamps"])
    pred_df = pred_df.set_index("timestamps")
    # History is fetched now, not stored, so for old forecasts it may no
    # longer overlap the stored prediction window.
    df = fetch_fmp(symbol, interval, int(lookback) + 10).tail(int(lookback)).reset_index(drop=True)
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        row_heights=[0.75, 0.25], vertical_spacing=0.03,
                        subplot_titles=(f"#{fid} {symbol} {interval} — stored {ts}", "Volume"))
    fig.add_trace(go.Candlestick(x=df["timestamps"], open=df["open"], high=df["high"],
                                 low=df["low"], close=df["close"], name="History (live)"), row=1, col=1)
    fig.add_trace(go.Scatter(x=pred_df.index, y=pred_df["close"], mode="lines",
                             line=dict(color="#00CED1", width=2, dash="dash"),
                             name="P50 close (stored)"), row=1, col=1)
    fig.add_trace(go.Bar(x=df["timestamps"], y=df["volume"], name="Vol", marker_color="#888"), row=2, col=1)
    # run_forecast stores the P50 path for every column, volume included, so
    # this series is the P50 volume (not the mean used on the live chart).
    fig.add_trace(go.Scatter(x=pred_df.index, y=pred_df["volume"], mode="lines",
                             line=dict(color="#FFD700", width=2),
                             name="Vol P50 (stored)"), row=2, col=1)
    fig.update_layout(height=700, template="plotly_dark", xaxis_rangeslider_visible=False,
                      showlegend=True, margin=dict(l=20, r=20, t=50, b=20))
    summary = pd.DataFrame({
        "Field": ["id", "stored_at", "symbol", "interval", "lookback", "pred_len",
                  "sample_count", "T", "top_p", "last_close", "forecast_close",
                  "expected_return_pct"],
        "Value": [fid, ts, symbol, interval, int(lookback), int(pred_len), int(sample_count),
                  T, top_p, last_close, forecast_close, f"{expected_return_pct:+.3f}%"],
    })
    return fig, summary
# ----- Live BTC tab -----
LIVE_REFRESH_SEC = 60
# Keys mirror run_forecast's parameter names so they can be passed as keywords.
LIVE_DEFAULTS = dict(symbol="BTCUSD", interval="5min", lookback=392, pred_len=120,
                     sample_count=5, temperature=1.0, top_p=0.9)


def live_refresh():
    """Run the fixed live forecast (never persisted) for the Live tab.

    Returns (figure, summary, refresh_epoch_seconds); the timestamp feeds
    ``live_countdown``.
    """
    fig, summary, _ = run_forecast(**LIVE_DEFAULTS, persist=False)
    return fig, summary, time.time()


def live_countdown(last_ts):
    """Markdown status line: seconds until the next auto refresh.

    A falsy ``last_ts`` (no refresh yet) yields the waiting message; otherwise
    the remaining time is clamped at 0.
    """
    cadence = f"(auto every {LIVE_REFRESH_SEC}s)"
    if not last_ts:
        return f"⏳ Waiting for first refresh… {cadence}"
    elapsed = int(time.time() - float(last_ts))
    remaining = max(0, LIVE_REFRESH_SEC - elapsed)
    return f"⏳ Next refresh in **{remaining}s** {cadence}"
# ----- Backtest -----
def run_backtest(symbol, interval, start_date, end_date,
                 lookback, pred_len, stride, sample_count, max_anchors):
    """Walk-forward backtest handler for the Backtest tab.

    ``autotune.backtest_core`` runs the anchor loop with fixed sampling
    (T=1.0, top_p=0.9). This handler then renders three outputs:

    * a 3-row figure: RMSE per anchor, running directional hit rate (%),
      and cumulative P&L (%) net of ``autotune.BACKTEST_COST_BP`` per trade;
    * a Metric/Value summary table;
    * the per-anchor table with string timestamps, prices/RMSE rounded to
      4 dp, and P&L / hit-rate columns converted to percent.

    Raises:
        gr.Error: wrapping any ValueError from ``backtest_core``.
    """
    def _locked_predict(**kwargs):
        # The predictor is shared with every other tab; serialize access.
        with PREDICT_LOCK:
            return predictor.predict(**kwargs)

    try:
        core = autotune.backtest_core(
            _locked_predict, fetch_fmp,
            symbol=symbol, interval=interval,
            start_date=start_date, end_date=end_date,
            lookback=lookback, pred_len=pred_len, stride=stride,
            T=1.0, top_p=0.9,
            sample_count=sample_count, max_anchors=max_anchors,
        )
    except ValueError as exc:
        raise gr.Error(str(exc))

    per_anchor = core["per_anchor"]
    pnl_title = (f"Cumulative P&L (long-if-up-else-short, "
                 f"{autotune.BACKTEST_COST_BP:g}bp/trade)")
    fig = make_subplots(
        rows=3, cols=1, shared_xaxes=True,
        subplot_titles=("RMSE per anchor", "Cumulative directional hit rate", pnl_title),
        vertical_spacing=0.07,
    )
    anchors_x = per_anchor["anchor_ts"]
    traces = [
        go.Scatter(x=anchors_x, y=per_anchor["rmse"], mode="lines+markers",
                   line=dict(color="#FF6B6B"), name="RMSE"),
        go.Scatter(x=anchors_x, y=per_anchor["hit_rate_running"] * 100.0,
                   mode="lines+markers", line=dict(color="#FFD700"), name="Hit %"),
        go.Scatter(x=anchors_x, y=per_anchor["cum_pnl"] * 100.0,
                   mode="lines+markers", line=dict(color="#00CED1"),
                   name="Cum P&L %", fill="tozeroy",
                   fillcolor="rgba(0,206,209,0.15)"),
    ]
    for row_no, trace in enumerate(traces, start=1):
        fig.add_trace(trace, row=row_no, col=1)
    for row_no, axis_title in enumerate(("USD", "%", "%"), start=1):
        fig.update_yaxes(title_text=axis_title, row=row_no, col=1)
    fig.update_layout(height=720, template="plotly_dark", showlegend=False,
                      margin=dict(l=20, r=20, t=50, b=20))

    metric_rows = [
        ("Symbol", symbol.upper()),
        ("Interval", interval),
        ("Anchors", core["anchors"]),
        ("Lookback", int(lookback)),
        ("Pred len", int(pred_len)),
        ("Stride", int(stride)),
        ("Sample count", int(sample_count)),
        ("Mean RMSE", f"{core['mean_rmse']:.4f}"),
        ("Final hit rate %", f"{core['hit_rate'] * 100.0:.2f}"),
        ("Total return %", f"{core['total_return_pct']:+.2f}"),
        ("Max drawdown %", f"{core['max_dd_pct']:.2f}"),
    ]
    summary = pd.DataFrame({
        "Metric": [name for name, _ in metric_rows],
        "Value": [value for _, value in metric_rows],
    })

    out_table = per_anchor.copy()
    out_table["anchor_ts"] = out_table["anchor_ts"].astype(str)
    for price_col in ("last_close", "forecast_close", "realized_close", "rmse"):
        out_table[price_col] = out_table[price_col].round(4)
    for pct_col, digits in (("trade_pnl", 3), ("cum_pnl", 3), ("hit_rate_running", 2)):
        out_table[pct_col] = (out_table[pct_col] * 100.0).round(digits)
    return fig, summary, out_table
def run_autotune_ui(symbol, interval, start_date, end_date,
                    lookback, pred_len, stride, sample_count, max_anchors):
    """Auto-tune tab handler: delegate the (T, top_p) search to
    ``autotune.run_autotune`` and return its outputs unchanged.

    ValueError / RuntimeError from the search are re-raised as gr.Error so
    Gradio shows them to the user.
    """
    def _serialized_predict(**kwargs):
        # Same shared predictor as every other tab; one inference at a time.
        with PREDICT_LOCK:
            return predictor.predict(**kwargs)

    search_args = dict(
        symbol=symbol, interval=interval,
        start_date=start_date, end_date=end_date,
        lookback=lookback, pred_len=pred_len, stride=stride,
        sample_count=sample_count, max_anchors=max_anchors,
    )
    try:
        return autotune.run_autotune(
            predict_fn=_serialized_predict, fetch_fn=fetch_fmp, db_path=DB_PATH,
            **search_args,
        )
    except (ValueError, RuntimeError) as exc:
        raise gr.Error(str(exc))
# ----- UI -----
# Every event that runs model inference uses concurrency_id="predictor" with
# concurrency_limit=1, so Gradio runs at most one of them at a time.
with gr.Blocks(title="Kronos Forecast Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📈 Kronos Financial Forecast Dashboard\nZero-shot OHLCV forecasting powered by [Kronos](https://github.com/shiyu-coder/Kronos) + FMP market data.")
    with gr.Tabs():
        # ---- Forecast tab ----
        with gr.Tab("Forecast"):
            with gr.Row():
                with gr.Column(scale=1):
                    symbol = gr.Textbox("SPY", label="Symbol (e.g. SPY, AAPL, BTCUSD, EURUSD)")
                    interval = gr.Dropdown(["1min","5min","15min","30min","1hour","4hour"], value="5min", label="Interval")
                    lookback = gr.Slider(100, 400, 392, step=1, label="Lookback bars")
                    pred_len = gr.Slider(12, 120, 120, step=1, label="Forecast bars")
                    sample_count = gr.Slider(1, 30, 5, step=1, label="Monte Carlo samples (cpu-basic: 30 ≈ 4 min)")
                    temperature = gr.Slider(0.1, 2.0, 1.0, step=0.1, label="Temperature (T)")
                    top_p = gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p")
                    use_tuned = gr.Checkbox(value=True,
                                            label="🎯 Use tuned defaults (if cached)")
                    run_btn = gr.Button("🚀 Forecast", variant="primary")
                with gr.Column(scale=3):
                    chart_out = gr.Plot(label="Forecast Chart")
                    summary_out = gr.Dataframe(label="Summary", interactive=False)
                    pred_out = gr.Dataframe(label="Forecasted Bars", interactive=False)
            run_btn.click(run_forecast,
                          inputs=[symbol, interval, lookback, pred_len, sample_count, temperature, top_p],
                          outputs=[chart_out, summary_out, pred_out],
                          concurrency_id="predictor", concurrency_limit=1)

            def _apply_tuning_to_sliders(symbol_val, interval_val, use_val):
                # Pre-fill T / top_p from the auto-tune cache; a None from
                # autotune.apply_tuning leaves that slider untouched.
                t_val, p_val = autotune.apply_tuning(DB_PATH, symbol_val, interval_val, use_val)
                t_update = gr.update(value=t_val) if t_val is not None else gr.update()
                p_update = gr.update(value=p_val) if p_val is not None else gr.update()
                return t_update, p_update

            symbol.change(_apply_tuning_to_sliders,
                          inputs=[symbol, interval, use_tuned],
                          outputs=[temperature, top_p])
            interval.change(_apply_tuning_to_sliders,
                            inputs=[symbol, interval, use_tuned],
                            outputs=[temperature, top_p])
            use_tuned.change(_apply_tuning_to_sliders,
                             inputs=[symbol, interval, use_tuned],
                             outputs=[temperature, top_p])
        # ---- Watchlist tab ----
        with gr.Tab("Watchlist"):
            with gr.Row():
                with gr.Column(scale=1):
                    wl_symbols = gr.Textbox("SPY, QQQ, AAPL, MSFT, NVDA, BTCUSD",
                                            label="Symbols (comma-separated)")
                    wl_interval = gr.Dropdown(["1min","5min","15min","30min","1hour","4hour"],
                                              value="5min", label="Interval")
                    wl_lookback = gr.Slider(100, 400, 200, step=1, label="Lookback bars")
                    wl_predlen = gr.Slider(12, 120, 60, step=1, label="Forecast bars")
                    wl_btn = gr.Button("📊 Run Watchlist", variant="primary")
                    wl_note = gr.Markdown("")
                with gr.Column(scale=3):
                    wl_table = gr.Dataframe(
                        label="Watchlist (sorted by expected return ↓)",
                        headers=["symbol","last_close","forecast_close","expected_return_pct",
                                 "forecast_volatility","kronos_confidence_score","sparkline"],
                        datatype=["str","number","number","number","number","number","str"],
                        interactive=False,
                    )
            wl_btn.click(run_watchlist,
                         inputs=[wl_symbols, wl_interval, wl_lookback, wl_predlen],
                         outputs=[wl_table, wl_note],
                         concurrency_id="predictor", concurrency_limit=1)
        # ---- History tab ----
        with gr.Tab("History"):
            with gr.Row():
                with gr.Column(scale=1):
                    hist_refresh = gr.Button("🔄 Refresh", variant="secondary")
                    hist_summary = gr.Dataframe(label="Selected forecast", interactive=False)
                with gr.Column(scale=3):
                    hist_table = gr.Dataframe(
                        label="Recent forecasts (click row to re-open chart)",
                        headers=["id","ts","symbol","interval","lookback","pred_len",
                                 "sample_count","last_close","forecast_close","expected_return_pct"],
                        datatype=["number","str","str","str","number","number","number",
                                  "number","number","number"],
                        interactive=False,
                    )
                    hist_chart = gr.Plot(label="Stored forecast vs current history")
            hist_refresh.click(list_forecasts, outputs=hist_table)
            demo.load(list_forecasts, outputs=hist_table)
            hist_table.select(reopen_from_history, inputs=hist_table,
                              outputs=[hist_chart, hist_summary],
                              concurrency_id="predictor", concurrency_limit=1)
        # ---- Live tab ----
        with gr.Tab("Live BTC/USDT (5m)"):
            gr.Markdown(f"Mimics the public Kronos demo — refresh BTCUSD {LIVE_DEFAULTS['interval']} every {LIVE_REFRESH_SEC}s.")
            with gr.Row():
                live_btn = gr.Button("⟳ Refresh now", variant="primary")
                live_auto = gr.Checkbox(value=False, label=f"🟢 Auto-refresh every {LIVE_REFRESH_SEC}s")
            live_status = gr.Markdown(value=live_countdown(0))
            live_chart = gr.Plot(label="Live forecast")
            live_summary = gr.Dataframe(label="Summary", interactive=False)
            # Epoch seconds of the last completed refresh (0.0 = none yet).
            live_last_ts = gr.State(value=0.0)
            live_timer = gr.Timer(value=LIVE_REFRESH_SEC, active=False)
            countdown_timer = gr.Timer(value=1, active=False)
            live_btn.click(live_refresh,
                           outputs=[live_chart, live_summary, live_last_ts],
                           concurrency_id="predictor", concurrency_limit=1)

            def _toggle_auto(enabled):
                # The checkbox drives both the refresh timer and the 1 s countdown.
                return gr.Timer(active=bool(enabled)), gr.Timer(active=bool(enabled))

            live_auto.change(_toggle_auto, inputs=live_auto, outputs=[live_timer, countdown_timer])
            # Timer refreshes run inference too, so they join the "predictor"
            # group like every other inference event.
            live_timer.tick(live_refresh,
                            outputs=[live_chart, live_summary, live_last_ts],
                            concurrency_id="predictor", concurrency_limit=1)
            countdown_timer.tick(live_countdown, inputs=live_last_ts, outputs=live_status)
        # ---- Backtest tab ----
        with gr.Tab("Backtest"):
            gr.Markdown(
                "Walk-forward Kronos through a date range. At each anchor, forecast "
                "`pred_len` bars ahead and compare to realized close. Stride=`pred_len` "
                "gives non-overlapping windows; stride=1 is bar-by-bar (slow on cpu-basic). "
                "`Max anchors` caps total work; if more anchors fit the window, they are "
                "evenly sub-sampled."
            )
            with gr.Row():
                with gr.Column(scale=1):
                    bt_symbol = gr.Textbox("SPY", label="Symbol")
                    bt_interval = gr.Dropdown(["1min","5min","15min","30min","1hour","4hour"],
                                              value="5min", label="Interval")
                    bt_start = gr.Textbox("", label="Start (YYYY-MM-DD HH:MM, blank=earliest)")
                    bt_end = gr.Textbox("", label="End (YYYY-MM-DD HH:MM, blank=latest)")
                    bt_lookback = gr.Slider(100, 400, 200, step=1, label="Lookback bars")
                    bt_predlen = gr.Slider(5, 120, 30, step=1, label="Forecast horizon (pred_len)")
                    bt_stride = gr.Slider(1, 120, 30, step=1,
                                          label="Stride (1 = bar-by-bar, slow)")
                    bt_samples = gr.Slider(1, 10, 1, step=1, label="MC samples per step")
                    bt_max = gr.Slider(5, 100, 20, step=1,
                                       label="Max anchors (caps total work)")
                    bt_btn = gr.Button("🧪 Run backtest", variant="primary")
                with gr.Column(scale=3):
                    bt_chart = gr.Plot(label="Backtest metrics over time")
                    bt_summary = gr.Dataframe(label="Summary", interactive=False)
                    bt_table = gr.Dataframe(label="Per-anchor results", interactive=False)
            bt_btn.click(
                run_backtest,
                inputs=[bt_symbol, bt_interval, bt_start, bt_end,
                        bt_lookback, bt_predlen, bt_stride, bt_samples, bt_max],
                outputs=[bt_chart, bt_summary, bt_table],
                concurrency_id="predictor", concurrency_limit=1,
            )
        # ---- Auto-tune tab ----
        with gr.Tab("Auto-tune"):
            gr.Markdown(
                "Search Kronos sampling knobs (T, top_p) on a 3×3 grid, score each "
                "cell with a walk-forward backtest, and persist the best "
                "(T, top_p) per (symbol, interval). The Forecast tab can then "
                "pre-fill those defaults via the 🎯 checkbox."
            )
            with gr.Row():
                with gr.Column(scale=1):
                    at_symbol = gr.Textbox("SPY", label="Symbol")
                    at_interval = gr.Dropdown(["1min","5min","15min","30min","1hour","4hour"],
                                              value="5min", label="Interval")
                    at_start = gr.Textbox("", label="Start (YYYY-MM-DD HH:MM, blank=earliest)")
                    at_end = gr.Textbox("", label="End (YYYY-MM-DD HH:MM, blank=latest)")
                    at_lookback = gr.Slider(100, 400, 200, step=1, label="Lookback bars")
                    at_predlen = gr.Slider(5, 120, 30, step=1, label="Forecast horizon (pred_len)")
                    at_stride = gr.Slider(1, 120, 30, step=1, label="Stride")
                    at_samples = gr.Slider(1, 5, 1, step=1, label="MC samples per step")
                    at_max = gr.Slider(3, 30, 5, step=1,
                                       label="Max anchors per cell (caps work; 9 cells × this)")
                    at_btn = gr.Button("🎯 Run auto-tune", variant="primary")
                with gr.Column(scale=3):
                    at_chart = gr.Plot(label="Heatmap: P&L % over (T, top_p)")
                    at_summary = gr.Dataframe(label="Best parameters", interactive=False)
                    at_table = gr.Dataframe(label="Per-cell results", interactive=False)
            at_btn.click(
                run_autotune_ui,
                inputs=[at_symbol, at_interval, at_start, at_end,
                        at_lookback, at_predlen, at_stride, at_samples, at_max],
                outputs=[at_chart, at_summary, at_table],
                concurrency_id="predictor", concurrency_limit=1,
            )
demo.queue(default_concurrency_limit=10, max_size=32)
if __name__ == "__main__":
    # Hugging Face Spaces set SPACE_ID: bind on all interfaces there and skip
    # opening a browser; locally, stay on loopback and open a browser tab.
    running_on_space = "SPACE_ID" in os.environ
    demo.launch(
        server_name="0.0.0.0" if running_on_space else "127.0.0.1",
        server_port=7860,
        share=False,
        inbrowser=not running_on_space,
    )