"""Gamma SSM Block with residual connections and normalization."""

import torch
import torch.nn as nn
from typing import Optional, Tuple

from .ssm_gamma import SSMGamma
from .normalization import LayerNorm


class GammaSingleBlock(nn.Module):
    """
    Single Gamma SSM Block with residual connection and layer normalization.

    Performs:
        y = Block(LayerNorm(x)) + residual_scale * x    (if prenorm=True)
        y = LayerNorm(Block(x) + residual_scale * x)    (if prenorm=False)

    where ``Block`` is the SSM followed by optional dropout. The LayerNorm is
    applied exactly once per call: before the SSM in prenorm mode, after the
    residual sum in postnorm mode. ``forward`` (whole sequence) and ``step``
    (single timestep) share the same residual/normalization logic.

    Args:
        d_model: Model dimension
        hidden_dim: Hidden dimension for the SSM state
        delta_t: Time discretization step (default: 0.1)
        kernel_length: Convolution kernel length, forwarded to SSMGamma
            (described as reserved for future use) (default: 4)
        A_type: Type of A matrix initialization (default: "tridiagonal")
        prenorm: Use prenorm (LayerNorm -> Block) vs postnorm
            (Block -> residual -> LayerNorm) (default: True)
        residual_scale: Scaling factor applied to the skip (input) branch of the
            residual connection (default: 1.0)
        dropout: Dropout rate applied to the SSM output before the residual
            sum (default: 0.0)

    Shape:
        - Input: (batch, seq_len, d_model)
        - Output: (batch, seq_len, d_model)
    """

    def __init__(
        self,
        d_model: int,
        hidden_dim: int,
        delta_t: float = 0.1,
        kernel_length: int = 4,
        A_type: str = "tridiagonal",
        prenorm: bool = True,
        residual_scale: float = 1.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.prenorm = prenorm
        self.dropout_p = dropout
        self.residual_scale = residual_scale

        # A single LayerNorm: used on the SSM input (prenorm) or on the
        # residual sum (postnorm), never both in the same call.
        self.norm = LayerNorm(d_model)

        # SSM block
        self.ssm = SSMGamma(
            state_dim=d_model,
            hidden_dim=hidden_dim,
            delta_t=delta_t,
            kernel_length=kernel_length,
            A_type=A_type,
        )

        # Dropout module only exists when enabled; ``None`` means "skip".
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    def _combine_residual(self, inp: torch.Tensor, ssm_out: torch.Tensor) -> torch.Tensor:
        """Apply dropout, the scaled residual connection and (postnorm) the final norm.

        Shared by ``forward`` and ``step`` so both paths compute the same formula.

        Args:
            inp: The block input (skip branch), un-normalized.
            ssm_out: The SSM output for the same positions.

        Returns:
            ``inp * residual_scale + dropout(ssm_out)``, normalized once more
            if the block is in postnorm mode.
        """
        if self.dropout is not None:
            ssm_out = self.dropout(ssm_out)

        out = inp * self.residual_scale + ssm_out

        if not self.prenorm:
            # Postnorm: normalize the residual sum exactly once. The SSM output
            # is deliberately NOT normalized beforehand (that would apply the
            # same LayerNorm twice).
            out = self.norm(out)
        return out

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        return_state: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass through block.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            state: Optional initial hidden state (batch, hidden_dim)
            mask: Optional mask (batch, seq_len) for valid positions; passed
                through to the SSM unchanged
            return_state: If False, the returned final state is ``None``

        Returns:
            output: (batch, seq_len, d_model)
            final_state: Final hidden state from SSM (batch, hidden_dim), or
                ``None`` when ``return_state`` is False
        """
        # Prenorm feeds the normalized input to the SSM; postnorm feeds the raw
        # input and normalizes after the residual sum (in _combine_residual).
        ssm_in = self.norm(x) if self.prenorm else x
        ssm_out, final_state = self.ssm(ssm_in, mask=mask, state=state)

        output = self._combine_residual(x, ssm_out)

        if not return_state:
            final_state = None

        return output, final_state

    def step(self, u: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Single step inference through block (RNN style).

        Uses the same normalization/residual formula as ``forward``.

        Args:
            u: Input tensor (batch, d_model) - single timestep
            h: Hidden state (batch, hidden_dim)

        Returns:
            output: (batch, d_model) - block output
            h_new: (batch, hidden_dim) - new hidden state
        """
        ssm_in = self.norm(u) if self.prenorm else u
        ssm_out, h_new = self.ssm.step(ssm_in, h)

        output = self._combine_residual(u, ssm_out)
        return output, h_new

    def allocate_inference_cache(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Allocate cache for efficient inference.

        Delegates to ``self.ssm.allocate_inference_cache`` and returns its
        result unchanged.
        """
        return self.ssm.allocate_inference_cache(batch_size, seq_len, device, dtype)

    def allocate_deployment_cache(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Alias of ``allocate_inference_cache`` (same arguments, same result)."""
        return self.allocate_inference_cache(batch_size, seq_len, device, dtype)

    def allocate_balanced_deployment_cache(
        self,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Alias of ``allocate_inference_cache`` (same arguments, same result)."""
        return self.allocate_inference_cache(batch_size, seq_len, device, dtype)