StarMist0012's picture
Add files using upload-large-folder tool
e2bfccc verified
"""Gamma SSM Block with residual connections and normalization."""
import torch
import torch.nn as nn
from typing import Optional, Tuple
from .ssm_gamma import SSMGamma
from .normalization import LayerNorm
class GammaSingleBlock(nn.Module):
    """
    Single Gamma SSM block with a (scaled) residual connection and layer normalization.

    Performs:
        prenorm=True:  y = residual_scale * x + Dropout(SSM(LayerNorm(x)))
        prenorm=False: y = LayerNorm(residual_scale * x + Dropout(SSM(x)))

    ``forward`` (whole sequence) and ``step`` (one timestep) share the same
    residual/normalization logic, so the two paths compose the block identically.
    The LayerNorm is applied exactly once per call in either mode.

    Args:
        d_model: Model (feature) dimension.
        hidden_dim: Hidden dimension for the SSM state (forwarded to SSMGamma).
        delta_t: Time discretization step, forwarded to SSMGamma (default: 0.1).
        kernel_length: Convolution kernel length, forwarded to SSMGamma (default: 4).
        A_type: Type of A matrix initialization, forwarded to SSMGamma
            (default: "tridiagonal").
        prenorm: Use prenorm (LayerNorm -> SSM) instead of postnorm
            (SSM -> residual add -> LayerNorm) (default: True).
        residual_scale: Factor multiplying the skip input ``x`` before it is
            added to the SSM branch (default: 1.0).
        dropout: Dropout rate applied to the SSM branch; 0 disables dropout
            (default: 0.0).

    Shape:
        - Input: (batch, seq_len, d_model)
        - Output: (batch, seq_len, d_model)
    """

    def __init__(
        self,
        d_model: int,
        hidden_dim: int,
        delta_t: float = 0.1,
        kernel_length: int = 4,
        A_type: str = "tridiagonal",
        prenorm: bool = True,
        residual_scale: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.prenorm = prenorm
        self.dropout_p = dropout
        self.residual_scale = residual_scale

        # A single LayerNorm: applied to the SSM input (prenorm) or to the
        # residual sum (postnorm) -- never to both.
        self.norm = LayerNorm(d_model)

        # NOTE(review): SSMGamma's `state_dim` receives d_model here, so it
        # appears to be the block's feature width; confirm against ssm_gamma.
        self.ssm = SSMGamma(
            state_dim=d_model,
            hidden_dim=hidden_dim,
            delta_t=delta_t,
            kernel_length=kernel_length,
            A_type=A_type,
        )

        # Dropout on the SSM branch only; None means "no dropout layer at all".
        self.dropout: Optional[nn.Dropout] = nn.Dropout(dropout) if dropout > 0 else None

    def _residual(self, x: torch.Tensor, ssm_out: torch.Tensor) -> torch.Tensor:
        """Merge the SSM branch into the skip path.

        Applies dropout to ``ssm_out`` (if configured), adds the scaled skip
        input ``x``, and, in postnorm mode only, normalizes the sum. Shared
        by ``forward`` and ``step`` so both compute the same block.
        """
        if self.dropout is not None:
            ssm_out = self.dropout(ssm_out)
        out = x * self.residual_scale + ssm_out
        if not self.prenorm:
            # Postnorm: normalize the residual sum exactly once.
            out = self.norm(out)
        return out

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        return_state: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass through the block over a whole sequence.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            state: Optional initial hidden state passed to the SSM
                (documented as (batch, hidden_dim)).
            mask: Optional mask (batch, seq_len) for valid positions; passed
                through to the SSM unchanged.
            return_state: If False, ``None`` is returned in place of the
                final state (the SSM still computes it).

        Returns:
            output: (batch, seq_len, d_model)
            final_state: Final hidden state from the SSM, or None when
                ``return_state`` is False.
        """
        # Prenorm normalizes the SSM input; postnorm defers the norm to the residual sum.
        ssm_in = self.norm(x) if self.prenorm else x
        ssm_out, final_state = self.ssm(ssm_in, mask=mask, state=state)
        output = self._residual(x, ssm_out)
        if not return_state:
            final_state = None
        return output, final_state

    def step(self, u: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Single-step (RNN-style) inference through the block.

        Uses the same norm/residual composition as ``forward`` so that stepping
        through a sequence matches the full-sequence path in both norm modes.

        Args:
            u: Input tensor (batch, d_model), a single timestep.
            h: Hidden state passed to ``SSMGamma.step`` (documented as
                (batch, hidden_dim)).

        Returns:
            output: (batch, d_model) block output.
            h_new: New hidden state returned by the SSM.
        """
        ssm_in = self.norm(u) if self.prenorm else u
        ssm_out, h_new = self.ssm.step(ssm_in, h)
        return self._residual(u, ssm_out), h_new

    def allocate_inference_cache(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Allocate cache for efficient inference.

        Delegates to ``SSMGamma.allocate_inference_cache`` and returns whatever
        it returns; this block adds no state of its own.
        """
        return self.ssm.allocate_inference_cache(batch_size, seq_len, device, dtype)

    def allocate_deployment_cache(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Alias of :meth:`allocate_inference_cache` (same arguments and result)."""
        return self.allocate_inference_cache(batch_size, seq_len, device, dtype)

    def allocate_balanced_deployment_cache(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Alias of :meth:`allocate_inference_cache` (same arguments and result)."""
        return self.allocate_inference_cache(batch_size, seq_len, device, dtype)