Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import pandas as pd | |
| import numpy as np | |
| import pickle | |
| import os | |
| import json | |
| import math | |
| from typing import Union | |
| from deployment.config import load_model_config, get_input_size | |
| from fastapi import FastAPI | |
| from gradio.themes.base import Base | |
# --- Helper function to get model device ---
def get_model_device(model):
    """Return the torch.device that *model*'s first parameter lives on."""
    return next(model.parameters()).device
# --- CausalConv1d (common to Hawk, Mamba2, xLSTM) ---
class CausalConv1d(nn.Module):
    """Depthwise causal 1-D convolution applied one timestep at a time.

    The state holds the previous ``kernel_size - 1`` inputs per channel, so a
    single-step call always sees exactly its causal receptive field.
    """

    def __init__(self, hidden_size, kernel_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        # groups=hidden_size -> one independent filter per channel (depthwise).
        self.conv = nn.Conv1d(
            hidden_size, hidden_size, kernel_size, groups=hidden_size, bias=True
        )

    def init_state(self, batch_size: int, device: Union[torch.device, None] = None):
        """Zero history of shape (batch, hidden, kernel_size - 1)."""
        dev = device if device is not None else get_model_device(self)
        return torch.zeros(
            batch_size, self.hidden_size, self.kernel_size - 1, device=dev
        )

    def forward(self, x: torch.Tensor, state: torch.Tensor):
        """One step: returns (output (batch, hidden), updated history)."""
        window = torch.cat([state, x.unsqueeze(-1)], dim=-1)
        out = self.conv(window)
        # Drop the oldest column; the newest input joins the history.
        return out.squeeze(-1), window[:, :, 1:]
# --- Hawk Model Definitions ---
class RGLRU(nn.Module):
    """Real-Gated Linear Recurrent Unit, evaluated one timestep at a time.

    State update: h' = a * h + sqrt(1 - a^2) * (i * x), where the decay
    a = base ** (c * r) is kept strictly inside (0, 1) for stability.
    """

    def __init__(self, hidden_size: int, c: float = 8.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.c = c
        self.input_gate = nn.Linear(hidden_size, hidden_size, bias=False)
        self.recurrence_gate = nn.Linear(hidden_size, hidden_size, bias=False)
        self._base_param = nn.Parameter(torch.empty(hidden_size))
        nn.init.normal_(self._base_param, mean=0.0, std=1.0)  # ok to be any real

    def forward(self, x_t: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        batch_size, width = x_t.shape
        assert width == self.hidden_size
        assert state.shape[0] == batch_size
        in_gate = torch.sigmoid(self.input_gate(x_t))
        rec_gate = torch.sigmoid(self.recurrence_gate(x_t))  # in (0,1)
        eps = 1e-4
        # Per-channel decay base squashed into (0,1) and clamped away from
        # the endpoints so the power below stays numerically safe.
        base = torch.sigmoid(self._base_param).unsqueeze(0)  # (1, hidden)
        base = base.clamp(min=eps, max=1.0 - eps)
        decay = base ** (self.c * rec_gate)  # (batch, hidden), in (0,1)
        # sqrt(1 - a^2), clamped to guard against tiny negative round-off.
        scale = torch.sqrt(torch.clamp(1.0 - decay * decay, min=0.0))
        return state * decay + scale * (in_gate * x_t)

    def init_state(self, batch_size: int, device: Union[torch.device, None] = None):
        if device is None:
            device = get_model_device(self)
        return torch.zeros(batch_size, self.hidden_size, device=device)
class Hawk(nn.Module):
    """Hawk recurrent block: gated causal conv + RG-LRU with output projection."""

    def __init__(self, hidden_size: int, conv_kernel_size: int = 4):
        super().__init__()
        self.conv_kernel_size = conv_kernel_size
        self.hidden_size = hidden_size
        self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.recurrent_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.conv = CausalConv1d(hidden_size, conv_kernel_size)
        self.rglru = RGLRU(hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(
        self, x: torch.Tensor, state: tuple[torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """One timestep. Returns (output, [new_conv_state, new_rglru_state])."""
        conv_state, rglru_state = state
        batch_size, width = x.shape
        assert batch_size == conv_state.shape[0] == rglru_state.shape[0]
        assert self.hidden_size == width == rglru_state.shape[1]
        gate = F.gelu(self.gate_proj(x))  # gating branch
        h = self.recurrent_proj(x)        # recurrent branch
        h, new_conv_state = self.conv(h, conv_state)
        new_rglru_state = self.rglru(h, rglru_state)
        out = self.out_proj(gate * new_rglru_state)
        return out, [new_conv_state, new_rglru_state]

    def init_state(
        self, batch_size: int, device: Union[torch.device, None] = None
    ) -> list[torch.Tensor]:
        return [
            self.conv.init_state(batch_size, device),
            self.rglru.init_state(batch_size, device),
        ]
class HawkPredictor(nn.Module):
    """Stack of Hawk blocks for one-scalar-per-timestep regression.

    Input: (batch, seq_len, input_size); output: ((batch, seq_len, 1), states).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 2,
        conv_kernel_size: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.input_proj = nn.Linear(input_size, hidden_size)
        self.input_norm = nn.LayerNorm(hidden_size)
        self.hawk_layers = nn.ModuleList(
            Hawk(hidden_size, conv_kernel_size) for _ in range(num_layers)
        )
        self.layer_norms = nn.ModuleList(
            nn.LayerNorm(hidden_size) for _ in range(num_layers)
        )
        self.dropout = nn.Dropout(dropout)
        # Two-layer MLP head producing one scalar per timestep.
        self.output_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, 1),
        )

    def forward(self, x: torch.Tensor, states=None):
        batch_size, seq_len, _ = x.shape
        if states is None:
            states = [
                layer.init_state(batch_size, x.device) for layer in self.hawk_layers
            ]
        x = self.input_norm(self.input_proj(x))
        step_outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]
            next_states = []
            # Post-norm residual stack, advanced one timestep per layer.
            for block, norm, block_state in zip(
                self.hawk_layers, self.layer_norms, states
            ):
                residual = x_t
                x_t, block_state = block(x_t, block_state)
                x_t = self.dropout(norm(x_t + residual))
                next_states.append(block_state)
            states = next_states
            step_outputs.append(x_t)
        predictions = self.output_head(torch.stack(step_outputs, dim=1))
        return predictions, states
# --- Mamba2 Model Definitions ---
class Mamba2(nn.Module):
    """Single-step Mamba-2 style selective state-space block.

    Maintains a per-head SSM state of shape (batch, num_heads, head_size,
    bc_head_size) plus three causal-conv histories. ``forward`` consumes one
    timestep of shape (batch, hidden_size) and returns (output, new_state).
    """

    def __init__(
        self,
        hidden_size: int,
        inner_size: Union[int, None] = None,
        head_size: int = 64,
        bc_head_size: int = 128,
        conv_kernel_size: int = 4,
    ):
        super().__init__()
        self.head_size = head_size
        self.bc_head_size = bc_head_size
        if inner_size is None:
            inner_size = 2 * hidden_size  # default expansion factor of 2
        assert inner_size % head_size == 0
        self.inner_size = inner_size
        self.num_heads = inner_size // head_size
        self.input_proj = nn.Linear(hidden_size, inner_size, bias=False)
        # z feeds the SiLU output gate applied before the final norm.
        self.z_proj = nn.Linear(hidden_size, inner_size, bias=False)
        self.b_proj = nn.Linear(hidden_size, bc_head_size, bias=False)  # input->state ("B")
        self.c_proj = nn.Linear(hidden_size, bc_head_size, bias=False)  # state->output ("C")
        self.dt_proj = nn.Linear(hidden_size, self.num_heads, bias=True)  # per-head step size
        self.input_conv = CausalConv1d(inner_size, conv_kernel_size)
        self.b_conv = CausalConv1d(bc_head_size, conv_kernel_size)
        self.c_conv = CausalConv1d(bc_head_size, conv_kernel_size)
        # a < 0 by construction, so exp(a * dt) in forward is a decay in (0, 1).
        self.a = nn.Parameter(-torch.empty(self.num_heads).uniform_(1, 16))
        self.d = nn.Parameter(torch.ones(self.num_heads))  # skip ("D") weight
        self.norm = nn.RMSNorm(inner_size, eps=1e-5)  # NOTE(review): needs torch >= 2.4
        self.out_proj = nn.Linear(inner_size, hidden_size, bias=False)

    def init_state(self, batch_size: int, device: Union[torch.device, None] = None):
        """Zeroed conv histories + zeroed SSM state, as a 4-element list."""
        if device is None:
            device = get_model_device(self)
        conv_states = [
            conv.init_state(batch_size, device)
            for conv in [self.input_conv, self.b_conv, self.c_conv]
        ]
        ssm_state = torch.zeros(
            batch_size, self.num_heads, self.head_size, self.bc_head_size, device=device
        )
        return conv_states + [ssm_state]

    def forward(self, t, state):
        """One timestep: t is (batch, hidden_size); state is the list from init_state."""
        batch_size = t.shape[0]
        x = self.input_proj(t)
        z = self.z_proj(t)
        b = self.b_proj(t)
        c = self.c_proj(t)
        dt = self.dt_proj(t)
        x_conv_state, b_conv_state, c_conv_state, ssm_state = state
        x, x_conv_state = self.input_conv(x, x_conv_state)
        b, b_conv_state = self.b_conv(b, b_conv_state)
        c, c_conv_state = self.c_conv(c, c_conv_state)
        x = F.silu(x)
        b = F.silu(b)
        c = F.silu(c)
        x = x.view(batch_size, self.num_heads, self.head_size)
        dt = F.softplus(dt)  # keep step sizes positive
        decay = torch.exp(self.a[None] * dt)  # (batch, num_heads), in (0, 1)
        # Outer product dt * (b ⊗ x), broadcast to (batch, heads, head_size, bc).
        new_state_contrib = dt[:, :, None, None] * b[:, None, None] * x[:, :, :, None]
        ssm_state = decay[:, :, None, None] * ssm_state + new_state_contrib
        # Readout: contract the state against c over the bc dimension.
        state_contrib = torch.einsum("bc,bnhc->bnh", c, ssm_state)
        y = state_contrib + self.d[None, :, None] * x  # D-skip connection
        y = y.view(batch_size, self.inner_size)
        y = y * F.silu(z)  # gated output
        y = self.norm(y)
        output = self.out_proj(y)
        new_state = [x_conv_state, b_conv_state, c_conv_state, ssm_state]
        return output, new_state
class Mamba2Predictor(nn.Module):
    """Stack of Mamba2 blocks for one-scalar-per-timestep regression.

    Input: (batch, seq_len, input_size); output: ((batch, seq_len, 1), states).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 2,
        inner_size: Union[int, None] = None,
        head_size: int = 64,
        bc_head_size: int = 128,
        conv_kernel_size: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.input_proj = nn.Linear(input_size, hidden_size)
        self.input_norm = nn.LayerNorm(hidden_size)
        self.mamba_layers = nn.ModuleList(
            Mamba2(
                hidden_size,
                inner_size=inner_size,
                head_size=head_size,
                bc_head_size=bc_head_size,
                conv_kernel_size=conv_kernel_size,
            )
            for _ in range(num_layers)
        )
        self.layer_norms = nn.ModuleList(
            nn.LayerNorm(hidden_size) for _ in range(num_layers)
        )
        self.dropout = nn.Dropout(dropout)
        # Two-layer MLP head producing one scalar per timestep.
        self.output_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, 1),
        )

    def forward(self, x: torch.Tensor, states=None):
        batch_size, seq_len, _ = x.shape
        if states is None:
            states = [
                layer.init_state(batch_size, x.device) for layer in self.mamba_layers
            ]
        x = self.input_norm(self.input_proj(x))
        step_outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]
            next_states = []
            # Post-norm residual stack, advanced one timestep per layer.
            for block, norm, block_state in zip(
                self.mamba_layers, self.layer_norms, states
            ):
                residual = x_t
                x_t, block_state = block(x_t, block_state)
                x_t = self.dropout(norm(x_t + residual))
                next_states.append(block_state)
            states = next_states
            step_outputs.append(x_t)
        predictions = self.output_head(torch.stack(step_outputs, dim=1))
        return predictions, states
# --- xLSTM Model Definitions ---
class MLSTMCell(nn.Module):
    """Matrix-memory LSTM cell (xLSTM mLSTM) with max-state stabilization."""

    def __init__(self, hidden_size: int, num_heads: int = 8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.eps = 1e-6
        # Scalar input/forget gate pre-activations per head, driven by [q; k; v].
        self.igate_proj = nn.Linear(3 * hidden_size, num_heads, bias=True)
        self.fgate_proj = nn.Linear(3 * hidden_size, num_heads, bias=True)
        self.outnorm = nn.GroupNorm(num_groups=num_heads, num_channels=hidden_size)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, state):
        """One step; state = (cell (b,n,h,h), norm (b,n,h), max (b,n))."""
        batch_size, width = q.shape
        cell_state, norm_state, max_state = state
        gate_input = torch.cat([q, k, v], dim=-1)
        i_preact = self.igate_proj(gate_input)
        f_preact = self.fgate_proj(gate_input)
        q = q.view(batch_size, self.num_heads, self.head_size)
        k = k.view(batch_size, self.num_heads, self.head_size)
        v = v.view(batch_size, self.num_heads, self.head_size)
        # Log-domain stabilizer: track the running max of gate pre-activations
        # so the exponentials below never overflow.
        log_f = torch.nn.functional.logsigmoid(f_preact)
        max_new = torch.maximum(i_preact, max_state + log_f)
        i_gate = torch.exp(i_preact - max_new)
        f_gate = torch.exp(log_f + max_state - max_new)
        k = k / math.sqrt(self.head_size)  # scaled-dot-product style key scaling
        # Rank-1 update k ⊗ v into the per-head matrix memory.
        cell_new = (
            f_gate[:, :, None, None] * cell_state
            + i_gate[:, :, None, None] * k[:, :, :, None] * v[:, :, None]
        )
        norm_new = f_gate[:, :, None] * norm_state + i_gate[:, :, None] * k
        numerator = torch.einsum("bnh,bnhk->bnk", q, cell_new)
        qn_dotproduct = torch.einsum("bnh,bnh->bn", q, norm_new)
        # Denominator floored at exp(-m) (plus eps) to avoid division blow-up.
        denominator = torch.maximum(qn_dotproduct.abs(), torch.exp(-max_new)) + self.eps
        out = numerator / denominator[:, :, None]
        out = self.outnorm(out.view(batch_size, self.hidden_size))
        return out, (cell_new, norm_new, max_new)

    def init_state(self, batch_size: int, device: torch.device):
        head_shape = (batch_size, self.num_heads, self.head_size)
        return (
            torch.zeros(*head_shape, self.head_size, device=device),
            torch.zeros(head_shape, device=device),
            torch.zeros(batch_size, self.num_heads, device=device),
        )
class BlockLinear(nn.Module):
    """Block-diagonal linear map: each of ``num_blocks`` input slices gets its
    own (block_size x block_size) weight, so blocks never mix."""

    def __init__(self, num_blocks: int, hidden_size: int, bias: bool = True):
        super().__init__()
        self.num_blocks = num_blocks
        self.block_size = hidden_size // num_blocks
        self.hidden_size = hidden_size
        self.weight = nn.Parameter(
            torch.empty(num_blocks, self.block_size, self.block_size)
        )
        nn.init.xavier_uniform_(self.weight)
        if bias:
            self.bias = nn.Parameter(torch.empty(self.hidden_size))
            nn.init.zeros_(self.bias)
        else:
            self.bias = None

    def forward(self, x):
        batch = x.shape[0]
        assert x.shape[1] == self.hidden_size
        blocks = x.view(batch, self.num_blocks, self.block_size)
        # y_bnk = sum_h x_bnh * W_nkh  (independent matmul per block)
        y = torch.einsum("bnh,nkh->bnk", blocks, self.weight)
        y = y.reshape(batch, self.hidden_size)
        return y if self.bias is None else y + self.bias
class MLSTMBlock(nn.Module):
    """Pre-norm mLSTM block: up-project, causal conv, matrix-LSTM cell,
    gated down-projection, residual add."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 8,
        conv_kernel_size: int = 4,
        qkv_proj_block_size: int = 4,
        expand_factor: int = 2,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.inner_size = expand_factor * hidden_size
        self.norm = nn.LayerNorm(hidden_size, bias=False)
        self.x_proj = nn.Linear(hidden_size, self.inner_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, self.inner_size, bias=False)
        num_blocks = self.inner_size // qkv_proj_block_size
        self.q_proj = BlockLinear(num_blocks, self.inner_size, bias=False)
        self.k_proj = BlockLinear(num_blocks, self.inner_size, bias=False)
        self.v_proj = BlockLinear(num_blocks, self.inner_size, bias=False)
        self.conv1d = CausalConv1d(self.inner_size, kernel_size=conv_kernel_size)
        self.mlstm_cell = MLSTMCell(self.inner_size, num_heads)
        self.proj_down = nn.Linear(self.inner_size, hidden_size, bias=False)
        self.learnable_skip = nn.Parameter(torch.ones(self.inner_size))

    def forward(self, x: torch.Tensor, state):
        """One timestep; state = (conv_state, mlstm_state)."""
        conv_state, recurrent_state = state
        skip = x
        h = self.norm(x)
        up = self.x_proj(h)        # mLSTM branch
        gate = self.gate_proj(h)   # output-gate branch
        conv_out, new_conv_state = self.conv1d(up, conv_state)
        conv_act = F.silu(conv_out)
        # q and k come from the convolved path; v from the raw up-projection.
        q = self.q_proj(conv_act)
        k = self.k_proj(conv_act)
        v = self.v_proj(up)
        mlstm_out, new_recurrent_state = self.mlstm_cell(q, k, v, recurrent_state)
        gated = (mlstm_out + self.learnable_skip * conv_act) * F.silu(gate)
        return self.proj_down(gated) + skip, (new_conv_state, new_recurrent_state)

    def init_state(self, batch_size: int, device: torch.device):
        return (
            self.conv1d.init_state(batch_size, device),
            self.mlstm_cell.init_state(batch_size, device),
        )
class SLSTMCell(nn.Module):
    """Scalar-memory LSTM cell (xLSTM sLSTM) with log-domain stabilization."""

    def __init__(self, hidden_size: int, num_heads: int = 4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.eps = 1e-6

    def forward(
        self, i: torch.Tensor, f: torch.Tensor, z: torch.Tensor, o: torch.Tensor, state
    ):
        """One step over gate pre-activations; state = (cell, norm, max)."""
        cell_state, norm_state, max_state = state
        # Stabilizer m' = max(i, m + log sigmoid(f)); gates are exponentiated
        # relative to m' so they stay bounded.
        shifted_log_f = max_state + torch.nn.functional.logsigmoid(f)
        max_new = torch.maximum(i, shifted_log_f)
        out_gate = torch.sigmoid(o)
        in_gate = torch.exp(i - max_new)
        forget_gate = torch.exp(shifted_log_f - max_new)
        cell_new = forget_gate * cell_state + in_gate * torch.tanh(z)
        norm_new = forget_gate * norm_state + in_gate
        hidden = out_gate * cell_new / (norm_new + self.eps)
        return hidden, (cell_new, norm_new, max_new)

    def init_state(self, batch_size: int, device: torch.device):
        shape = (batch_size, self.hidden_size)
        return (
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
            # Max state starts at -inf so the first step is unconstrained.
            torch.full(shape, float("-inf"), device=device),
        )
class SLSTMBlock(nn.Module):
    """Pre-norm sLSTM block: causal conv feeds the i/f gates, the normed input
    feeds z/o, plus head-wise recurrent gate contributions and a residual add."""

    def __init__(self, hidden_size: int, num_heads: int = 4, conv_kernel_size: int = 4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.norm = nn.LayerNorm(hidden_size, bias=False)
        self.conv1d = CausalConv1d(hidden_size, kernel_size=conv_kernel_size)
        self.igate_input = BlockLinear(num_heads, hidden_size, bias=False)
        self.fgate_input = BlockLinear(num_heads, hidden_size, bias=False)
        self.zgate_input = BlockLinear(num_heads, hidden_size, bias=False)
        self.ogate_input = BlockLinear(num_heads, hidden_size, bias=False)
        self.igate_state = BlockLinear(num_heads, hidden_size)
        self.fgate_state = BlockLinear(num_heads, hidden_size)
        self.zgate_state = BlockLinear(num_heads, hidden_size)
        self.ogate_state = BlockLinear(num_heads, hidden_size)
        self.slstm_cell = SLSTMCell(hidden_size, num_heads)
        self.group_norm = nn.GroupNorm(num_groups=num_heads, num_channels=hidden_size)

    def forward(self, x: torch.Tensor, state):
        """One timestep; state = (conv_state, hidden_state, slstm_cell_state)."""
        conv_state, recurrent_state, slstm_state = state
        skip = x
        normed = self.norm(x)
        conv_out, new_conv_state = self.conv1d(normed, conv_state)
        conv_act = F.silu(conv_out)
        # i/f gates see the convolved input; z/o see the normed input directly.
        i = self.igate_input(conv_act) + self.igate_state(recurrent_state)
        f = self.fgate_input(conv_act) + self.fgate_state(recurrent_state)
        z = self.zgate_input(normed) + self.zgate_state(recurrent_state)
        o = self.ogate_input(normed) + self.ogate_state(recurrent_state)
        new_recurrent_state, new_slstm_state = self.slstm_cell(i, f, z, o, slstm_state)
        out = self.group_norm(new_recurrent_state) + skip
        return out, (new_conv_state, new_recurrent_state, new_slstm_state)

    def init_state(self, batch_size: int, device: torch.device):
        return (
            self.conv1d.init_state(batch_size, device),
            torch.zeros(batch_size, self.hidden_size, device=device),
            self.slstm_cell.init_state(batch_size, device),
        )
class xLSTMPredictor(nn.Module):
    """Stack of mLSTM or sLSTM blocks for one-scalar-per-timestep regression.

    Input: (batch, seq_len, input_size); output: ((batch, seq_len, 1), states).
    Raises ValueError for an unknown ``block_type``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 2,
        block_type: str = "mlstm",
        num_heads: int = 8,
        conv_kernel_size: int = 4,
        dropout: float = 0.1,
        expand_factor: int = 2,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.block_type = block_type
        self.input_proj = nn.Linear(input_size, hidden_size)
        self.input_norm = nn.LayerNorm(hidden_size)
        self.xlstm_layers = nn.ModuleList()
        for _ in range(num_layers):
            if block_type == "mlstm":
                block = MLSTMBlock(
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    conv_kernel_size=conv_kernel_size,
                    expand_factor=expand_factor,
                )
            elif block_type == "slstm":
                block = SLSTMBlock(
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    conv_kernel_size=conv_kernel_size,
                )
            else:
                raise ValueError(f"Unknown block type: {block_type}")
            self.xlstm_layers.append(block)
        self.dropout = nn.Dropout(dropout)
        # Two-layer MLP head producing one scalar per timestep.
        self.output_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, 1),
        )

    def forward(self, x: torch.Tensor, states=None):
        batch_size, seq_len, _ = x.shape
        if states is None:
            states = [
                layer.init_state(batch_size, x.device) for layer in self.xlstm_layers
            ]
        x = self.input_norm(self.input_proj(x))
        step_outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]
            next_states = []
            # Blocks carry their own residual; only dropout is applied here.
            for block, block_state in zip(self.xlstm_layers, states):
                x_t, block_state = block(x_t, block_state)
                x_t = self.dropout(x_t)
                next_states.append(block_state)
            states = next_states
            step_outputs.append(x_t)
        predictions = self.output_head(torch.stack(step_outputs, dim=1))
        return predictions, states
# --- Load Models ---
# Paths are relative to the Space's working directory.
MODELS_DIR = "deployment/models"
# Registry of every servable model, keyed by the dropdown value in the UI.
models = {}
# Load PyTorch models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load Hawk model
hawk_config = load_model_config("hawk", models_dir="deployment/models")
input_size_hawk = get_input_size(hawk_config)
hawk_model = HawkPredictor(
    input_size=input_size_hawk,
    hidden_size=hawk_config["hidden_size"],
    num_layers=hawk_config["num_layers"],
    conv_kernel_size=hawk_config["conv_kernel_size"],
    dropout=hawk_config["dropout"]
)
# NOTE(review): weights_only=False unpickles arbitrary objects from the
# checkpoint; only acceptable because these files ship with the deployment.
hawk_model.load_state_dict(torch.load(os.path.join(MODELS_DIR, "hawk_best_model.pt"), map_location=device, weights_only=False)['model_state_dict'])
hawk_model.to(device)
hawk_model.eval()
models["hawk"] = hawk_model
# Load Mamba2 model
mamba_config = load_model_config("mamba", models_dir="deployment/models")
input_size_mamba = get_input_size(mamba_config)
mamba_model = Mamba2Predictor(
    input_size=input_size_mamba,
    hidden_size=mamba_config["hidden_size"],
    num_layers=mamba_config["num_layers"],
    inner_size=mamba_config["inner_size"],
    head_size=mamba_config["head_size"],
    bc_head_size=mamba_config["bc_head_size"],
    conv_kernel_size=mamba_config["conv_kernel_size"],
    dropout=mamba_config["dropout"]
)
mamba_model.load_state_dict(torch.load(os.path.join(MODELS_DIR, "mamba_best_model.pt"), map_location=device, weights_only=False)['model_state_dict'])
mamba_model.to(device)
mamba_model.eval()
models["mamba"] = mamba_model
# Load xLSTM model
xlstm_config = load_model_config("xlstm", models_dir="deployment/models")
input_size_xlstm = get_input_size(xlstm_config)
xlstm_model = xLSTMPredictor(
    input_size=input_size_xlstm,
    hidden_size=xlstm_config["hidden_size"],
    num_layers=xlstm_config["num_layers"],
    block_type=xlstm_config["block_type"],
    num_heads=xlstm_config["num_heads"],
    conv_kernel_size=xlstm_config["conv_kernel_size"],
    dropout=xlstm_config["dropout"],
    expand_factor=xlstm_config["expand_factor"]
)
xlstm_model.load_state_dict(torch.load(os.path.join(MODELS_DIR, "xlstm_best_model.pt"), map_location=device, weights_only=False)['model_state_dict'])
xlstm_model.to(device)
xlstm_model.eval()
models["xlstm"] = xlstm_model
# Load Scikit-learn models
# NOTE(review): pickle.load can execute code from the file; trusted artifact only.
with open(os.path.join(MODELS_DIR, "RandomForest_model.pkl"), "rb") as f:
    rf_model = pickle.load(f)
models["random_forest"] = rf_model
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score | |
| import matplotlib.pyplot as plt | |
def predict(model_name, file):
    """Run one registered model over an uploaded CSV and report results.

    Parameters
    ----------
    model_name : str
        Key into the module-level ``models`` registry.
    file :
        Gradio File object; ``file.name`` is the path of the uploaded CSV.

    Returns
    -------
    tuple
        (last_prediction_str, metrics_json_str, matplotlib_figure), or
        ("Model not found", None, None) for an unknown model name.
    """
    model = models.get(model_name)
    # `is None` rather than truthiness: a falsy-but-valid model object
    # (e.g. an empty container-like estimator) must not be rejected.
    if model is None:
        return "Model not found", None, None
    df = pd.read_csv(file.name)
    config = load_model_config(model_name, models_dir="deployment/models")
    feature_cols = config["feature_cols"]
    target_col = config["target_col"]
    seq_length = config["seq_length"]
    # Data preparation (assuming the uploaded file is the test set).
    # NOTE(review): fitting the scaler on random data is a placeholder — in a
    # real deployment the scaler fitted during training must be loaded instead.
    scaler = StandardScaler()
    scaler.fit(np.random.rand(100, len(feature_cols)))
    features = scaler.transform(df[feature_cols].values)
    targets = df[target_col].values
    # Build sliding windows of length seq_length over the series.
    feature_windows = []
    target_windows = []
    for i in range(len(features) - seq_length):
        feature_windows.append(features[i : i + seq_length])
        target_windows.append(targets[i : i + seq_length])
    X_test = torch.FloatTensor(np.array(feature_windows))
    y_test = np.array(target_windows)
    is_torch_model = model_name in ["hawk", "mamba", "xlstm"]
    # Prediction
    if is_torch_model:
        X_test = X_test.to(device)
        with torch.no_grad():
            predictions, _ = model(X_test)
        predictions = predictions.cpu().numpy()
        # Torch outputs are (windows, seq, 1); score the last step per window.
        y_pred_for_metrics = predictions[:, -1, 0]
    else:  # scikit-learn models expect flat 2-D feature matrices
        X_flat = X_test.reshape(len(X_test), -1).numpy()
        predictions = model.predict(X_flat)
        # Broadcast the single per-window prediction across the window so the
        # array shape matches y_test.
        predictions = np.repeat(predictions[:, np.newaxis], y_test.shape[1], axis=1)
        y_pred_for_metrics = predictions[:, -1]
    # Calculate metrics on the last timestep of every window; cast to plain
    # floats so json.dumps never chokes on numpy scalar types.
    y_true_for_metrics = y_test[:, -1]
    mse = mean_squared_error(y_true_for_metrics, y_pred_for_metrics)
    metrics = {
        "MSE": float(mse),
        "RMSE": float(np.sqrt(mse)),  # reuse MSE instead of recomputing it
        "MAE": float(mean_absolute_error(y_true_for_metrics, y_pred_for_metrics)),
        "R2": float(r2_score(y_true_for_metrics, y_pred_for_metrics)),
    }
    metrics_str = json.dumps(metrics, indent=4)
    # Plot predicted vs. actual for the scored timestep of each window.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(y_true_for_metrics, label="Actual")
    ax.plot(y_pred_for_metrics, label="Predicted")
    ax.set_title("Predictions vs Actual")
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Value")
    ax.legend()
    ax.grid(True)
    # Headline number: the final prediction of the final window.
    last_prediction = (
        predictions[-1, -1, 0] if is_torch_model else predictions[-1, -1]
    )
    return f"{last_prediction:.4f}", metrics_str, fig
# --- Gradio Interface ---
# Two-column layout: inputs (model picker + CSV upload) on the left,
# outputs (prediction, metrics JSON, plot) on the right.
with gr.Blocks(theme=Base(), title="Stock Predictor") as demo:
    gr.Markdown(
        """
        # Stock Price Predictor
        Select a model and upload a CSV file with the required features to get a prediction.
        """
    )
    with gr.Row():
        with gr.Column():
            # Choices come from the module-level model registry populated above.
            model_name = gr.Dropdown(
                label="Select Model", choices=list(models.keys())
            )
            feature_input = gr.File(
                label="Upload CSV with features",
            )
            predict_btn = gr.Button("Predict")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction")
            metrics_output = gr.Textbox(label="Metrics")
            plot_output = gr.Plot(label="Plots")
    # Wire the button to the predict() handler defined above.
    predict_btn.click(
        fn=predict,
        inputs=[model_name, feature_input],
        outputs=[prediction_output, metrics_output, plot_output],
    )
# --- FastAPI App ---
app = FastAPI()
from fastapi.responses import RedirectResponse


# Fix: read_root was defined but never registered with the app, so the root
# URL returned 404 instead of redirecting; register it with @app.get("/").
@app.get("/")
def read_root():
    """Redirect the bare root URL to the mounted Gradio UI."""
    return RedirectResponse(url="/gradio")


app = gr.mount_gradio_app(app, demo, path="/gradio")