Spaces:
Sleeping
Sleeping
| """ | |
| app.py — Sniper Model Evaluator Space | |
| Two modes: Backtester and Model Evaluator. | |
| """ | |
| import sys | |
| import json | |
| import time | |
| import logging | |
| from datetime import datetime, date | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| # Must run before any joblib.load — injects SniperModel into __main__ | |
| # so pickles trained with __main__.SniperModel deserialize correctly. | |
| from src.sniper_model import patch_main as _patch_main | |
| _patch_main() | |
| from src.registry import ( | |
| discover_repos, load_bundle, get_all_bundle_labels, | |
| find_bundle_by_label, ModelRepo, ArtifactBundle, | |
| ) | |
| from src.data_loader import ( | |
| download_ticker_batch, filter_ticker_data, extract_market_series, | |
| ) | |
| from src.backtester import run_backtest, BacktestConfig, PRESETS | |
| from src.evaluator import run_evaluation, DIMENSION_WEIGHTS | |
| from src.charts import ( | |
| nav_chart, monthly_returns_heatmap, exit_reasons_chart, | |
| trade_return_distribution, radar_chart, reliability_diagram, | |
| regime_heatmap, feature_psi_chart, multi_model_comparison, | |
| ) | |
# Configure logging once at import time with a timestamped
# "[LEVEL] message" layout, then create the app's named logger.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("SniperApp")
| # --------------------------------------------------------------------------- | |
| # Global state | |
| # --------------------------------------------------------------------------- | |
# Repositories found by the most recent registry refresh
# (reassigned by _refresh_registry).
_repos: list[ModelRepo] = []
# Bundles already loaded in this process, filled by _get_or_load_bundle.
_loaded_bundles: dict[str, ArtifactBundle] = {}  # label -> loaded bundle
# Results and display labels of the most recent evaluator run
# (reassigned together by run_evaluator).
_eval_results: list = []  # for multi-model comparison
_eval_labels: list[str] = []
# Ticker-universe choices offered by the "Universe" dropdowns.
# The keys are used both as UI labels and as lookup keys in
# _universe_for_preset, so they are kept verbatim.
TICKER_UNIVERSE_PRESETS = {
    # None = use the bundle's full universe.
    "All (~800 tickers)": None,
    # Representative subset. Despite the "(200)" in the label the list holds
    # 183 unique symbols; the label is unchanged because other code keys on it.
    # AMZN, HD, RTX and UPS used to appear twice and are now listed once.
    # NOTE(review): "US" does not look like a real ticker and a few symbols
    # (e.g. ABMD, ECOL, SRCL) may no longer trade -- confirm before relying
    # on them.
    "S&P 500 subset (200)": [
        "AAPL", "MSFT", "AMZN", "GOOGL", "META", "NVDA", "TSLA", "BRK-B", "UNH", "JNJ",
        "XOM", "JPM", "V", "PG", "MA", "CVX", "HD", "LLY", "ABBV", "MRK", "PEP", "KO",
        "BAC", "PFE", "COST", "TMO", "AVGO", "CSCO", "ABT", "ACN", "WMT", "CRM", "NEE",
        "TXN", "DHR", "VZ", "AMGN", "NKE", "MDT", "UPS", "LIN", "PM", "RTX", "QCOM",
        "HON", "T", "SBUX", "INTU", "AMAT", "ISRG", "GS", "MS", "SCHW", "AXP", "BLK",
        "C", "USB", "TGT", "GILD", "CVS", "ZTS", "SYK", "ELV", "ADP", "NOW", "SPGI",
        "CL", "MO", "DUK", "SO", "D", "EXC", "AEP", "SRE", "PCG", "WEC", "ES", "XEL",
        "LMT", "NOC", "GD", "BA", "HII", "TDG", "HWM", "HEI", "LDOS",
        "CAT", "DE", "EMR", "ETN", "ITW", "PH", "ROK", "DOV", "CARR", "OTIS",
        "F", "GM", "APTV", "BWA", "LEA", "MGA", "AZO", "ORLY", "GPC", "AAP",
        "EBAY", "ETSY", "ROST", "TJX", "LOW", "WST", "SHW", "RPM",
        "ECL", "PPG", "EMN", "LYB", "APD", "CF", "MOS", "NUE", "FCX", "FMC",
        "BIIB", "REGN", "VRTX", "ILMN", "IDXX", "MRNA", "HOLX", "BIO", "ABMD",
        "MCO", "ICE", "CME", "CBOE", "NDAQ", "MSCI", "FDS", "MKTX", "BR", "AMP",
        "TROW", "BEN", "IVZ", "FHI", "AMG", "SEIC", "APAM", "VRTS", "HLNE", "PIPR",
        "PLD", "EQIX", "AMT", "CCI", "SBAC", "ARE", "AVB", "EQR", "MAA", "UDR",
        "WM", "RSG", "CWST", "CLH", "GFL", "WCN", "SRCL", "ECOL", "ARIS", "US",
        "FDX", "XPO", "ODFL", "SAIA", "WERN", "JBHT", "ARCB", "CHRW", "EXPD",
    ],
    # Sentinel: the symbols come from the free-text box instead.
    "Custom (paste tickers)": "custom",
}
# Hex colour used to render each letter grade; grades missing from this
# table fall back to a default chosen by the caller.
GRADE_COLORS = {
    # A tier
    "A+": "#00d4aa",
    "A": "#00d4aa",
    "A-": "#34d399",
    # B tier
    "B+": "#60a5fa",
    "B": "#60a5fa",
    "B-": "#93c5fd",
    # C tier
    "C+": "#f5a623",
    "C": "#f5a623",
    "C-": "#fbbf24",
    # D tier and failing
    "D+": "#ff4d6a",
    "D": "#ff4d6a",
    "F": "#dc2626",
}
| # --------------------------------------------------------------------------- | |
| # Shared helpers | |
| # --------------------------------------------------------------------------- | |
def _refresh_registry(progress=gr.Progress()):
    """Rescan the model registry and rebuild both model dropdowns.

    Returns a 3-tuple ``(backtester dropdown, evaluator dropdown, status)``
    where ``status`` is plain text.

    If the scan itself raises, the error is logged and reported in the status
    text; the previously discovered repositories (if any) stay in place so an
    already-populated UI is not wiped by a transient failure.
    """
    global _repos

    def _cb(msg):
        # The scanner only reports text, so the bar stays at 0 and just the
        # description changes.
        progress(0, desc=msg)

    progress(0, desc="Scanning HuggingFace for Sniper repositories...")
    scan_error = None
    try:
        _repos = discover_repos(progress_cb=_cb)
    except Exception as exc:  # top-level UI boundary: report, don't crash
        logger.exception("Registry refresh failed")
        scan_error = f"{type(exc).__name__}: {exc}"

    labels = get_all_bundle_labels(_repos)
    if not labels:
        if scan_error:
            status = f"❌ Registry scan failed ({scan_error})."
        else:
            status = "⚠️ No models found. Check that Arkm20/sniper-model-* repos are public."
        return (
            gr.Dropdown(choices=[], value=None, label="Model / Run"),
            gr.Dropdown(choices=[], value=None, label="Model / Run (Evaluator)"),
            status,
        )

    status = f"✅ Found {len(_repos)} repo(s), {len(labels)} run(s) total.\n" + "\n".join(
        f" · {label}" for label in labels
    )
    if scan_error:
        status = (
            f"⚠️ Refresh failed ({scan_error}); showing previously discovered runs.\n"
            + status
        )
    return (
        gr.Dropdown(choices=labels, value=labels[0], label="Model / Run"),
        gr.Dropdown(choices=labels, value=labels[0], label="Model / Run (Evaluator)"),
        status,
    )
def _get_or_load_bundle(label: str, progress_cb=None) -> ArtifactBundle | None:
    """Return the loaded bundle for ``label``, loading it on first use.

    Args:
        label: Run label as shown in the model dropdowns.
        progress_cb: Optional callback forwarded to ``load_bundle``.

    Returns:
        The loaded bundle, or ``None`` when no discovered run has this label.
        A bundle whose ``main_model`` is ``None`` is still returned (callers
        check for it) but is not cached.
    """
    cached = _loaded_bundles.get(label)
    if cached is not None:
        return cached

    stub = find_bundle_by_label(_repos, label)
    if stub is None:
        return None

    loaded = load_bundle(stub, progress_cb=progress_cb)
    # Cache only usable bundles. Storing one without a main model would make
    # a transient load failure permanent until the process restarts.
    if loaded is not None and loaded.main_model is not None:
        _loaded_bundles[label] = loaded
    return loaded
| def _universe_for_preset(preset_name: str, custom_text: str, bundle) -> list[str]: | |
| if preset_name == "Custom (paste tickers)": | |
| tickers = [t.strip().upper() for t in custom_text.replace(",", " ").split() if t.strip()] | |
| return tickers if tickers else (bundle.metadata.get("configuration", {}).get("TICKER_UNIVERSE", []) or []) | |
| if preset_name == "S&P 500 subset (200)": | |
| return TICKER_UNIVERSE_PRESETS["S&P 500 subset (200)"] | |
| # "All" - use training universe from bundle metadata if available, else hardcoded | |
| return bundle.metadata.get("configuration", {}).get("TICKER_UNIVERSE", []) or [] | |
| # --------------------------------------------------------------------------- | |
| # Backtester runner | |
| # --------------------------------------------------------------------------- | |
def run_backtester(
    model_label, bt_start, bt_end, initial_cash,
    conviction_threshold, max_positions,
    sl_multiplier, pt_multiplier, horizon_days, cooldown_days,
    sizing_mode, risk_fraction, transaction_pct,
    account_mode, withdrawal_fraction,
    use_regime_routing, benchmark,
    use_confluence, min_confluence_score,
    ticker_preset, custom_tickers,
    progress=gr.Progress(),
):
    """Run one backtest from the UI inputs and build every output widget value.

    Percent-style inputs (``risk_fraction``, ``transaction_pct``,
    ``withdrawal_fraction``) are divided by 100 before being put into the
    ``BacktestConfig``.

    Returns:
        A 9-tuple: metrics HTML, NAV figure, monthly-returns figure,
        exit-reasons figure, return-distribution figure, trades dataframe,
        trades CSV text, metrics JSON text, status text. On any failure the
        same shape is produced by ``_bt_empty`` with the error message.
    """
    last_frac = 0.0

    def _cb(msg, frac=None):
        # Callees may report text only; keep the bar where it was instead of
        # snapping it back to 0.
        nonlocal last_frac
        if frac is not None:
            last_frac = frac
        progress(last_frac, desc=msg)

    try:
        # ---- Load model ----
        _cb("Resolving model artifacts from HuggingFace...", 0.01)
        bundle = _get_or_load_bundle(model_label, progress_cb=_cb)
        if bundle is None or bundle.main_model is None:
            return _bt_empty("❌ Could not load model. Check model label and connectivity.")

        # ---- Determine universe ----
        tickers = _universe_for_preset(ticker_preset, custom_tickers, bundle)
        if not tickers:
            return _bt_empty("❌ No tickers found for selected universe.")

        # Also download the benchmark symbol; dict.fromkeys drops repeats
        # (e.g. the benchmark already being part of the universe).
        bm_sym = benchmark if benchmark != "None" else None
        extra = [bm_sym] if bm_sym else []
        download_symbols = list(dict.fromkeys([*tickers, *extra]))

        # ---- Download data ----
        _cb("Downloading market data — batch 1/N...", 0.05)
        ticker_data = download_ticker_batch(
            download_symbols,
            start=str(bt_start), end=str(bt_end),
            progress_cb=_cb,
        )
        ticker_data = filter_ticker_data(ticker_data, progress_cb=_cb)

        # ---- Build config ----
        config = BacktestConfig(
            start_date=str(bt_start),
            end_date=str(bt_end),
            initial_cash=float(initial_cash),
            conviction_threshold=float(conviction_threshold),
            use_regime_routing=bool(use_regime_routing),
            max_positions=int(max_positions),
            pt_multiplier=float(pt_multiplier),
            sl_multiplier=float(sl_multiplier),
            atr_period=14,
            horizon_days=int(horizon_days),
            cooldown_days=int(cooldown_days),
            sizing_mode="volatility_adjusted" if sizing_mode == "Volatility-adjusted" else "equal_weight",
            risk_fraction=float(risk_fraction) / 100.0,
            transaction_pct=float(transaction_pct) / 100.0,
            account_mode="compound" if account_mode == "Compound" else "realistic",
            withdrawal_fraction=float(withdrawal_fraction) / 100.0,
            use_confluence=bool(use_confluence),
            min_confluence_score=int(min_confluence_score),
            benchmark=benchmark,
        )

        # ---- Run backtest ----
        result = run_backtest(ticker_data, bundle, config, progress_cb=_cb)

        # ---- Build outputs ----
        _cb("Rendering results...", 0.97)
        m = result.metrics
        metrics_html = _build_metrics_html(m)
        fig_nav = nav_chart(result.nav_df, result.benchmark_df, config.initial_cash)
        fig_monthly = monthly_returns_heatmap(result.nav_df, config.initial_cash)
        fig_exit = exit_reasons_chart(result.trades_df)
        fig_dist = trade_return_distribution(result.trades_df)

        has_trades = not result.trades_df.empty
        # With no trades, show an empty table that still has the log columns.
        trades_display = result.trades_df if has_trades else pd.DataFrame(
            columns=["Ticker", "Entry Date", "Exit Date", "Exit Reason",
                     "Entry Price", "Exit Price", "Return %", "Profit $", "Days Held"]
        )
        # Export strings
        csv_str = result.trades_df.to_csv(index=False) if has_trades else ""
        json_str = json.dumps(result.metrics, indent=2, default=str)

        _cb("Done.", 1.0)
        return (
            metrics_html,
            fig_nav, fig_monthly, fig_exit, fig_dist,
            trades_display,
            csv_str, json_str,
            f"✅ Backtest complete — {m.get('Total Trades', 0)} trades over {result.n_tickers_processed} tickers.",
        )
    except Exception as e:  # top-level UI boundary: log and show the error
        logger.exception("Backtest error")
        return _bt_empty(f"❌ Error: {e}")
def _bt_empty(msg):
    """Build the 9-value backtester output tuple for an error/empty state.

    ``msg`` is HTML-escaped before it is placed in the banner, because it can
    carry exception text containing ``<`` or ``&``. The final tuple element is
    the raw message for the plain-text status box.
    """
    import html

    placeholder_fig = _empty_fig(msg)
    banner = f"<div style='color:#ff4d6a;padding:16px'>{html.escape(str(msg))}</div>"
    return (
        banner,
        placeholder_fig, placeholder_fig, placeholder_fig, placeholder_fig,
        pd.DataFrame(),
        "", "{}", msg,
    )
def _empty_fig(msg="No data"):
    """Return a 300px-high Plotly figure with no traces and ``msg`` centred on it."""
    import plotly.graph_objects as go

    muted = "#7a7f94"
    # Paper coordinates (0.5, 0.5) place the note in the middle of the canvas.
    centred_note = {
        "text": msg,
        "xref": "paper",
        "yref": "paper",
        "x": 0.5,
        "y": 0.5,
        "showarrow": False,
        "font": {"size": 13, "color": muted},
    }
    placeholder = go.Figure()
    placeholder.update_layout(
        height=300,
        paper_bgcolor="#0d0f14",
        plot_bgcolor="#14171f",
        font={"color": muted},
        annotations=[centred_note],
    )
    return placeholder
| def _build_metrics_html(m: dict) -> str: | |
| if not m: | |
| return "<p style='color:#7a7f94'>No metrics available.</p>" | |
| def _card(label, value, color="#e4e6ef", sub=""): | |
| return f""" | |
| <div style="background:#14171f;border:1px solid #2a2f3d;border-radius:10px; | |
| padding:16px 20px;min-width:140px;flex:1"> | |
| <div style="color:#7a7f94;font-size:11px;font-family:'DM Sans',sans-serif; | |
| text-transform:uppercase;letter-spacing:0.08em;margin-bottom:6px">{label}</div> | |
| <div style="color:{color};font-size:22px;font-weight:700;font-family:'JetBrains Mono',monospace">{value}</div> | |
| {f'<div style="color:#7a7f94;font-size:10px;margin-top:4px">{sub}</div>' if sub else ''} | |
| </div>""" | |
| total_ret = m.get("Total Return %", 0) | |
| ann_ret = m.get("Annualized Return %", 0) | |
| sharpe = m.get("Sharpe Ratio", 0) | |
| max_dd = m.get("Max Drawdown %", 0) | |
| win_rate = m.get("Win Rate %", 0) | |
| n_trades = m.get("Total Trades", 0) | |
| pf = m.get("Profit Factor", 0) | |
| calmar = m.get("Calmar Ratio", 0) | |
| ret_color = "#00d4aa" if total_ret >= 0 else "#ff4d6a" | |
| dd_color = "#ff4d6a" if max_dd < -10 else "#f5a623" if max_dd < -5 else "#e4e6ef" | |
| sharpe_color = "#00d4aa" if sharpe > 1 else "#f5a623" if sharpe > 0.5 else "#ff4d6a" | |
| cards = "".join([ | |
| _card("Total Return", f"{total_ret:+.1f}%", ret_color), | |
| _card("Ann. Return", f"{ann_ret:+.1f}%", ret_color), | |
| _card("Sharpe", f"{sharpe:.2f}", sharpe_color), | |
| _card("Max Drawdown", f"{max_dd:.1f}%", dd_color), | |
| _card("Win Rate", f"{win_rate:.1f}%"), | |
| _card("Trades", f"{n_trades:,}"), | |
| _card("Profit Factor", f"{pf:.2f}"), | |
| _card("Calmar", f"{calmar:.2f}"), | |
| ]) | |
| return f""" | |
| <div style="display:flex;flex-wrap:wrap;gap:10px;padding:4px 0 12px"> | |
| {cards} | |
| </div>""" | |
def apply_preset(preset_name):
    """Translate a named preset into values for the backtester controls.

    Unknown names use the "Balanced" preset. Returns a 9-tuple: conviction
    threshold, max positions, stop-loss multiplier, profit-target multiplier,
    horizon days, cooldown days, sizing-mode label, risk as a percentage,
    account-mode label.
    """
    fallback = PRESETS["Balanced"]
    chosen = PRESETS.get(preset_name, fallback)

    # Map internal enum-like strings back to the radio-button captions.
    if chosen.sizing_mode == "volatility_adjusted":
        sizing_caption = "Volatility-adjusted"
    else:
        sizing_caption = "Equal-weight"

    if chosen.account_mode == "compound":
        account_caption = "Compound"
    else:
        account_caption = "Realistic"

    # The config stores risk as a fraction; the slider works in percent.
    risk_percent = chosen.risk_fraction * 100

    return (
        chosen.conviction_threshold,
        chosen.max_positions,
        chosen.sl_multiplier,
        chosen.pt_multiplier,
        chosen.horizon_days,
        chosen.cooldown_days,
        sizing_caption,
        risk_percent,
        account_caption,
    )
| # --------------------------------------------------------------------------- | |
| # Evaluator runner | |
| # --------------------------------------------------------------------------- | |
def run_evaluator(
    model_labels_selected,
    eval_start, eval_end,
    ticker_preset, custom_tickers,
    pt_multiplier, sl_multiplier, horizon,
    w_discrimination, w_feature_health, w_signal_stability,
    w_calibration, w_regime_robustness, w_asymmetry,
    progress=gr.Progress(),
):
    """Evaluate one or more model runs and build every evaluator output.

    The six ``w_*`` inputs are percentages and are divided by 100 to form the
    dimension weights. Each selected model is evaluated independently; a model
    that cannot be loaded, has no tickers, or raises is skipped and the reason
    is reported in the status text. The first successful model is the
    "primary" one that drives the grade, tables and single-model charts.

    Returns:
        A 10-tuple: grade HTML, dimension table HTML, radar figure,
        reliability figure, regime figure, PSI figure, comparison figure,
        PSI dataframe, flags HTML, status text. ``_eval_empty`` supplies the
        same shape when nothing could be evaluated.
    """
    global _eval_results, _eval_labels
    last_frac = 0.0

    def _cb(msg, frac=None):
        # Keep the bar where it was when a callee reports text only.
        nonlocal last_frac
        if frac is not None:
            last_frac = frac
        progress(last_frac, desc=msg)

    if not model_labels_selected:
        return _eval_empty("❌ Select at least one model.")
    if isinstance(model_labels_selected, str):
        model_labels_selected = [model_labels_selected]

    weights = {
        "discrimination": w_discrimination / 100,
        "feature_health": w_feature_health / 100,
        "signal_stability": w_signal_stability / 100,
        "calibration": w_calibration / 100,
        "regime_robustness": w_regime_robustness / 100,
        "asymmetry": w_asymmetry / 100,
    }

    all_results = []
    all_labels = []
    skipped = []  # human-readable reasons, one per model that produced nothing
    for model_label in model_labels_selected:
        try:
            _cb(f"Loading model: {model_label}...", 0.02)
            bundle = _get_or_load_bundle(model_label, progress_cb=_cb)
            if bundle is None or bundle.main_model is None:
                logger.warning("Could not load %s, skipping.", model_label)
                skipped.append(f"{model_label}: model could not be loaded")
                continue
            tickers = _universe_for_preset(ticker_preset, custom_tickers, bundle)
            if not tickers:
                logger.warning("No tickers for %s, skipping.", model_label)
                skipped.append(f"{model_label}: no tickers for selected universe")
                continue
            _cb(f"Downloading data for {model_label}...", 0.05)
            ticker_data = download_ticker_batch(
                tickers, start=str(eval_start), end=str(eval_end), progress_cb=_cb
            )
            ticker_data = filter_ticker_data(ticker_data, progress_cb=_cb)
            _cb(f"Evaluating {model_label}...", 0.38)
            result = run_evaluation(
                ticker_data=ticker_data,
                bundle=bundle,
                pt_multiplier=float(pt_multiplier),
                sl_multiplier=float(sl_multiplier),
                horizon=int(horizon),
                dimension_weights=weights,
                progress_cb=_cb,
            )
            all_results.append(result)
            all_labels.append(model_label.split("·")[0].strip())
        except Exception as exc:  # one bad model must not abort the others
            logger.exception("Evaluation error for %s", model_label)
            # Only the exception type goes to the UI; the traceback is logged.
            skipped.append(f"{model_label}: evaluation failed ({type(exc).__name__}), see logs")

    skipped_note = f" Skipped — {'; '.join(skipped)}." if skipped else ""
    if not all_results:
        return _eval_empty("❌ No results produced. Check model and data." + skipped_note)

    _eval_results = all_results
    _eval_labels = all_labels
    primary = all_results[0]
    primary_label = all_labels[0]

    # Grade banner and per-dimension table
    grade_html = _build_grade_html(primary, primary_label)
    dim_table = _build_dim_table(primary)

    # Charts
    fig_radar = radar_chart(primary.dimensions)
    fig_reliability = reliability_diagram(primary.reliability_bins)
    fig_regime = regime_heatmap(primary.regime_scores)
    fig_psi = feature_psi_chart(primary.feature_psi)
    if len(all_results) > 1:
        fig_compare = multi_model_comparison(all_results, all_labels)
    else:
        fig_compare = _empty_fig("Add more models to enable comparison")

    # Flags raised by any dimension of the primary model
    all_flags = [
        f"⚠️ [{dim.name.replace('_',' ').title()}] {flag}"
        for dim in primary.dimensions
        for flag in dim.flags
    ]
    flags_html = _build_flags_html(all_flags)

    # PSI table (empty frame with the expected columns when there is no data)
    psi_display = primary.feature_psi if not primary.feature_psi.empty else pd.DataFrame(
        columns=["Feature", "NaN Rate", "Inf Rate", "PSI", "Status"]
    )

    status = (
        f"✅ Evaluated {len(all_results)} model(s). "
        f"Primary score: {primary.overall_score:.1f} ({primary.grade})"
        + skipped_note
    )
    _cb("Done.", 1.0)
    return (
        grade_html,
        dim_table,
        fig_radar,
        fig_reliability,
        fig_regime,
        fig_psi,
        fig_compare,
        psi_display,
        flags_html,
        status,
    )
def _build_grade_html(result, label: str) -> str:
    """Render the headline grade banner for one evaluation result.

    Reads ``grade``, ``overall_score``, ``n_samples``, ``n_positives`` and the
    two-element ``eval_date_range`` from ``result``. ``label`` is
    HTML-escaped because it is derived from a repository/run name. The colour
    comes from ``GRADE_COLORS`` with a neutral default for unknown grades.
    """
    import html

    grade = result.grade
    score = result.overall_score
    color = GRADE_COLORS.get(grade, "#e4e6ef")
    dr = result.eval_date_range
    # max(..., 1) avoids a division by zero when there are no samples.
    positive_rate = result.n_positives / max(result.n_samples, 1)
    safe_label = html.escape(str(label))
    return f"""
    <div style="display:flex;align-items:center;gap:28px;padding:20px 24px;
                background:#14171f;border:1px solid #2a2f3d;border-radius:14px;margin-bottom:8px">
      <div style="font-size:72px;font-weight:800;color:{color};
                  font-family:'JetBrains Mono',monospace;line-height:1">{grade}</div>
      <div>
        <div style="font-size:36px;font-weight:700;color:{color};
                    font-family:'JetBrains Mono',monospace">{score:.1f}<span style="font-size:18px;color:#7a7f94"> / 100</span></div>
        <div style="color:#e4e6ef;font-size:14px;margin-top:4px;font-family:'DM Sans',sans-serif">{safe_label}</div>
        <div style="color:#7a7f94;font-size:12px;margin-top:2px;font-family:'DM Sans',sans-serif">
          {result.n_samples:,} samples · {result.n_positives:,} positives ({positive_rate:.1%} rate)
          · {dr[0]} → {dr[1]}
        </div>
      </div>
    </div>"""
| def _build_dim_table(result) -> str: | |
| rows = "" | |
| for dim in result.dimensions: | |
| bar_color = "#00d4aa" if dim.score >= 70 else "#f5a623" if dim.score >= 50 else "#ff4d6a" | |
| bar_w = int(dim.score) | |
| label = dim.name.replace("_", " ").title() | |
| rows += f""" | |
| <tr> | |
| <td style="padding:10px 12px;color:#e4e6ef;font-family:'DM Sans',sans-serif;white-space:nowrap">{label}</td> | |
| <td style="padding:10px 12px;width:200px"> | |
| <div style="background:#1c202c;border-radius:4px;height:8px;overflow:hidden"> | |
| <div style="background:{bar_color};height:8px;width:{bar_w}%;border-radius:4px; | |
| transition:width 0.4s ease"></div> | |
| </div> | |
| </td> | |
| <td style="padding:10px 12px;font-family:'JetBrains Mono',monospace; | |
| color:{bar_color};font-size:14px;font-weight:600">{dim.score:.1f}</td> | |
| <td style="padding:10px 12px;color:#7a7f94;font-size:11px">{int(dim.weight*100)}% weight</td> | |
| </tr>""" | |
| return f""" | |
| <table style="width:100%;border-collapse:collapse;background:#0d0f14; | |
| border-radius:10px;overflow:hidden;border:1px solid #2a2f3d"> | |
| <thead> | |
| <tr style="background:#14171f;border-bottom:1px solid #2a2f3d"> | |
| <th style="padding:10px 12px;text-align:left;color:#7a7f94;font-size:11px; | |
| font-family:'DM Sans',sans-serif;font-weight:600;text-transform:uppercase; | |
| letter-spacing:0.06em">Dimension</th> | |
| <th style="padding:10px 12px;text-align:left;color:#7a7f94;font-size:11px; | |
| font-family:'DM Sans',sans-serif;font-weight:600;text-transform:uppercase; | |
| letter-spacing:0.06em">Score Bar</th> | |
| <th style="padding:10px 12px;text-align:left;color:#7a7f94;font-size:11px; | |
| font-family:'DM Sans',sans-serif;font-weight:600;text-transform:uppercase; | |
| letter-spacing:0.06em">Score</th> | |
| <th style="padding:10px 12px;text-align:left;color:#7a7f94;font-size:11px; | |
| font-family:'DM Sans',sans-serif;font-weight:600;text-transform:uppercase; | |
| letter-spacing:0.06em">Weight</th> | |
| </tr> | |
| </thead> | |
| <tbody> | |
| {rows} | |
| </tbody> | |
| </table>""" | |
| def _build_flags_html(flags: list) -> str: | |
| if not flags: | |
| return "<div style='color:#34d399;padding:12px;font-family:DM Sans,sans-serif'>✅ No critical issues detected.</div>" | |
| items = "".join( | |
| f"<li style='margin-bottom:6px;color:#e4e6ef;font-family:DM Sans,sans-serif;font-size:13px'>{f}</li>" | |
| for f in flags | |
| ) | |
| return f"<ul style='padding-left:20px;margin:0'>{items}</ul>" | |
def _eval_empty(msg):
    """Build the 10-value evaluator output tuple for an error/empty state.

    The HTML slots (grade, dimension table, flags) show ``msg`` HTML-escaped;
    the five figure slots share one placeholder figure; the last element is
    the raw message for the plain-text status box.
    """
    import html

    placeholder_fig = _empty_fig(msg)
    err_html = (
        "<div style='color:#ff4d6a;padding:16px;font-family:DM Sans,sans-serif'>"
        f"{html.escape(str(msg))}</div>"
    )
    return (
        err_html, err_html,
        placeholder_fig, placeholder_fig, placeholder_fig, placeholder_fig, placeholder_fig,
        pd.DataFrame(),
        err_html,
        msg,
    )
| # --------------------------------------------------------------------------- | |
| # CSS | |
| # --------------------------------------------------------------------------- | |
# Stylesheet handed to gr.Blocks(css=...). It imports the DM Sans and
# JetBrains Mono web fonts, defines the colour palette as CSS variables on
# :root, and styles the classes the UI attaches through elem_classes
# (e.g. sniper-header, refresh-row, status-box, btn-primary, btn-secondary,
# preset-radio, section-label).
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=DM+Sans:wght@400;500;600;700&family=JetBrains+Mono:wght@400;600;700&display=swap');
:root {
--bg: #0d0f14;
--surface: #14171f;
--surface2: #1c202c;
--border: #2a2f3d;
--text: #e4e6ef;
--muted: #7a7f94;
--teal: #00d4aa;
--red: #ff4d6a;
--amber: #f5a623;
--purple: #a78bfa;
}
body, .gradio-container {
background: var(--bg) !important;
font-family: 'DM Sans', sans-serif !important;
color: var(--text) !important;
}
.gradio-container {
max-width: 1400px !important;
}
/* Header */
.sniper-header {
padding: 28px 0 20px;
border-bottom: 1px solid var(--border);
margin-bottom: 24px;
}
.sniper-header h1 {
font-size: 28px;
font-weight: 700;
color: var(--text);
margin: 0;
display: flex;
align-items: center;
gap: 10px;
}
.sniper-header p {
color: var(--muted);
font-size: 13px;
margin: 6px 0 0;
}
/* Tab styling */
.tab-nav button {
font-family: 'DM Sans', sans-serif !important;
font-size: 14px !important;
font-weight: 600 !important;
color: var(--muted) !important;
background: transparent !important;
border-bottom: 2px solid transparent !important;
padding: 10px 20px !important;
transition: all 0.2s ease !important;
}
.tab-nav button.selected {
color: var(--teal) !important;
border-bottom: 2px solid var(--teal) !important;
}
/* Panel backgrounds */
.panel-box {
background: var(--surface);
border: 1px solid var(--border);
border-radius: 12px;
padding: 20px;
}
/* Sliders */
input[type=range] {
accent-color: var(--teal) !important;
}
/* Labels */
label span {
color: var(--text) !important;
font-size: 13px !important;
font-weight: 500 !important;
}
/* Inputs */
input, select, textarea {
background: var(--surface2) !important;
border: 1px solid var(--border) !important;
color: var(--text) !important;
border-radius: 8px !important;
}
input:focus, select:focus, textarea:focus {
border-color: var(--teal) !important;
outline: none !important;
box-shadow: 0 0 0 2px rgba(0, 212, 170, 0.15) !important;
}
/* Buttons */
.btn-primary {
background: var(--teal) !important;
color: #0d0f14 !important;
font-weight: 700 !important;
font-family: 'DM Sans', sans-serif !important;
border: none !important;
border-radius: 8px !important;
padding: 12px 28px !important;
font-size: 14px !important;
cursor: pointer !important;
transition: opacity 0.2s ease !important;
width: 100% !important;
}
.btn-primary:hover { opacity: 0.88 !important; }
.btn-secondary {
background: var(--surface2) !important;
color: var(--text) !important;
font-weight: 600 !important;
font-family: 'DM Sans', sans-serif !important;
border: 1px solid var(--border) !important;
border-radius: 8px !important;
padding: 10px 20px !important;
font-size: 13px !important;
cursor: pointer !important;
width: 100% !important;
}
/* Progress bar */
.progress-bar {
background: var(--teal) !important;
}
/* Status box */
.status-box {
background: var(--surface);
border: 1px solid var(--border);
border-radius: 8px;
padding: 12px 16px;
font-family: 'JetBrains Mono', monospace;
font-size: 12px;
color: var(--muted);
min-height: 48px;
}
/* Dataframe */
.dataframe {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important;
overflow: hidden !important;
}
.dataframe th {
background: var(--surface2) !important;
color: var(--muted) !important;
font-size: 11px !important;
text-transform: uppercase !important;
letter-spacing: 0.06em !important;
border-bottom: 1px solid var(--border) !important;
padding: 10px 12px !important;
}
.dataframe td {
color: var(--text) !important;
font-family: 'JetBrains Mono', monospace !important;
font-size: 12px !important;
padding: 9px 12px !important;
border-bottom: 1px solid var(--border) !important;
}
/* Section labels */
.section-label {
font-size: 11px;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.08em;
color: var(--muted);
margin-bottom: 8px;
}
/* Preset radio */
.preset-radio label {
display: inline-flex !important;
align-items: center !important;
gap: 6px !important;
padding: 7px 14px !important;
border: 1px solid var(--border) !important;
border-radius: 6px !important;
cursor: pointer !important;
font-size: 13px !important;
transition: all 0.15s ease !important;
}
.preset-radio label:has(input:checked) {
border-color: var(--teal) !important;
color: var(--teal) !important;
background: rgba(0,212,170,0.07) !important;
}
/* Accordion */
.accordion {
border: 1px solid var(--border) !important;
border-radius: 10px !important;
background: var(--surface) !important;
}
/* Dividers */
hr { border-color: var(--border) !important; }
/* Refresh button */
.refresh-row { margin-bottom: 20px; }
"""
| # --------------------------------------------------------------------------- | |
| # Build Gradio UI | |
| # --------------------------------------------------------------------------- | |
| def build_ui(): | |
| with gr.Blocks(css=CUSTOM_CSS, title="🎯 Sniper Model Evaluator") as demo: | |
| # ---- Header ---- | |
| gr.HTML(""" | |
| <div class="sniper-header"> | |
| <h1>🎯 Sniper Model Evaluator</h1> | |
| <p>Asymmetric precision stock signal models — backtesting & diagnostic suite</p> | |
| </div>""") | |
| # ---- Registry row ---- | |
| with gr.Row(elem_classes="refresh-row"): | |
| with gr.Column(scale=5): | |
| registry_status = gr.Textbox( | |
| label="Registry Status", | |
| value="Click 'Refresh Models' to discover available model runs.", | |
| interactive=False, | |
| elem_classes="status-box", | |
| lines=2, | |
| ) | |
| with gr.Column(scale=1): | |
| refresh_btn = gr.Button("🔄 Refresh Models", elem_classes="btn-secondary") | |
| # ---- Tabs ---- | |
| with gr.Tabs(): | |
| # ============================================================== | |
| # TAB 1 — BACKTESTER | |
| # ============================================================== | |
| with gr.Tab("📈 Backtester"): | |
| with gr.Row(): | |
| # ---- LEFT: Parameter panel ---- | |
| with gr.Column(scale=1, min_width=310): | |
| gr.HTML('<div class="section-label">Model</div>') | |
| bt_model = gr.Dropdown( | |
| choices=[], value=None, label="Model / Run", | |
| info="Auto-discovered from Arkm20/sniper-model-*", | |
| ) | |
| gr.HTML('<div class="section-label" style="margin-top:16px">Preset</div>') | |
| preset_radio = gr.Radio( | |
| choices=["Conservative", "Balanced", "Aggressive", "Paper test"], | |
| value="Balanced", | |
| label="", | |
| elem_classes="preset-radio", | |
| ) | |
| gr.HTML('<div class="section-label" style="margin-top:16px">Date Range</div>') | |
| with gr.Row(): | |
| bt_start = gr.Textbox(value="2021-01-01", label="Start date", max_lines=1) | |
| bt_end = gr.Textbox(value="2024-12-31", label="End date", max_lines=1) | |
| gr.HTML('<div class="section-label" style="margin-top:16px">Capital & Signal</div>') | |
| initial_cash = gr.Number(value=10000, label="Initial capital ($)", minimum=1000, maximum=1_000_000, step=1000) | |
| conviction_thr = gr.Slider(0.20, 0.90, value=0.50, step=0.01, | |
| label="Conviction threshold", | |
| info="Minimum calibrated probability to enter a trade") | |
| gr.HTML('<div class="section-label" style="margin-top:16px">Position Controls</div>') | |
| max_positions = gr.Slider(1, 20, value=5, step=1, label="Max concurrent positions") | |
| sl_mult = gr.Slider(0.25, 3.0, value=0.5, step=0.05, | |
| label="Stop loss (× ATR)", | |
| info="Mirrors trainer's SL_MULTIPLIER") | |
| pt_mult = gr.Slider(1.0, 6.0, value=3.0, step=0.1, | |
| label="Profit target (× ATR)", | |
| info="Mirrors trainer's PT_MULTIPLIER") | |
| horizon_days = gr.Slider(5, 60, value=15, step=1, | |
| label="Max hold horizon (days)") | |
| cooldown_days = gr.Slider(0, 10, value=2, step=1, | |
| label="Cooldown after exit (days)") | |
| gr.HTML('<div class="section-label" style="margin-top:16px">Sizing</div>') | |
| sizing_mode = gr.Radio( | |
| ["Volatility-adjusted", "Equal-weight"], | |
| value="Volatility-adjusted", | |
| label="Position sizing mode", | |
| ) | |
| risk_fraction = gr.Slider(0.5, 5.0, value=2.0, step=0.1, | |
| label="Risk % of NAV per trade", | |
| info="Used in volatility-adjusted mode only") | |
| with gr.Accordion("Advanced", open=False): | |
| transaction_pct = gr.Slider(0.0, 0.5, value=0.1, step=0.01, | |
| label="Transaction cost (%)", | |
| info="Applied on entry and exit") | |
| account_mode = gr.Radio( | |
| ["Compound", "Realistic"], | |
| value="Compound", | |
| label="Account mode", | |
| info="Realistic withdraws a fraction of profits to personal cash", | |
| ) | |
| withdrawal_fraction = gr.Slider(5, 50, value=20, step=5, | |
| label="Profit withdrawal % (Realistic mode)") | |
| use_regime = gr.Checkbox(value=True, label="Auto-route to regime models") | |
| benchmark = gr.Dropdown(["SPY", "QQQ", "None"], value="SPY", label="Benchmark overlay") | |
| use_confluence = gr.Checkbox(value=False, label="Confluence filter") | |
| min_confluence = gr.Slider(1, 10, value=3, step=1, label="Min confluence score") | |
| gr.HTML('<div class="section-label" style="margin-top:16px">Ticker Universe</div>') | |
| ticker_preset = gr.Dropdown( | |
| choices=list(TICKER_UNIVERSE_PRESETS.keys()), | |
| value="All (~800 tickers)", | |
| label="Universe", | |
| ) | |
| custom_tickers = gr.Textbox( | |
| label="Custom tickers (space or comma separated)", | |
| placeholder="AAPL MSFT NVDA TSLA ...", | |
| visible=False, lines=3, | |
| ) | |
| bt_run_btn = gr.Button("▶ Run Backtest", elem_classes="btn-primary", variant="primary") | |
| bt_status = gr.Textbox(label="Status", value="", interactive=False, | |
| elem_classes="status-box", lines=1) | |
| # ---- RIGHT: Results panel ---- | |
| with gr.Column(scale=3): | |
| bt_metrics_html = gr.HTML( | |
| "<div style='color:#7a7f94;padding:20px;font-family:DM Sans,sans-serif'>" | |
| "Configure parameters and run a backtest to see results.</div>" | |
| ) | |
| bt_nav_plot = gr.Plot(label="") | |
| with gr.Row(): | |
| bt_monthly_plot = gr.Plot(label="") | |
| bt_exit_plot = gr.Plot(label="") | |
| bt_dist_plot = gr.Plot(label="") | |
| with gr.Accordion("Trade Log", open=False): | |
| bt_trades_df = gr.Dataframe( | |
| headers=["Ticker","Entry Date","Exit Date","Exit Reason", | |
| "Entry Price","Exit Price","Return %","Profit $","Days Held"], | |
| label="", wrap=False, | |
| ) | |
| with gr.Accordion("Export", open=False): | |
| with gr.Row(): | |
| bt_csv_out = gr.Textbox(label="Trades CSV", lines=6, interactive=False) | |
| bt_json_out = gr.Textbox(label="Metrics JSON", lines=6, interactive=False) | |
| # ============================================================== | |
| # TAB 2 — EVALUATOR | |
| # ============================================================== | |
| with gr.Tab("🔬 Model Evaluator"): | |
| with gr.Row(): | |
| # ---- LEFT: Config ---- | |
| with gr.Column(scale=1, min_width=310): | |
| gr.HTML('<div class="section-label">Models to Evaluate</div>') | |
| eval_models = gr.Dropdown( | |
| choices=[], value=None, multiselect=True, | |
| label="Model runs (select one or more for comparison)", | |
| ) | |
| gr.HTML('<div class="section-label" style="margin-top:16px">Evaluation Range</div>') | |
| with gr.Row(): | |
| eval_start = gr.Textbox(value="2019-01-01", label="Start date", max_lines=1) | |
| eval_end = gr.Textbox(value="2020-12-31", label="End date", max_lines=1) | |
| gr.HTML('<div class="section-label" style="margin-top:16px">Label Parameters</div>') | |
| gr.HTML('<p style="color:#7a7f94;font-size:12px;margin:0 0 8px">Match these to the model\'s training config for accurate scoring.</p>') | |
| eval_pt = gr.Slider(1.0, 6.0, value=3.0, step=0.1, | |
| label="Profit target (× ATR)", | |
| info="Should match PT_MULTIPLIER in training") | |
| eval_sl = gr.Slider(0.1, 3.0, value=0.5, step=0.05, | |
| label="Stop loss (× ATR)", | |
| info="Should match SL_MULTIPLIER in training") | |
| eval_horizon = gr.Slider(5, 60, value=15, step=1, | |
| label="Horizon (days)", | |
| info="Should match HORIZON in training") | |
| gr.HTML('<div class="section-label" style="margin-top:16px">Ticker Universe</div>') | |
| eval_ticker_preset = gr.Dropdown( | |
| choices=list(TICKER_UNIVERSE_PRESETS.keys()), | |
| value="S&P 500 subset (200)", | |
| label="Universe", | |
| info="Smaller universe = faster evaluation", | |
| ) | |
| eval_custom_tickers = gr.Textbox( | |
| label="Custom tickers", | |
| placeholder="AAPL MSFT NVDA ...", | |
| visible=False, lines=3, | |
| ) | |
| with gr.Accordion("Scoring Weights", open=False): | |
| gr.HTML('<p style="color:#7a7f94;font-size:12px;margin:0 0 10px">Adjust how each dimension contributes to the overall score. Values are relative.</p>') | |
| w_disc = gr.Slider(0, 100, value=20, step=5, label="Discrimination") | |
| w_feat = gr.Slider(0, 100, value=20, step=5, label="Feature Health") | |
| w_stab = gr.Slider(0, 100, value=15, step=5, label="Signal Stability") | |
| w_cal = gr.Slider(0, 100, value=15, step=5, label="Calibration") | |
| w_reg = gr.Slider(0, 100, value=15, step=5, label="Regime Robustness") | |
| w_asym = gr.Slider(0, 100, value=15, step=5, label="Asymmetry Capture") | |
| eval_run_btn = gr.Button("▶ Run Evaluation", elem_classes="btn-primary", variant="primary") | |
| eval_status = gr.Textbox(label="Status", value="", interactive=False, | |
| elem_classes="status-box", lines=1) | |
| # ---- RIGHT: Results ---- | |
| with gr.Column(scale=3): | |
| eval_grade_html = gr.HTML( | |
| "<div style='color:#7a7f94;padding:20px;font-family:DM Sans,sans-serif'>" | |
| "Select models and run evaluation to see the scorecard.</div>" | |
| ) | |
| eval_dim_table = gr.HTML("") | |
| with gr.Row(): | |
| eval_radar = gr.Plot(label="") | |
| eval_reliability = gr.Plot(label="") | |
| with gr.Row(): | |
| eval_regime = gr.Plot(label="") | |
| eval_psi_chart = gr.Plot(label="") | |
| eval_compare = gr.Plot(label="") | |
| with gr.Accordion("Feature PSI Table", open=False): | |
| eval_psi_table = gr.Dataframe( | |
| headers=["Feature", "NaN Rate", "Inf Rate", "PSI", "Status"], | |
| label="", wrap=False, | |
| ) | |
| with gr.Accordion("Warnings & Flags", open=False): | |
| eval_flags_html = gr.HTML("") | |
| # --------------------------------------------------------------------------- | |
| # Event wiring | |
| # --------------------------------------------------------------------------- | |
| # Refresh registry → update both dropdowns | |
| refresh_btn.click( | |
| fn=_refresh_registry, | |
| outputs=[bt_model, eval_models, registry_status], | |
| ) | |
| # Preset → fill sliders | |
| preset_radio.change( | |
| fn=apply_preset, | |
| inputs=[preset_radio], | |
| outputs=[conviction_thr, max_positions, sl_mult, pt_mult, | |
| horizon_days, cooldown_days, sizing_mode, risk_fraction, account_mode], | |
| ) | |
| # Show custom ticker box when "Custom" selected | |
def _toggle_custom_bt(val):
    """Return a Textbox update that is visible only for the custom universe.

    ``val`` is the currently selected universe preset; the box is shown
    exactly when it equals ``"Custom (paste tickers)"``.
    """
    show_box = val == "Custom (paste tickers)"
    return gr.Textbox(visible=show_box)
def _toggle_custom_eval(val):
    """Show the custom-ticker Textbox only when the custom preset is chosen.

    Any other preset value yields a hidden Textbox update.
    """
    if val == "Custom (paste tickers)":
        is_custom = True
    else:
        is_custom = False
    return gr.Textbox(visible=is_custom)
| ticker_preset.change(fn=_toggle_custom_bt, inputs=[ticker_preset], outputs=[custom_tickers]) | |
| eval_ticker_preset.change(fn=_toggle_custom_eval, inputs=[eval_ticker_preset], outputs=[eval_custom_tickers]) | |
| # Run backtest | |
| bt_run_btn.click( | |
| fn=run_backtester, | |
| inputs=[ | |
| bt_model, bt_start, bt_end, initial_cash, | |
| conviction_thr, max_positions, | |
| sl_mult, pt_mult, horizon_days, cooldown_days, | |
| sizing_mode, risk_fraction, transaction_pct, | |
| account_mode, withdrawal_fraction, | |
| use_regime, benchmark, | |
| use_confluence, min_confluence, | |
| ticker_preset, custom_tickers, | |
| ], | |
| outputs=[ | |
| bt_metrics_html, | |
| bt_nav_plot, bt_monthly_plot, bt_exit_plot, bt_dist_plot, | |
| bt_trades_df, | |
| bt_csv_out, bt_json_out, | |
| bt_status, | |
| ], | |
| ) | |
| # Run evaluator | |
| eval_run_btn.click( | |
| fn=run_evaluator, | |
| inputs=[ | |
| eval_models, eval_start, eval_end, | |
| eval_ticker_preset, eval_custom_tickers, | |
| eval_pt, eval_sl, eval_horizon, | |
| w_disc, w_feat, w_stab, w_cal, w_reg, w_asym, | |
| ], | |
| outputs=[ | |
| eval_grade_html, eval_dim_table, | |
| eval_radar, eval_reliability, | |
| eval_regime, eval_psi_chart, | |
| eval_compare, | |
| eval_psi_table, | |
| eval_flags_html, | |
| eval_status, | |
| ], | |
| ) | |
| # Auto-refresh on load | |
| demo.load(fn=_refresh_registry, outputs=[bt_model, eval_models, registry_status]) | |
| return demo | |
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Build the Blocks app, then start the server with fixed host/port
    # settings; show_error=True is passed through to Gradio's launch().
    launch_kwargs = dict(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
    app = build_ui()
    app.launch(**launch_kwargs)