# Deploy Aetheris to HF Space (commit 1df0e33)
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..config import AetherisConfig
def selective_scan_native(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                          B: torch.Tensor, C: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Sequential (pure-PyTorch) selective state-space scan.

    Implements the discretized SSM recurrence, one timestep at a time:

        h_t = exp(delta_t * A) * h_{t-1} + (delta_t * B_t) * u_t
        y_t = <h_t, C_t> + D * u_t

    Args:
        u:     input sequence, shape (batch, seq_len, d_inner).
        delta: positive step sizes, shape (batch, seq_len, d_inner).
        A:     state-transition matrix, shape (batch, d_inner, d_state) —
               already batch-expanded by the caller.
        B:     per-step input projection, shape (batch, seq_len, d_state).
        C:     per-step output projection, shape (batch, seq_len, d_state).
        D:     per-channel skip-connection scale, shape (d_inner,).

    Returns:
        Output sequence of shape (batch, seq_len, d_inner).
    """
    B_size, L, D_inner = u.shape
    D_state = A.shape[-1]
    h = torch.zeros(B_size, D_inner, D_state, device=u.device, dtype=u.dtype)
    # Preallocate the output and write each timestep in place, instead of
    # appending L tensors to a list and stacking them: the only large
    # intermediates are the running state `h` and `y` itself, which is what
    # the "reduced intermediate tensors" claim actually requires.
    y = torch.empty_like(u)
    for t in range(L):
        dt = delta[:, t, :].unsqueeze(-1)             # (B, d_inner, 1)
        dA = torch.exp(dt * A)                        # discretized transition
        dB = dt * B[:, t, :].unsqueeze(1)             # (B, d_inner, d_state) via broadcast
        h = dA * h + dB * u[:, t, :].unsqueeze(-1)    # state update
        # Contract the state dimension against C_t to produce this step's output.
        y[:, t, :] = torch.sum(h * C[:, t, :].unsqueeze(1), dim=-1)
    # D provides a per-channel residual/skip path straight from the input.
    return y + u * D
class SSMBlock(nn.Module):
    """Memory-optimized State Space Model with stability improvements.

    Pre-norm residual block: LayerNorm -> input projection -> causal
    depthwise conv -> selective scan -> SiLU gating -> output projection,
    with a skip connection back to the block input.
    """

    def __init__(self, config: AetherisConfig) -> None:
        super().__init__()
        self.d_model = config.d_model
        self.d_state = config.ssm_d_state
        self.d_inner = config.d_inner
        # Projects to 2 * d_inner; forward() chunks the result in half.
        # NOTE(review): the second half (z_gate) is never used — see forward().
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=False)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False)
        # Depthwise conv with padding = kernel_size - 1 = 2; forward() slices
        # off the trailing 2 positions so the convolution stays causal.
        self.conv_d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=3,
                                padding=2, groups=self.d_inner, bias=False)
        self.gate_proj = nn.Linear(self.d_model, self.d_inner, bias=False)
        self.B_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.C_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.delta_proj = nn.Linear(self.d_inner, self.d_inner, bias=False)
        # Initialize A to be more stable (closer to -1): forward() uses
        # A = -exp(A_log), so a mean of -4.0 here gives A ≈ -exp(-4) ≈ -0.018,
        # i.e. a slowly-decaying state.
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state) * 0.1 - 4.0)
        # Per-channel skip-connection scale consumed by selective_scan_native.
        self.D = nn.Parameter(torch.ones(self.d_inner) * 0.1)
        self.act = nn.SiLU()
        self.norm = nn.LayerNorm(config.d_model)
        # Proper initialization: small-gain Xavier on every projection.
        # NOTE(review): do not reorder the statements above — parameter
        # registration order and the randn/xavier calls fix the RNG stream,
        # so reordering silently changes initialization.
        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.gate_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.B_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.C_proj.weight, gain=0.5)
        nn.init.xavier_uniform_(self.delta_proj.weight, gain=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to x of shape (batch, seq_len, d_model); same shape out."""
        B, L, D = x.shape
        x_norm = self.norm(x)  # pre-norm residual style
        xz = self.in_proj(x_norm)
        # NOTE(review): z_gate is computed here but never used below — gating
        # is done via gate_proj(x_norm) instead, so half of in_proj's output
        # (and its parameters) are dead weight. Confirm whether z_gate was
        # meant to be the gate (as in standard Mamba) before removing either
        # path; fixing it would change behavior for existing checkpoints.
        x_in, z_gate = xz.chunk(2, dim=-1)
        x_conv = self.conv_d(x_in.transpose(1, 2))
        # Slice off the last 2 elements (the "future" leakage) to keep the
        # padding=2 convolution causal, then restore (B, L, d_inner) layout.
        x_conv = x_conv[:, :, :-2].transpose(1, 2)
        x_conv = self.act(x_conv)
        # Add small epsilon to prevent numerical issues and clamp max value,
        # keeping the step size strictly positive and bounded.
        delta = torch.clamp(F.softplus(self.delta_proj(x_conv)), max=5.0) + 1e-4
        B_ssm = self.B_proj(x_conv)
        C_ssm = self.C_proj(x_conv)
        # Clamp A_log to prevent extreme values; -exp(...) keeps A negative,
        # so the recurrence exp(delta * A) is a contraction (stable).
        A_fixed = -torch.exp(torch.clamp(self.A_log, min=-10.0, max=2.0))
        A_batched = A_fixed.unsqueeze(0).expand(B, -1, -1)
        y_ssm = selective_scan_native(x_conv, delta, A_batched, B_ssm, C_ssm, self.D)
        # Gate comes from the normalized input (not z_gate — see NOTE above).
        y_gate = F.silu(self.gate_proj(x_norm)) * y_ssm
        output = self.out_proj(y_gate)
        return x + output  # residual connection