Spaces:
Sleeping
Sleeping
Deploy app code
Browse files- .DS_Store +0 -0
- app.py +112 -0
- artifacts/.DS_Store +0 -0
- artifacts/best_model.pt +3 -0
- requirements.txt +9 -0
- src/.DS_Store +0 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/engine.cpython-311.pyc +0 -0
- src/__pycache__/models.cpython-311.pyc +0 -0
- src/engine.py +104 -0
- src/models.py +88 -0
- src/utils.py +0 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from src.models import get_model
|
| 8 |
+
from src.engine import quantize_model
|
| 9 |
+
|
| 10 |
+
# --- LOAD DATASET FROM HUGGING FACE ---
# Hub dataset id ("<user>/<dataset>"); replace with your own when forking.
DATASET_NAME = "aayushkrm/wunder-fund-hft-data"
try:
    print("Loading dataset from Hugging Face...")
    # Only the first 1% of the train split is pulled to keep startup fast.
    dataset = load_dataset(DATASET_NAME, split="train[:1%]")
    df = dataset.to_pandas()
    # A missing 'seq_ix' column raises here and triggers the fallback below.
    # .tolist() yields native Python scalars in a plain list, so SEQ_IDS has
    # the same type on the success path as on the fallback path.
    SEQ_IDS = df['seq_ix'].unique().tolist()
    if not SEQ_IDS:
        # The UI indexes SEQ_IDS[0]; an empty slice must not get that far.
        raise ValueError("dataset slice contains no sequences")
except Exception as e:
    # Deliberate best-effort: the demo must still start without the dataset.
    print(f"Could not load HF dataset: {e}")
    # Fallback to dummy data (inference() generates random input when df is None)
    df = None
    SEQ_IDS = [0, 1, 2]
|
| 25 |
+
|
| 26 |
+
def load_cached_model():
    """Build the demo model, load trained weights if present, and quantize it.

    Returns:
        The model returned by ``quantize_model``. When the checkpoint file is
        missing or cannot be loaded, a warning is printed and the model keeps
        its freshly initialised weights.
    """
    # Local import: the module's import header does not bring in ``os``, so
    # the original ``os.path.exists`` call raised NameError at import time.
    import os

    # Strategy ED Configuration: 32 Input, 256 Hidden, 6 Layers
    # NOTE(review): the original author was unsure whether Strategy ED
    # ("SE-MISH-SWARM (Best-of-Best Fusion)") used hidden_size 256 or 240.
    # A size mismatch would make load_state_dict raise, which is caught and
    # reported below.
    model = get_model("winner", input_size=32, hidden_size=256, layers=6)

    model_path = "artifacts/best_model.pt"

    if os.path.exists(model_path):
        try:
            # The checkpoint is expected to be a bare state_dict.
            state = torch.load(model_path, map_location='cpu')

            # Weights may have been saved as half precision; cast to float32
            # before quantization.
            state = {k: v.float() for k, v in state.items()}

            model.load_state_dict(state)
            print("✅ Loaded best_model.pt")
        except Exception as e:
            # Best-effort: the demo keeps running with untrained weights.
            print(f"⚠️ Error loading model: {e}")
    else:
        print("⚠️ Model file not found, using random weights.")

    # Quantize for inference (just to show off the capability)
    model = quantize_model(model)
    return model
|
| 54 |
+
|
| 55 |
+
MODEL = load_cached_model()
|
| 56 |
+
|
| 57 |
+
def inference(seq_id, steps_to_plot):
    """Run step-by-step inference on one sequence and plot feature 0.

    Args:
        seq_id: Value of the ``seq_ix`` column selecting the sequence. Ignored
            when the dataset is unavailable (``df is None``); random data is
            used instead.
        steps_to_plot: Maximum number of time steps to run and plot. Coerced
            to a non-negative ``int`` because ``range()`` and slicing reject
            floats.

    Returns:
        A Plotly figure with the normalised input series and the model output
        for feature 0.
    """
    n_requested = max(int(steps_to_plot), 0)

    if df is not None:
        # Extract the sequence in time order and its 32 raw feature columns
        # (columns are named "0".."31").
        seq_data = df[df['seq_ix'] == seq_id].sort_values('step_in_seq')
        raw_values = seq_data[[str(i) for i in range(32)]].values.astype(np.float32)
        if len(raw_values) == 0:
            # Unknown seq_id: mean/std of an empty array would only yield NaNs
            # and runtime warnings, so skip normalisation and plot nothing.
            norm_values = raw_values
        else:
            # Simple per-sequence z-score; epsilon guards constant columns.
            mean = raw_values.mean(axis=0)
            std = raw_values.std(axis=0) + 1e-6
            norm_values = (raw_values - mean) / std
    else:
        # Dummy data
        norm_values = np.random.randn(1000, 32).astype(np.float32)

    # Add a batch dimension: (1, steps, 32)
    x = torch.tensor(norm_values).unsqueeze(0)
    n_steps = min(x.shape[1], n_requested)

    preds = []
    hidden = None
    with torch.no_grad():
        # Deliberately fed one step at a time, carrying the hidden state,
        # as in the original demo loop.
        for t in range(n_steps):
            out, hidden = MODEL(x[:, t:t + 1, :], hidden)
            preds.append(out[0, 0, 0].item())  # 1st feature dim only

    # NOTE(review): the original comments say the model output is the
    # next-step prediction, yet both traces share the same x index, so the
    # prediction drawn at index t would correspond to the actual value at
    # t + 1 - confirm whether a one-step shift is wanted.
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=norm_values[:n_requested, 0], mode='lines', name='Actual Feature 0', line=dict(color='gray')))
    fig.add_trace(go.Scatter(y=preds, mode='lines', name='Predicted Feature 0', line=dict(color='green')))

    fig.update_layout(title=f"Forecasting Sequence {seq_id}", xaxis_title="Time Step", yaxis_title="Normalized Value")

    return fig
|
| 96 |
+
|
| 97 |
+
# --- GRADIO UI ---
# Layout: header, one row of inputs, the plot, and a button that triggers
# inference() with the two input values.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# ⚡ Quant-Lab: HFT Sequence Modeling")
    gr.Markdown("**Strategy ED (Rank 28):** SE-Mish-DeepResGRU (INT8 Quantized)")

    with gr.Row():
        # Only the first 20 sequence ids are offered; the first is preselected.
        sequence_dropdown = gr.Dropdown(
            choices=list(SEQ_IDS[:20]),
            label="Select Market Sequence",
            value=SEQ_IDS[0],
        )
        steps_slider = gr.Slider(
            minimum=50,
            maximum=1000,
            value=200,
            label="Steps to Visualize",
        )

    forecast_plot = gr.Plot(label="Forecast Visualization")

    run_button = gr.Button("Run Inference")
    run_button.click(
        inference,
        inputs=[sequence_dropdown, steps_slider],
        outputs=forecast_plot,
    )

if __name__ == "__main__":
    demo.launch()
|
artifacts/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
artifacts/best_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a5d0b30d0dd025f271aed3de1afaa7a968270428de27b0af288daa00e812e56
|
| 3 |
+
size 2480994
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
numpy
|
| 3 |
+
pandas
|
| 4 |
+
scikit-learn
|
| 5 |
+
streamlit
|
| 6 |
+
gradio
|
| 7 |
+
plotly
|
| 8 |
+
tqdm
|
| 9 |
+
datasets
|
src/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
src/__pycache__/engine.cpython-311.pyc
ADDED
|
Binary file (6.99 kB). View file
|
|
|
src/__pycache__/models.cpython-311.pyc
ADDED
|
Binary file (7.29 kB). View file
|
|
|
src/engine.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.optim import Optimizer
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# --- 1. RANGER OPTIMIZER (Full Implementation) ---
|
| 8 |
+
class Ranger(Optimizer):
    """Ranger optimizer: a rectified Adam-style update plus slow-weight averaging.

    Each step updates exponential moving averages of the gradient and its
    square. While the rectification term ``N_sma`` is below the threshold the
    update is a bias-corrected momentum step; above it, the adaptive
    (``sqrt(exp_avg_sq)``-scaled) step is used. Every ``k`` steps the
    parameters are pulled towards a slowly moving copy ("slow buffer") by the
    interpolation factor ``alpha``.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        alpha: Interpolation factor for the slow weights.
        k: Number of steps between slow-weight synchronisations.
        N_sma_threshhold: ``N_sma`` value from which the adaptive step is used.
        betas: Coefficients for the running averages of gradient and its square.
        eps: Added to the denominator for numerical stability.
        weight_decay: Multiplicative weight decay, scaled by ``lr``.

    ``alpha``, ``k``, ``N_sma_threshhold`` and ``betas`` may be overridden per
    parameter group; ``step`` reads them from the group.
    """

    def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5, weight_decay=0):
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
        # Kept as instance attributes for backward compatibility; they also
        # serve as fallbacks when a group lacks the corresponding key.
        self.N_sma_threshhold = N_sma_threshhold
        self.alpha = alpha
        self.k = k
        # Ten-slot cache of [key, N_sma, step_size], indexed by step % 10.
        self.radam_buffer = [[None, None, None] for _ in range(10)]

    def __setstate__(self, state):
        super().__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimisation step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The closure's return value, or ``None`` when no closure is given.

        Raises:
            RuntimeError: If any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            # Honour per-group overrides (previously only the constructor
            # values stored on the instance were used).
            threshold = group.get('N_sma_threshhold', self.N_sma_threshhold)
            slow_alpha = group.get('alpha', self.alpha)

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('Ranger does not support sparse gradients')

                # All arithmetic is done in float32, then copied back.
                grad = p.grad.data.float()
                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                    # Slow weights start as a copy of the initial parameters.
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                state['step'] += 1
                step = state['step']

                # The cached values depend on the betas and the threshold, so
                # they are part of the key; keying on the step alone let groups
                # with different hyper-parameters reuse each other's values.
                cache_key = (step, beta1, beta2, threshold)
                buffered = self.radam_buffer[int(step % 10)]
                if buffered[0] == cache_key:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = cache_key
                    beta2_t = beta2 ** step
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma >= threshold:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** step)
                    else:
                        step_size = 1.0 / (1 - beta1 ** step)
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                if N_sma >= threshold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                p.data.copy_(p_data_fp32)

                # Every k steps: move the slow weights towards the fast ones
                # and reset the fast weights to the result.
                if step % group['k'] == 0:
                    slow_p = state['slow_buffer']
                    slow_p.add_(p.data - slow_p, alpha=slow_alpha)
                    p.data.copy_(slow_p)
        return loss
|
| 77 |
+
|
| 78 |
+
# --- 2. QUANTIZATION PIPELINE ---
|
| 79 |
+
def quantize_model(model):
    """Return a dynamically INT8-quantized version of ``model``.

    The given model is first moved to CPU and switched to eval mode (this
    affects the passed-in object). ``Linear``, ``GRU`` and ``LSTM`` submodules
    are the ones targeted for quantization.
    """
    target_layers = {torch.nn.Linear, torch.nn.GRU, torch.nn.LSTM}
    model.cpu().eval()
    return torch.quantization.quantize_dynamic(
        model,
        target_layers,
        dtype=torch.qint8,
    )
|
| 90 |
+
|
| 91 |
+
def save_model(model, path):
    """Serialize ``model``'s state dict (weights only, not the module) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)
|
| 93 |
+
|
| 94 |
+
def load_model(model_class, path, quantized=False):
    """Instantiate ``model_class()`` and load a saved state dict into it.

    Args:
        model_class: Zero-argument callable returning a fresh model.
        path: Checkpoint location handed to ``torch.load``.
        quantized: When true, the fresh model is passed through
            ``quantize_model`` before loading, so its structure matches a
            quantized checkpoint.

    Returns:
        The model with the loaded state.
    """
    load_kwargs = {'map_location': 'cpu'}
    net = model_class()
    if quantized:
        net = quantize_model(net)
        # Per the original author's note, quantized state dicts need
        # weights_only=False to load.
        load_kwargs['weights_only'] = False

    net.load_state_dict(torch.load(path, **load_kwargs))
    return net
|
src/models.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
# --- ACTIVATIONS & BLOCKS ---
|
| 7 |
+
class Mish(nn.Module):
    """
    Mish activation, computed as ``x * tanh(softplus(x))``.

    Author's note: found to work better than ReLU for deep RNNs in
    low-signal regimes.
    """

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate
|
| 14 |
+
|
| 15 |
+
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation style gate for 1D sequences.

    A two-layer bottleneck MLP (no biases) ending in a sigmoid produces a
    gate with the same shape as the input, which then scales the input
    element-wise. No pooling is applied here: the gate is computed from each
    feature vector along the last dimension on its own.

    Args:
        channel: Size of the last (feature) dimension.
        reduction: Bottleneck divisor; the hidden width is
            ``channel // reduction``.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        bottleneck = channel // reduction
        # Layer order and indices are kept so state-dict keys stay
        # ``fc.0.weight`` / ``fc.2.weight``.
        gate_layers = [
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        return x * self.fc(x)
|
| 31 |
+
|
| 32 |
+
# --- 1. THE WINNER: SE-Mish-DeepResGRU ---
|
| 33 |
+
class PreNormGRUCell(nn.Module):
    """Residual GRU layer that applies LayerNorm before the recurrence.

    Data flow: ``x -> LayerNorm -> GRU -> Mish -> Dropout``, with the result
    added back onto ``x``.

    Args:
        dim: Feature width of both the input and the GRU hidden state.
        dropout: Dropout probability applied to the activated GRU output.
    """

    def __init__(self, dim, dropout):
        super().__init__()
        # Single-layer GRU, batch-first, with matching input and hidden width
        # so the residual sum is well-formed.
        self.gru = nn.GRU(dim, dim, 1, batch_first=True)
        self.drop = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(dim)
        self.act = Mish()

    def forward(self, x, h):
        """Return ``(x + branch_output, new_hidden_state)``."""
        gru_out, next_h = self.gru(self.ln(x), h)
        branch = self.drop(self.act(gru_out))
        return x + branch, next_h  # Residual Connection
|
| 45 |
+
|
| 46 |
+
class SEMishGRU(nn.Module):
    """Stack of residual pre-norm GRU layers behind a linear embedding and SE gate.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Width of the embedding and of every GRU layer.
        layers: Number of stacked ``PreNormGRUCell`` layers.

    The output head always produces 32 values per time step.
    """

    def __init__(self, input_size=32, hidden_size=240, layers=6):
        super().__init__()
        self.embed = nn.Linear(input_size, hidden_size)
        self.se = SEBlock(hidden_size)
        stack = [PreNormGRUCell(hidden_size, 0.15) for _ in range(layers)]
        self.layers = nn.ModuleList(stack)
        self.head = nn.Linear(hidden_size, 32)
        self.final_ln = nn.LayerNorm(hidden_size)
        self.layers_count = layers

    def forward(self, x, h_list=None):
        """Run the stack and return ``(output, hidden_states)``.

        Args:
            x: Input passed to the embedding layer.
            h_list: Optional per-layer hidden states, indexed by layer; when
                omitted every layer starts from ``None``.

        Returns:
            Tuple of the head output and the list of new hidden states, one
            per layer in layer order.
        """
        if h_list is None:
            h_list = [None] * self.layers_count

        feats = self.se(self.embed(x))

        next_states = []
        for idx, cell in enumerate(self.layers):
            feats, state = cell(feats, h_list[idx])
            next_states.append(state)

        out = self.head(self.final_ln(feats))
        return out, next_states
|
| 66 |
+
|
| 67 |
+
# --- 2. THE CHALLENGER: Transformer-Encoder (Failed due to overfitting) ---
|
| 68 |
+
class TransformerModel(nn.Module):
    """Transformer-encoder baseline that predicts from the last time step.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Model width (``d_model``); the encoder uses 4 heads.
        layers: Number of encoder layers.

    The output head always produces 32 values.
    """

    def __init__(self, input_size=32, hidden_size=256, layers=4):
        super().__init__()
        self.embed = nn.Linear(input_size, hidden_size)
        block = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(block, num_layers=layers)
        self.head = nn.Linear(hidden_size, 32)

    def forward(self, x):
        """Encode the whole sequence and project the final position only."""
        encoded = self.transformer(self.embed(x))
        last_step = encoded[:, -1, :]  # Predict on last step
        return self.head(last_step)
|
| 80 |
+
|
| 81 |
+
# --- FACTORY ---
|
| 82 |
+
def get_model(name, **kwargs):
    """Construct a model by name.

    Args:
        name: ``"winner"`` for ``SEMishGRU`` or ``"transformer"`` for
            ``TransformerModel``.
        **kwargs: Forwarded unchanged to the selected model's constructor.

    Raises:
        ValueError: If ``name`` is not one of the known model names.
    """
    # Guard clause first; the two supported names follow.
    if name != "winner" and name != "transformer":
        raise ValueError(f"Unknown model: {name}")
    if name == "winner":
        return SEMishGRU(**kwargs)
    return TransformerModel(**kwargs)
|
src/utils.py
ADDED
|
File without changes
|