# CSU-EP / modular_csuep.py
# Tingxie's picture
# Rename modular_curei.py to modular_csuep.py
# d28e605 verified
import math
from contextlib import nullcontext
from typing import Dict, Literal, Optional, Tuple, Union
from einops import rearrange, repeat
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
)
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
if is_flash_attn_2_available():
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.ops.triton.rotary import apply_rotary
else:
RotaryEmbedding = object
# Names handed to `add_code_sample_docstrings` when decorating model `forward` methods.
_CHECKPOINT_FOR_DOC = "Csuep"
_CONFIG_FOR_DOC = "CsuepConfig"

# Module-level logger (transformers logging wrapper).
logger = logging.get_logger(__name__)
class CsuepConfig(PretrainedConfig):
    """Configuration container for CSU-EP models.

    Each constructor argument is stored on the instance under its own name;
    ``pad_token_id`` and any additional ``kwargs`` are forwarded to
    :class:`~transformers.PretrainedConfig`.

    Arguments whose use is visible in this module include:
        vocab_size, hidden_size, pad_token_id: size the token embedding table.
        intensity_encoding_method: ``"concatenate"`` or ``"multiply"``; selects how
            `CsuepEmbeddings` merges intensities into the token embeddings.
        intermediate_size, hidden_activation, mlp_bias, mlp_dropout: shape the gated MLP.
        num_attention_heads, attention_bias, attention_dropout: shape the attention block.
        global_attn_every_n_layers, local_attention: layers whose index is a multiple of
            ``global_attn_every_n_layers`` use global attention, the others use a window of
            ``local_attention // 2`` on each side.
        global_rope_theta, local_rope_theta: rotary bases for global / local layers.
        norm_eps, norm_bias: LayerNorm settings.
        initializer_range, initializer_cutoff_factor: truncated-normal weight init.
        reference_compile: toggles the ``torch.compile``-d sub-paths; it is a runtime
            switch and is therefore removed from the serialized dict.

    Raises:
        ValueError: if ``classifier_pooling`` is neither ``"cls"`` nor ``"mean"``.
    """

    model_type = "CSU-EP"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=1004,
        hidden_size=768,
        intermediate_size=1152,
        num_hidden_layers=12,
        num_attention_heads=12,
        hidden_activation="gelu",
        max_position_embeddings=400,
        mlm_probability=0.3,
        mask_token_id=1003,
        initializer_range=0.02,
        initializer_cutoff_factor=2.0,
        norm_eps=1e-5,
        norm_bias=False,
        pad_token_id=1002,
        global_rope_theta=1000.0,
        attention_bias=False,
        attention_dropout=0.1,
        global_attn_every_n_layers=12,
        local_attention=128,
        local_rope_theta=1000.0,
        embedding_dropout=0.0,
        mlp_bias=False,
        mlp_dropout=0.0,
        decoder_bias=True,
        intensity_encoding_method="concatenate",
        classifier_pooling: Literal["cls", "mean"] = "cls",
        classifier_dropout=0.0,
        classifier_bias=False,
        classifier_activation="gelu",
        deterministic_flash_attn=False,
        sparse_prediction=False,
        sparse_pred_ignore_index=-100,
        reference_compile=False,
        repad_logits_with_grad=False,
        _attn_implementation="eager",
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Vocabulary / geometry.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Masked-LM settings.
        self.mlm_probability = mlm_probability
        self.mask_token_id = mask_token_id

        # Initialisation and normalisation.
        self.initializer_range = initializer_range
        self.initializer_cutoff_factor = initializer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias

        # Attention, rotary embeddings and MLP.
        self.global_rope_theta = global_rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.local_attention = local_attention
        self.local_rope_theta = local_rope_theta
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout

        # Heads.
        self.decoder_bias = decoder_bias
        self.intensity_encoding_method = intensity_encoding_method
        self.classifier_pooling = classifier_pooling
        self.classifier_dropout = classifier_dropout
        self.classifier_bias = classifier_bias
        self.classifier_activation = classifier_activation

        # Runtime / kernel switches.
        self.deterministic_flash_attn = deterministic_flash_attn
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
        self.repad_logits_with_grad = repad_logits_with_grad
        self._attn_implementation = _attn_implementation

        allowed_pooling = ("cls", "mean")
        if self.classifier_pooling not in allowed_pooling:
            raise ValueError(
                f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
            )

    def to_dict(self):
        """Serialize the config, leaving out the runtime-only ``reference_compile`` flag."""
        serialized = super().to_dict()
        serialized.pop("reference_compile", None)
        return serialized
def _unpad_csuep_input(
inputs: torch.Tensor,
intensities: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Remove padding from input sequences.
Args:
inputs: (batch, seqlen, ...) or (batch, seqlen)
intensities: (batch, seqlen, ...) or (batch, seqlen)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
position_ids: (batch, seqlen), int, position ids
labels: (batch, seqlen), int, labels
Returns:
unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
unpadded_intens: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz)
cu_seqlens: (batch + 1), the cumulative sequence lengths
max_seqlen_in_batch: int
unpadded_position_ids: (total_nnz) or None
unpadded_labels: (total_nnz) or None
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = int(seqlens_in_batch.max().item())
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
if inputs.dim() == 2:
unpadded_inputs = inputs.flatten()[indices]
unpadded_intens = intensities.flatten()[indices]
else:
batch, seqlen, *rest = inputs.shape
shape = batch * seqlen
unpadded_inputs = inputs.view(shape, *rest)[indices]
unpadded_intens = intensities.flatten()[indices]
unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
unpadded_labels = labels.flatten()[indices] if labels is not None else None
return unpadded_inputs, unpadded_intens,indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
def _pad_csuep_output(
inputs: torch.Tensor,
indices: torch.Tensor,
batch: int,
seqlen: int,
) -> torch.Tensor:
"""
Add padding to sequences.
Args:
inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz)
batch: int, batch size
seqlen: int, max sequence length
Returns:
padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
"""
if inputs.dim() == 1:
output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
output[indices] = inputs
padded_inputs = output.view(batch, seqlen)
else:
_, *rest = inputs.shape
output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
output[indices] = inputs
padded_inputs = output.view(batch, seqlen, *rest)
return padded_inputs
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """Autograd function applying rotary embeddings in place to the Q and K slices of a packed, unpadded QKV tensor.

    Relies on the `apply_rotary` Triton kernel imported from flash-attn, so it is only
    usable when flash-attn is installed.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        """Rotate Q and K of `qkv` (total_nnz, 3, nheads, headdim) and return the packed tensor.

        NOTE(review): `.contiguous()` hands back the same tensor when it is already
        contiguous, in which case the caller's tensor is modified in place.
        """
        # (total_nnz, 3, nheads, headdim)
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape
        # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
        # we get the same tensor
        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )

        ctx.save_for_backward(cos, sin, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        """Apply the conjugate rotation to the Q/K slices of the incoming gradient."""
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
        # we get the same tensor
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )

        # One gradient slot per forward input: (qkv, cos, sin, cu_seqlens, max_seqlen).
        return do, None, None, None, None
def apply_rotary_unpadded(
    qkv,
    cos,
    sin,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Apply rotary embeddings to the Q and K slices of a packed, unpadded QKV tensor.

    Thin functional wrapper around `ApplyRotaryEmbUnpad`, which always calls the kernel
    with `interleaved=False`, `inplace=True` and `seqlen_offsets=0`.

    Arguments:
        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (total_nnz, 3, nheads, headdim), the packed tensor with Q and K rotated.
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of each head.
    """
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
class CsuepUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    The rotary position embeddings applied directly to unpadded sequences.

    The base class is flash-attn's `RotaryEmbedding`; when flash-attn is not installed the
    name is aliased to `object` at import time, so this class is only functional on the
    flash-attention path.
    """

    def __init__(
        self,
        dim: int,
        base: float = 1000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        dim: rotary dimension forwarded to the base class.
        base: rotary base (theta) forwarded to the base class.
        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
            up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
            the cos_sin_cache will be recomputed during the forward pass.
        """
        super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
        self.max_seqlen = max_seqlen

        # Precompute the cache only when everything needed to build it is known.
        if max_seqlen is not None and device is not None and dtype is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Apply rotary embedding *inplace* to qkv.
        qkv: (total_nnz, 3, nheads, headdim)
        cu_seqlens: (batch + 1,) cumulative sequence lengths
        max_seqlen: int max seq length in the batch

        Returns the packed qkv tensor produced by `apply_rotary_unpadded`.
        """
        # Refresh the cached cos/sin tables for this batch's length, device and dtype.
        # NOTE(review): when max_seqlen is None the cache must already exist — confirm callers.
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)

        qkv = apply_rotary_unpadded(
            qkv,
            self._cos_cached,
            self._sin_cached,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        return qkv

    def extra_repr(self) -> str:
        # dim / base / scale_base are presumably set by the flash-attn base class — verify.
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
class CsuepEmbeddings(nn.Module):
"""
Embeddings layer adapted to handle unpadded, 1D inputs for both m/z tokens
and intensities. Supports both concatenation and multiplication for intensity encoding.
"""
def __init__(self, config: CsuepConfig, intensity_encoding_method: str = "concatenate"):
super().__init__()
self.config = config
self.tok_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
self.intensity_encoding_method = config.intensity_encoding_method
if self.intensity_encoding_method not in ["concatenate", "multiply"]:
raise ValueError("intensity_encoding_method must be either 'concatenate' or 'multiply'")
if self.intensity_encoding_method == "concatenate":
self.intensity_compress = nn.Linear(config.hidden_size + 1, config.hidden_size, bias=config.mlp_bias)
else:
self.intensity_compress = None
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
self.drop = nn.Dropout(config.embedding_dropout)
def _combine_with_intensities_unpadded(
self, token_embeddings: torch.Tensor, intensities: torch.Tensor
) -> torch.Tensor:
"""
Combines token embeddings with intensity information for UNPADDED inputs.
Args:
token_embeddings: (total_nnz, hidden_size)
intensities: (total_nnz,)
Returns:
Combined embeddings of shape (total_nnz, hidden_size).
"""
if self.intensity_encoding_method == "concatenate":
intensity_reshaped = intensities.unsqueeze(-1)
combined = torch.cat([token_embeddings, intensity_reshaped], dim=-1)
return self.intensity_compress(combined)
elif self.intensity_encoding_method == "multiply":
intensity_reshaped = intensities.unsqueeze(-1)
return token_embeddings * intensity_reshaped
@torch.compile(dynamic=True)
def compiled_embeddings(self, input_ids: torch.LongTensor, intensities: torch.FloatTensor) -> torch.Tensor:
token_embeddings = self.tok_embeddings(input_ids)
combined_embeddings = self._combine_with_intensities_unpadded(token_embeddings, intensities)
return self.drop(self.norm(combined_embeddings))
def forward(
self,
input_ids: torch.LongTensor, # Expected shape: (total_nnz,)
intensities: torch.FloatTensor, # Expected shape: (total_nnz,)
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.config.reference_compile:
return self.compiled_embeddings(input_ids, intensities)
if inputs_embeds is not None:
token_embeddings = inputs_embeds
else:
if input_ids is None:
raise ValueError("You must provide either input_ids or inputs_embeds")
token_embeddings = self.tok_embeddings(input_ids)
combined_embeddings = self._combine_with_intensities_unpadded(token_embeddings, intensities)
final_embeddings = self.drop(self.norm(combined_embeddings))
return final_embeddings
class CsuepMLP(nn.Module):
    """Applies the GLU at the end of each CSU-EP layer.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
    """

    def __init__(self, config: CsuepConfig):
        super().__init__()
        self.config = config
        # One projection yields both halves of the gated unit (value and gate).
        self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
        self.act = ACT2FN[config.hidden_activation]
        self.drop = nn.Dropout(config.mlp_dropout)
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``Wo(drop(act(value) * gate))`` where ``value, gate`` are the two halves of ``Wi(hidden_states)``."""
        projected = self.Wi(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        gated = self.act(value) * gate
        return self.Wo(self.drop(gated))
class CsuepRotaryEmbedding(GemmaRotaryEmbedding):
    """Rotary embedding for the padded (eager / SDPA) attention paths.

    Thin subclass of `GemmaRotaryEmbedding` that takes `max_position_embeddings` from the
    config and forwards `dim`, `base` and `device` unchanged.
    NOTE(review): this relies on the base class accepting `dim` / `max_position_embeddings` /
    `base` / `device` keywords — confirm against the installed transformers version.
    """

    def __init__(self, config: CsuepConfig, dim: int, base: float, device: Optional[torch.device] = None):
        super().__init__(
            dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            device=device
        )
def eager_attention_forward(
    module: "CsuepAttention",
    qkv: torch.Tensor,
    intensities: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    output_attentions: Optional[bool] = False,
    **_kwargs,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
    """Reference attention built from plain matmul + softmax.

    Args:
        module: the owning attention module; provides `rotary_emb`, `head_dim`,
            `attention_dropout` and the `training` flag.
        qkv: packed projections, [batch_size, seqlen, 3, nheads, headdim].
        intensities: accepted for signature parity with the other backends; unused here.
        attention_mask: additive mask used by global-attention layers.
        sliding_window_mask: additive mask used instead when the layer is local.
        position_ids: positions handed to the rotary embedding.
        local_attention: `(-1, -1)` marks a global layer; anything else a local one.
        bs: batch size used to reshape the output.
        dim: flattened head dimension of the output.
        output_attentions: also return the attention probabilities when truthy.

    Returns:
        `(attn_output,)` or `(attn_output, attn_weights)`, with `attn_output` of shape (bs, -1, dim).
    """
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)

    # Split the packed tensor into [batch_size, heads, seq_len, head_dim] pieces.
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    is_local_layer = local_attention != (-1, -1)
    additive_mask = sliding_window_mask if is_local_layer else attention_mask

    scores = torch.matmul(query, key.transpose(2, 3)) * module.head_dim**-0.5
    scores = scores + additive_mask

    # Softmax is computed in fp32 and cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=module.attention_dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous().view(bs, -1, dim)

    return (context, probs) if output_attentions else (context,)
def flash_attention_forward(
    module: "CsuepAttention",
    qkv: torch.Tensor,
    intensities: torch.Tensor,
    rotary_emb: CsuepUnpaddedRotaryEmbedding,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Flash Attention 2 backend operating on unpadded, packed sequences.

    Args:
        module: the owning attention module; provides `attention_dropout`,
            `deterministic_flash_attn` and the `training` flag.
        qkv: packed projections, (total_seqlen, 3, nheads, headdim).
        intensities: accepted for signature parity with the other backends; unused here.
        rotary_emb: rotary module applied to `qkv` before the kernel call.
        cu_seqlens: cumulative sequence lengths of the packed batch.
        max_seqlen: longest sequence in the batch.
        local_attention: forwarded to the kernel as `window_size`.
        bs: first dimension of the reshaped output.
        dim: second dimension of the reshaped output.
        target_dtype: dtype used for the kernel when `qkv` is neither fp16 nor bf16.

    Returns:
        A one-element tuple holding the attention output viewed as (bs, dim).
    """
    # (total_seqlen, 3, nheads, headdim)
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
    # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
    orig_dtype = qkv.dtype
    needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
    if needs_cast:
        qkv = qkv.to(target_dtype)

    attn = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=module.attention_dropout if module.training else 0.0,
        deterministic=module.deterministic_flash_attn,
        window_size=local_attention,
    )

    if needs_cast:
        # Hand the result back in the dtype the caller supplied.
        attn = attn.to(orig_dtype)  # type: ignore

    return (attn.view(bs, dim),)
def sdpa_attention_forward(
    module: "CsuepAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Attention backend built on `torch.nn.functional.scaled_dot_product_attention`.

    Args:
        module: the owning attention module; provides `rotary_emb`,
            `attention_dropout` and the `training` flag.
        qkv: packed projections, [batch_size, seqlen, 3, nheads, headdim].
        attention_mask: mask used by global-attention layers.
        sliding_window_mask: mask used instead when the layer is local.
        position_ids: positions handed to the rotary embedding.
        local_attention: `(-1, -1)` marks a global layer; anything else a local one.
        bs: batch size used to reshape the output.
        dim: flattened head dimension of the output.

    Returns:
        A one-element tuple holding the attention output of shape (bs, -1, dim).
    """
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)

    # Split the packed tensor into [batch_size, heads, seq_len, head_dim] pieces.
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    is_local_layer = local_attention != (-1, -1)
    mask = sliding_window_mask if is_local_layer else attention_mask

    dropout_p = module.attention_dropout if module.training else 0.0
    context = F.scaled_dot_product_attention(
        query,
        key,
        value,
        dropout_p=dropout_p,
        attn_mask=mask,
    )
    context = context.transpose(1, 2).contiguous().view(bs, -1, dim)
    return (context,)
# Dispatch table from `config._attn_implementation` to the matching attention backend;
# `CsuepAttention.forward` looks the backend up here on every call.
CSUEP_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}
class CsuepAttention(nn.Module):
    """Performs multi-headed self attention on a batch of unpadded sequences.

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
    If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
    which requires padding and unpadding inputs, adding some overhead.

    See `forward` method for additional details.

    Args:
        config: model configuration.
        layer_id: index of the layer inside the encoder. Layers whose index is a multiple of
            `config.global_attn_every_n_layers` attend globally; all others use a sliding
            window of `config.local_attention // 2` tokens on each side.

    Raises:
        ValueError: if `hidden_size` is not divisible by `num_attention_heads`, or if
            `layer_id` is left as `None` (the global/local choice cannot be made without it).
    """

    def __init__(self, config: CsuepConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
            )
        if layer_id is None:
            # Previously surfaced as an opaque TypeError from the modulo below.
            raise ValueError("`layer_id` must be an integer; it selects between global and local attention.")

        self.attention_dropout = config.attention_dropout
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)

        # (-1, -1) is the sentinel for global attention; otherwise (left, right) window sizes.
        if layer_id % config.global_attn_every_n_layers != 0:
            self.local_attention = (config.local_attention // 2, config.local_attention // 2)
        else:
            self.local_attention = (-1, -1)

        # Local layers may use their own rotary base.
        rope_theta = config.global_rope_theta
        if self.local_attention != (-1, -1) and config.local_rope_theta is not None:
            rope_theta = config.local_rope_theta

        if config._attn_implementation == "flash_attention_2":
            self.rotary_emb = CsuepUnpaddedRotaryEmbedding(
                dim=self.head_dim, base=rope_theta
            )
        else:
            import copy

            # Work on a copy so the per-layer rope_theta does not leak into the shared config.
            config_copy = copy.deepcopy(config)
            config_copy.rope_theta = rope_theta
            self.rotary_emb = CsuepRotaryEmbedding(config=config_copy, dim=self.head_dim, base=rope_theta)

        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        intensities: torch.Tensor,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        """Project to QKV, run the configured attention backend, then apply the output projection.

        Args:
            hidden_states: input activations; packed (total_nnz, hidden) on the flash-attention
                path, (batch, seqlen, hidden) otherwise.
            intensities: forwarded untouched to the backend.
            output_attentions: forwarded to the backend.
            **kwargs: backend-specific arguments (masks, position ids, cu_seqlens, max_seqlen, ...).

        Returns:
            A tuple whose first element is the projected attention output, followed by
            whatever extra elements (e.g. attention weights) the backend returned.
        """
        qkv = self.Wqkv(hidden_states)

        bs = hidden_states.shape[0]
        if self.config._attn_implementation == "flash_attention_2":
            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        else:
            qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)

        attn_outputs = CSUEP_ATTENTION_FUNCTION[self.config._attn_implementation](
            self,
            qkv=qkv,
            intensities=intensities,
            rotary_emb=self.rotary_emb,
            local_attention=self.local_attention,
            bs=bs,
            dim=self.all_head_size,
            output_attentions=output_attentions,
            **kwargs,
        )
        hidden_states = attn_outputs[0]
        hidden_states = self.out_drop(self.Wo(hidden_states))

        return (hidden_states,) + attn_outputs[1:]  # add attentions if outputted
class CsuepEncoderLayer(nn.Module):
    """One pre-norm encoder block: attention and gated MLP, each wrapped in a residual connection.

    The first layer (`layer_id == 0`) uses an identity in place of the attention LayerNorm.
    """

    def __init__(self, config: CsuepConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        is_first_layer = layer_id == 0
        self.attn_norm = (
            nn.Identity()
            if is_first_layer
            else nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        )
        self.attn = CsuepAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.mlp = CsuepMLP(config)

    @torch.compile(dynamic=True)
    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compiled version of the norm + MLP sub-block."""
        return self.mlp(self.mlp_norm(hidden_states))

    def forward(
        self,
        hidden_states: torch.Tensor,
        intensities: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run the block and return `(hidden_states, *extra_attention_outputs)`."""
        normed_input = self.attn_norm(hidden_states)
        attn_outputs = self.attn(
            normed_input,
            intensities,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
        )

        # Residual connection around attention.
        after_attn = hidden_states + attn_outputs[0]

        if self.config.reference_compile:
            mlp_output = self.compiled_mlp(after_attn)
        else:
            mlp_output = self.mlp(self.mlp_norm(after_attn))

        # Residual connection around the MLP; pass through attentions if outputted.
        return (after_attn + mlp_output, *attn_outputs[1:])
CSUEP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`CsuepConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
    "The bare CSU-EP Model outputting raw hidden-states without any specific head on top.",
    CSUEP_START_DOCSTRING,
)
class CsuepPreTrainedModel(PreTrainedModel):
    # Base class for CSU-EP models: declares the config class and supported attention
    # backends, and provides weight initialisation plus the torch.compile fallbacks.
    config_class = CsuepConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CsuepEmbeddings", "CsuepEncoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = False

    def _init_weights(self, module: nn.Module):
        """Initialise `module` with a truncated normal whose std depends on the module's role."""
        cutoff_factor = self.config.initializer_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        def init_weight(module: nn.Module, std: float):
            # Truncated normal clipped at +/- cutoff_factor standard deviations.
            nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-cutoff_factor * std,
                b=cutoff_factor * std,
            )

            # Linear biases, when present, start at zero.
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # Output projections are scaled down by sqrt(2 * num_hidden_layers).
        # NOTE(review): "final_out" is not used by any branch visible in this method.
        stds = {
            "in": self.config.initializer_range,
            "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
            "embedding": self.config.initializer_range,
            "final_out": self.config.hidden_size**-0.5,
        }

        if isinstance(module, CsuepEmbeddings):
            init_weight(module.tok_embeddings, stds["embedding"])
        elif isinstance(module, CsuepMLP):
            init_weight(module.Wi, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, CsuepAttention):
            init_weight(module.Wqkv, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, CsuepPredictionHead):
            init_weight(module.dense, stds["out"])
        elif isinstance(module, CsuepForMaskedLM):
            init_weight(module.decoder, stds["out"])
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    @classmethod
    def _autoset_attn_implementation(
        cls,
        config,
        use_flash_attention_2: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        check_device_map: bool = True,
    ):
        """Pick the attention backend, trying Flash Attention 2 first when none was requested."""
        # If the user didn't specify anything, try to use flash_attention_2 if available.
        # Otherwise we fall back to the default SDPA -> Eager from the super() method.
        # Csu-ep's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
        if config._attn_implementation_internal is None:
            config._attn_implementation_internal = "flash_attention_2"
            try:
                return cls._check_and_enable_flash_attn_2(
                    config,
                    torch_dtype=torch.float16,
                    device_map=device_map,
                    hard_check_only=False,
                    check_device_map=check_device_map,
                )
            except (ValueError, ImportError):
                # FA2 is unavailable: undo the tentative choice and let the parent decide.
                config._attn_implementation_internal = None
        return super()._autoset_attn_implementation(
            config,
            use_flash_attention_2=use_flash_attention_2,
            torch_dtype=torch.float16,
            device_map=device_map,
            check_device_map=check_device_map,
        )

    def _maybe_set_compile(self):
        """Switch `config.reference_compile` off when the model is split across devices or runs on mps / cpu."""
        if self.config.reference_compile is False:
            return

        # In each case below the warning is only emitted when compilation was explicitly
        # requested (truthy); the flag is set to False either way.
        if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
            if self.config.reference_compile:
                logger.warning_once(
                    "If `accelerate` split the model across devices, `torch.compile` will not work. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.device.type == "mps":
            if self.config.reference_compile:
                logger.warning_once(
                    "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.device.type == "cpu":
            if self.config.reference_compile:
                logger.warning_once(
                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

    def resize_token_embeddings(self, *args, **kwargs):
        """Resize via the parent implementation, then disable the compiled paths."""
        model_embeds = super().resize_token_embeddings(*args, **kwargs)

        if self.config.reference_compile in {True, None}:
            if self.config.reference_compile:
                logger.warning_once(
                    "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        return model_embeds
CSUEP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
by default should you provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding or far-away tokens. In CSUEP, only every few layers
perform global attention, while the rest perform local attention. This mask is used to avoid attending to
far-away tokens in the local attention layers when not using Flash Attention.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
max_seqlen (`int`, *optional*):
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
batch_size (`int`, *optional*):
Batch size of the input sequences. Used to pad the output tensors.
seq_len (`int`, *optional*):
Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare CSU-EP Model outputting raw hidden-states without any specific head on top.",
    CSUEP_START_DOCSTRING,
)
class CsuepModel(CsuepPreTrainedModel):
    # Encoder backbone: token/intensity embeddings, a stack of
    # `config.num_hidden_layers` encoder layers, then a final LayerNorm.

    def __init__(self, config: CsuepConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = CsuepEmbeddings(config)
        self.layers = nn.ModuleList(
            [CsuepEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
        )
        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module held by `self.embeddings`."""
        return self.embeddings.tok_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module held by `self.embeddings`."""
        self.embeddings.tok_embeddings = value

    @add_start_docstrings_to_model_forward(CSUEP_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        intensities: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]:
        # Per-call flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Raises when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        self._maybe_set_compile()

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

        # Shapes are only inferred when the caller supplied neither value.
        if batch_size is None and seq_len is None:
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

        # `repad` records that *this* call removed the padding, so the outputs
        # must be scattered back to (batch, seq_len, ...) before returning.
        repad = False
        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                repad = True
                if inputs_embeds is None:
                    # Integer ids carry no gradient, so unpadding them needs no autograd graph.
                    with torch.no_grad():
                        input_ids, intensities, indices, cu_seqlens, max_seqlen, *_ = _unpad_csuep_input(
                            inputs=input_ids, intensities=intensities, attention_mask=attention_mask
                        )
                else:
                    inputs_embeds, intensities, indices, cu_seqlens, max_seqlen, *_ = _unpad_csuep_input(
                        inputs=inputs_embeds, intensities=intensities, attention_mask=attention_mask
                    )
        else:
            if position_ids is None:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            # Any `sliding_window_mask` passed by the caller is replaced here.
            attention_mask, sliding_window_mask = self._update_attention_mask(
                attention_mask, output_attentions=output_attentions
            )

        hidden_states = self.embeddings(input_ids=input_ids, intensities=intensities, inputs_embeds=inputs_embeds)

        for encoder_layer in self.layers:
            # Hidden states are recorded *before* each layer; the last layer's
            # output is appended after the loop.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are positional here and must stay in the same order
                # as the keyword call in the `else` branch.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    intensities,
                    attention_mask,
                    sliding_window_mask,
                    position_ids,
                    cu_seqlens,
                    max_seqlen,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    intensities=intensities,
                    attention_mask=attention_mask,
                    sliding_window_mask=sliding_window_mask,
                    position_ids=position_ids,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                    output_attentions=output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions and len(layer_outputs) > 1:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.final_norm(hidden_states)

        if repad:
            hidden_states = _pad_csuep_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )
            if all_hidden_states is not None:
                all_hidden_states = tuple(
                    _pad_csuep_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
                    for hs in all_hidden_states
                )

        if not return_dict:
            # Tuple output drops the entries that were not requested.
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _update_attention_mask(
        self, attention_mask: torch.Tensor, output_attentions: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the 4D global attention mask and its sliding-window variant.

        Args:
            attention_mask: padding mask passed to `_prepare_4d_attention_mask`.
            output_attentions: whether attention weights were requested. With
                "sdpa" this switches `config._attn_implementation` to "eager"
                (a persistent change on the config object).

        Returns:
            `(global_attention_mask, sliding_window_mask)`. The second is the
            first with every position farther than `config.local_attention // 2`
            from the query filled with `torch.finfo(self.dtype).min`.
        """
        if output_attentions:
            if self.config._attn_implementation == "sdpa":
                logger.warning_once(
                    "Outputting attentions is only supported with the 'eager' attention implementation, "
                    'not with "sdpa". Falling back to `attn_implementation="eager"`.'
                )
                self.config._attn_implementation = "eager"
            elif self.config._attn_implementation != "eager":
                # NOTE: this branch only logs; it does not itself change
                # `output_attentions` for the caller.
                logger.warning_once(
                    "Outputting attentions is only supported with the eager attention implementation, "
                    f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
                    " Setting `output_attentions=False`."
                )

        global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)

        # Position indices are created directly on the mask's device so that no
        # host-to-device copy is needed on every forward pass.
        rows = torch.arange(global_attention_mask.shape[2], device=attention_mask.device).unsqueeze(0)
        # Absolute distance between every pair of positions.
        distance = torch.abs(rows - rows.T)

        # True for positions inside the local window, False outside.
        window_mask = (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0)

        # Combine with the existing mask: everything outside the window is masked out.
        sliding_window_mask = global_attention_mask.masked_fill(
            window_mask.logical_not(), torch.finfo(self.dtype).min
        )

        return global_attention_mask, sliding_window_mask
class CsuepPredictionHead(nn.Module):
    """Transform applied to encoder outputs: dense -> activation -> LayerNorm.

    The hidden size is preserved; the activation is looked up in `ACT2FN`
    under `config.classifier_activation`.
    """

    def __init__(self, config: CsuepConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.dense = nn.Linear(width, width, config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(width, eps=config.norm_eps, bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense projection, the activation and the norm, in that order."""
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)
def fixed_cross_entropy(
    source,
    target,
    num_items_in_batch: Optional[Union[int, torch.Tensor]] = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Cross-entropy with optional normalization by an external item count.

    Args:
        source: logits passed to `nn.functional.cross_entropy`.
        target: class indices; entries equal to `ignore_index` are skipped.
        num_items_in_batch: when given, the loss is summed and divided by this
            count instead of being averaged over the non-ignored targets.
            May be a Python int or a tensor.
        ignore_index: target value excluded from the loss.
        **kwargs: accepted and ignored so callers can pass extra loss kwargs.

    Returns:
        A scalar loss tensor.
    """
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        # A tensor count may live on another device than the loss (e.g. under
        # model parallelism); align it before dividing.
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch
    return loss
def ForMaskedLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    """Masked-LM loss: flatten logits and labels, then call `fixed_cross_entropy`.

    Logits are upcast to float32 before the loss to avoid precision issues,
    and labels are moved to the logits' device (model parallelism).
    """
    flat_logits = logits.float().view(-1, vocab_size)
    flat_labels = labels.view(-1).to(flat_logits.device)
    return fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index, **kwargs)
@add_start_docstrings(
    "The CSU-EP Model with a decoder head on top that is used for masked language modeling.",
    CSUEP_START_DOCSTRING,
)
class CsuepForMaskedLM(CsuepPreTrainedModel):
    # Backbone (`CsuepModel`) + prediction head + linear decoder to vocabulary logits.
    _tied_weights_keys = ["decoder.weight"]

    def __init__(self, config: CsuepConfig):
        super().__init__(config)
        self.config = config
        self.model = CsuepModel(config)
        self.head = CsuepPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)

        self.sparse_prediction = self.config.sparse_prediction
        self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the decoder projection."""
        return self.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        """Replace the decoder projection."""
        self.decoder = new_embeddings

    @torch.compile(dynamic=True)
    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
        """Compiled equivalent of `self.decoder(self.head(output))`."""
        return self.decoder(self.head(output))

    @add_start_docstrings_to_model_forward(CSUEP_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        intensities: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        # With flash attention the padding is removed here (not in the backbone)
        # so that `labels` and `position_ids` are unpadded together with the inputs.
        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None and seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device

                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

                if inputs_embeds is None:
                    # Integer ids carry no gradient, so no autograd graph is needed.
                    with torch.no_grad():
                        (
                            input_ids,
                            intensities,
                            indices,
                            cu_seqlens,
                            max_seqlen,
                            position_ids,
                            labels,
                        ) = _unpad_csuep_input(
                            inputs=input_ids,
                            intensities=intensities,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            labels=labels,
                        )
                else:
                    (
                        inputs_embeds,
                        intensities,
                        indices,
                        cu_seqlens,
                        max_seqlen,
                        position_ids,
                        labels,
                    ) = _unpad_csuep_input(
                        inputs=inputs_embeds,
                        intensities=intensities,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        labels=labels,
                    )

        outputs = self.model(
            input_ids=input_ids,
            intensities=intensities,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index rather than attribute access: with `return_dict=False` the
        # backbone returns a plain tuple, which has no `.last_hidden_state`.
        last_hidden_state = outputs[0]

        if self.sparse_prediction and labels is not None:
            # flatten labels and output first
            labels = labels.view(-1)
            last_hidden_state = last_hidden_state.view(labels.shape[0], -1)

            # then filter out the non-masked tokens
            mask_tokens = labels != self.sparse_pred_ignore_index
            last_hidden_state = last_hidden_state[mask_tokens]
            labels = labels[mask_tokens]

        logits = (
            self.compiled_head(last_hidden_state)
            if self.config.reference_compile
            else self.decoder(self.head(last_hidden_state))
        )

        loss = None
        if labels is not None:
            # Forward the caller-supplied item count (if any) so the loss is
            # normalized by it instead of silently falling back to a mean.
            loss = ForMaskedLMLoss(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=kwargs.get("num_items_in_batch"),
            )

        if self.config._attn_implementation == "flash_attention_2":
            # Re-padding only needs gradients when explicitly requested or when
            # no loss was computed from the unpadded logits.
            # NOTE(review): with sparse prediction and labels, `logits` holds only
            # the masked positions while `indices` covers all unpadded tokens;
            # confirm `_pad_csuep_output` handles that combination.
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_csuep_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Names exported by `from <module> import *`.
__all__ = [
"CsuepConfig",
"CsuepModel",
"CsuepPreTrainedModel",
"CsuepForMaskedLM"
]