"""
app.py — Sniper Model Evaluator Space
Two modes: Backtester and Model Evaluator.
"""
import sys
import json
import time
import logging
from datetime import datetime, date
from pathlib import Path
import numpy as np
import pandas as pd
import gradio as gr
sys.path.insert(0, str(Path(__file__).parent))
# Must run before any joblib.load — injects SniperModel into __main__
# so pickles trained with __main__.SniperModel deserialize correctly.
from src.sniper_model import patch_main as _patch_main
_patch_main()
from src.registry import (
discover_repos, load_bundle, get_all_bundle_labels,
find_bundle_by_label, ModelRepo, ArtifactBundle,
)
from src.data_loader import (
download_ticker_batch, filter_ticker_data, extract_market_series,
)
from src.backtester import run_backtest, BacktestConfig, PRESETS
from src.evaluator import run_evaluation, DIMENSION_WEIGHTS
from src.charts import (
nav_chart, monthly_returns_heatmap, exit_reasons_chart,
trade_return_distribution, radar_chart, reliability_diagram,
regime_heatmap, feature_psi_chart, multi_model_comparison,
)
# Configure root logging once at import time: INFO level, timestamped records.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Application logger used by the runner functions below for warnings and tracebacks.
logger = logging.getLogger("SniperApp")
# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
# Repositories returned by the most recent discover_repos() call in _refresh_registry().
_repos: list[ModelRepo] = []
# Bundle cache keyed by dropdown label, maintained by _get_or_load_bundle().
_loaded_bundles: dict[str, ArtifactBundle] = {} # label -> loaded bundle
# Results and short labels written by the most recent run_evaluator() call.
# NOTE(review): nothing in this file reads these two lists back — confirm they are still needed.
_eval_results: list = [] # for multi-model comparison
_eval_labels: list[str] = []
# Universe choices offered in both "Universe" dropdowns.
#
# The keys are UI labels that other code compares against verbatim, so they
# must not be renamed. Values:
#   * None      -> use the universe stored in the bundle metadata
#   * list[str] -> fixed list of symbols
#   * "custom"  -> symbols are taken from the custom-tickers textbox
#
# The "(200)" in the subset label is nominal: the list holds 183 unique
# symbols. Repeated entries (RTX, AMZN, HD, UPS) were removed so no symbol is
# requested twice from the downloader.
# NOTE(review): "US" and a few small caps do not look like S&P 500 members —
# confirm the intended symbols with the owner of this list.
TICKER_UNIVERSE_PRESETS = {
    "All (~800 tickers)": None,  # None = use bundle's full universe
    "S&P 500 subset (200)": [  # representative subset (see note above)
        "AAPL", "MSFT", "AMZN", "GOOGL", "META", "NVDA", "TSLA", "BRK-B", "UNH", "JNJ",
        "XOM", "JPM", "V", "PG", "MA", "CVX", "HD", "LLY", "ABBV", "MRK", "PEP", "KO",
        "BAC", "PFE", "COST", "TMO", "AVGO", "CSCO", "ABT", "ACN", "WMT", "CRM", "NEE",
        "TXN", "DHR", "VZ", "AMGN", "NKE", "MDT", "UPS", "LIN", "PM", "RTX", "QCOM",
        "HON", "T", "SBUX", "INTU", "AMAT", "ISRG", "GS", "MS", "SCHW", "AXP", "BLK",
        "C", "USB", "TGT", "GILD", "CVS", "ZTS", "SYK", "ELV", "ADP", "NOW", "SPGI",
        "CL", "MO", "DUK", "SO", "D", "EXC", "AEP", "SRE", "PCG", "WEC", "ES", "XEL",
        "LMT", "NOC", "GD", "BA", "HII", "TDG", "HWM", "HEI", "LDOS",
        "CAT", "DE", "EMR", "ETN", "ITW", "PH", "ROK", "DOV", "CARR", "OTIS",
        "F", "GM", "APTV", "BWA", "LEA", "MGA", "AZO", "ORLY", "GPC", "AAP",
        "EBAY", "ETSY", "ROST", "TJX", "LOW", "WST", "SHW", "RPM",
        "ECL", "PPG", "EMN", "LYB", "APD", "CF", "MOS", "NUE", "FCX", "FMC",
        "BIIB", "REGN", "VRTX", "ILMN", "IDXX", "MRNA", "HOLX", "BIO", "ABMD",
        "MCO", "ICE", "CME", "CBOE", "NDAQ", "MSCI", "FDS", "MKTX", "BR", "AMP",
        "TROW", "BEN", "IVZ", "FHI", "AMG", "SEIC", "APAM", "VRTS", "HLNE", "PIPR",
        "PLD", "EQIX", "AMT", "CCI", "SBAC", "ARE", "AVB", "EQR", "MAA", "UDR",
        "WM", "RSG", "CWST", "CLH", "GFL", "WCN", "SRCL", "ECOL", "ARIS", "US",
        "FDX", "XPO", "ODFL", "SAIA", "WERN", "JBHT", "ARCB", "CHRW", "EXPD",
    ],
    "Custom (paste tickers)": "custom",
}
# Hex colour per letter grade, looked up by _build_grade_html(); grades that are
# not listed fall back to "#e4e6ef" there.
# NOTE(review): there is no "D-" entry — confirm whether the evaluator can emit that grade.
GRADE_COLORS = {
"A+": "#00d4aa", "A": "#00d4aa", "A-": "#34d399",
"B+": "#60a5fa", "B": "#60a5fa", "B-": "#93c5fd",
"C+": "#f5a623", "C": "#f5a623", "C-": "#fbbf24",
"D+": "#ff4d6a", "D": "#ff4d6a", "F": "#dc2626",
}
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _refresh_registry(progress=gr.Progress()):
    """Rediscover model repositories and rebuild both model dropdowns.

    Args:
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A 3-tuple ``(backtester_dropdown, evaluator_dropdown, status_text)``
        matching the outputs wired to the refresh button and to ``demo.load``.

    A discovery failure is logged and reported in the status text instead of
    propagating; the previously discovered repositories (if any) stay in place.
    """
    global _repos

    def _cb(msg):
        # Discovery reports text only, so the bar stays at 0 and shows the message.
        progress(0, desc=msg)

    progress(0, desc="Scanning HuggingFace for Sniper repositories...")
    failure = None
    try:
        _repos = discover_repos(progress_cb=_cb)
    except Exception as exc:  # UI boundary: report the problem rather than crash
        logger.exception("Registry refresh failed")
        failure = f"❌ Registry refresh failed: {exc}"

    labels = get_all_bundle_labels(_repos)
    if not labels:
        return (
            gr.Dropdown(choices=[], value=None, label="Model / Run"),
            gr.Dropdown(choices=[], value=None, label="Model / Run (Evaluator)"),
            failure or "⚠️ No models found. Check that Arkm20/sniper-model-* repos are public.",
        )

    status = (
        f"✅ Found {len(_repos)} repo(s), {len(labels)} run(s) total.\n"
        + "\n".join(f" · {label}" for label in labels)
    )
    if failure:
        status = f"{failure}\nShowing previously discovered runs.\n{status}"
    return (
        gr.Dropdown(choices=labels, value=labels[0], label="Model / Run"),
        # The evaluator dropdown is multiselect, so its value is a list.
        gr.Dropdown(choices=labels, value=[labels[0]], label="Model / Run (Evaluator)"),
        status,
    )
def _get_or_load_bundle(label: str, progress_cb=None) -> ArtifactBundle | None:
    """Return the loaded bundle for ``label``, using the in-process cache.

    Args:
        label: Dropdown label identifying a model run.
        progress_cb: Optional callback forwarded to ``load_bundle``.

    Returns:
        The loaded bundle, or ``None`` when no discovered bundle has that label.
        The returned bundle may still have ``main_model`` set to ``None``;
        callers check for that.
    """
    cached = _loaded_bundles.get(label)
    if cached is not None:
        return cached

    bundle = find_bundle_by_label(_repos, label)
    if bundle is None:
        return None

    bundle = load_bundle(bundle, progress_cb=progress_cb)
    # Cache only usable bundles: caching one without a main model would make a
    # transient load failure permanent for the lifetime of the process.
    if bundle is not None and getattr(bundle, "main_model", None) is not None:
        _loaded_bundles[label] = bundle
    return bundle
def _universe_for_preset(preset_name: str, custom_text: str, bundle) -> list[str]:
if preset_name == "Custom (paste tickers)":
tickers = [t.strip().upper() for t in custom_text.replace(",", " ").split() if t.strip()]
return tickers if tickers else (bundle.metadata.get("configuration", {}).get("TICKER_UNIVERSE", []) or [])
if preset_name == "S&P 500 subset (200)":
return TICKER_UNIVERSE_PRESETS["S&P 500 subset (200)"]
# "All" - use training universe from bundle metadata if available, else hardcoded
return bundle.metadata.get("configuration", {}).get("TICKER_UNIVERSE", []) or []
# ---------------------------------------------------------------------------
# Backtester runner
# ---------------------------------------------------------------------------
def run_backtester(
    model_label, bt_start, bt_end, initial_cash,
    conviction_threshold, max_positions,
    sl_multiplier, pt_multiplier, horizon_days, cooldown_days,
    sizing_mode, risk_fraction, transaction_pct,
    account_mode, withdrawal_fraction,
    use_regime_routing, benchmark,
    use_confluence, min_confluence_score,
    ticker_preset, custom_tickers,
    progress=gr.Progress(),
):
    """Run one backtest from the UI inputs and build every output widget value.

    The parameters mirror, in order, the inputs wired to the "Run Backtest"
    button. Percent-style inputs (``risk_fraction``, ``transaction_pct``,
    ``withdrawal_fraction``) arrive as percentages and are divided by 100
    before they are stored in ``BacktestConfig``.

    Returns:
        A 9-tuple ``(metrics_html, nav_fig, monthly_fig, exit_fig, dist_fig,
        trades_table, trades_csv, metrics_json, status_text)``. Any failure is
        logged and reported through the same tuple shape via ``_bt_empty``.
    """
    last_frac = 0.0

    def _cb(msg, frac=None):
        # Helpers that report text only must not snap the bar back to 0,
        # so the last known fraction is reused.
        nonlocal last_frac
        if frac is not None:
            last_frac = frac
        progress(last_frac, desc=msg)

    try:
        # ---- Load model ----
        _cb("Resolving model artifacts from HuggingFace...", 0.01)
        bundle = _get_or_load_bundle(model_label, progress_cb=_cb)
        if bundle is None or bundle.main_model is None:
            return _bt_empty("❌ Could not load model. Check model label and connectivity.")

        # ---- Determine universe ----
        tickers = _universe_for_preset(ticker_preset, custom_tickers, bundle)
        if not tickers:
            return _bt_empty("❌ No tickers found for selected universe.")

        # The benchmark is downloaded with the universe, but only once even
        # when the universe already contains it.
        bm_sym = benchmark if benchmark != "None" else None
        symbols = list(tickers)
        if bm_sym and bm_sym not in symbols:
            symbols.append(bm_sym)

        # ---- Download data ----
        _cb("Downloading market data — batch 1/N...", 0.05)
        ticker_data = download_ticker_batch(
            symbols,
            start=str(bt_start), end=str(bt_end),
            progress_cb=_cb,
        )
        ticker_data = filter_ticker_data(ticker_data, progress_cb=_cb)

        # ---- Build config ----
        config = BacktestConfig(
            start_date=str(bt_start),
            end_date=str(bt_end),
            initial_cash=float(initial_cash),
            conviction_threshold=float(conviction_threshold),
            use_regime_routing=bool(use_regime_routing),
            max_positions=int(max_positions),
            pt_multiplier=float(pt_multiplier),
            sl_multiplier=float(sl_multiplier),
            atr_period=14,
            horizon_days=int(horizon_days),
            cooldown_days=int(cooldown_days),
            sizing_mode="volatility_adjusted" if sizing_mode == "Volatility-adjusted" else "equal_weight",
            risk_fraction=float(risk_fraction) / 100.0,
            transaction_pct=float(transaction_pct) / 100.0,
            account_mode="compound" if account_mode == "Compound" else "realistic",
            withdrawal_fraction=float(withdrawal_fraction) / 100.0,
            use_confluence=bool(use_confluence),
            min_confluence_score=int(min_confluence_score),
            # NOTE(review): the raw dropdown value (possibly the string "None")
            # is forwarded unchanged — confirm run_backtest expects that.
            benchmark=benchmark,
        )

        # ---- Run backtest ----
        result = run_backtest(ticker_data, bundle, config, progress_cb=_cb)

        # ---- Build outputs ----
        _cb("Rendering results...", 0.97)
        m = result.metrics
        trades_df = result.trades_df
        has_trades = not trades_df.empty

        metrics_html = _build_metrics_html(m)
        fig_nav = nav_chart(result.nav_df, result.benchmark_df, config.initial_cash)
        fig_monthly = monthly_returns_heatmap(result.nav_df, config.initial_cash)
        fig_exit = exit_reasons_chart(trades_df)
        fig_dist = trade_return_distribution(trades_df)

        # With no trades, show an empty table that still carries the headers.
        trades_display = trades_df if has_trades else pd.DataFrame(
            columns=["Ticker", "Entry Date", "Exit Date", "Exit Reason",
                     "Entry Price", "Exit Price", "Return %", "Profit $", "Days Held"]
        )

        # Export strings for the "Export" accordion.
        csv_str = trades_df.to_csv(index=False) if has_trades else ""
        json_str = json.dumps(m, indent=2, default=str)

        _cb("Done.", 1.0)
        return (
            metrics_html,
            fig_nav, fig_monthly, fig_exit, fig_dist,
            trades_display,
            csv_str, json_str,
            f"✅ Backtest complete — {m.get('Total Trades', 0)} trades over {result.n_tickers_processed} tickers.",
        )
    except Exception as e:  # UI boundary: log the traceback, show the message
        logger.exception("Backtest error")
        return _bt_empty(f"❌ Error: {e}")
def _bt_empty(msg):
    """Build the backtester's 9-output tuple for a "nothing to show" state.

    Args:
        msg: Text shown in the metrics panel, on every placeholder chart and
            in the status box.

    Returns:
        ``(metrics_html, fig, fig, fig, fig, empty_df, "", "{}", msg)`` in the
        order of the outputs wired to the "Run Backtest" button.
    """
    from html import escape

    empty_fig = _empty_fig(msg)
    # NOTE(review): the wrapper markup was missing from the source and has been
    # reconstructed. The message is escaped because it can carry exception text.
    banner = (
        "<div style='padding:16px;border:1px solid #2a2f3d;border-radius:8px;"
        "background:#14171f;color:#e4e6ef;font-size:13px;'>%s</div>" % escape(str(msg))
    )
    return (
        banner,
        empty_fig, empty_fig, empty_fig, empty_fig,
        pd.DataFrame(),
        "", "{}", msg,
    )
def _empty_fig(msg="No data"):
    """Return a 300px-high dark placeholder figure with ``msg`` centred on it."""
    import plotly.graph_objects as go

    # One annotation, anchored to the middle of the paper, carries the message.
    centered_note = dict(
        text=msg,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=13, color="#7a7f94"),
    )
    placeholder = go.Figure()
    placeholder.update_layout(
        paper_bgcolor="#0d0f14",
        plot_bgcolor="#14171f",
        font=dict(color="#7a7f94"),
        annotations=[centered_note],
        height=300,
    )
    return placeholder
def _build_metrics_html(m: dict) -> str:
if not m:
return "No metrics available.
"
def _card(label, value, color="#e4e6ef", sub=""):
return f"""
{label}
{value}
{f'
{sub}
' if sub else ''}
"""
total_ret = m.get("Total Return %", 0)
ann_ret = m.get("Annualized Return %", 0)
sharpe = m.get("Sharpe Ratio", 0)
max_dd = m.get("Max Drawdown %", 0)
win_rate = m.get("Win Rate %", 0)
n_trades = m.get("Total Trades", 0)
pf = m.get("Profit Factor", 0)
calmar = m.get("Calmar Ratio", 0)
ret_color = "#00d4aa" if total_ret >= 0 else "#ff4d6a"
dd_color = "#ff4d6a" if max_dd < -10 else "#f5a623" if max_dd < -5 else "#e4e6ef"
sharpe_color = "#00d4aa" if sharpe > 1 else "#f5a623" if sharpe > 0.5 else "#ff4d6a"
cards = "".join([
_card("Total Return", f"{total_ret:+.1f}%", ret_color),
_card("Ann. Return", f"{ann_ret:+.1f}%", ret_color),
_card("Sharpe", f"{sharpe:.2f}", sharpe_color),
_card("Max Drawdown", f"{max_dd:.1f}%", dd_color),
_card("Win Rate", f"{win_rate:.1f}%"),
_card("Trades", f"{n_trades:,}"),
_card("Profit Factor", f"{pf:.2f}"),
_card("Calmar", f"{calmar:.2f}"),
])
return f"""
{cards}
"""
def apply_preset(preset_name):
    """Translate a named preset into values for the nine preset-driven controls.

    Unknown names fall back to the "Balanced" preset. The tuple order matches
    the outputs wired to the preset radio: conviction threshold, max positions,
    stop-loss multiplier, profit-target multiplier, horizon, cooldown, sizing
    label, risk percent and account label.
    """
    fallback = PRESETS["Balanced"]
    preset = PRESETS.get(preset_name, fallback)

    # Map the config's internal mode strings back to the radio labels.
    if preset.sizing_mode == "volatility_adjusted":
        sizing_choice = "Volatility-adjusted"
    else:
        sizing_choice = "Equal-weight"

    if preset.account_mode == "compound":
        account_choice = "Compound"
    else:
        account_choice = "Realistic"

    # The risk slider works in percent, the config stores a fraction.
    risk_percent = preset.risk_fraction * 100

    return (
        preset.conviction_threshold,
        preset.max_positions,
        preset.sl_multiplier,
        preset.pt_multiplier,
        preset.horizon_days,
        preset.cooldown_days,
        sizing_choice,
        risk_percent,
        account_choice,
    )
# ---------------------------------------------------------------------------
# Evaluator runner
# ---------------------------------------------------------------------------
def run_evaluator(
    model_labels_selected,
    eval_start, eval_end,
    ticker_preset, custom_tickers,
    pt_multiplier, sl_multiplier, horizon,
    w_discrimination, w_feature_health, w_signal_stability,
    w_calibration, w_regime_robustness, w_asymmetry,
    progress=gr.Progress(),
):
    """Evaluate the selected model runs and build every evaluator output.

    The parameters mirror, in order, the inputs wired to the "Run Evaluation"
    button. The six ``w_*`` slider values are relative weights: they are
    normalised to sum to 1 before being passed to ``run_evaluation`` (the
    default sliders sum to 100, so the defaults are unchanged).

    Models that cannot be loaded, have no tickers or raise during evaluation
    are skipped; the reason is logged and appended to the status text.

    Returns:
        A 10-tuple ``(grade_html, dim_table_html, radar_fig, reliability_fig,
        regime_fig, psi_fig, compare_fig, psi_table, flags_html, status_text)``.
        The detail panels describe the first successfully evaluated model.
    """
    global _eval_results, _eval_labels

    last_frac = 0.0

    def _cb(msg, frac=None):
        # Helpers that report text only must not snap the bar back to 0,
        # so the last known fraction is reused.
        nonlocal last_frac
        if frac is not None:
            last_frac = frac
        progress(last_frac, desc=msg)

    if not model_labels_selected:
        return _eval_empty("❌ Select at least one model.")
    if isinstance(model_labels_selected, str):
        model_labels_selected = [model_labels_selected]

    raw_weights = {
        "discrimination": float(w_discrimination or 0),
        "feature_health": float(w_feature_health or 0),
        "signal_stability": float(w_signal_stability or 0),
        "calibration": float(w_calibration or 0),
        "regime_robustness": float(w_regime_robustness or 0),
        "asymmetry": float(w_asymmetry or 0),
    }
    weight_total = sum(raw_weights.values())
    if weight_total <= 0:
        return _eval_empty("❌ At least one scoring weight must be greater than zero.")
    weights = {name: value / weight_total for name, value in raw_weights.items()}

    all_results = []
    all_labels = []
    skipped = []  # "label (reason)" for every model that produced no result

    for model_label in model_labels_selected:
        try:
            _cb(f"Loading model: {model_label}...", 0.02)
            bundle = _get_or_load_bundle(model_label, progress_cb=_cb)
            if bundle is None or bundle.main_model is None:
                logger.warning(f"Could not load {model_label}, skipping.")
                skipped.append(f"{model_label} (could not load)")
                continue

            tickers = _universe_for_preset(ticker_preset, custom_tickers, bundle)
            if not tickers:
                logger.warning(f"No tickers for {model_label}, skipping.")
                skipped.append(f"{model_label} (no tickers)")
                continue

            _cb(f"Downloading data for {model_label}...", 0.05)
            ticker_data = download_ticker_batch(
                tickers, start=str(eval_start), end=str(eval_end), progress_cb=_cb
            )
            ticker_data = filter_ticker_data(ticker_data, progress_cb=_cb)

            _cb(f"Evaluating {model_label}...", 0.38)
            result = run_evaluation(
                ticker_data=ticker_data,
                bundle=bundle,
                pt_multiplier=float(pt_multiplier),
                sl_multiplier=float(sl_multiplier),
                horizon=int(horizon),
                dimension_weights=weights,
                progress_cb=_cb,
            )
            all_results.append(result)
            # Short label: the part of the dropdown label before the first "·".
            all_labels.append(model_label.split("·")[0].strip())
        except Exception as exc:  # keep going so one bad model does not block the rest
            logger.exception(f"Evaluation error for {model_label}")
            skipped.append(f"{model_label} ({exc})")
            continue

    skipped_note = f" Skipped: {'; '.join(skipped)}" if skipped else ""

    if not all_results:
        return _eval_empty(f"❌ No results produced. Check model and data.{skipped_note}")

    _eval_results = all_results
    _eval_labels = all_labels

    primary = all_results[0]
    primary_label = all_labels[0]

    # Scorecard and per-dimension table
    grade_html = _build_grade_html(primary, primary_label)
    dim_table = _build_dim_table(primary)

    # Charts
    fig_radar = radar_chart(primary.dimensions)
    fig_reliability = reliability_diagram(primary.reliability_bins)
    fig_regime = regime_heatmap(primary.regime_scores)
    fig_psi = feature_psi_chart(primary.feature_psi)
    if len(all_results) > 1:
        fig_compare = multi_model_comparison(all_results, all_labels)
    else:
        fig_compare = _empty_fig("Add more models to enable comparison")

    # Flags raised by any dimension of the primary model
    all_flags = [
        f"⚠️ [{dim.name.replace('_',' ').title()}] {flag}"
        for dim in primary.dimensions
        for flag in dim.flags
    ]
    flags_html = _build_flags_html(all_flags)

    # PSI table; with no rows, show an empty table that still carries the headers
    psi_display = primary.feature_psi if not primary.feature_psi.empty else pd.DataFrame(
        columns=["Feature", "NaN Rate", "Inf Rate", "PSI", "Status"]
    )

    status = (
        f"✅ Evaluated {len(all_results)} model(s). "
        f"Primary score: {primary.overall_score:.1f} ({primary.grade})"
    )
    if skipped_note:
        status = f"{status} ·{skipped_note}"

    _cb("Done.", 1.0)
    return (
        grade_html,
        dim_table,
        fig_radar,
        fig_reliability,
        fig_regime,
        fig_psi,
        fig_compare,
        psi_display,
        flags_html,
        status,
    )
def _build_grade_html(result, label: str) -> str:
    """Render the headline scorecard: grade, score, label and sample summary.

    Args:
        result: Evaluation result exposing ``grade``, ``overall_score``,
            ``n_samples``, ``n_positives`` and ``eval_date_range`` (indexed
            as ``[0]`` and ``[1]``).
        label: Display name of the evaluated model.

    Returns:
        An HTML fragment coloured by ``GRADE_COLORS`` (light grey for grades
        that are not listed there).

    NOTE(review): the markup was missing from the source and has been
    reconstructed; the text content and number formats are unchanged.
    """
    from html import escape

    grade = result.grade
    score = result.overall_score
    color = GRADE_COLORS.get(grade, "#e4e6ef")
    dr = result.eval_date_range

    positive_rate = result.n_positives / max(result.n_samples, 1)
    summary = (
        f"{result.n_samples:,} samples · {result.n_positives} positives "
        f"({positive_rate:.1%} rate)"
    )
    # The date range is appended only when both ends are available.
    if dr is not None and len(dr) >= 2:
        summary += f" · {dr[0]} → {dr[1]}"

    return (
        "<div style='background:#14171f;border:1px solid #2a2f3d;border-radius:12px;"
        "padding:24px;text-align:center;margin-bottom:16px;'>"
        f"<div style='font-size:64px;font-weight:700;line-height:1;color:{color};'>"
        f"{escape(str(grade))}</div>"
        f"<div style='font-size:18px;color:#e4e6ef;margin-top:6px;'>{score:.1f} / 100</div>"
        f"<div style='font-size:14px;color:#e4e6ef;margin-top:10px;'>{escape(str(label))}</div>"
        f"<div style='font-size:12px;color:#7a7f94;margin-top:6px;'>{escape(summary)}</div>"
        "</div>"
    )
def _build_dim_table(result) -> str:
rows = ""
for dim in result.dimensions:
bar_color = "#00d4aa" if dim.score >= 70 else "#f5a623" if dim.score >= 50 else "#ff4d6a"
bar_w = int(dim.score)
label = dim.name.replace("_", " ").title()
rows += f"""
| {label} |
|
{dim.score:.1f} |
{int(dim.weight*100)}% weight |
"""
return f"""
| Dimension |
Score Bar |
Score |
Weight |
{rows}
"""
def _build_flags_html(flags: list) -> str:
if not flags:
return "✅ No critical issues detected.
"
items = "".join(
f"{f}"
for f in flags
)
return f""
def _eval_empty(msg):
    """Build the evaluator's 10-output tuple for a "nothing to show" state.

    Args:
        msg: Text shown in the HTML panels, on every placeholder chart and in
            the status box.

    Returns:
        ``(html, html, fig, fig, fig, fig, fig, empty_df, html, msg)`` in the
        order of the outputs wired to the "Run Evaluation" button.
    """
    from html import escape

    empty_fig = _empty_fig(msg)
    # NOTE(review): the wrapper markup was missing from the source and has been
    # reconstructed. The message is escaped because it can carry model labels
    # or exception text.
    err_html = (
        "<div style='padding:16px;border:1px solid #2a2f3d;border-radius:8px;"
        f"background:#14171f;color:#e4e6ef;font-size:13px;'>{escape(str(msg))}</div>"
    )
    return (
        err_html, err_html,
        empty_fig, empty_fig, empty_fig, empty_fig, empty_fig,
        pd.DataFrame(),
        err_html, msg,
    )
# ---------------------------------------------------------------------------
# CSS
# ---------------------------------------------------------------------------
# Dark-theme stylesheet handed to gr.Blocks(css=...) in build_ui().
# It declares the colour palette as CSS variables on :root and styles the
# classes attached through elem_classes in build_ui (refresh-row, status-box,
# btn-primary, btn-secondary, preset-radio) together with generic form elements.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=DM+Sans:wght@400;500;600;700&family=JetBrains+Mono:wght@400;600;700&display=swap');
:root {
--bg: #0d0f14;
--surface: #14171f;
--surface2: #1c202c;
--border: #2a2f3d;
--text: #e4e6ef;
--muted: #7a7f94;
--teal: #00d4aa;
--red: #ff4d6a;
--amber: #f5a623;
--purple: #a78bfa;
}
body, .gradio-container {
background: var(--bg) !important;
font-family: 'DM Sans', sans-serif !important;
color: var(--text) !important;
}
.gradio-container {
max-width: 1400px !important;
}
/* Header */
.sniper-header {
padding: 28px 0 20px;
border-bottom: 1px solid var(--border);
margin-bottom: 24px;
}
.sniper-header h1 {
font-size: 28px;
font-weight: 700;
color: var(--text);
margin: 0;
display: flex;
align-items: center;
gap: 10px;
}
.sniper-header p {
color: var(--muted);
font-size: 13px;
margin: 6px 0 0;
}
/* Tab styling */
.tab-nav button {
font-family: 'DM Sans', sans-serif !important;
font-size: 14px !important;
font-weight: 600 !important;
color: var(--muted) !important;
background: transparent !important;
border-bottom: 2px solid transparent !important;
padding: 10px 20px !important;
transition: all 0.2s ease !important;
}
.tab-nav button.selected {
color: var(--teal) !important;
border-bottom: 2px solid var(--teal) !important;
}
/* Panel backgrounds */
.panel-box {
background: var(--surface);
border: 1px solid var(--border);
border-radius: 12px;
padding: 20px;
}
/* Sliders */
input[type=range] {
accent-color: var(--teal) !important;
}
/* Labels */
label span {
color: var(--text) !important;
font-size: 13px !important;
font-weight: 500 !important;
}
/* Inputs */
input, select, textarea {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
color: var(--text) !important;
border-radius: 8px !important;
}
input:focus, select:focus, textarea:focus {
border-color: var(--teal) !important;
outline: none !important;
box-shadow: 0 0 0 2px rgba(0, 212, 170, 0.15) !important;
}
/* Buttons */
.btn-primary {
background: var(--teal) !important;
color: #0d0f14 !important;
font-weight: 700 !important;
font-family: 'DM Sans', sans-serif !important;
border: none !important;
border-radius: 8px !important;
padding: 12px 28px !important;
font-size: 14px !important;
cursor: pointer !important;
transition: opacity 0.2s ease !important;
width: 100% !important;
}
.btn-primary:hover { opacity: 0.88 !important; }
.btn-secondary {
background: var(--surface2) !important;
color: var(--text) !important;
font-weight: 600 !important;
font-family: 'DM Sans', sans-serif !important;
border: 1px solid var(--border) !important;
border-radius: 8px !important;
padding: 10px 20px !important;
font-size: 13px !important;
cursor: pointer !important;
width: 100% !important;
}
/* Progress bar */
.progress-bar {
background: var(--teal) !important;
}
/* Status box */
.status-box {
background: var(--surface);
border: 1px solid var(--border);
border-radius: 8px;
padding: 12px 16px;
font-family: 'JetBrains Mono', monospace;
font-size: 12px;
color: var(--muted);
min-height: 48px;
}
/* Dataframe */
.dataframe {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important;
overflow: hidden !important;
}
.dataframe th {
background: var(--surface2) !important;
color: var(--muted) !important;
font-size: 11px !important;
text-transform: uppercase !important;
letter-spacing: 0.06em !important;
border-bottom: 1px solid var(--border) !important;
padding: 10px 12px !important;
}
.dataframe td {
color: var(--text) !important;
font-family: 'JetBrains Mono', monospace !important;
font-size: 12px !important;
padding: 9px 12px !important;
border-bottom: 1px solid var(--border) !important;
}
/* Section labels */
.section-label {
font-size: 11px;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.08em;
color: var(--muted);
margin-bottom: 8px;
}
/* Preset radio */
.preset-radio label {
display: inline-flex !important;
align-items: center !important;
gap: 6px !important;
padding: 7px 14px !important;
border: 1px solid var(--border) !important;
border-radius: 6px !important;
cursor: pointer !important;
font-size: 13px !important;
transition: all 0.15s ease !important;
}
.preset-radio label:has(input:checked) {
border-color: var(--teal) !important;
color: var(--teal) !important;
background: rgba(0,212,170,0.07) !important;
}
/* Accordion */
.accordion {
border: 1px solid var(--border) !important;
border-radius: 10px !important;
background: var(--surface) !important;
}
/* Dividers */
hr { border-color: var(--border) !important; }
/* Refresh button */
.refresh-row { margin-bottom: 20px; }
"""
# ---------------------------------------------------------------------------
# Build Gradio UI
# ---------------------------------------------------------------------------
def build_ui():
    """Assemble the Gradio Blocks app: header, registry row, two tabs and wiring.

    Layout:
        * a status box plus "Refresh Models" button,
        * a "Backtester" tab (parameter column on the left, results on the right),
        * a "Model Evaluator" tab with the same two-column arrangement.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.

    NOTE(review): the HTML snippets (header, section labels, hints and
    placeholders) were missing their markup in the source and have been
    reconstructed around the ``.sniper-header`` and ``.section-label`` rules in
    CUSTOM_CSS. The header wording is a best guess — confirm the intended copy.
    The container nesting was inferred from statement order — confirm the
    intended layout.
    """

    def _section_label(text):
        # Small upper-case heading; styled by the `.section-label` CSS rule.
        return gr.HTML(f'<div class="section-label">{text}</div>')

    def _hint(text):
        # Muted helper paragraph shown under a section heading.
        return gr.HTML(f'<p style="color:#7a7f94;font-size:12px;margin:0 0 8px;">{text}</p>')

    def _placeholder(text):
        # Centred notice shown in a results panel before the first run.
        return (
            '<div style="padding:40px;text-align:center;color:#7a7f94;font-size:14px;">'
            f"{text}</div>"
        )

    def _toggle_custom(val):
        # The custom-tickers textbox is visible only for the custom preset.
        return gr.Textbox(visible=(val == "Custom (paste tickers)"))

    with gr.Blocks(css=CUSTOM_CSS, title="🎯 Sniper Model Evaluator") as demo:
        # ---- Header ----
        gr.HTML(
            '<div class="sniper-header">'
            "<h1>🎯 Sniper Model Evaluator</h1>"
            "<p>Two modes: Backtester and Model Evaluator.</p>"
            "</div>"
        )

        # ---- Registry row ----
        with gr.Row(elem_classes="refresh-row"):
            with gr.Column(scale=5):
                registry_status = gr.Textbox(
                    label="Registry Status",
                    value="Click 'Refresh Models' to discover available model runs.",
                    interactive=False,
                    elem_classes="status-box",
                    lines=2,
                )
            with gr.Column(scale=1):
                refresh_btn = gr.Button("🔄 Refresh Models", elem_classes="btn-secondary")

        # ---- Tabs ----
        with gr.Tabs():
            # ==============================================================
            # TAB 1 — BACKTESTER
            # ==============================================================
            with gr.Tab("📈 Backtester"):
                with gr.Row():
                    # ---- LEFT: Parameter panel ----
                    with gr.Column(scale=1, min_width=310):
                        _section_label("Model")
                        bt_model = gr.Dropdown(
                            choices=[], value=None, label="Model / Run",
                            info="Auto-discovered from Arkm20/sniper-model-*",
                        )

                        _section_label("Preset")
                        preset_radio = gr.Radio(
                            choices=["Conservative", "Balanced", "Aggressive", "Paper test"],
                            value="Balanced",
                            label="",
                            elem_classes="preset-radio",
                        )

                        _section_label("Date Range")
                        with gr.Row():
                            bt_start = gr.Textbox(value="2021-01-01", label="Start date", max_lines=1)
                            bt_end = gr.Textbox(value="2024-12-31", label="End date", max_lines=1)

                        _section_label("Capital & Signal")
                        initial_cash = gr.Number(value=10000, label="Initial capital ($)",
                                                 minimum=1000, maximum=1_000_000, step=1000)
                        conviction_thr = gr.Slider(0.20, 0.90, value=0.50, step=0.01,
                                                   label="Conviction threshold",
                                                   info="Minimum calibrated probability to enter a trade")

                        _section_label("Position Controls")
                        max_positions = gr.Slider(1, 20, value=5, step=1, label="Max concurrent positions")
                        sl_mult = gr.Slider(0.25, 3.0, value=0.5, step=0.05,
                                            label="Stop loss (× ATR)",
                                            info="Mirrors trainer's SL_MULTIPLIER")
                        pt_mult = gr.Slider(1.0, 6.0, value=3.0, step=0.1,
                                            label="Profit target (× ATR)",
                                            info="Mirrors trainer's PT_MULTIPLIER")
                        horizon_days = gr.Slider(5, 60, value=15, step=1,
                                                 label="Max hold horizon (days)")
                        cooldown_days = gr.Slider(0, 10, value=2, step=1,
                                                  label="Cooldown after exit (days)")

                        _section_label("Sizing")
                        sizing_mode = gr.Radio(
                            ["Volatility-adjusted", "Equal-weight"],
                            value="Volatility-adjusted",
                            label="Position sizing mode",
                        )
                        risk_fraction = gr.Slider(0.5, 5.0, value=2.0, step=0.1,
                                                  label="Risk % of NAV per trade",
                                                  info="Used in volatility-adjusted mode only")

                        with gr.Accordion("Advanced", open=False):
                            transaction_pct = gr.Slider(0.0, 0.5, value=0.1, step=0.01,
                                                        label="Transaction cost (%)",
                                                        info="Applied on entry and exit")
                            account_mode = gr.Radio(
                                ["Compound", "Realistic"],
                                value="Compound",
                                label="Account mode",
                                info="Realistic withdraws a fraction of profits to personal cash",
                            )
                            withdrawal_fraction = gr.Slider(5, 50, value=20, step=5,
                                                            label="Profit withdrawal % (Realistic mode)")
                            use_regime = gr.Checkbox(value=True, label="Auto-route to regime models")
                            benchmark = gr.Dropdown(["SPY", "QQQ", "None"], value="SPY",
                                                    label="Benchmark overlay")
                            use_confluence = gr.Checkbox(value=False, label="Confluence filter")
                            min_confluence = gr.Slider(1, 10, value=3, step=1, label="Min confluence score")

                        _section_label("Ticker Universe")
                        ticker_preset = gr.Dropdown(
                            choices=list(TICKER_UNIVERSE_PRESETS.keys()),
                            value="All (~800 tickers)",
                            label="Universe",
                        )
                        custom_tickers = gr.Textbox(
                            label="Custom tickers (space or comma separated)",
                            placeholder="AAPL MSFT NVDA TSLA ...",
                            visible=False, lines=3,
                        )

                        bt_run_btn = gr.Button("▶ Run Backtest", elem_classes="btn-primary", variant="primary")
                        bt_status = gr.Textbox(label="Status", value="", interactive=False,
                                               elem_classes="status-box", lines=1)

                    # ---- RIGHT: Results panel ----
                    with gr.Column(scale=3):
                        bt_metrics_html = gr.HTML(
                            _placeholder("Configure parameters and run a backtest to see results.")
                        )
                        bt_nav_plot = gr.Plot(label="")
                        with gr.Row():
                            bt_monthly_plot = gr.Plot(label="")
                            bt_exit_plot = gr.Plot(label="")
                        bt_dist_plot = gr.Plot(label="")
                        with gr.Accordion("Trade Log", open=False):
                            bt_trades_df = gr.Dataframe(
                                headers=["Ticker", "Entry Date", "Exit Date", "Exit Reason",
                                         "Entry Price", "Exit Price", "Return %", "Profit $", "Days Held"],
                                label="", wrap=False,
                            )
                        with gr.Accordion("Export", open=False):
                            with gr.Row():
                                bt_csv_out = gr.Textbox(label="Trades CSV", lines=6, interactive=False)
                                bt_json_out = gr.Textbox(label="Metrics JSON", lines=6, interactive=False)

            # ==============================================================
            # TAB 2 — EVALUATOR
            # ==============================================================
            with gr.Tab("🔬 Model Evaluator"):
                with gr.Row():
                    # ---- LEFT: Config ----
                    with gr.Column(scale=1, min_width=310):
                        _section_label("Models to Evaluate")
                        eval_models = gr.Dropdown(
                            choices=[], value=None, multiselect=True,
                            label="Model runs (select one or more for comparison)",
                        )

                        _section_label("Evaluation Range")
                        with gr.Row():
                            eval_start = gr.Textbox(value="2019-01-01", label="Start date", max_lines=1)
                            eval_end = gr.Textbox(value="2020-12-31", label="End date", max_lines=1)

                        _section_label("Label Parameters")
                        _hint("Match these to the model's training config for accurate scoring.")
                        eval_pt = gr.Slider(1.0, 6.0, value=3.0, step=0.1,
                                            label="Profit target (× ATR)",
                                            info="Should match PT_MULTIPLIER in training")
                        eval_sl = gr.Slider(0.1, 3.0, value=0.5, step=0.05,
                                            label="Stop loss (× ATR)",
                                            info="Should match SL_MULTIPLIER in training")
                        eval_horizon = gr.Slider(5, 60, value=15, step=1,
                                                 label="Horizon (days)",
                                                 info="Should match HORIZON in training")

                        _section_label("Ticker Universe")
                        eval_ticker_preset = gr.Dropdown(
                            choices=list(TICKER_UNIVERSE_PRESETS.keys()),
                            value="S&P 500 subset (200)",
                            label="Universe",
                            info="Smaller universe = faster evaluation",
                        )
                        eval_custom_tickers = gr.Textbox(
                            label="Custom tickers",
                            placeholder="AAPL MSFT NVDA ...",
                            visible=False, lines=3,
                        )

                        with gr.Accordion("Scoring Weights", open=False):
                            _hint("Adjust how each dimension contributes to the overall score. "
                                  "Values are relative.")
                            w_disc = gr.Slider(0, 100, value=20, step=5, label="Discrimination")
                            w_feat = gr.Slider(0, 100, value=20, step=5, label="Feature Health")
                            w_stab = gr.Slider(0, 100, value=15, step=5, label="Signal Stability")
                            w_cal = gr.Slider(0, 100, value=15, step=5, label="Calibration")
                            w_reg = gr.Slider(0, 100, value=15, step=5, label="Regime Robustness")
                            w_asym = gr.Slider(0, 100, value=15, step=5, label="Asymmetry Capture")

                        eval_run_btn = gr.Button("▶ Run Evaluation", elem_classes="btn-primary", variant="primary")
                        eval_status = gr.Textbox(label="Status", value="", interactive=False,
                                                 elem_classes="status-box", lines=1)

                    # ---- RIGHT: Results ----
                    with gr.Column(scale=3):
                        eval_grade_html = gr.HTML(
                            _placeholder("Select models and run evaluation to see the scorecard.")
                        )
                        eval_dim_table = gr.HTML("")
                        with gr.Row():
                            eval_radar = gr.Plot(label="")
                            eval_reliability = gr.Plot(label="")
                        with gr.Row():
                            eval_regime = gr.Plot(label="")
                            eval_psi_chart = gr.Plot(label="")
                        eval_compare = gr.Plot(label="")
                        with gr.Accordion("Feature PSI Table", open=False):
                            eval_psi_table = gr.Dataframe(
                                headers=["Feature", "NaN Rate", "Inf Rate", "PSI", "Status"],
                                label="", wrap=False,
                            )
                        with gr.Accordion("Warnings & Flags", open=False):
                            eval_flags_html = gr.HTML("")

        # -------------------------------------------------------------------
        # Event wiring
        # -------------------------------------------------------------------
        # Refresh registry → update both dropdowns
        refresh_btn.click(
            fn=_refresh_registry,
            outputs=[bt_model, eval_models, registry_status],
        )

        # Preset → fill sliders
        preset_radio.change(
            fn=apply_preset,
            inputs=[preset_radio],
            outputs=[conviction_thr, max_positions, sl_mult, pt_mult,
                     horizon_days, cooldown_days, sizing_mode, risk_fraction, account_mode],
        )

        # Show the custom ticker box when "Custom" is selected (same handler for both tabs)
        ticker_preset.change(fn=_toggle_custom, inputs=[ticker_preset], outputs=[custom_tickers])
        eval_ticker_preset.change(fn=_toggle_custom, inputs=[eval_ticker_preset],
                                  outputs=[eval_custom_tickers])

        # Run backtest
        bt_run_btn.click(
            fn=run_backtester,
            inputs=[
                bt_model, bt_start, bt_end, initial_cash,
                conviction_thr, max_positions,
                sl_mult, pt_mult, horizon_days, cooldown_days,
                sizing_mode, risk_fraction, transaction_pct,
                account_mode, withdrawal_fraction,
                use_regime, benchmark,
                use_confluence, min_confluence,
                ticker_preset, custom_tickers,
            ],
            outputs=[
                bt_metrics_html,
                bt_nav_plot, bt_monthly_plot, bt_exit_plot, bt_dist_plot,
                bt_trades_df,
                bt_csv_out, bt_json_out,
                bt_status,
            ],
        )

        # Run evaluator
        eval_run_btn.click(
            fn=run_evaluator,
            inputs=[
                eval_models, eval_start, eval_end,
                eval_ticker_preset, eval_custom_tickers,
                eval_pt, eval_sl, eval_horizon,
                w_disc, w_feat, w_stab, w_cal, w_reg, w_asym,
            ],
            outputs=[
                eval_grade_html, eval_dim_table,
                eval_radar, eval_reliability,
                eval_regime, eval_psi_chart,
                eval_compare,
                eval_psi_table,
                eval_flags_html,
                eval_status,
            ],
        )

        # Auto-refresh on load
        demo.load(fn=_refresh_registry, outputs=[bt_model, eval_models, registry_status])

    return demo
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then serve it on port 7860,
    # listening on all interfaces and surfacing handler errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    app = build_ui()
    app.launch(**launch_options)