ChromFound / src /models /chromfd_block.py
YifengJiao's picture
Upload folder using huggingface_hub
534e5a3 verified
import torch
from mamba_ssm.ops.triton.layer_norm import RMSNorm
from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
from torch import nn
from chromfd_flashatt import ChromFoundTransformerBlock
class Block(nn.Module):
def __init__(
self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False,
wpsa_window_size=0, shift_size=0, wpsa_heads=0, seq_length=0
):
"""
Simple block wrapping a mixer class with LayerNorm/RMSNorm, residual connection, and wpsa heads.
Args:
wpsa_window_size: Window size for wpsa.
shift_size: Shift size for wpsa.
wpsa_heads: Number of attention heads for wpsa.
seq_length: Input sequence length (for wpsa initialization).
"""
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.norm = norm_cls(dim)
self.mixer = mixer_cls(32)
self.reduction_layer = nn.Linear(dim, 32)
self.expansion_layer = nn.Linear(32, dim)
if mlp_cls is not nn.Identity:
self.norm2 = norm_cls(dim)
self.mlp = mlp_cls(dim)
else:
self.mlp = None
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
if wpsa_window_size > 0:
self.chromfound_block = ChromFoundTransformerBlock(
dim=dim,
input_resolution=seq_length,
num_heads=wpsa_heads,
window_size=wpsa_window_size,
shift_size=shift_size
)
else:
self.chromfound_block = None
def forward(
self, hidden_states, residual=None, inference_params=None, **mixer_kwargs
):
"""
Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: hidden_states = Mixer(LN(residual)).
inference_params: Inference parameters for mamba(required).
"""
# Step 1: Residual connection and normalization
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
is_rms_norm=isinstance(self.norm, RMSNorm)
)
# Step 2: Pass through TransformerBlock if present
if self.chromfound_block is not None:
hidden_states = self.chromfound_block(hidden_states)
hidden_states = self.reduction_layer(hidden_states)
# Step 3: Pass through mamba mixer
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
hidden_states = self.expansion_layer(hidden_states)
# Step 4: Optional MLP block
if self.mlp is not None:
if not self.fused_add_norm:
residual = hidden_states + residual
residual = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm2.weight,
self.norm2.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm2.eps,
is_rms_norm=isinstance(self.norm2, RMSNorm)
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)