# Hugging Face Spaces page header (Space status: Sleeping)
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import os | |
| import sys | |
| from datasets import load_dataset | |
| from src.models import get_model | |
| from src.engine import quantize_model | |
# --- CONFIGURATION ---
# Hugging Face dataset repository the demo reads its sequences from.
DATASET_NAME = "aayushkrm/wunder-fund-hft-data"

# --- LOAD DATASET ---
# Runs once at import time. On success `df` holds the whole "train" split as a
# DataFrame and SEQ_IDS the sorted distinct `seq_ix` values as Python ints.
# On any failure the app keeps going with `df = None` and a short hard-coded
# list of IDs (best-effort fallback, hence the broad `except`).
print("Initializing App...")
try:
    print("Loading FULL dataset from Hugging Face...")
    dataset = load_dataset(DATASET_NAME, split="train")
    df = dataset.to_pandas()
    df['seq_ix'] = df['seq_ix'].astype(int)
    unique_ids_np = df['seq_ix'].unique()
    # Convert to builtin ints before sorting so the list holds no NumPy scalars.
    SEQ_IDS = sorted(map(int, unique_ids_np.tolist()))
    print(f"β Loaded {len(df)} rows.")
except Exception as e:
    print(f"β οΈ Could not load HF dataset: {e}")
    df, SEQ_IDS = None, [0, 42, 100, 500]
# --- CACHED MODEL LOADER ---
def load_cached_model():
    """Build the "winner" model, load saved weights if available, and quantize it.

    The checkpoint is looked up as ``best_model.pt`` in the current directory
    first, then as ``artifacts/best_model.pt``. If neither exists, or loading
    raises, a message is printed and the model is quantized as it stands.

    Returns:
        The result of ``quantize_model(model)``.
    """
    model = get_model("winner", input_size=32, hidden_size=256, layers=6)

    # Current directory takes precedence over the artifacts folder.
    model_path = (
        "best_model.pt"
        if os.path.exists("best_model.pt")
        else "artifacts/best_model.pt"
    )

    if not os.path.exists(model_path):
        print("β οΈ Model file not found, using random weights.")
        return quantize_model(model)

    try:
        # NOTE(review): weights_only=False makes torch.load unpickle arbitrary
        # objects; only point this at checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        # Cast every entry to float32 before handing it to the model.
        state = {k: v.float() for k, v in checkpoint.items()}
        model.load_state_dict(state)
        print(f"β Loaded {model_path}")
    except Exception as e:
        # Best-effort: report the problem and continue with the model as-is.
        print(f"β οΈ Error loading model: {e}")

    return quantize_model(model)


# Built once at import time and reused by every inference call.
MODEL = load_cached_model()
def inference(seq_id_input, steps_input):
    """Run MODEL step by step over one sequence and plot prediction vs. input.

    Args:
        seq_id_input: Sequence ID to look up in ``df['seq_ix']``; coerced with
            ``int()``.
        steps_input: Maximum number of time steps to run and plot; coerced
            with ``int()``.

    Returns:
        A ``plotly.graph_objects.Figure`` with two line traces: column 0 of
        the (normalized) input and the model's per-step output.

    Random ``(1000, 32)`` float32 data is used when the dataset is not loaded
    or when the requested sequence has no rows.
    """
    seq_id = int(seq_id_input)
    steps_to_plot = int(steps_input)

    if df is None:
        # No dataset available: feed standard-normal noise, used as-is.
        norm_values = np.random.randn(1000, 32).astype(np.float32)
    else:
        seq_data = df[df['seq_ix'] == seq_id].sort_values('step_in_seq')
        if len(seq_data) > 0:
            feature_columns = [str(i) for i in range(32)]
            raw_values = seq_data[feature_columns].values.astype(np.float32)
        else:
            raw_values = np.random.randn(1000, 32).astype(np.float32)
        # Per-feature standardization over this sequence; epsilon avoids /0.
        mean = raw_values.mean(axis=0)
        std = raw_values.std(axis=0) + 1e-6
        norm_values = (raw_values - mean) / std

    # Add a leading batch dimension of 1.
    x = torch.tensor(norm_values).unsqueeze(0)
    n_steps = min(len(x[0]), steps_to_plot)

    preds = []
    h = None
    with torch.no_grad():
        for t in range(n_steps):
            # Feed one time step at a time, carrying the returned state forward.
            o, h = MODEL(x[:, t:t + 1, :], h)
            # assumes the output is indexable as [batch, step, feature] — TODO confirm
            preds.append(float(o.numpy()[0, 0, 0]))

    # Plotly Chart
    y_actual = norm_values[:steps_to_plot, 0].flatten().tolist()
    x_axis = list(range(len(y_actual)))

    actual_trace = go.Scatter(
        x=x_axis,
        y=y_actual,
        mode='lines',
        name='Actual Market Feature',
        line=dict(color='gray', width=1, dash='dot'),
    )
    predicted_trace = go.Scatter(
        x=x_axis,
        y=preds,
        mode='lines',
        name='Model Prediction (SE-Mish-GRU)',
        line=dict(color='#00ff00', width=2),
    )

    fig = go.Figure()
    fig.add_trace(actual_trace)
    fig.add_trace(predicted_trace)
    fig.update_layout(
        title=f"Sequence {seq_id} | Lookahead: 1 Step",
        xaxis_title="Time Step (t)",
        yaxis_title="Normalized Price/Volume Feature",
        template="plotly_dark",
        height=450,
        hovermode="x unified",
        margin=dict(l=20, r=20, t=50, b=20),
    )
    return fig
# ==========================================
# PROFESSIONAL UI LAYOUT
# ==========================================
# NOTE(review): several user-facing strings below contain what looks like
# mis-encoded text (e.g. "β‘", "RΒ²", "π¬"). They are reproduced unchanged;
# confirm the intended characters against the original file before editing.

# Extra CSS handed to gr.Blocks: hides the footer, caps the page width,
# centers h1 headings and defines a `.metric-card` style.
custom_css = """
footer {visibility: hidden;}
.gradio-container {max-width: 900px !important;}
h1 {text-align: center;}
.metric-card {
background-color: #222;
padding: 15px;
border-radius: 10px;
text-align: center;
border: 1px solid #444;
}
"""

# Markdown body of the "Technical Report & Architecture" tab.
_TECH_REPORT_MD = """
### The Challenge
Predict high-frequency financial time-series $X_{t+1}$ given history $X_{0...t}$.
**Constraints:** Maximum solution size 20 MB, inference on 1 CPU core < 60 mins.
### The Winning Solution: SE-Mish-Swarm
After testing 42 architectures, the best solution was a **10-Model Quantized Ensemble** of Deep Residual GRUs with custom gating.
1. **Mish Activation:** Replaced standard `ReLU/Tanh` with `Mish`. This preserved gradient flow across 6-layer deep RNNs, which is critical for highly noisy financial data.
2. **Squeeze-and-Excitation (SE-Block):** Integrated 1D channel attention *before* the RNN loop. This allows the model to dynamically suppress "noisy" features based on the current market state.
3. **INT8 Dynamic Quantization:** Implemented Post-Training Quantization (PTQ) to reduce model sizes by **73%** (5MB $\\rightarrow$ 1.3MB) with 0.0% accuracy degradation.
4. **Variance Reduction:** Fitting 10 models inside the 20MB limit allowed for extreme ensemble robustness, surviving the Private Leaderboard shakeup.
"""

# Markdown body of the "Ablation Study (Failures)" tab.
_ABLATION_MD = """
### What Failed and Why?
Treating this challenge as a research residency, I benchmarked modern SOTA architectures against traditional approaches.
| Architecture / Strategy | Outcome | Reason for Failure |
| :--- | :--- | :--- |
| **Transformers (XL, Performer)** | β Failed | Context fragmentation and severe overfitting to noise. |
| **Mamba-2 (SSM)** | β Failed | High training instability (NaNs) in floating-point financial data. |
| **WaveNet (Causal CNN)** | β Failed | Inference timeout. Recomputing large receptive fields on CPU is $O(N)$ per step. |
| **Feature Diffing ($X_t - X_{t-1}$)** | β Failed | Models lost spatial anchor points, failing to reconstruct absolute price levels. |
| **SE-Mish-GRU (Selected)** | β **Winner** | Maximum inductive bias for time-series combined with scale compression. |
"""

# (index into SEQ_IDS, steps) pairs offered as one-click presets.
_PRESET_EXAMPLES = ((2, 200), (10, 500), (42, 1000))

# Theme and custom CSS are passed to gr.Blocks here.
with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
    # HEADER
    gr.Markdown("# β‘ Quant-Lab: Efficient High-Frequency Trading")
    gr.Markdown("### **Rank 28 (Top 1%)** in the Wunder Fund RNN Challenge")
    gr.Markdown(
        "A resource-constrained Deep Learning pipeline optimized for sub-millisecond CPU inference and strict 20MB memory footprints. "
        "Developed by **Aayush Kumar**."
    )

    with gr.Tabs():
        # --- TAB 1: INTERACTIVE DEMO ---
        with gr.Tab("π Interactive Forecast"):
            gr.Markdown("Select a market sequence to see the quantized **SE-Mish-GRU** perform real-time prediction.")
            with gr.Row():
                seq_selector = gr.Dropdown(
                    choices=SEQ_IDS,
                    label="Market Sequence ID",
                    # Guard the default so an empty ID list cannot raise here.
                    value=SEQ_IDS[0] if SEQ_IDS else None,
                    scale=2,
                )
                step_slider = gr.Slider(minimum=50, maximum=1000, value=200, label="Steps to Predict", scale=2)
                btn = gr.Button("β‘ Run Inference", variant="primary", scale=1)
            plot = gr.Plot(label="Forecast Visualization")

            # Preset Examples
            # The hard-coded fallback ID list used when the dataset fails to
            # load has only four entries, so indexing [10] / [42] directly
            # would raise IndexError while the UI is being built. Keep only
            # the presets whose index actually exists.
            preset_rows = [
                [int(SEQ_IDS[idx]), steps]
                for idx, steps in _PRESET_EXAMPLES
                if idx < len(SEQ_IDS)
            ]
            if preset_rows:
                gr.Examples(
                    examples=preset_rows,
                    inputs=[seq_selector, step_slider],
                    label="Interesting Market Regimes"
                )

            btn.click(inference, inputs=[seq_selector, step_slider], outputs=plot)

        # --- TAB 2: TECHNICAL REPORT ---
        with gr.Tab("π¬ Technical Report & Architecture"):
            gr.Markdown(_TECH_REPORT_MD)

        # --- TAB 3: ABLATION STUDY ---
        with gr.Tab("π Ablation Study (Failures)"):
            gr.Markdown(_ABLATION_MD)

        # --- TAB 4: METRICS ---
        with gr.Tab("β±οΈ Performance Metrics"):
            # Metrics are laid out with gr.Row / gr.Column and plain Markdown.
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### RΒ² Score (Private LB)\n# 0.3873\n**+14.5% vs Baseline**")
                with gr.Column():
                    gr.Markdown("### Model Size (Quantized)\n# 1.8 MB\n**-3.2 MB (FP32)**")
                with gr.Column():
                    gr.Markdown("### Inference Latency\n# < 1.0 ms\n**Per step, CPU**")

# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()