# hft-quant-lab / app.py
# Author: aayushkrm
# Commit 4ca194f: Update UI with Technical Report and Metrics
import gradio as gr
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import os
import sys
from datasets import load_dataset
from src.models import get_model
from src.engine import quantize_model
# --- CONFIGURATION ---
DATASET_NAME = "aayushkrm/wunder-fund-hft-data"

# --- LOAD DATASET ---
# The whole "train" split is pulled into a DataFrame once, at import time.
# Any failure along the way leaves df = None and a short hard-coded id list,
# so the rest of the module can still be imported and the UI built.
print("Initializing App...")
try:
    print("Loading FULL dataset from Hugging Face...")
    dataset = load_dataset(DATASET_NAME, split="train")
    df = dataset.to_pandas()
    df['seq_ix'] = df['seq_ix'].astype(int)
    # Plain Python ints, ascending, for use as dropdown choices.
    SEQ_IDS = sorted(int(seq_id) for seq_id in df['seq_ix'].unique().tolist())
    print(f"βœ… Loaded {len(df)} rows.")
except Exception as e:
    print(f"⚠️ Could not load HF dataset: {e}")
    df = None
    SEQ_IDS = [0, 42, 100, 500]
# --- CACHED MODEL LOADER ---
def load_cached_model():
    """Build the "winner" network, load a checkpoint if one exists, and quantize it.

    ``best_model.pt`` in the working directory is preferred; otherwise
    ``artifacts/best_model.pt`` is tried. If no checkpoint file exists, or
    loading it raises, a warning is printed and the model is used exactly as
    ``get_model`` constructed it, so the app still starts.

    Returns:
        The result of ``quantize_model`` applied to the prepared model.
    """
    model = get_model("winner", input_size=32, hidden_size=256, layers=6)

    # Working directory takes precedence over the artifacts folder.
    model_path = "best_model.pt" if os.path.exists("best_model.pt") else "artifacts/best_model.pt"

    if not os.path.exists(model_path):
        print("⚠️ Model file not found, using random weights.")
        return quantize_model(model)

    try:
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects; this is only safe for checkpoint files from a trusted source.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        # Every entry is cast with .float() before being handed to the model.
        float_state = {name: tensor.float() for name, tensor in checkpoint.items()}
        model.load_state_dict(float_state)
        print(f"βœ… Loaded {model_path}")
    except Exception as e:
        # Best effort: report the problem and carry on with the unloaded model.
        print(f"⚠️ Error loading model: {e}")

    return quantize_model(model)
# Built once when the module is imported; inference() reuses this object on every call.
MODEL = load_cached_model()
def inference(seq_id_input, steps_input):
    """Run MODEL one time step at a time over a sequence and plot the result.

    Args:
        seq_id_input: Sequence id chosen in the UI; coerced with ``int``.
        steps_input: Upper bound on the number of time steps to predict and
            plot; coerced with ``int``.

    Returns:
        A Plotly ``Figure`` with two line traces: column 0 of the normalized
        input, and the first output value MODEL produced at each step.
    """
    seq_id = int(seq_id_input)
    steps_to_plot = int(steps_input)

    if df is None:
        # No dataset was loaded: feed the model Gaussian noise instead.
        norm_values = np.random.randn(1000, 32).astype(np.float32)
    else:
        seq_rows = df[df['seq_ix'] == seq_id].sort_values('step_in_seq')
        if len(seq_rows) > 0:
            feature_cols = [str(i) for i in range(32)]
            raw_values = seq_rows[feature_cols].values.astype(np.float32)
        else:
            # The id matched no rows: fall back to noise here as well.
            raw_values = np.random.randn(1000, 32).astype(np.float32)
        # Standardize each column over this sequence; the epsilon avoids
        # dividing by zero for constant columns.
        col_mean = raw_values.mean(axis=0)
        col_std = raw_values.std(axis=0) + 1e-6
        norm_values = (raw_values - col_mean) / col_std

    # Leading axis of size 1 added so the model receives a single-item batch.
    x = torch.tensor(norm_values).unsqueeze(0)
    n_steps = min(x.shape[1], steps_to_plot)

    preds = []
    h = None  # state returned by MODEL, passed back in on the next step
    with torch.no_grad():
        for t in range(n_steps):
            o, h = MODEL(x[:, t:t+1, :], h)
            preds.append(float(o.numpy()[0, 0, 0]))

    # Plotly Chart
    y_actual = norm_values[:steps_to_plot, 0].tolist()
    x_axis = list(range(len(y_actual)))

    actual_trace = go.Scatter(
        x=x_axis,
        y=y_actual,
        mode='lines',
        name='Actual Market Feature',
        line=dict(color='gray', width=1, dash='dot'),
    )
    predicted_trace = go.Scatter(
        x=x_axis,
        y=preds,
        mode='lines',
        name='Model Prediction (SE-Mish-GRU)',
        line=dict(color='#00ff00', width=2),
    )

    fig = go.Figure(data=[actual_trace, predicted_trace])
    fig.update_layout(
        title=f"Sequence {seq_id} | Lookahead: 1 Step",
        xaxis_title="Time Step (t)",
        yaxis_title="Normalized Price/Volume Feature",
        template="plotly_dark",
        height=450,
        hovermode="x unified",
        margin=dict(l=20, r=20, t=50, b=20),
    )
    return fig
# ==========================================
# PROFESSIONAL UI LAYOUT
# ==========================================
# Stylesheet passed to gr.Blocks(css=...): hides the page footer, caps the
# container width at 900px, centers <h1> headings, and defines a .metric-card
# class (no component in this file currently opts into that class).
custom_css = """
footer {visibility: hidden;}
.gradio-container {max-width: 900px !important;}
h1 {text-align: center;}
.metric-card {
background-color: #222;
padding: 15px;
border-radius: 10px;
text-align: center;
border: 1px solid #444;
}
"""
# Gradio UI: a header followed by four tabs (interactive forecast, technical
# report, ablation study, headline metrics). The Blocks object is bound to
# `demo` so it can be launched from the script entry point.
with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
    # HEADER
    gr.Markdown("# ⚑ Quant-Lab: Efficient High-Frequency Trading")
    gr.Markdown("### **Rank 28 (Top 1%)** in the Wunder Fund RNN Challenge")
    gr.Markdown(
        "A resource-constrained Deep Learning pipeline optimized for sub-millisecond CPU inference and strict 20MB memory footprints. "
        "Developed by **Aayush Kumar**."
    )
    with gr.Tabs():
        # --- TAB 1: INTERACTIVE DEMO ---
        with gr.Tab("πŸ“ˆ Interactive Forecast"):
            gr.Markdown("Select a market sequence to see the quantized **SE-Mish-GRU** perform real-time prediction.")
            with gr.Row():
                seq_selector = gr.Dropdown(
                    choices=SEQ_IDS,
                    label="Market Sequence ID",
                    # Guard against an empty id list instead of raising IndexError.
                    value=SEQ_IDS[0] if SEQ_IDS else None,
                    scale=2,
                )
                step_slider = gr.Slider(minimum=50, maximum=1000, value=200, label="Steps to Predict", scale=2)
                btn = gr.Button("⚑ Run Inference", variant="primary", scale=1)
            plot = gr.Plot(label="Forecast Visualization")
            # Preset Examples
            # Each preset is (position in SEQ_IDS, steps). Positions are clamped
            # to the last valid index: SEQ_IDS can be as short as four entries
            # when the dataset could not be loaded, and indexing it at 10 or 42
            # would otherwise raise IndexError while the UI is being built.
            preset_specs = [(2, 200), (10, 500), (42, 1000)]
            preset_examples = [
                [int(SEQ_IDS[min(pos, len(SEQ_IDS) - 1)]), steps]
                for pos, steps in preset_specs
            ] if SEQ_IDS else []
            if preset_examples:
                gr.Examples(
                    examples=preset_examples,
                    inputs=[seq_selector, step_slider],
                    label="Interesting Market Regimes"
                )
            btn.click(inference, inputs=[seq_selector, step_slider], outputs=plot)
        # --- TAB 2: TECHNICAL REPORT ---
        with gr.Tab("πŸ”¬ Technical Report & Architecture"):
            gr.Markdown("""
### The Challenge
Predict high-frequency financial time-series $X_{t+1}$ given history $X_{0...t}$.
**Constraints:** Maximum solution size 20 MB, inference on 1 CPU core < 60 mins.
### The Winning Solution: SE-Mish-Swarm
After testing 42 architectures, the best solution was a **10-Model Quantized Ensemble** of Deep Residual GRUs with custom gating.
1. **Mish Activation:** Replaced standard `ReLU/Tanh` with `Mish`. This preserved gradient flow across 6-layer deep RNNs, which is critical for highly noisy financial data.
2. **Squeeze-and-Excitation (SE-Block):** Integrated 1D channel attention *before* the RNN loop. This allows the model to dynamically suppress "noisy" features based on the current market state.
3. **INT8 Dynamic Quantization:** Implemented Post-Training Quantization (PTQ) to reduce model sizes by **73%** (5MB $\\rightarrow$ 1.3MB) with 0.0% accuracy degradation.
4. **Variance Reduction:** Fitting 10 models inside the 20MB limit allowed for extreme ensemble robustness, surviving the Private Leaderboard shakeup.
""")
        # --- TAB 3: ABLATION STUDY ---
        with gr.Tab("πŸ“Š Ablation Study (Failures)"):
            gr.Markdown("""
### What Failed and Why?
Treating this challenge as a research residency, I benchmarked modern SOTA architectures against traditional approaches.
| Architecture / Strategy | Outcome | Reason for Failure |
| :--- | :--- | :--- |
| **Transformers (XL, Performer)** | ❌ Failed | Context fragmentation and severe overfitting to noise. |
| **Mamba-2 (SSM)** | ❌ Failed | High training instability (NaNs) in floating-point financial data. |
| **WaveNet (Causal CNN)** | ❌ Failed | Inference timeout. Recomputing large receptive fields on CPU is $O(N)$ per step. |
| **Feature Diffing ($X_t - X_{t-1}$)** | ❌ Failed | Models lost spatial anchor points, failing to reconstruct absolute price levels. |
| **SE-Mish-GRU (Selected)** | βœ… **Winner** | Maximum inductive bias for time-series combined with scale compression. |
""")
        # --- TAB 4: METRICS ---
        with gr.Tab("⏱️ Performance Metrics"):
            # Three equal columns in one row, each a Markdown "metric" card.
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### RΒ² Score (Private LB)\n# 0.3873\n**+14.5% vs Baseline**")
                with gr.Column():
                    gr.Markdown("### Model Size (Quantized)\n# 1.8 MB\n**-3.2 MB (FP32)**")
                with gr.Column():
                    gr.Markdown("### Inference Latency\n# < 1.0 ms\n**Per step, CPU**")
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()