daeeeee's picture
feat(ui): Forecast tab pre-fills T/top_p from cached tuning
f414087
Raw
History Blame Contribute Delete
31.6 kB
import os, json, time, sqlite3, threading, requests
from concurrent.futures import ThreadPoolExecutor
from contextlib import closing
from datetime import datetime, timezone
import numpy as np
import pandas as pd
import torch
import gradio as gr
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from model import Kronos, KronosTokenizer, KronosPredictor
import autotune
# ----- Load model once at startup -----
# Pick the device first so the startup log reports where the model will actually run
# (the old message always said "CPU", even when CUDA was selected).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Loading Kronos-small on {device}...")
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
# max_context=512 is why the handlers reject lookback + pred_len > 512.
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
print(f"Model loaded on {device}")

FMP_KEY = os.getenv("FMP_API_KEY")
if not FMP_KEY:
    # Don't fail at import (the UI can still browse cached history), but make the cause obvious.
    print("[fmp] FMP_API_KEY is not set; market-data requests will fail")

# Serializes every call into `predictor`, which all tabs share.
PREDICT_LOCK = threading.Lock()
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "forecasts.db")
# Serializes writers to the SQLite cache.
DB_LOCK = threading.Lock()
# ----- SQLite cache -----
def init_db():
    """Create the ``forecasts`` cache table if missing, then the autotune tuning table.

    The ``forecasts`` DDL uses ``CREATE TABLE IF NOT EXISTS``, so repeated calls are safe
    for that table.
    """
    # A sqlite3 connection used as a context manager only commits or rolls back; it does
    # not close. `closing` releases the file handle, and the inner `c` commits.
    with closing(sqlite3.connect(DB_PATH)) as c, c:
        c.execute("""CREATE TABLE IF NOT EXISTS forecasts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            ts TEXT NOT NULL,
            symbol TEXT NOT NULL,
            interval TEXT NOT NULL,
            lookback INTEGER NOT NULL,
            pred_len INTEGER NOT NULL,
            sample_count INTEGER NOT NULL,
            temperature REAL NOT NULL,
            top_p REAL NOT NULL,
            last_close REAL NOT NULL,
            forecast_close REAL NOT NULL,
            expected_return_pct REAL NOT NULL,
            pred_json TEXT NOT NULL
        )""")
    # Called after the connection above is closed so the two never overlap.
    autotune.init_tuning_table(DB_PATH)


init_db()
def save_forecast(symbol, interval, lookback, pred_len, sample_count, temperature, top_p,
                  last_close, forecast_close, expected_return_pct, pred_df):
    """Persist one forecast run to the ``forecasts`` table.

    ``pred_df`` is the P50 forecast frame indexed by forecast timestamps. It is stored as
    JSON records with a string ``timestamps`` field plus one field per column. The row's
    ``ts`` is the current UTC time in ISO-8601 (seconds precision).
    """
    # rename_axis makes the reset column "timestamps" whatever the index was named
    # (the old rename only worked for an unnamed index).
    payload = pred_df.rename_axis("timestamps").reset_index()
    payload["timestamps"] = pd.to_datetime(payload["timestamps"]).astype(str)
    pred_json = payload.to_json(orient="records")
    ts = datetime.now(timezone.utc).isoformat(timespec="seconds")
    # `closing` releases the handle; the inner `c` commits the INSERT (or rolls back).
    with DB_LOCK, closing(sqlite3.connect(DB_PATH)) as c, c:
        c.execute("""INSERT INTO forecasts
            (ts, symbol, interval, lookback, pred_len, sample_count, temperature, top_p,
             last_close, forecast_close, expected_return_pct, pred_json)
            VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""",
                  (ts, symbol, interval, int(lookback), int(pred_len), int(sample_count),
                   float(temperature), float(top_p), float(last_close), float(forecast_close),
                   float(expected_return_pct), pred_json))
def list_forecasts(limit=200):
    """Return the newest ``limit`` cached forecasts (newest first) as a DataFrame.

    Prices are rounded to 4 decimals and expected return to 3 for display. The
    ``pred_json`` payload and sampling knobs are omitted.
    """
    # `closing` ensures the read-only connection is actually closed.
    with closing(sqlite3.connect(DB_PATH)) as c:
        rows = c.execute("""SELECT id, ts, symbol, interval, lookback, pred_len, sample_count,
            ROUND(last_close, 4), ROUND(forecast_close, 4),
            ROUND(expected_return_pct, 3)
            FROM forecasts ORDER BY id DESC LIMIT ?""", (limit,)).fetchall()
    return pd.DataFrame(rows, columns=[
        "id", "ts", "symbol", "interval", "lookback", "pred_len", "sample_count",
        "last_close", "forecast_close", "expected_return_pct"
    ])
def load_forecast(forecast_id):
    """Return the full ``forecasts`` row (all columns, table order) for ``forecast_id``, or None."""
    # `closing` ensures the connection is released even though we only read.
    with closing(sqlite3.connect(DB_PATH)) as c:
        return c.execute("SELECT * FROM forecasts WHERE id = ?",
                         (int(forecast_id),)).fetchone()
# ----- Data fetch -----
def fetch_fmp(symbol: str, interval: str, n_bars: int) -> pd.DataFrame:
    """Fetch the latest ``n_bars`` intraday bars for ``symbol`` from FMP.

    Returns a frame sorted oldest→newest with columns
    ``timestamps, open, high, low, close, volume, amount``. ``amount`` is
    approximated as ``close * volume``.

    Raises:
        gr.Error: API key missing, FMP returned a non-list payload, or no rows.
        requests.HTTPError: on a non-2xx HTTP response.
    """
    if not FMP_KEY:
        raise gr.Error("FMP_API_KEY is not set; cannot fetch market data")
    url = f"https://financialmodelingprep.com/api/v3/historical-chart/{interval}/{symbol}"
    # Let requests build and encode the query string.
    r = requests.get(url, params={"apikey": FMP_KEY}, timeout=30)
    r.raise_for_status()
    payload = r.json()
    # A JSON object instead of a list of bars cannot be turned into OHLCV rows;
    # presumably it is an FMP error message (e.g. {"Error Message": ...}) — surface it
    # rather than letting pandas fail with an unrelated ValueError.
    if isinstance(payload, dict):
        msg = payload.get("Error Message") or payload.get("error") or str(payload)
        raise gr.Error(f"FMP error for {symbol} at {interval}: {msg}")
    df = pd.DataFrame(payload).rename(columns={"date": "timestamps"})
    if df.empty:
        raise gr.Error(f"No data from FMP for {symbol} at {interval}")
    df["timestamps"] = pd.to_datetime(df["timestamps"])
    df = df.sort_values("timestamps").reset_index(drop=True)
    df["amount"] = df["close"] * df["volume"]
    df = df[["timestamps", "open", "high", "low", "close", "volume", "amount"]]
    return df.tail(n_bars).reset_index(drop=True)
def fetch_fmp_safe(symbol, interval, n_bars):
    """Call ``fetch_fmp`` without raising, for use in a thread-pool fan-out.

    Returns ``(symbol, df, None)`` on success or ``(symbol, None, message)`` when any
    exception occurs.
    """
    frame, problem = None, None
    try:
        frame = fetch_fmp(symbol, interval, n_bars)
    except Exception as exc:
        problem = str(exc)
    return symbol, frame, problem
# ----- Forecast helpers -----
def _percentiles_from_samples(sample_dfs):
samples = np.stack([d.values for d in sample_dfs], axis=0) # (S, T, F)
cols = sample_dfs[0].columns.tolist()
p10 = np.percentile(samples, 10, axis=0)
p50 = np.percentile(samples, 50, axis=0)
p90 = np.percentile(samples, 90, axis=0)
mean_v = samples.mean(axis=0)
return cols, p10, p50, p90, mean_v, samples
def _build_chart(df, pred_index, p10_close, p50_close, p90_close, vol_mean,
                 title, sample_count_label=None):
    """Build the two-panel forecast figure.

    Top panel: historical candlesticks, a shaded P10–P90 close band and the dashed
    P50 close path. Bottom panel: historical volume bars and the mean forecast volume.
    ``sample_count_label``, when given, is appended to the title as "(MC, n=...)".
    """
    if sample_count_label is not None:
        title = f"{title} (MC, n={sample_count_label})"
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        row_heights=[0.75, 0.25], vertical_spacing=0.03,
                        subplot_titles=(title, "Volume"))

    band_edge = dict(width=0, color="rgba(0,206,209,0)")
    price_traces = [
        go.Candlestick(x=df["timestamps"], open=df["open"], high=df["high"],
                       low=df["low"], close=df["close"], name="History"),
        # The lower edge must precede the upper one: fill="tonexty" shades down to it.
        go.Scatter(x=pred_index, y=p10_close, mode="lines", line=band_edge,
                   name="P10", showlegend=False, hoverinfo="skip"),
        go.Scatter(x=pred_index, y=p90_close, mode="lines", line=band_edge,
                   fill="tonexty", fillcolor="rgba(0,206,209,0.22)",
                   name="P10–P90 close"),
        go.Scatter(x=pred_index, y=p50_close, mode="lines",
                   line=dict(color="#00CED1", width=2, dash="dash"),
                   name="P50 close"),
    ]
    volume_traces = [
        go.Bar(x=df["timestamps"], y=df["volume"], name="Vol", marker_color="#888"),
        go.Scatter(x=pred_index, y=vol_mean, mode="lines",
                   line=dict(color="#FFD700", width=2),
                   name="Vol mean (fcst)"),
    ]
    for trace in price_traces:
        fig.add_trace(trace, row=1, col=1)
    for trace in volume_traces:
        fig.add_trace(trace, row=2, col=1)

    fig.update_layout(height=700, template="plotly_dark", xaxis_rangeslider_visible=False,
                      showlegend=True, margin=dict(l=20, r=20, t=50, b=20))
    return fig
def run_forecast(symbol, interval, lookback, pred_len, sample_count, temperature, top_p,
                 persist=True):
    """Run a Monte-Carlo Kronos forecast for one symbol and build the Forecast-tab outputs.

    Args:
        symbol: ticker; surrounding whitespace is stripped and case is folded.
        interval: FMP bar interval (e.g. "5min").
        lookback: history bars fed to the model.
        pred_len: bars to forecast; lookback + pred_len must be ≤ 512.
        sample_count: Monte-Carlo sample paths.
        temperature, top_p: sampling knobs forwarded as ``T`` / ``top_p``.
        persist: when True, store the P50 forecast in the SQLite cache (best-effort).

    Returns:
        (fig, summary, pred_out): the chart, a Metric/Value summary table, and the
        per-bar P50 forecast with a ``timestamps`` column.

    Raises:
        gr.Error: invalid lengths, empty symbol, or fewer than two history bars.
    """
    lookback, pred_len, sample_count = int(lookback), int(pred_len), int(sample_count)
    if lookback + pred_len > 512:
        raise gr.Error(f"lookback + pred_len must be ≤ 512 (got {lookback+pred_len})")
    # Normalize like the watchlist does, so " spy" and "SPY" hit the same URL and cache rows.
    symbol = (symbol or "").strip().upper()
    if not symbol:
        raise gr.Error("Provide a symbol")

    df = fetch_fmp(symbol, interval, lookback + 10)
    df = df.tail(lookback).reset_index(drop=True)
    # Future timestamps are extrapolated from the median bar spacing, which needs ≥ 2 bars.
    if len(df) < 2:
        raise gr.Error(f"Not enough history for {symbol} at {interval} (got {len(df)} bars)")
    step = df["timestamps"].diff().median()
    last_ts = df["timestamps"].iloc[-1]
    y_timestamp = pd.Series([last_ts + step * (i + 1) for i in range(pred_len)])
    x_df = df[["open", "high", "low", "close", "volume", "amount"]]
    x_timestamp = df["timestamps"]

    with PREDICT_LOCK:
        sample_dfs = predictor.predict(
            df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
            pred_len=pred_len, T=temperature, top_p=top_p,
            sample_count=sample_count, verbose=False, return_samples=True,
        )

    cols, p10, p50, p90, mean_v, _ = _percentiles_from_samples(sample_dfs)
    close_idx = cols.index("close")
    vol_idx = cols.index("volume")
    # The P50 path is the point forecast used everywhere below.
    pred_df = pd.DataFrame(p50, columns=cols, index=y_timestamp)
    fig = _build_chart(df, pred_df.index, p10[:, close_idx], p50[:, close_idx], p90[:, close_idx],
                       mean_v[:, vol_idx],
                       title=f"{symbol} {interval} — Kronos Forecast",
                       sample_count_label=sample_count)

    last_close = float(df["close"].iloc[-1])
    forecast_close = float(pred_df["close"].iloc[-1])
    expected_return_pct = (forecast_close / last_close - 1.0) * 100.0
    summary = pd.DataFrame({
        "Metric": ["Last close", "Forecast close", "Expected change %", "Forecast high", "Forecast low"],
        "Value": [
            f"{last_close:.2f}",
            f"{forecast_close:.2f}",
            f"{expected_return_pct:+.2f}%",
            f"{pred_df['high'].max():.2f}",
            f"{pred_df['low'].min():.2f}",
        ]
    })
    pred_out = pred_df.reset_index().rename(columns={"index": "timestamps"})

    if persist:
        # Caching is best-effort: a DB failure must not lose the forecast the user is viewing.
        try:
            save_forecast(symbol, interval, lookback, pred_len, sample_count, temperature, top_p,
                          last_close, forecast_close, expected_return_pct, pred_df)
        except Exception as e:
            print(f"[cache] save failed: {e}")
    return fig, summary, pred_out
# ----- Watchlist -----
_SPARK_BARS = "▁▂▃▄▅▆▇█"
def _sparkline_text(prices, target_len=24):
arr = np.asarray(prices, dtype=float)
if arr.size < 2 or not np.all(np.isfinite(arr)):
return ""
if arr.size > target_len:
idx = np.linspace(0, arr.size - 1, target_len).astype(int)
arr = arr[idx]
pmin, pmax = float(arr.min()), float(arr.max())
rng = max(pmax - pmin, 1e-9)
bins = np.clip(((arr - pmin) / rng * (len(_SPARK_BARS) - 1)).astype(int), 0, len(_SPARK_BARS) - 1)
arrow = "▲" if arr[-1] >= arr[0] else "▼"
return arrow + " " + "".join(_SPARK_BARS[i] for i in bins)
def run_watchlist(symbols_csv, interval, lookback, pred_len,
                  sample_count=30, temperature=1.0, top_p=0.9):
    """Batch-forecast a comma-separated watchlist and rank it by expected return.

    Args:
        symbols_csv: comma-separated tickers. Blanks are ignored, case is folded, and
            duplicates are dropped (first occurrence keeps its position).
        interval: FMP bar interval.
        lookback: history bars per symbol; symbols with fewer bars are skipped.
        pred_len: bars to forecast; lookback + pred_len must be ≤ 512.
        sample_count, temperature, top_p: Monte-Carlo sampling settings. The defaults
            are the values this tab has always used.

    Returns:
        (table, note): one row per forecast symbol, sorted by expected_return_pct
        descending, and a Markdown note listing skipped symbols ("" if none).

    Raises:
        gr.Error: no symbols, lookback + pred_len > 512, or no symbol had usable data.
    """
    # dict.fromkeys de-duplicates while preserving order, so a repeated ticker is not
    # fetched and forecast twice.
    symbols = list(dict.fromkeys(
        s.strip().upper() for s in (symbols_csv or "").split(",") if s.strip()
    ))
    if not symbols:
        raise gr.Error("Provide at least one symbol")
    lookback, pred_len, sample_count = int(lookback), int(pred_len), int(sample_count)
    if lookback + pred_len > 512:
        raise gr.Error(f"lookback + pred_len must be ≤ 512 (got {lookback+pred_len})")

    # Fetches are network I/O, so run up to 8 concurrently.
    with ThreadPoolExecutor(max_workers=min(8, len(symbols))) as ex:
        fetched = list(ex.map(lambda s: fetch_fmp_safe(s, interval, lookback + 10), symbols))

    df_list, x_ts_list, y_ts_list, valid = [], [], [], []
    errors = []
    for sym, df, err in fetched:
        if err is not None or df is None or len(df) < lookback:
            errors.append(f"{sym}: {err or 'insufficient data'}")
            continue
        df = df.tail(lookback).reset_index(drop=True)
        step = df["timestamps"].diff().median()
        last_ts = df["timestamps"].iloc[-1]
        y_ts = pd.Series([last_ts + step * (i + 1) for i in range(pred_len)])
        df_list.append(df[["open", "high", "low", "close", "volume", "amount"]])
        x_ts_list.append(df["timestamps"])
        y_ts_list.append(y_ts)
        valid.append((sym, df))
    if not df_list:
        raise gr.Error("No valid symbols fetched. " + "; ".join(errors))

    with PREDICT_LOCK:
        per_symbol_samples = predictor.predict_batch(
            df_list=df_list, x_timestamp_list=x_ts_list, y_timestamp_list=y_ts_list,
            pred_len=pred_len, T=float(temperature), top_p=float(top_p),
            sample_count=sample_count, verbose=False, return_samples=True,
        )

    rows = []
    for (sym, hist_df), sample_dfs in zip(valid, per_symbol_samples):
        cols, p10, p50, p90, _, _ = _percentiles_from_samples(sample_dfs)
        ci = cols.index("close")
        last_close = float(hist_df["close"].iloc[-1])
        p50_close = p50[:, ci]
        forecast_close = float(p50_close[-1])
        expected_return_pct = (forecast_close / last_close - 1.0) * 100.0
        # Volatility: std of log returns along the P50 close path, in %.
        log_rets = np.diff(np.log(np.maximum(p50_close, 1e-9)))
        forecast_vol = float(np.std(log_rets) * 100.0)
        # Confidence: 1 - mean relative P10–P90 close width, clipped to [0, 1].
        spread = (p90[:, ci] - p10[:, ci]) / np.maximum(np.abs(p50_close), 1e-9)
        confidence = float(np.clip(1.0 - float(np.mean(spread)), 0.0, 1.0))
        rows.append([
            sym,
            round(last_close, 4),
            round(forecast_close, 4),
            round(expected_return_pct, 3),
            round(forecast_vol, 3),
            round(confidence, 3),
            _sparkline_text(p50_close),
        ])
    rows.sort(key=lambda r: r[3], reverse=True)
    out = pd.DataFrame(rows, columns=[
        "symbol", "last_close", "forecast_close", "expected_return_pct",
        "forecast_volatility", "kronos_confidence_score", "sparkline"
    ])
    note = "" if not errors else f"Skipped: {'; '.join(errors)}"
    return out, note
# ----- History tab -----
def reopen_from_history(history_df: pd.DataFrame, evt: gr.SelectData):
    """Re-open a cached forecast selected in the History table.

    Plots the stored P50 close/volume path over the current live history for the same
    symbol and interval. If the live fetch fails, the stored forecast is still shown on
    its own, and the title notes that history is unavailable.

    Returns:
        (fig, summary): the chart and a Field/Value table of the stored row.

    Raises:
        gr.Error: the table is empty or the selected id no longer exists.
    """
    if history_df is None or len(history_df) == 0:
        raise gr.Error("History empty")
    # evt.index may be a [row, col] pair or a plain row index; handle both.
    row_idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
    fid = int(history_df.iloc[row_idx]["id"])
    rec = load_forecast(fid)
    if rec is None:
        raise gr.Error(f"Forecast #{fid} not found")
    (_id, ts, symbol, interval, lookback, pred_len, sample_count, T, top_p,
     last_close, forecast_close, expected_return_pct, pred_json) = rec

    pred_df = pd.DataFrame(json.loads(pred_json))
    pred_df["timestamps"] = pd.to_datetime(pred_df["timestamps"])
    pred_df = pred_df.set_index("timestamps")

    # Live history is only context for the self-contained stored forecast, so a fetch
    # failure (network, rate limit, delisted symbol) should not block viewing it.
    try:
        df = fetch_fmp(symbol, interval, int(lookback) + 10).tail(int(lookback)).reset_index(drop=True)
    except Exception as e:
        print(f"[history] live fetch failed for {symbol} {interval}: {e}")
        df = None

    title = f"#{fid} {symbol} {interval} — stored {ts}"
    if df is None:
        title += " (live history unavailable)"
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                        row_heights=[0.75, 0.25], vertical_spacing=0.03,
                        subplot_titles=(title, "Volume"))
    if df is not None:
        fig.add_trace(go.Candlestick(x=df["timestamps"], open=df["open"], high=df["high"],
                                     low=df["low"], close=df["close"], name="History (live)"),
                      row=1, col=1)
    fig.add_trace(go.Scatter(x=pred_df.index, y=pred_df["close"], mode="lines",
                             line=dict(color="#00CED1", width=2, dash="dash"),
                             name="P50 close (stored)"), row=1, col=1)
    if df is not None:
        fig.add_trace(go.Bar(x=df["timestamps"], y=df["volume"], name="Vol", marker_color="#888"),
                      row=2, col=1)
    fig.add_trace(go.Scatter(x=pred_df.index, y=pred_df["volume"], mode="lines",
                             line=dict(color="#FFD700", width=2),
                             name="Vol mean (stored)"), row=2, col=1)
    fig.update_layout(height=700, template="plotly_dark", xaxis_rangeslider_visible=False,
                      showlegend=True, margin=dict(l=20, r=20, t=50, b=20))

    summary = pd.DataFrame({
        "Field": ["id", "stored_at", "symbol", "interval", "lookback", "pred_len",
                  "sample_count", "T", "top_p", "last_close", "forecast_close",
                  "expected_return_pct"],
        "Value": [fid, ts, symbol, interval, int(lookback), int(pred_len), int(sample_count),
                  T, top_p, last_close, forecast_close, f"{expected_return_pct:+.3f}%"],
    })
    return fig, summary
# ----- Live BTC tab -----
LIVE_REFRESH_SEC = 60
LIVE_DEFAULTS = {
    "symbol": "BTCUSD",
    "interval": "5min",
    "lookback": 392,
    "pred_len": 120,
    "sample_count": 5,
    "temperature": 1.0,
    "top_p": 0.9,
}


def live_refresh():
    """Run one forecast with LIVE_DEFAULTS without writing it to the cache.

    Returns (figure, summary, refresh_time), where refresh_time is ``time.time()`` taken
    after the forecast completes. It feeds the countdown state.
    """
    # LIVE_DEFAULTS keys are exactly run_forecast's parameter names.
    fig, summary, _bars = run_forecast(**LIVE_DEFAULTS, persist=False)
    return fig, summary, time.time()


def live_countdown(last_ts):
    """Return the Markdown status line for the Live tab's auto-refresh countdown.

    A falsy ``last_ts`` (0 or None) means no refresh has completed yet.
    """
    if last_ts:
        elapsed = int(time.time() - float(last_ts))
        rem = max(0, LIVE_REFRESH_SEC - elapsed)
        return f"⏳ Next refresh in **{rem}s** (auto every {LIVE_REFRESH_SEC}s)"
    return f"⏳ Waiting for first refresh… (auto every {LIVE_REFRESH_SEC}s)"
# ----- Backtest -----
def _backtest_figure(df):
    """Build the three-panel backtest chart: RMSE, running hit rate %, cumulative P&L %."""
    fig = make_subplots(
        rows=3, cols=1, shared_xaxes=True,
        subplot_titles=("RMSE per anchor",
                        "Cumulative directional hit rate",
                        f"Cumulative P&L (long-if-up-else-short, "
                        f"{autotune.BACKTEST_COST_BP:g}bp/trade)"),
        vertical_spacing=0.07,
    )
    fig.add_trace(go.Scatter(x=df["anchor_ts"], y=df["rmse"], mode="lines+markers",
                             line=dict(color="#FF6B6B"), name="RMSE"), row=1, col=1)
    fig.add_trace(go.Scatter(x=df["anchor_ts"], y=df["hit_rate_running"] * 100.0,
                             mode="lines+markers", line=dict(color="#FFD700"),
                             name="Hit %"), row=2, col=1)
    fig.add_trace(go.Scatter(x=df["anchor_ts"], y=df["cum_pnl"] * 100.0,
                             mode="lines+markers", line=dict(color="#00CED1"),
                             name="Cum P&L %", fill="tozeroy",
                             fillcolor="rgba(0,206,209,0.15)"), row=3, col=1)
    fig.update_yaxes(title_text="USD", row=1, col=1)
    fig.update_yaxes(title_text="%", row=2, col=1)
    fig.update_yaxes(title_text="%", row=3, col=1)
    fig.update_layout(height=720, template="plotly_dark", showlegend=False,
                      margin=dict(l=20, r=20, t=50, b=20))
    return fig


def _backtest_table(df):
    """Return a display copy of per-anchor results (string timestamps, rounded, ratios ×100)."""
    out = df.copy()
    out["anchor_ts"] = out["anchor_ts"].astype(str)
    for col in ("last_close", "forecast_close", "realized_close", "rmse"):
        out[col] = out[col].round(4)
    for col, ndigits in (("trade_pnl", 3), ("cum_pnl", 3), ("hit_rate_running", 2)):
        out[col] = (out[col] * 100.0).round(ndigits)
    return out


def run_backtest(symbol, interval, start_date, end_date,
                 lookback, pred_len, stride, sample_count, max_anchors,
                 temperature=1.0, top_p=0.9):
    """Walk-forward backtest Kronos over a date range (Backtest tab handler).

    The work is delegated to ``autotune.backtest_core``. ``temperature``/``top_p`` are
    forwarded as ``T``/``top_p``. Their defaults are the values this handler previously
    hard-coded, so existing positional callers behave the same.

    Returns:
        (fig, summary, table): metrics chart, Metric/Value summary, per-anchor table.

    Raises:
        gr.Error: when ``backtest_core`` rejects the inputs with ValueError.
    """
    def _predict(**kwargs):
        # The model is shared with every other tab; call it one request at a time.
        with PREDICT_LOCK:
            return predictor.predict(**kwargs)

    try:
        core = autotune.backtest_core(
            _predict, fetch_fmp,
            symbol=symbol, interval=interval,
            start_date=start_date, end_date=end_date,
            lookback=lookback, pred_len=pred_len, stride=stride,
            T=float(temperature), top_p=float(top_p),
            sample_count=sample_count, max_anchors=max_anchors,
        )
    except ValueError as e:
        raise gr.Error(str(e)) from e

    df = core["per_anchor"]
    summary = pd.DataFrame({
        "Metric": ["Symbol", "Interval", "Anchors", "Lookback", "Pred len", "Stride",
                   "Sample count", "T", "top_p", "Mean RMSE", "Final hit rate %",
                   "Total return %", "Max drawdown %"],
        "Value": [
            symbol.upper(), interval, core["anchors"],
            int(lookback), int(pred_len), int(stride), int(sample_count),
            f"{float(temperature):g}", f"{float(top_p):g}",
            f"{core['mean_rmse']:.4f}",
            f"{core['hit_rate'] * 100.0:.2f}",
            f"{core['total_return_pct']:+.2f}",
            f"{core['max_dd_pct']:.2f}",
        ],
    })
    return _backtest_figure(df), summary, _backtest_table(df)
def run_autotune_ui(symbol, interval, start_date, end_date,
                    lookback, pred_len, stride, sample_count, max_anchors):
    """Auto-tune tab handler: grid-search (T, top_p) through ``autotune.run_autotune``.

    Returns whatever ``autotune.run_autotune`` returns; the UI expects a
    (figure, summary, table) triple. ValueError/RuntimeError from the search are
    re-raised as ``gr.Error`` so Gradio shows them to the user.
    """
    def locked_predict(**kwargs):
        # Serialize model access with the other tabs.
        with PREDICT_LOCK:
            return predictor.predict(**kwargs)

    search_args = dict(
        symbol=symbol, interval=interval,
        start_date=start_date, end_date=end_date,
        lookback=lookback, pred_len=pred_len, stride=stride,
        sample_count=sample_count, max_anchors=max_anchors,
    )
    try:
        return autotune.run_autotune(
            predict_fn=locked_predict, fetch_fn=fetch_fmp, db_path=DB_PATH,
            **search_args,
        )
    except (ValueError, RuntimeError) as err:
        raise gr.Error(str(err))
# ----- UI -----
# Build the Gradio UI. Every event that runs the model shares concurrency_id="predictor"
# with limit 1, so model calls queue in Gradio instead of tying up workers on PREDICT_LOCK.
with gr.Blocks(title="Kronos Forecast Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📈 Kronos Financial Forecast Dashboard\nZero-shot OHLCV forecasting powered by [Kronos](https://github.com/shiyu-coder/Kronos) + FMP market data.")
    with gr.Tabs():
        # ---- Forecast tab ----
        with gr.Tab("Forecast"):
            with gr.Row():
                with gr.Column(scale=1):
                    symbol = gr.Textbox("SPY", label="Symbol (e.g. SPY, AAPL, BTCUSD, EURUSD)")
                    interval = gr.Dropdown(["1min","5min","15min","30min","1hour","4hour"], value="5min", label="Interval")
                    lookback = gr.Slider(100, 400, 392, step=1, label="Lookback bars")
                    pred_len = gr.Slider(12, 120, 120, step=1, label="Forecast bars")
                    sample_count = gr.Slider(1, 30, 5, step=1, label="Monte Carlo samples (cpu-basic: 30 ≈ 4 min)")
                    temperature = gr.Slider(0.1, 2.0, 1.0, step=0.1, label="Temperature (T)")
                    top_p = gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p")
                    use_tuned = gr.Checkbox(value=True,
                                            label="🎯 Use tuned defaults (if cached)")
                    run_btn = gr.Button("🚀 Forecast", variant="primary")
                with gr.Column(scale=3):
                    chart_out = gr.Plot(label="Forecast Chart")
                    summary_out = gr.Dataframe(label="Summary", interactive=False)
                    pred_out = gr.Dataframe(label="Forecasted Bars", interactive=False)
            run_btn.click(run_forecast,
                          inputs=[symbol, interval, lookback, pred_len, sample_count, temperature, top_p],
                          outputs=[chart_out, summary_out, pred_out],
                          concurrency_id="predictor", concurrency_limit=1)

            def _apply_tuning_to_sliders(symbol_val, interval_val, use_val):
                """Look up cached (T, top_p) for the pair; leave a slider as-is when its value is None."""
                t_val, p_val = autotune.apply_tuning(DB_PATH, symbol_val, interval_val, use_val)
                t_update = gr.update(value=t_val) if t_val is not None else gr.update()
                p_update = gr.update(value=p_val) if p_val is not None else gr.update()
                return t_update, p_update

            symbol.change(_apply_tuning_to_sliders,
                          inputs=[symbol, interval, use_tuned],
                          outputs=[temperature, top_p])
            interval.change(_apply_tuning_to_sliders,
                            inputs=[symbol, interval, use_tuned],
                            outputs=[temperature, top_p])
            use_tuned.change(_apply_tuning_to_sliders,
                             inputs=[symbol, interval, use_tuned],
                             outputs=[temperature, top_p])
            # Also pre-fill on page load. The .change events only fire on edits, so the
            # default symbol/interval never picked up cached tuning.
            demo.load(_apply_tuning_to_sliders,
                      inputs=[symbol, interval, use_tuned],
                      outputs=[temperature, top_p])

        # ---- Watchlist tab ----
        with gr.Tab("Watchlist"):
            with gr.Row():
                with gr.Column(scale=1):
                    wl_symbols = gr.Textbox("SPY, QQQ, AAPL, MSFT, NVDA, BTCUSD",
                                            label="Symbols (comma-separated)")
                    wl_interval = gr.Dropdown(["1min","5min","15min","30min","1hour","4hour"],
                                              value="5min", label="Interval")
                    wl_lookback = gr.Slider(100, 400, 200, step=1, label="Lookback bars")
                    wl_predlen = gr.Slider(12, 120, 60, step=1, label="Forecast bars")
                    wl_btn = gr.Button("📊 Run Watchlist", variant="primary")
                    wl_note = gr.Markdown("")
                with gr.Column(scale=3):
                    wl_table = gr.Dataframe(
                        label="Watchlist (sorted by expected return ↓)",
                        headers=["symbol","last_close","forecast_close","expected_return_pct",
                                 "forecast_volatility","kronos_confidence_score","sparkline"],
                        datatype=["str","number","number","number","number","number","str"],
                        interactive=False,
                    )
            wl_btn.click(run_watchlist,
                         inputs=[wl_symbols, wl_interval, wl_lookback, wl_predlen],
                         outputs=[wl_table, wl_note],
                         concurrency_id="predictor", concurrency_limit=1)

        # ---- History tab ----
        with gr.Tab("History"):
            with gr.Row():
                with gr.Column(scale=1):
                    hist_refresh = gr.Button("🔄 Refresh", variant="secondary")
                    hist_summary = gr.Dataframe(label="Selected forecast", interactive=False)
                with gr.Column(scale=3):
                    hist_table = gr.Dataframe(
                        label="Recent forecasts (click row to re-open chart)",
                        headers=["id","ts","symbol","interval","lookback","pred_len",
                                 "sample_count","last_close","forecast_close","expected_return_pct"],
                        datatype=["number","str","str","str","number","number","number",
                                  "number","number","number"],
                        interactive=False,
                    )
                    hist_chart = gr.Plot(label="Stored forecast vs current history")
            hist_refresh.click(list_forecasts, outputs=hist_table)
            demo.load(list_forecasts, outputs=hist_table)
            hist_table.select(reopen_from_history, inputs=hist_table,
                              outputs=[hist_chart, hist_summary],
                              concurrency_id="predictor", concurrency_limit=1)

        # ---- Live tab ----
        with gr.Tab("Live BTC/USDT (5m)"):
            gr.Markdown(f"Mimics the public Kronos demo — refresh BTCUSD {LIVE_DEFAULTS['interval']} every {LIVE_REFRESH_SEC}s.")
            with gr.Row():
                live_btn = gr.Button("⟳ Refresh now", variant="primary")
                live_auto = gr.Checkbox(value=False, label=f"🟢 Auto-refresh every {LIVE_REFRESH_SEC}s")
            live_status = gr.Markdown(value=live_countdown(0))
            live_chart = gr.Plot(label="Live forecast")
            live_summary = gr.Dataframe(label="Summary", interactive=False)
            live_last_ts = gr.State(value=0.0)
            live_timer = gr.Timer(value=LIVE_REFRESH_SEC, active=False)
            countdown_timer = gr.Timer(value=1, active=False)
            live_btn.click(live_refresh,
                           outputs=[live_chart, live_summary, live_last_ts],
                           concurrency_id="predictor", concurrency_limit=1)

            def _toggle_auto(enabled):
                """Start or stop both the refresh timer and the 1 s countdown timer together."""
                return gr.Timer(active=bool(enabled)), gr.Timer(active=bool(enabled))

            live_auto.change(_toggle_auto, inputs=live_auto, outputs=[live_timer, countdown_timer])
            # Same concurrency group as the manual refresh: the tick also runs the model.
            live_timer.tick(live_refresh,
                            outputs=[live_chart, live_summary, live_last_ts],
                            concurrency_id="predictor", concurrency_limit=1)
            countdown_timer.tick(live_countdown, inputs=live_last_ts, outputs=live_status)

        # ---- Backtest tab ----
        with gr.Tab("Backtest"):
            gr.Markdown(
                "Walk-forward Kronos through a date range. At each anchor, forecast "
                "`pred_len` bars ahead and compare to realized close. Stride=`pred_len` "
                "gives non-overlapping windows; stride=1 is bar-by-bar (slow on cpu-basic). "
                "`Max anchors` caps total work; if more anchors fit the window, they are "
                "evenly sub-sampled."
            )
            with gr.Row():
                with gr.Column(scale=1):
                    bt_symbol = gr.Textbox("SPY", label="Symbol")
                    bt_interval = gr.Dropdown(["1min","5min","15min","30min","1hour","4hour"],
                                              value="5min", label="Interval")
                    bt_start = gr.Textbox("", label="Start (YYYY-MM-DD HH:MM, blank=earliest)")
                    bt_end = gr.Textbox("", label="End (YYYY-MM-DD HH:MM, blank=latest)")
                    bt_lookback = gr.Slider(100, 400, 200, step=1, label="Lookback bars")
                    bt_predlen = gr.Slider(5, 120, 30, step=1, label="Forecast horizon (pred_len)")
                    bt_stride = gr.Slider(1, 120, 30, step=1,
                                          label="Stride (1 = bar-by-bar, slow)")
                    bt_samples = gr.Slider(1, 10, 1, step=1, label="MC samples per step")
                    bt_max = gr.Slider(5, 100, 20, step=1,
                                       label="Max anchors (caps total work)")
                    bt_btn = gr.Button("🧪 Run backtest", variant="primary")
                with gr.Column(scale=3):
                    bt_chart = gr.Plot(label="Backtest metrics over time")
                    bt_summary = gr.Dataframe(label="Summary", interactive=False)
                    bt_table = gr.Dataframe(label="Per-anchor results", interactive=False)
            bt_btn.click(
                run_backtest,
                inputs=[bt_symbol, bt_interval, bt_start, bt_end,
                        bt_lookback, bt_predlen, bt_stride, bt_samples, bt_max],
                outputs=[bt_chart, bt_summary, bt_table],
                concurrency_id="predictor", concurrency_limit=1,
            )

        # ---- Auto-tune tab ----
        with gr.Tab("Auto-tune"):
            gr.Markdown(
                "Search Kronos sampling knobs (T, top_p) on a 3×3 grid, score each "
                "cell with a walk-forward backtest, and persist the best "
                "(T, top_p) per (symbol, interval). The Forecast tab can then "
                "pre-fill those defaults via the 🎯 checkbox."
            )
            with gr.Row():
                with gr.Column(scale=1):
                    at_symbol = gr.Textbox("SPY", label="Symbol")
                    at_interval = gr.Dropdown(["1min","5min","15min","30min","1hour","4hour"],
                                              value="5min", label="Interval")
                    at_start = gr.Textbox("", label="Start (YYYY-MM-DD HH:MM, blank=earliest)")
                    at_end = gr.Textbox("", label="End (YYYY-MM-DD HH:MM, blank=latest)")
                    at_lookback = gr.Slider(100, 400, 200, step=1, label="Lookback bars")
                    at_predlen = gr.Slider(5, 120, 30, step=1, label="Forecast horizon (pred_len)")
                    at_stride = gr.Slider(1, 120, 30, step=1, label="Stride")
                    at_samples = gr.Slider(1, 5, 1, step=1, label="MC samples per step")
                    at_max = gr.Slider(3, 30, 5, step=1,
                                       label="Max anchors per cell (caps work; 9 cells × this)")
                    at_btn = gr.Button("🎯 Run auto-tune", variant="primary")
                with gr.Column(scale=3):
                    at_chart = gr.Plot(label="Heatmap: P&L % over (T, top_p)")
                    at_summary = gr.Dataframe(label="Best parameters", interactive=False)
                    at_table = gr.Dataframe(label="Per-cell results", interactive=False)
            at_btn.click(
                run_autotune_ui,
                inputs=[at_symbol, at_interval, at_start, at_end,
                        at_lookback, at_predlen, at_stride, at_samples, at_max],
                outputs=[at_chart, at_summary, at_table],
                concurrency_id="predictor", concurrency_limit=1,
            )
# Enable the request queue: up to 10 concurrent events by default and at most
# 32 queued requests.
demo.queue(default_concurrency_limit=10, max_size=32)

if __name__ == "__main__":
    # SPACE_ID is used as the "running on Hugging Face Spaces" marker: listen on all
    # interfaces there, and only open a local browser for local runs.
    running_on_spaces = "SPACE_ID" in os.environ
    launch_host = "0.0.0.0" if running_on_spaces else "127.0.0.1"
    demo.launch(
        server_name=launch_host,
        server_port=7860,
        share=False,
        inbrowser=not running_on_spaces,
    )