|
|
| import math |
| from contextlib import nullcontext |
| from typing import Dict, Literal, Optional, Tuple, Union |
| from einops import rearrange, repeat |
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
| from transformers.activations import ACT2FN |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask |
| from transformers.modeling_outputs import ( |
| BaseModelOutput, |
| MaskedLMOutput, |
| QuestionAnsweringModelOutput, |
| SequenceClassifierOutput, |
| TokenClassifierOutput, |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ( |
| add_code_sample_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_2_available, |
| logging, |
| ) |
| from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb |
|
|
|
|
| if is_flash_attn_2_available(): |
| from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func |
| from flash_attn.layers.rotary import RotaryEmbedding |
| from flash_attn.ops.triton.rotary import apply_rotary |
| else: |
| RotaryEmbedding = object |
|
|
| _CHECKPOINT_FOR_DOC = "Csuep" |
| _CONFIG_FOR_DOC = "CsuepConfig" |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class CsuepConfig(PretrainedConfig):
    """Configuration for the CSU-EP encoder.

    Stores the hyper-parameters used to instantiate a CSU-EP model (a
    ModernBERT-style encoder over m/z token + intensity spectra).

    Notable args:
        vocab_size: size of the m/z token vocabulary.
        intensity_encoding_method: how peak intensities are merged into token
            embeddings — "concatenate" (learned compression of
            [embedding; intensity]) or "multiply" (scale embedding by intensity).
        classifier_pooling: sequence pooling for classification heads,
            either "cls" or "mean".
        global_attn_every_n_layers: every n-th layer uses global attention;
            the remaining layers use a sliding window of `local_attention` tokens.
        reference_compile: whether to use the torch.compile'd reference paths;
            dropped from serialization (see `to_dict`).
    """

    model_type = "CSU-EP"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=1004,
        hidden_size=768,
        intermediate_size=1152,
        num_hidden_layers=12,
        num_attention_heads=12,
        hidden_activation="gelu",
        max_position_embeddings=400,
        mlm_probability=0.3,
        mask_token_id=1003,
        initializer_range=0.02,
        initializer_cutoff_factor=2.0,
        norm_eps=1e-5,
        norm_bias=False,
        pad_token_id=1002,
        global_rope_theta=1000.0,
        attention_bias=False,
        attention_dropout=0.1,
        global_attn_every_n_layers=12,
        local_attention=128,
        local_rope_theta=1000.0,
        embedding_dropout=0.0,
        mlp_bias=False,
        mlp_dropout=0.0,
        decoder_bias=True,
        intensity_encoding_method="concatenate",
        classifier_pooling: Literal["cls", "mean"] = "cls",
        classifier_dropout=0.0,
        classifier_bias=False,
        classifier_activation="gelu",
        deterministic_flash_attn=False,
        sparse_prediction=False,
        sparse_pred_ignore_index=-100,
        reference_compile=False,
        repad_logits_with_grad=False,
        _attn_implementation="eager",
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.mlm_probability = mlm_probability
        self.mask_token_id = mask_token_id
        self.initializer_range = initializer_range
        self.initializer_cutoff_factor = initializer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias
        self.global_rope_theta = global_rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.local_attention = local_attention
        self.local_rope_theta = local_rope_theta
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout
        self.decoder_bias = decoder_bias
        self.intensity_encoding_method = intensity_encoding_method
        self.classifier_pooling = classifier_pooling
        self.classifier_dropout = classifier_dropout
        self.classifier_bias = classifier_bias
        self.classifier_activation = classifier_activation
        self.deterministic_flash_attn = deterministic_flash_attn
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
        self.repad_logits_with_grad = repad_logits_with_grad
        self._attn_implementation = _attn_implementation

        if self.classifier_pooling not in ["cls", "mean"]:
            raise ValueError(
                f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
            )
        # Fail fast at config time; CsuepEmbeddings performs the same check,
        # but only when the model is actually built.
        if self.intensity_encoding_method not in ["concatenate", "multiply"]:
            raise ValueError("intensity_encoding_method must be either 'concatenate' or 'multiply'")

    def to_dict(self):
        """Serialize, dropping `reference_compile` (machine/device specific, not portable)."""
        output = super().to_dict()
        output.pop("reference_compile", None)
        return output
|
|
|
|
| def _unpad_csuep_input( |
| inputs: torch.Tensor, |
| intensities: torch.Tensor, |
| attention_mask: torch.Tensor, |
| position_ids: Optional[torch.Tensor] = None, |
| labels: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: |
| """ |
| Remove padding from input sequences. |
| |
| Args: |
| inputs: (batch, seqlen, ...) or (batch, seqlen) |
| intensities: (batch, seqlen, ...) or (batch, seqlen) |
| attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. |
| position_ids: (batch, seqlen), int, position ids |
| labels: (batch, seqlen), int, labels |
| |
| Returns: |
| unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. |
| unpadded_intens: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. |
| indices: (total_nnz) |
| cu_seqlens: (batch + 1), the cumulative sequence lengths |
| max_seqlen_in_batch: int |
| unpadded_position_ids: (total_nnz) or None |
| unpadded_labels: (total_nnz) or None |
| """ |
| seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| max_seqlen_in_batch = int(seqlens_in_batch.max().item()) |
| cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
|
|
| if inputs.dim() == 2: |
| unpadded_inputs = inputs.flatten()[indices] |
| unpadded_intens = intensities.flatten()[indices] |
| else: |
| batch, seqlen, *rest = inputs.shape |
| shape = batch * seqlen |
| unpadded_inputs = inputs.view(shape, *rest)[indices] |
| unpadded_intens = intensities.flatten()[indices] |
|
|
| unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None |
| unpadded_labels = labels.flatten()[indices] if labels is not None else None |
|
|
| return unpadded_inputs, unpadded_intens,indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels |
|
|
|
|
| def _pad_csuep_output( |
| inputs: torch.Tensor, |
| indices: torch.Tensor, |
| batch: int, |
| seqlen: int, |
| ) -> torch.Tensor: |
| """ |
| Add padding to sequences. |
| |
| Args: |
| inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. |
| indices: (total_nnz) |
| batch: int, batch size |
| seqlen: int, max sequence length |
| |
| Returns: |
| padded_inputs: (batch, seqlen, ...) or (batch, seqlen) |
| """ |
| if inputs.dim() == 1: |
| output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) |
| output[indices] = inputs |
| padded_inputs = output.view(batch, seqlen) |
| else: |
| _, *rest = inputs.shape |
| output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) |
| output[indices] = inputs |
| padded_inputs = output.view(batch, seqlen, *rest) |
|
|
| return padded_inputs |
|
|
|
|
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """Autograd wrapper applying the flash-attn triton rotary kernel in place
    to the Q and K slices of a packed, unpadded QKV tensor.

    Because `apply_rotary` rotates Q/K through a view of `qkv` with
    `inplace=True`, the returned tensor is the (mutated) input tensor itself.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,  # (total_nnz, 3, nheads, headdim) packed QKV for unpadded tokens
        cos,  # (seqlen_rotary, rotary_dim / 2)
        sin,  # (seqlen_rotary, rotary_dim / 2)
        cu_seqlens: Optional[torch.Tensor] = None,  # (batch + 1,) cumulative lengths
        max_seqlen: Optional[int] = None,
    ):
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape
        # `qk` aliases the Q and K planes of `qkv` (view of a contiguous
        # tensor), so the in-place rotation below mutates `qkv` directly;
        # V is left untouched.
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )

        ctx.save_for_backward(cos, sin, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        # Gradient of a rotation is the conjugate (inverse) rotation applied
        # to the incoming grad's Q/K planes, again in place.
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )

        # forward() takes 5 inputs; the trailing extra Nones are tolerated by
        # autograd (surplus None gradients are allowed).
        return do, None, None, None, None, None, None
|
|
|
|
def apply_rotary_unpadded(
    qkv,
    cos,
    sin,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """Apply rotary position embeddings to a packed, unpadded QKV tensor.

    Thin functional wrapper around :class:`ApplyRotaryEmbUnpad` so that
    autograd tracks the (in-place) rotation of the Q and K planes.

    Arguments:
        qkv: (total_nnz, 3, nheads, headdim) packed QKV over all unpadded tokens.
        cos, sin: (seqlen_rotary, rotary_dim / 2) rotation tables; rotary_dim
            must be <= headdim, and only the first rotary_dim of each head is
            rotated (GPT-NeoX style, i.e. non-interleaved halves).
        cu_seqlens: (batch + 1,) cumulative sequence lengths, or None.
        max_seqlen: maximum sequence length in the batch.

    Return:
        The rotated qkv tensor (same storage as the input).
    """
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
|
|
| |
class CsuepUnpaddedRotaryEmbedding(RotaryEmbedding):
    """Rotary position embeddings applied directly to unpadded (packed) sequences.

    Builds on flash-attn's ``RotaryEmbedding`` for its cos/sin cache handling
    (``_update_cos_sin_cache`` / ``_cos_cached`` / ``_sin_cached``).
    """

    def __init__(
        self,
        dim: int,
        base: float = 1000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        If max_seqlen, device, and dtype are all provided, the cos/sin cache is
        precomputed up to max_seqlen here; otherwise (or if they change later)
        it is (re)computed lazily during the forward pass.
        """
        super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
        self.max_seqlen = max_seqlen

        precompute = max_seqlen is not None and device is not None and dtype is not None
        if precompute:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Rotate Q/K of ``qkv`` *in place* and return it.

        Args:
            qkv: (total_nnz, 3, nheads, headdim) packed QKV.
            cu_seqlens: (batch + 1,) cumulative sequence lengths.
            max_seqlen: max sequence length in the batch; when given, the
                cos/sin cache is refreshed for qkv's device/dtype first.
        """
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)

        return apply_rotary_unpadded(
            qkv,
            self._cos_cached,
            self._sin_cached,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

    def extra_repr(self) -> str:
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
|
|
|
|
| class CsuepEmbeddings(nn.Module): |
| """ |
| Embeddings layer adapted to handle unpadded, 1D inputs for both m/z tokens |
| and intensities. Supports both concatenation and multiplication for intensity encoding. |
| """ |
|
|
| def __init__(self, config: CsuepConfig, intensity_encoding_method: str = "concatenate"): |
| super().__init__() |
| self.config = config |
| self.tok_embeddings = nn.Embedding( |
| config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id |
| ) |
| |
| self.intensity_encoding_method = config.intensity_encoding_method |
| if self.intensity_encoding_method not in ["concatenate", "multiply"]: |
| raise ValueError("intensity_encoding_method must be either 'concatenate' or 'multiply'") |
|
|
| if self.intensity_encoding_method == "concatenate": |
| self.intensity_compress = nn.Linear(config.hidden_size + 1, config.hidden_size, bias=config.mlp_bias) |
| else: |
| self.intensity_compress = None |
|
|
| self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) |
| self.drop = nn.Dropout(config.embedding_dropout) |
|
|
| def _combine_with_intensities_unpadded( |
| self, token_embeddings: torch.Tensor, intensities: torch.Tensor |
| ) -> torch.Tensor: |
| """ |
| Combines token embeddings with intensity information for UNPADDED inputs. |
| |
| Args: |
| token_embeddings: (total_nnz, hidden_size) |
| intensities: (total_nnz,) |
| |
| Returns: |
| Combined embeddings of shape (total_nnz, hidden_size). |
| """ |
| if self.intensity_encoding_method == "concatenate": |
| intensity_reshaped = intensities.unsqueeze(-1) |
| combined = torch.cat([token_embeddings, intensity_reshaped], dim=-1) |
| return self.intensity_compress(combined) |
| |
| elif self.intensity_encoding_method == "multiply": |
| intensity_reshaped = intensities.unsqueeze(-1) |
| return token_embeddings * intensity_reshaped |
|
|
| @torch.compile(dynamic=True) |
| def compiled_embeddings(self, input_ids: torch.LongTensor, intensities: torch.FloatTensor) -> torch.Tensor: |
| token_embeddings = self.tok_embeddings(input_ids) |
| combined_embeddings = self._combine_with_intensities_unpadded(token_embeddings, intensities) |
| return self.drop(self.norm(combined_embeddings)) |
|
|
| def forward( |
| self, |
| input_ids: torch.LongTensor, |
| intensities: torch.FloatTensor, |
| inputs_embeds: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
|
|
| if self.config.reference_compile: |
| return self.compiled_embeddings(input_ids, intensities) |
| |
| if inputs_embeds is not None: |
| token_embeddings = inputs_embeds |
| else: |
| if input_ids is None: |
| raise ValueError("You must provide either input_ids or inputs_embeds") |
| token_embeddings = self.tok_embeddings(input_ids) |
| combined_embeddings = self._combine_with_intensities_unpadded(token_embeddings, intensities) |
| final_embeddings = self.drop(self.norm(combined_embeddings)) |
| return final_embeddings |
|
|
class CsuepMLP(nn.Module):
    """Gated (GLU-style) feed-forward block closing every CSU-EP layer.

    A single `Wi` projection emits both the activation input and the gate
    (each of size `intermediate_size`); their elementwise product is projected
    back to `hidden_size` by `Wo`. This fuses what the default BERT
    architecture splits across `BertIntermediate` and `SelfOutput`.
    """

    def __init__(self, config: "CsuepConfig"):
        super().__init__()
        self.config = config
        # One matmul produces both GLU halves: [activation_input | gate].
        self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
        self.act = ACT2FN[config.hidden_activation]
        self.drop = nn.Dropout(config.mlp_dropout)
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """(..., hidden_size) -> (..., hidden_size)."""
        pre_act, gate = torch.chunk(self.Wi(hidden_states), 2, dim=-1)
        gated = self.act(pre_act) * gate
        return self.Wo(self.drop(gated))
|
|
class CsuepRotaryEmbedding(GemmaRotaryEmbedding):
    """Rotary embedding for the padded (eager / SDPA) paths, bound to CSU-EP's config."""

    def __init__(self, config: "CsuepConfig", dim: int, base: float, device: Optional[torch.device] = None):
        # NOTE(review): assumes a transformers version whose GemmaRotaryEmbedding
        # still accepts dim/base/max_position_embeddings keywords — confirm
        # against the pinned transformers release.
        super().__init__(
            max_position_embeddings=config.max_position_embeddings,
            dim=dim,
            base=base,
            device=device,
        )
|
|
|
|
def eager_attention_forward(
    module: "CsuepAttention",
    qkv: torch.Tensor,
    intensities: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    output_attentions: Optional[bool] = False,
    **_kwargs,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
    """Reference (pure PyTorch) attention over padded (bs, seqlen, ...) inputs.

    `intensities` is accepted for signature parity with the other kernels but
    is not used here. Returns (context,) or (context, attention_probs).
    """
    # Rotate Q/K with rotary position embeddings.
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # (bs, seqlen, 3, nheads, headdim) -> three (bs, nheads, seqlen, headdim).
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    scores = torch.matmul(query, key.transpose(2, 3)) * module.head_dim**-0.5

    # Local-attention layers attend through the sliding-window mask instead.
    mask = sliding_window_mask if local_attention != (-1, -1) else attention_mask
    scores = scores + mask

    # Softmax in fp32 for stability, then cast back to the compute dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=module.attention_dropout, training=module.training)

    context = torch.matmul(probs, value)
    context = context.transpose(1, 2).contiguous().view(bs, -1, dim)

    if output_attentions:
        return (context, probs)
    return (context,)
|
|
|
|
def flash_attention_forward(
    module: "CsuepAttention",
    qkv: torch.Tensor,
    intensities: torch.Tensor,
    rotary_emb: "CsuepUnpaddedRotaryEmbedding",
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Flash Attention 2 kernel over unpadded (total_nnz, 3, nheads, headdim) QKV.

    `intensities` is accepted for signature parity with the other kernels but
    is not used here. Returns a 1-tuple with the (total_nnz, dim) context.
    """
    # Rotary embedding is applied in place to the packed Q/K planes.
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    # FA2 kernels only accept fp16/bf16; round-trip through target_dtype when
    # the activations are in another dtype (e.g. fp32).
    orig_dtype = qkv.dtype
    convert_dtype = orig_dtype not in (torch.float16, torch.bfloat16)
    if convert_dtype:
        qkv = qkv.to(target_dtype)

    # Single call site (previously duplicated verbatim in both dtype branches).
    attn = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=module.attention_dropout if module.training else 0.0,
        deterministic=module.deterministic_flash_attn,
        window_size=local_attention,
    )

    if convert_dtype:
        attn = attn.to(orig_dtype)
    return (attn.view(bs, dim),)
|
|
|
|
def sdpa_attention_forward(
    module: "CsuepAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: Tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> Tuple[torch.Tensor]:
    """Attention via PyTorch's fused SDPA kernel over padded inputs.

    Mirrors `eager_attention_forward` but never materializes the attention
    probabilities, so it always returns a 1-tuple with the context tensor.
    """
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # (bs, seqlen, 3, nheads, headdim) -> three (bs, nheads, seqlen, headdim).
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    # Local-attention layers replace the global mask with the sliding-window one.
    mask = sliding_window_mask if local_attention != (-1, -1) else attention_mask

    context = F.scaled_dot_product_attention(
        query,
        key,
        value,
        dropout_p=module.attention_dropout if module.training else 0.0,
        attn_mask=mask,
    )
    context = context.transpose(1, 2).contiguous().view(bs, -1, dim)
    return (context,)
|
|
|
|
# Dispatch table: maps config._attn_implementation to the matching attention
# kernel (all share the kwargs-based calling convention used by CsuepAttention).
CSUEP_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}
|
|
|
|
class CsuepAttention(nn.Module):
    """Performs multi-headed self attention on a batch of unpadded sequences.

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
    If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
    which requires padding and unpadding inputs, adding some overhead.

    See `forward` method for additional details.
    """

    def __init__(self, config: CsuepConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
            )

        self.attention_dropout = config.attention_dropout
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        # Single fused projection producing Q, K and V at once.
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)

        # Every `global_attn_every_n_layers`-th layer (including layer 0)
        # attends globally; the others use a symmetric sliding window of
        # `local_attention` tokens.
        # NOTE(review): `layer_id` defaults to None, which would raise a
        # TypeError on this modulo — callers appear to always pass an int.
        if layer_id % config.global_attn_every_n_layers != 0:
            self.local_attention = (config.local_attention // 2, config.local_attention // 2)
        else:
            self.local_attention = (-1, -1)

        rope_theta = config.global_rope_theta
        max_position_embeddings = config.max_position_embeddings
        if self.local_attention != (-1, -1):
            # Local layers may use their own RoPE base frequency.
            if config.local_rope_theta is not None:
                rope_theta = config.local_rope_theta
            # NOTE(review): `max_position_embeddings` is computed here but never
            # used below — confirm whether it was meant to be forwarded to the
            # rotary embedding.
            max_position_embeddings = config.local_attention

        if config._attn_implementation == "flash_attention_2":
            # Unpadded rotary embedding, applied in place inside the FA2 path.
            self.rotary_emb = CsuepUnpaddedRotaryEmbedding(
                dim=self.head_dim, base=rope_theta
            )
        else:
            import copy
            # Deep-copy so the per-layer rope_theta does not leak into the
            # shared config object.
            config_copy = copy.deepcopy(config)
            config_copy.rope_theta = rope_theta
            self.rotary_emb = CsuepRotaryEmbedding(config=config_copy, dim=self.head_dim, base=rope_theta)

        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        intensities: torch.Tensor,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        """Project to QKV, dispatch to the configured attention kernel, project out.

        `hidden_states` is (total_nnz, hidden) under flash_attention_2 and
        (batch, seqlen, hidden) otherwise; kernel-specific arguments
        (attention_mask, cu_seqlens, max_seqlen, ...) arrive via **kwargs.
        Returns (hidden_states,) plus the attention probs when the kernel
        emits them.
        """
        qkv = self.Wqkv(hidden_states)

        # total_nnz for the unpadded path, batch size for the padded paths.
        bs = hidden_states.shape[0]
        if self.config._attn_implementation == "flash_attention_2":
            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        else:
            qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)

        attn_outputs = CSUEP_ATTENTION_FUNCTION[self.config._attn_implementation](
            self,
            qkv=qkv,
            intensities=intensities,
            rotary_emb=self.rotary_emb,
            local_attention=self.local_attention,
            bs=bs,
            dim=self.all_head_size,
            output_attentions=output_attentions,
            **kwargs,
        )
        hidden_states = attn_outputs[0]
        hidden_states = self.out_drop(self.Wo(hidden_states))

        return (hidden_states,) + attn_outputs[1:]
|
|
|
|
class CsuepEncoderLayer(nn.Module):
    """One pre-norm CSU-EP transformer layer: attention + gated MLP, both residual."""

    def __init__(self, config: "CsuepConfig", layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        # Layer 0 skips the attention pre-norm (the embedding layer has
        # already applied a LayerNorm).
        self.attn_norm = (
            nn.Identity()
            if layer_id == 0
            else nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        )
        self.attn = CsuepAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.mlp = CsuepMLP(config)

    @torch.compile(dynamic=True)
    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """torch.compile'd norm+MLP path, used when config.reference_compile is set."""
        return self.mlp(self.mlp_norm(hidden_states))

    def forward(
        self,
        hidden_states: torch.Tensor,
        intensities: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """Returns (hidden_states, [attention_probs]) mirroring the attention kernel's output."""
        attn_out = self.attn(
            self.attn_norm(hidden_states),
            intensities,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out[0]

        if self.config.reference_compile:
            mlp_out = self.compiled_mlp(hidden_states)
        else:
            mlp_out = self.mlp(self.mlp_norm(hidden_states))

        return (hidden_states + mlp_out,) + attn_out[1:]
|
|
|
|
| CSUEP_START_DOCSTRING = r""" |
| This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
| library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
| etc.) |
| |
| This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
| Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
| and behavior. |
| |
| Parameters: |
| config ([`CsuepConfig`]): |
| Model configuration class with all the parameters of the model. Initializing with a config file does not |
| load the weights associated with the model, only the configuration. Check out the |
| [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
| """ |
|
|
|
|
@add_start_docstrings(
    "The bare CSU-EP Model outputting raw hidden-states without any specific head on top.",
    CSUEP_START_DOCSTRING,
)
class CsuepPreTrainedModel(PreTrainedModel):
    """Base class wiring CSU-EP modules into the HF PreTrainedModel machinery:
    weight init, attention-implementation selection, and compile gating."""

    config_class = CsuepConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CsuepEmbeddings", "CsuepEncoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = False

    def _init_weights(self, module: nn.Module):
        """Truncated-normal init with per-module std, clipped at
        `initializer_cutoff_factor` standard deviations (default 3)."""
        cutoff_factor = self.config.initializer_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        def init_weight(module: nn.Module, std: float):
            # Truncated normal keeps weights within [-cutoff*std, cutoff*std].
            nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-cutoff_factor * std,
                b=cutoff_factor * std,
            )

        # Linear biases are always zeroed, regardless of which branch below
        # (re)initializes the weights.
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        stds = {
            "in": self.config.initializer_range,
            # Output projections are scaled down with depth (GPT-2-style).
            "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
            "embedding": self.config.initializer_range,
            "final_out": self.config.hidden_size**-0.5,
        }

        # CsuepPredictionHead / CsuepForMaskedLM are defined later in this file.
        if isinstance(module, CsuepEmbeddings):
            init_weight(module.tok_embeddings, stds["embedding"])
        elif isinstance(module, CsuepMLP):
            init_weight(module.Wi, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, CsuepAttention):
            init_weight(module.Wqkv, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, CsuepPredictionHead):
            init_weight(module.dense, stds["out"])
        elif isinstance(module, CsuepForMaskedLM):
            init_weight(module.decoder, stds["out"])
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    @classmethod
    def _autoset_attn_implementation(
        cls,
        config,
        use_flash_attention_2: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        check_device_map: bool = True,
    ):
        """Prefer Flash Attention 2 when no implementation was requested;
        fall back to the parent's selection (sdpa/eager) if FA2 is unavailable.

        NOTE(review): the `torch_dtype` parameter is ignored — torch.float16 is
        hard-coded in both calls below; confirm this is intentional.
        """
        if config._attn_implementation_internal is None:
            config._attn_implementation_internal = "flash_attention_2"
            try:
                return cls._check_and_enable_flash_attn_2(
                    config,
                    torch_dtype=torch.float16,
                    device_map=device_map,
                    hard_check_only=False,
                    check_device_map=check_device_map,
                )
            except (ValueError, ImportError):
                # FA2 not usable — undo the tentative choice and let the
                # parent pick an implementation.
                config._attn_implementation_internal = None
        return super()._autoset_attn_implementation(
            config,
            use_flash_attention_2=use_flash_attention_2,
            torch_dtype=torch.float16,
            device_map=device_map,
            check_device_map=check_device_map,
        )

    def _maybe_set_compile(self):
        """Disable `config.reference_compile` in environments where
        torch.compile is known not to work (multi-device maps, MPS, CPU)."""
        if self.config.reference_compile is False:
            return

        if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
            if self.config.reference_compile:
                logger.warning_once(
                    "If `accelerate` split the model across devices, `torch.compile` will not work. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.device.type == "mps":
            if self.config.reference_compile:
                logger.warning_once(
                    "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.device.type == "cpu":
            if self.config.reference_compile:
                logger.warning_once(
                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

    def resize_token_embeddings(self, *args, **kwargs):
        """Resize embeddings as usual, then force-disable reference_compile
        (resizing invalidates compiled graphs)."""
        model_embeds = super().resize_token_embeddings(*args, **kwargs)

        if self.config.reference_compile in {True, None}:
            if self.config.reference_compile:
                logger.warning_once(
                    "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        return model_embeds
|
|
|
|
| CSUEP_INPUTS_DOCSTRING = r""" |
| Args: |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored |
| by default should you provide it. |
| |
| Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| [`PreTrainedTokenizer.__call__`] for details. |
| |
| [What are input IDs?](../glossary#input-ids) |
| attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| |
| [What are attention masks?](../glossary#attention-mask) |
| |
| Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| [`PreTrainedTokenizer.__call__`] for details. |
| |
| If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] |
| and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more |
| information on the default strategy. |
| |
| - 1 indicates the head is **not masked**, |
| - 0 indicates the head is **masked**. |
| sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on padding or far-away tokens. In CSUEP, only every few layers |
| perform global attention, while the rest perform local attention. This mask is used to avoid attending to |
| far-away tokens in the local attention layers when not using Flash Attention. |
| position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
| config.n_positions - 1]`. |
| |
| [What are position IDs?](../glossary#position-ids) |
| inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
| Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This |
| is useful if you want more control over how to convert `input_ids` indices into associated vectors than the |
| model's internal embedding lookup matrix. |
| indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): |
| Indices of the non-padding tokens in the input sequence. Used for unpadding the output. |
| cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): |
| Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. |
| max_seqlen (`int`, *optional*): |
| Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. |
| batch_size (`int`, *optional*): |
| Batch size of the input sequences. Used to pad the output tensors. |
| seq_len (`int`, *optional*): |
| Sequence length of the input sequences including padding tokens. Used to pad the output tensors. |
| output_attentions (`bool`, *optional*): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| tensors for more detail. |
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| """ |
|
|
|
|
@add_start_docstrings(
    "The bare CSU-EP Model outputting raw hidden-states without any specific head on top.",
    CSUEP_START_DOCSTRING,
)
class CsuepModel(CsuepPreTrainedModel):
    """Bare CSU-EP encoder.

    Pipeline: token/intensity embeddings -> ``num_hidden_layers`` encoder layers
    -> final LayerNorm. When the attention implementation is flash_attention_2,
    the batch can be unpadded ("varlen") before the layer stack and re-padded on
    the way out; otherwise dense 4D global and sliding-window masks are built here.
    """

    def __init__(self, config: CsuepConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = CsuepEmbeddings(config)
        # layer_id is passed so behavior can vary with depth (per the inputs
        # docstring, only some layers use global attention, the rest local).
        self.layers = nn.ModuleList(
            [CsuepEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
        )
        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.gradient_checkpointing = False
        self.post_init()


    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings.tok_embeddings


    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embeddings.tok_embeddings = value


    @add_start_docstrings_to_model_forward(CSUEP_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        intensities: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]:
        """Encode the inputs and return the final hidden states.

        Returns a `BaseModelOutput` (or tuple when ``return_dict=False``) whose
        ``last_hidden_state`` has shape ``(batch_size, seq_len, hidden_size)``
        after any flash-attention re-padding.
        """
        # Resolve per-call flags against config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict


        # Exactly one of input_ids / inputs_embeds must be provided:
        # the XOR is True when both or neither are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")


        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None


        self._maybe_set_compile()


        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)


        # Derive batch/sequence dimensions unless the caller supplied them
        # (callers that pre-unpadded inputs pass these explicitly).
        if batch_size is None and seq_len is None:
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]
        device = input_ids.device if input_ids is not None else inputs_embeds.device


        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)


        # Flash-attention path: strip padding tokens once here, and remember to
        # re-pad the outputs (repad=True) before returning.
        repad = False
        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                repad = True
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, intensities, indices, cu_seqlens, max_seqlen, *_ = _unpad_csuep_input(
                            inputs=input_ids, intensities=intensities, attention_mask=attention_mask
                        )
                else:
                    inputs_embeds, intensities, indices, cu_seqlens, max_seqlen, *_ = _unpad_csuep_input(
                        inputs=inputs_embeds, intensities=intensities, attention_mask=attention_mask
                    )
        else:
            # Padded (eager/sdpa) path: build positions and dense 4D masks.
            if position_ids is None:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)


            attention_mask, sliding_window_mask = self._update_attention_mask(
                attention_mask, output_attentions=output_attentions
            )


        # NOTE(review): `intensities` are forwarded to the embeddings and every
        # layer; their exact semantics depend on CsuepEmbeddings — confirm there.
        hidden_states = self.embeddings(input_ids=input_ids, intensities=intensities,inputs_embeds=inputs_embeds)


        for encoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)


            if self.gradient_checkpointing and self.training:
                # Positional argument order must match CsuepEncoderLayer.forward.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    intensities,
                    attention_mask,
                    sliding_window_mask,
                    position_ids,
                    cu_seqlens,
                    max_seqlen,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    intensities=intensities,
                    attention_mask=attention_mask,
                    sliding_window_mask=sliding_window_mask,
                    position_ids=position_ids,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                    output_attentions=output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions and len(layer_outputs) > 1:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)


        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)


        hidden_states = self.final_norm(hidden_states)


        # Re-pad every returned tensor back to (batch_size, seq_len, ...) if we
        # unpadded the inputs above.
        if repad:
            hidden_states = _pad_csuep_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )
            if all_hidden_states is not None:
                all_hidden_states = tuple(
                    _pad_csuep_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
                    for hs in all_hidden_states
                )


        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


    def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the dense (global, sliding-window) additive masks for the padded path.

        Both returned masks are 4D, with 0 at attendable positions and the
        dtype minimum at masked positions.
        """
        if output_attentions:
            # Attention weights can only be returned by the eager implementation;
            # sdpa is silently downgraded, anything else just warns.
            if self.config._attn_implementation == "sdpa":
                logger.warning_once(
                    "Outputting attentions is only supported with the 'eager' attention implementation, "
                    'not with "sdpa". Falling back to `attn_implementation="eager"`.'
                )
                self.config._attn_implementation = "eager"
            elif self.config._attn_implementation != "eager":
                logger.warning_once(
                    "Outputting attentions is only supported with the eager attention implementation, "
                    f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
                    " Setting `output_attentions=False`."
                )


        # Expand the (batch, seq) padding mask to an additive 4D mask.
        global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)


        # |i - j| distance between every query row i and key column j.
        rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)

        distance = torch.abs(rows - rows.T)


        # True inside the local window, broadcast to (1, 1, seq, seq).
        window_mask = (
            (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
        )
        # Sliding-window mask additionally drops positions outside the window.
        sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)


        return global_attention_mask, sliding_window_mask
|
|
|
|
class CsuepPredictionHead(nn.Module):
    """MLM transformation head: dense projection, activation, then LayerNorm."""

    def __init__(self, config: CsuepConfig):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply dense -> activation -> norm, preserving the input's shape."""
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)
|
|
def fixed_cross_entropy(source, target, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs):
    """Cross-entropy loss with optional caller-supplied normalization.

    When ``num_items_in_batch`` is given, the per-token losses are summed and
    divided by that count (so gradient accumulation normalizes by the true
    number of items); otherwise the standard mean reduction is used.

    Args:
        source: Logits of shape ``(N, num_classes)``.
        target: Class indices of shape ``(N,)``; entries equal to
            ``ignore_index`` contribute nothing to the loss.
        num_items_in_batch: Denominator for the "sum" reduction, or ``None``
            for a plain mean.
        ignore_index: Target value that is ignored.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        A scalar loss tensor.
    """
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
|
|
def ForMaskedLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs
):
    """Masked-LM loss: flattened cross-entropy over the vocabulary.

    Args:
        logits: Prediction scores; reshaped to ``(-1, vocab_size)``.
        labels: Target token ids; reshaped to ``(-1,)``. Positions equal to
            ``ignore_index`` are excluded from the loss.
        vocab_size: Size of the vocabulary (last logits dimension).
        num_items_in_batch: Optional denominator forwarded to
            ``fixed_cross_entropy`` for gradient-accumulation-aware scaling.
        ignore_index: Label value to skip (default -100).
        **kwargs: Forwarded to ``fixed_cross_entropy``.

    Returns:
        A scalar loss tensor.
    """
    # Upcast to float32 so the softmax/log in cross-entropy is numerically stable
    # even when the model runs in half precision.
    logits = logits.float()

    # Flatten (batch, seq, vocab) logits and (batch, seq) labels to token level.
    logits = logits.view(-1, vocab_size)
    labels = labels.view(-1)

    # Enable model-parallel setups where labels may live on a different device.
    labels = labels.to(logits.device)
    loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
|
|
@add_start_docstrings(
    "The CSU-EP Model with a decoder head on top that is used for masked language modeling.",
    CSUEP_START_DOCSTRING,
)
class CsuepForMaskedLM(CsuepPreTrainedModel):
    """CSU-EP encoder with a prediction head and linear decoder for masked-LM."""

    # Keys participating in weight tying (presumably tied to the input token
    # embeddings — confirm via get_output_embeddings / tie_weights).
    _tied_weights_keys = ["decoder.weight"]


    def __init__(self, config: CsuepConfig):
        super().__init__(config)
        self.config = config
        self.model = CsuepModel(config)
        self.head = CsuepPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)


        # With sparse prediction enabled, logits are only computed for positions
        # whose label differs from sparse_pred_ignore_index (masked positions).
        self.sparse_prediction = self.config.sparse_prediction
        self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index


        self.post_init()


    def get_output_embeddings(self):
        """Return the vocabulary decoder (output embedding) layer."""
        return self.decoder


    def set_output_embeddings(self, new_embeddings: nn.Linear):
        """Replace the vocabulary decoder layer."""
        self.decoder = new_embeddings


    @torch.compile(dynamic=True)
    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
        """torch.compile'd head+decoder; used when config.reference_compile is set."""
        return self.decoder(self.head(output))


    @add_start_docstrings_to_model_forward(CSUEP_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        intensities: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """Run the encoder and decode hidden states into vocabulary logits.

        When ``labels`` are given, also computes the masked-LM loss (optionally
        only over the masked positions if sparse prediction is enabled). On the
        flash-attention path, inputs are unpadded here (together with labels)
        and the logits are re-padded before returning.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()


        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                # Unpad here (instead of inside CsuepModel) so labels and
                # position_ids are unpadded consistently with the inputs.
                if batch_size is None and seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device


                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)


                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, intensities, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_csuep_input(
                            inputs=input_ids, intensities=intensities, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                        )
                else:
                    inputs_embeds, intensities, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_csuep_input(
                        inputs=inputs_embeds, intensities=intensities, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                    )


        outputs = self.model(
            input_ids=input_ids,
            intensities=intensities,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # BUGFIX: index positionally — when return_dict=False the model returns
        # a plain tuple, which has no `.last_hidden_state` attribute.
        last_hidden_state = outputs[0]


        if self.sparse_prediction and labels is not None:
            # Flatten to token level.
            labels = labels.view(-1)
            last_hidden_state = last_hidden_state.view(labels.shape[0], -1)


            # Keep only the positions that actually have a label (masked tokens);
            # decoding the rest would be wasted compute.
            mask_tokens = labels != self.sparse_pred_ignore_index
            last_hidden_state = last_hidden_state[mask_tokens]
            labels = labels[mask_tokens]


        logits = (
            self.compiled_head(last_hidden_state)
            if self.config.reference_compile
            else self.decoder(self.head(last_hidden_state))
        )


        loss = None
        if labels is not None:
            loss = ForMaskedLMLoss(logits, labels, vocab_size=self.config.vocab_size)


        if self.config._attn_implementation == "flash_attention_2":
            # Re-pad logits to (batch, seq, vocab); skip grad tracking unless the
            # caller explicitly asked for repadded logits with grad (or no loss
            # was computed, in which case the logits are the training signal).
            # NOTE(review): batch_size/seq_len may still be None if the caller
            # passed pre-unpadded inputs — confirm _pad_csuep_output handles that.
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_csuep_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)


        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output


        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
| __all__ = [ |
| "CsuepConfig", |
| "CsuepModel", |
| "CsuepPreTrainedModel", |
| "CsuepForMaskedLM" |
| ] |
|
|