| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| from transformers import GenerationConfig |
| from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES |
| from transformers.utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_greater_or_equal_2_10, |
| logging, |
| replace_return_docstrings, |
| ) |
|
|
| from .block_config import AttentionConfig, FFNConfig |
| from .configuration_decilm import DeciLMConfig |
| from .transformers_4_44_2__activations import ACT2FN |
| from .transformers_4_44_2__cache_utils import Cache, StaticCache |
| from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter |
| from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import _flash_attention_forward |
| from .transformers_4_44_2__modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| QuestionAnsweringModelOutput, |
| SequenceClassifierOutputWithPast, |
| TokenClassifierOutput, |
| ) |
| from .transformers_4_44_2__modeling_rope_utils import ROPE_INIT_FUNCTIONS |
| from .transformers_4_44_2__pytorch_utils import ALL_LAYERNORM_LAYERS |
| from .variable_cache import VariableCache |
|
|
# Register this architecture's causal-LM class name under the config's `model_type`
# in the causal-LM name mapping imported from transformers.
| MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[DeciLMConfig.model_type] = "DeciLMForCausalLM" |
# Module-level logger obtained through the transformers logging utility.
| logger = logging.get_logger(__name__) |
|
|
# Name of the config class kept as a string constant (presumably consumed by
# docstring decorators further down the file -- not used in this part of it).
| _CONFIG_FOR_DOC = "DeciLMConfig" |
|
|
|
|
| def _prepare_4d_causal_attention_mask_with_cache_position( |
| attention_mask: torch.Tensor, |
| sequence_length: int, |
| target_length: int, |
| dtype: torch.dtype, |
| device: torch.device, |
| min_dtype: float, |
| cache_position: torch.Tensor, |
| batch_size: int, |
| ): |
| """ |
| Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| |
| Args: |
| attention_mask (`torch.Tensor`): |
| A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. |
| sequence_length (`int`): |
| The sequence length being processed. |
| target_length (`int`): |
| The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. |
| dtype (`torch.dtype`): |
| The dtype to use for the 4D attention mask. |
| device (`torch.device`): |
| The device to place the 4D attention mask on. |
| min_dtype (`float`): |
| The minimum value representable with the dtype `dtype`. |
| cache_position (`torch.Tensor`): |
| Indices depicting the position of the input sequence tokens in the sequence. |
| batch_size (`torch.Tensor`): |
| Batch size. |
| """ |
| if attention_mask is not None and attention_mask.dim() == 4: |
| |
| causal_mask = attention_mask |
| else: |
| causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) |
| if sequence_length != 1: |
| causal_mask = torch.triu(causal_mask, diagonal=1) |
| causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| if attention_mask is not None: |
| causal_mask = causal_mask.clone() |
| mask_length = attention_mask.shape[-1] |
| padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| padding_mask = padding_mask == 0 |
| causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| padding_mask, min_dtype |
| ) |
|
|
| return causal_mask |
|
|
|
|
class DeciLMRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, followed by a learned per-channel gain."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: Size of the normalized (last) dimension; also the length of the gain vector.
            eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalize `hidden_states` in float32, cast back to the input dtype, then apply the gain."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Show the gain shape and epsilon in the module's printed representation."""
        gain_shape = tuple(self.weight.shape)
        return f"{gain_shape}, eps={self.variance_epsilon}"
|
|
|
|
# Add DeciLMRMSNorm to the shared list of layer-norm classes imported from the
# vendored `transformers_4_44_2__pytorch_utils` module.
| ALL_LAYERNORM_LAYERS.append(DeciLMRMSNorm) |
|
|
|
|
class DeciLMRotaryEmbedding(nn.Module):
    """
    Rotary position embedding: produces the `(cos, sin)` pair consumed by `apply_rotary_pos_emb`.

    The inverse frequencies come from the `ROPE_INIT_FUNCTIONS` entry selected by `rope_type`. For RoPE types
    whose name contains "dynamic", the frequencies are recomputed when the sequence outgrows the cached length
    and restored to the initial ones once a sequence shorter than the original maximum is seen again.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[DeciLMConfig] = None,
    ):
        """
        Args:
            dim, max_position_embeddings, base, scaling_factor, rope_type:
                Legacy parameterization. Only used when `config` is `None`, in which case they are forwarded
                to the RoPE init function as keyword arguments.
            device: Device forwarded to the RoPE init function.
            config (`DeciLMConfig`, *optional*):
                Model config. When given, `rope_scaling` and `max_position_embeddings` are read from it and the
                legacy arguments above are ignored.
        """
        super().__init__()
        # Stays empty on the config path; holds the legacy arguments otherwise.
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`DeciLMRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.45"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # Accept both the "rope_type" key and the older "type" key of `rope_scaling`.
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Plain attribute (NOT a buffer): kept so dynamic RoPE can restore the initial frequencies.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        Dynamic RoPE layers recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)

        Args:
            position_ids: Position indices of the current input; their maximum + 1 is taken as the length.
            device: Device on which the frequencies must live for the current forward pass.
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # `original_inv_freq` is a plain attribute, so `module.to(...)` moved the `inv_freq` buffer but
            # left this copy on the device it was created on. Bring it to the active device before it
            # becomes the buffer again, otherwise the matmul in `forward` mixes devices.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Compute the rotary `cos` and `sin` tables for `position_ids`.

        Args:
            x: Tensor used only for its device and dtype.
            position_ids: 2D tensor of position indices (indexed as `[batch, seq]` below).

        Returns:
            `(cos, sin)`, both cast to `x.dtype`.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # (dim/2,) -> (batch, dim/2, 1) and (batch, seq) -> (batch, 1, seq) so a batched matmul
        # gives one angle per (position, frequency) pair.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Autocast is switched off so the angles are computed in float32; "mps" (and any non-string
        # device type) falls back to the "cpu" autocast context.
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Scaling factor returned by the RoPE init function alongside the frequencies.
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
class DeciLMLinearScalingRotaryEmbedding(DeciLMRotaryEmbedding):
    """Deprecated shim: `DeciLMRotaryEmbedding` with `rope_type` forced to "linear"."""

    def __init__(self, *args, **kwargs):
        """Emit the deprecation warning, then defer to the parent with `rope_type="linear"`."""
        deprecation_notice = (
            "`DeciLMLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
            "`DeciLMRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
        )
        logger.warning_once(deprecation_notice)
        # Any caller-supplied `rope_type` is overridden.
        super().__init__(*args, **{**kwargs, "rope_type": "linear"})
|
|
|
|
class DeciLMDynamicNTKScalingRotaryEmbedding(DeciLMRotaryEmbedding):
    """Deprecated shim: `DeciLMRotaryEmbedding` with `rope_type` forced to "dynamic"."""

    def __init__(self, *args, **kwargs):
        """Emit the deprecation warning, then defer to the parent with `rope_type="dynamic"`."""
        deprecation_notice = (
            "`DeciLMDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
            "`DeciLMRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
            "__init__)."
        )
        logger.warning_once(deprecation_notice)
        # Any caller-supplied `rope_type` is overridden.
        super().__init__(*args, **{**kwargs, "rope_type": "dynamic"})
|
|
|
|
def rotate_half(x):
    """Split the last dimension in two halves `(a, b)` and return `(-b, a)` concatenated along it."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with the given rotary `cos`/`sin` tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so they broadcast against `q` and `k`. With `cos`/`sin` of shape
            `[batch_size, seq_len, head_dim]`, use `1` when `q`/`k` are `[batch_size, heads, seq_len, head_dim]`
            and `2` when they are `[batch_size, seq_len, heads, head_dim]`.
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # x * cos + rotate_half(x) * sin
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
class DeciLMMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self,
                 config: DeciLMConfig,
                 ffn_config: FFNConfig,
                 ):
        """
        Args:
            config: Model config; supplies `hidden_size`, `mlp_bias`, `hidden_act` and `pretraining_tp`.
            ffn_config: Per-block FFN settings; `ffn_mult` determines the intermediate width and a non-`None`
                `sparsify` registers `sparsity_backward_hook` as a full backward hook on this module.
        """
        super().__init__()
        self.config = config
        self.ffn_config = ffn_config
        self.hidden_size = config.hidden_size
        self.intermediate_size = _ffn_mult_to_intermediate_size(
            ffn_config.ffn_mult, config.hidden_size)

        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

        if ffn_config.sparsify is not None:
            self.register_full_backward_hook(sparsity_backward_hook)

    def forward(self, x):
        """Apply the gated MLP to `x`; with `pretraining_tp > 1` the weights are applied in shards."""
        tp = self.config.pretraining_tp
        if not tp > 1:
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        # Sharded path: only the weight matrices are split and applied via `F.linear`
        # (no bias term is passed here).
        shard_width = self.intermediate_size // tp
        gate_shards = self.gate_proj.weight.split(shard_width, dim=0)
        up_shards = self.up_proj.weight.split(shard_width, dim=0)
        down_shards = self.down_proj.weight.split(shard_width, dim=1)

        gate_out = torch.cat([F.linear(x, gate_shards[i]) for i in range(tp)], dim=-1)
        up_out = torch.cat([F.linear(x, up_shards[i]) for i in range(tp)], dim=-1)

        hidden_shards = (self.act_fn(gate_out) * up_out).split(shard_width, dim=2)
        partial_outputs = [F.linear(hidden_shards[i], down_shards[i]) for i in range(tp)]
        return sum(partial_outputs)
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input of shape
    `(batch, num_key_value_heads, seqlen, head_dim)` becomes
    `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`. With `n_rep == 1` the input is returned as-is.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
class DeciLMAttention(nn.Module):
    """
    Eager multi-head self-attention with grouped key/value heads and rotary position embeddings.

    The number of query heads sharing one key/value head is taken from `attention_config.n_heads_in_group`.
    """

    def __init__(self,
                 config: DeciLMConfig,
                 attention_config: AttentionConfig,
                 layer_idx: Optional[int] = None,
                 ):
        """
        Args:
            config: Model config (hidden size, head count, dropout, bias flags, RoPE settings).
            attention_config: Per-block attention settings (`n_heads_in_group`, `sparsify`, ...).
            layer_idx: Index of this layer; passed to the cache on update. A warning is logged when omitted.

        Raises:
            ValueError: if `hidden_size` is not divisible by the number of attention heads.
        """
        super().__init__()
        self.config = config
        self.attention_config = attention_config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = attention_config.n_heads_in_group
        self.num_key_value_heads = self.num_heads // self.num_key_value_groups
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        use_bias = config.attention_bias
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=use_bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)

        # Only used in `forward` when the caller does not pass `position_embeddings`.
        self.rotary_emb = DeciLMRotaryEmbedding(config=self.config)

        if attention_config.sparsify is not None:
            self.register_full_backward_hook(sparsity_backward_hook)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None,
            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input of shape `(batch, seq_len, hidden_size)`.
            attention_mask: Optional additive 4D mask; it is cropped to the key length and added to the scores.
            position_ids: Positions used to compute RoPE when `position_embeddings` is not given.
            past_key_value: Optional cache, updated in place with this layer's key/value states.
            output_attentions: When true, the attention probabilities are returned.
            use_cache: Accepted for interface compatibility; not read here.
            cache_position: Forwarded to the cache update.
            position_embeddings: Optional precomputed `(cos, sin)` pair.
            **kwargs: Ignored.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`.

        Raises:
            ValueError: if the attention output does not have the expected 4D shape.
        """
        bsz, q_len, _ = hidden_states.size()
        tp = self.config.pretraining_tp

        if tp > 1:
            # Sharded projections: split each weight into `tp` row blocks and concatenate the results.
            kv_shard = (self.num_key_value_heads * self.head_dim) // tp
            q_shards = self.q_proj.weight.split((self.num_heads * self.head_dim) // tp, dim=0)
            k_shards = self.k_proj.weight.split(kv_shard, dim=0)
            v_shards = self.v_proj.weight.split(kv_shard, dim=0)

            query_states = torch.cat([F.linear(hidden_states, q_shards[i]) for i in range(tp)], dim=-1)
            key_states = torch.cat([F.linear(hidden_states, k_shards[i]) for i in range(tp)], dim=-1)
            value_states = torch.cat([F.linear(hidden_states, v_shards[i]) for i in range(tp)], dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin/cos and cache_position are handed to the cache alongside the new states.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand the key/value heads so every query head has a matching key/value head.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]

        # Softmax is evaluated in float32 and cast back to the query dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)

        if tp > 1:
            shard = self.hidden_size // tp
            out_shards = attn_output.split(shard, dim=2)
            o_proj_shards = self.o_proj.weight.split(shard, dim=1)
            attn_output = sum([F.linear(out_shards[i], o_proj_shards[i]) for i in range(tp)])
        else:
            attn_output = self.o_proj(attn_output)

        return attn_output, (attn_weights if output_attentions else None), past_key_value
|
|
|
|
class DeciLMFlashAttention2(DeciLMAttention):
    """
    DeciLM flash attention module. It inherits from `DeciLMAttention` and keeps the same weights; only the
    forward pass differs, delegating the attention computation to `_flash_attention_forward`.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent attention module, then record the flash-attention specific settings."""
        super().__init__(*args, **kwargs)

        # True when the installed flash-attn is older than 2.10; forwarded as `use_top_left_mask`.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

        # Window size forwarded to the flash-attention call as `sliding_window`.
        self.sliding_window = self.attention_config.prefill_sliding_window

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None,
            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run self-attention through the flash-attention kernel wrapper.

        Args:
            hidden_states: Input of shape `(batch, seq_len, hidden_size)`.
            attention_mask: Optional mask, passed unchanged to `_flash_attention_forward`.
            position_ids: Positions used for RoPE when `position_embeddings` is not given; also forwarded to
                the flash-attention call.
            past_key_value: Optional cache, updated in place with this layer's key/value states.
            output_attentions: Ignored; attention weights are never returned on this path.
            use_cache: Accepted for interface compatibility; not read here.
            cache_position: Forwarded to the cache update.
            position_embeddings: Optional precomputed `(cos, sin)` pair.
            **kwargs: Ignored. Accepted so this class is call-compatible with `DeciLMAttention.forward`;
                the decoder layer forwards its extra keyword arguments to whichever attention class is
                configured.

        Returns:
            `(attn_output, None, past_key_value)`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim): the layout RoPE and the cache
        # update below operate on.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin/cos and cache_position are handed to the cache alongside the new states.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Back to (batch, seq, heads, head_dim) before handing the states to the flash-attention wrapper.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # If the states ended up in float32 (see the warning text below), cast them back to the dtype
        # the model is meant to run in before calling the kernel.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Quantized models expose the dtype they had before quantization on the config.
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=self.sliding_window,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        # This path never materializes attention probabilities.
        attn_weights = None
        return attn_output, attn_weights, past_key_value
|
|
|
|
# Maps an attention-implementation name to the attention class to instantiate;
# DeciLMDecoderLayer indexes it with `config._attn_implementation`.
| DECILM_ATTENTION_CLASSES = { |
| "eager": DeciLMAttention, |
| "flash_attention_2": DeciLMFlashAttention2, |
| } |
|
|
|
|
class DeciLMDecoderLayer(nn.Module):
    """
    One decoder block: an optional attention sub-block followed by an optional feed-forward sub-block,
    each with a pre-norm and a residual connection.

    The per-layer `block_config` decides, independently for attention and FFN, whether the sub-block is
    skipped (`no_op`), replaced by a linear module (`replace_with_linear`), or built in full.
    """

    def __init__(self, config: DeciLMConfig, layer_idx: int):
        """
        Args:
            config: Model config; `config.block_configs[layer_idx]` supplies this layer's settings.
            layer_idx: Index of this layer within the model.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.block_config = config.block_configs[layer_idx]
        self.attention_config = self.block_config.attention
        self.ffn_config = self.block_config.ffn

        if not self.attention_config.no_op:
            self.input_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if self.attention_config.replace_with_linear:
                self.self_attn = DeciLMLinearAttention(config)
            else:
                attention_cls = DECILM_ATTENTION_CLASSES[config._attn_implementation]
                self.self_attn = attention_cls(
                    config=config, attention_config=self.attention_config, layer_idx=layer_idx)

        if not self.ffn_config.no_op:
            self.post_attention_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if self.ffn_config.replace_with_linear:
                self.mlp = DeciLMLinearMLP(config)
            else:
                self.mlp = DeciLMMLP(config, self.ffn_config)

        self.is_sliding = self.attention_config.is_sliding
        self.sliding_window = self.attention_config.prefill_sliding_window

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: Optional[bool] = False,
            use_cache: Optional[bool] = False,
            cache_position: Optional[torch.LongTensor] = None,
            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): forwarded to the attention module.
            past_key_value (`Cache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to append the attention weights to the returned tuple.
            use_cache (`bool`, *optional*):
                If set to `True`, the key/value cache is appended to the returned tuple.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs forwarded to the attention module (used for FSDP and other methods that
                inject code into the model).

        Returns:
            Tuple of `hidden_states`, then the attention weights if `output_attentions`, then the cache if
            `use_cache`.
        """
        attn_cfg = self.attention_config

        # Adapt the mask to this layer's attention pattern before anything else.
        if attn_cfg.unshifted_sink and attn_cfg.is_sink:
            attention_mask = self._unshifted_sink_mask(
                attention_mask, hidden_states,
                attn_cfg.window_length, attn_cfg.num_sink_tokens)
        else:
            attention_mask = self._gemma2_window_mask(attention_mask, hidden_states, past_key_value)

        self_attn_weights = None
        present_key_value = past_key_value

        # Attention sub-block (skipped entirely for `no_op`).
        if not attn_cfg.no_op:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
            if attn_cfg.replace_with_linear:
                attn_out = self.self_attn(normed)
            else:
                attn_out, self_attn_weights, present_key_value = self.self_attn(
                    hidden_states=normed,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            hidden_states = residual + attn_out

        # Feed-forward sub-block (skipped entirely for `no_op`).
        if not self.ffn_config.no_op:
            residual = hidden_states
            ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
            hidden_states = residual + ffn_out

        results = [hidden_states]
        if output_attentions:
            results.append(self_attn_weights)
        if use_cache:
            results.append(present_key_value)
        return tuple(results)

    def _gemma2_window_mask(self,
                            attention_mask: Optional[torch.Tensor],
                            hidden_states: torch.Tensor,
                            past_key_value: Optional[VariableCache],
                            ) -> Optional[torch.Tensor]:
        """
        Restrict `attention_mask` to this layer's sliding window.

        The mask is returned unchanged when the layer is not a sliding-window layer or no mask was given.
        For flash attention the 2D mask is cropped to its last `sliding_window` columns when a cache is
        present; otherwise the 4D additive mask gets its out-of-window entries set to the dtype minimum.
        """
        if not self.is_sliding or attention_mask is None:
            return attention_mask

        if self.config._attn_implementation == "flash_attention_2":
            if past_key_value is not None:
                attention_mask = attention_mask[:, -self.sliding_window:]
            return attention_mask

        min_dtype = torch.finfo(hidden_states.dtype).min
        # True for every (query, key) pair where the key lies `sliding_window` or more positions behind.
        outside_window = torch.tril(
            torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
        )
        attention_mask = torch.where(outside_window, min_dtype, attention_mask)
        if attention_mask.shape[-1] <= 1:
            attention_mask = attention_mask[:, :, :, -self.sliding_window:]
        return attention_mask

    def _unshifted_sink_mask(self,
                             attention_mask: torch.Tensor,
                             hidden_states: torch.Tensor,
                             window_length: int,
                             num_sink_tokens: Optional[int],
                             ) -> torch.Tensor:
        """
        Build a mask that keeps a trailing window of `window_length` keys plus the first `num_sink_tokens`
        keys, working on a copy of the 4D additive `attention_mask`.

        With `window_length == 0` everything except the sink columns is masked. The sink columns are set
        to 0 for every query row.

        NOTE(review): if `num_sink_tokens` is None the slice `[:None]` covers every key column -- confirm
        callers never pass None here.
        """
        assert self.config._attn_implementation == "eager", "Unshifted sink is only supported in 'eager' mode."
        assert attention_mask is not None, "The attention mask seems to not be prepared"

        attention_mask = attention_mask.clone()
        min_dtype = torch.finfo(hidden_states.dtype).min

        if window_length == 0:
            attention_mask = torch.full_like(attention_mask, fill_value=min_dtype)
        elif attention_mask.shape[-2] == 1:
            # Single query row (decoding): mask everything but the last `window_length` keys.
            attention_mask[:, :, :, :-window_length] = min_dtype
        else:
            outside_window = torch.tril(
                torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-window_length
            )
            attention_mask = torch.where(outside_window, min_dtype, attention_mask)

        attention_mask[:, :, :, :num_sink_tokens] = 0
        return attention_mask
|
|
|
|
# Shared introductory docstring text; passed to `add_start_docstrings` when
# decorating DeciLMPreTrainedModel below.
| DECILM_START_DOCSTRING = r""" |
| This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
| library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
| etc.) |
| |
| This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
| Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
| and behavior. |
| |
| Parameters: |
| config ([`DeciLMConfig`]): |
| Model configuration class with all the parameters of the model. Initializing with a config file does not |
| load the weights associated with the model, only the configuration. Check out the |
| [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
| """ |
|
|
|
|
# Base class for the DeciLM models: declares the capability flags read by the transformers
# machinery, the weight-initialization scheme, and the generation-config hook that selects
# the "variable" cache implementation.
@add_start_docstrings(
    "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
    DECILM_START_DOCSTRING,
)
class DeciLMPreTrainedModel(PreTrainedModel):
    config_class = DeciLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeciLMDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True
    _supports_quantized_cache = False
    _supports_static_cache = True

    def _init_weights(self, module):
        """
        Initialize `nn.Linear` and `nn.Embedding` weights from a normal distribution with standard
        deviation `config.initializer_range`; zero linear biases and the embedding padding row.
        Other module types are left untouched.
        """
        init_std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=init_std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    def _prepare_generation_config(
            self,
            generation_config: Optional[GenerationConfig],
            *args,
            **kwargs,
    ) -> tuple[GenerationConfig, dict]:
        """
        Defer to the parent implementation, then force `cache_implementation = "variable"` on the
        resulting generation config and register `VariableCache` under that name in
        `NEED_SETUP_CACHE_CLASSES_MAPPING`.

        Returns:
            The prepared `(generation_config, model_kwargs)` pair.
        """
        prepared = super()._prepare_generation_config(generation_config, *args, **kwargs)
        generation_config, model_kwargs = prepared
        generation_config.cache_implementation = "variable"
        NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
        return generation_config, model_kwargs
|
|
|
|
# Shared argument documentation for the `forward` methods in this file. It is attached
# to them through the `add_start_docstrings_to_model_forward` decorator.
DECILM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`VariableCache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            If passed to the forward function, past_key_values must be a VariableCache object (see imports).
            For generation purposes, this is already handled inside model.generate().

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
|
|
|
|
@add_start_docstrings(
    "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
    DECILM_START_DOCSTRING,
)
class DeciLMModel(DeciLMPreTrainedModel):
    """
    Transformer decoder made of *config.num_hidden_layers* layers, each of them a [`DeciLMDecoderLayer`].

    Args:
        config: DeciLMConfig
    """

    def __init__(self, config: DeciLMConfig):
        super().__init__(config)
        pad_id = config.pad_token_id
        self.padding_idx = pad_id
        self.vocab_size = config.vocab_size

        # Sub-modules are created in a fixed order: embedding, decoder layers, final norm, rotary embedding.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, pad_id)
        decoder_layers = [DeciLMDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DeciLMRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Run the post-construction step inherited from the base class.
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Flags left as None fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` has to be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        checkpointing_active = self.gradient_checkpointing and self.training
        if checkpointing_active and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Only `Cache` instances are accepted; anything else (e.g. tuples of tensors) is rejected.
        if past_key_values is not None and not isinstance(past_key_values, Cache):
            raise NotImplementedError("DeciLMModel does not support legacy cache format, please use a newer "
                                      "transformers version or use VariableCache explicitly (see import in this file).")

        if cache_position is None:
            # Continue counting from the number of tokens the cache reports as already seen.
            first_new = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                first_new, first_new + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # Rotary position embeddings are computed once and handed to every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Per-layer outputs are gathered in lists and converted to tuples at the end.
        hidden_trace = [] if output_hidden_states else None
        attention_trace = [] if output_attentions else None
        next_decoder_cache = None
        cache_slot = 2 if output_attentions else 1

        for decoder_layer in self.layers:
            if output_hidden_states:
                hidden_trace.append(hidden_states)

            if checkpointing_active:
                # The checkpointing function receives the layer arguments positionally.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[cache_slot]

            if output_attentions:
                attention_trace.append(layer_outputs[1])

        hidden_states = self.norm(hidden_states)

        # The normalised output of the last layer is recorded as well.
        if output_hidden_states:
            hidden_trace.append(hidden_states)

        all_hidden_states = None if hidden_trace is None else tuple(hidden_trace)
        all_self_attns = None if attention_trace is None else tuple(attention_trace)
        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            # Tuple output drops every entry that is None.
            candidates = (hidden_states, next_cache, all_hidden_states, all_self_attns)
            return tuple(entry for entry in candidates if entry is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the attention mask handed to the decoder layers.

        Returns `None` when no mask is needed, the caller's `attention_mask` unchanged on the
        flash-attention-2 path when it contains a zero, and otherwise the mask produced by
        `_prepare_4d_causal_attention_mask_with_cache_position`.
        """
        attn_impl = self.config._attn_implementation

        if attn_impl == "flash_attention_2":
            # Flash attention only needs the caller's mask, and only when something is masked out.
            if attention_mask is None or 0.0 not in attention_mask:
                return None
            return attention_mask

        past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
        assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache"
        using_static_cache = isinstance(past_key_values, StaticCache)

        if attn_impl == "sdpa" and not using_static_cache and not output_attentions:
            mask_can_be_skipped = AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            )
            # The mask is dropped only if no layer is flagged `is_sliding`.
            if mask_can_be_skipped and all(not layer.is_sliding for layer in self.layers):
                return None

        dtype = input_tensor.dtype
        device = input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]

        if using_static_cache:
            target_length = past_key_values.get_max_length()
        elif isinstance(attention_mask, torch.Tensor):
            target_length = attention_mask.shape[-1]
        else:
            target_length = past_seen_tokens + sequence_length + 1

        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        needs_unmask = (
            attn_impl == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        )
        if needs_unmask:
            # SDPA on CUDA: post-process the mask with `AttentionMaskConverter._unmask_unattended`.
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
|
|
|
|
class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
    """`DeciLMModel` decoder followed by a bias-free linear language-modeling head.

    Several generation hooks are overridden to work around a cache-length problem that
    appears when `generate` receives `inputs_embeds` without any input ids.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeciLMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Run the post-construction step inherited from the base class.
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder with `decoder`."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder."""
        return self.model

    @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Return:
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the decoder; `use_cache` is forwarded as given and resolved by the decoder itself.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Apply the head slice by slice along the vocabulary dimension, then concatenate.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Position t is scored against the label at position t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten to (tokens, vocab) / (tokens,) for the loss.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Labels may live on a different device than the logits.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Assemble the keyword arguments for one `forward` call during generation.

        Returns:
            A dict with `input_ids` or `inputs_embeds` (the other one set to None) plus
            `position_ids`, `cache_position`, `past_key_values`, `use_cache` and `attention_mask`.
        """
        # With a cache present, keep only the token positions listed in `cache_position`.
        if past_key_values is not None:
            if inputs_embeds is not None:
                # `input_ids` may be shorter than the embedded prompt: keep the trailing entries.
                input_ids = input_ids[:, -cache_position.shape[0]:]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # Derive position ids from the mask: running count of attended tokens minus one,
            # with masked positions set to 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

                # Materialise the slice as a fresh contiguous tensor.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # `inputs_embeds` are only used for the step whose first cache position is 0.
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # Fresh contiguous copy, as for `position_ids` above.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache"
        # Only reachable when assertions are disabled.
        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            dtype = self.lm_head.weight.dtype
            min_dtype = torch.finfo(dtype).min

            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_length(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _maybe_initialize_input_ids_for_generation(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[torch.Tensor] = None,
        model_kwargs: Optional[dict[str, torch.Tensor]] = None,
    ) -> torch.LongTensor:
        """
        Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model

        When the parent returns input ids with a zero-length sequence dimension and
        `inputs_embeds` is present in `model_kwargs`, the ids are replaced by an all-zero
        tensor with the same batch size and sequence length as the embeddings.
        """
        input_ids = super()._maybe_initialize_input_ids_for_generation(
            inputs=inputs, bos_token_id=bos_token_id, model_kwargs=model_kwargs)
        # `model_kwargs` defaults to None, so guard the membership test against it.
        if (
            model_kwargs is not None
            and "inputs_embeds" in model_kwargs
            and input_ids is not None
            and input_ids.shape[1] == 0
        ):
            batch_size, input_sequence_length = model_kwargs["inputs_embeds"].shape[:2]
            input_ids = torch.zeros((batch_size, input_sequence_length), dtype=torch.long, device=self.device)
        return input_ids

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model

        When only `inputs_embeds` is supplied and the parent returns a plain tensor, the first
        `inputs_embeds.shape[1]` columns of that tensor are removed before returning it.
        Non-tensor results are returned unchanged.
        """
        only_passed_inputs_embeds = (
            "inputs_embeds" in kwargs and
            "input_ids" not in kwargs and
            inputs is None
        )
        if only_passed_inputs_embeds:
            input_sequence_length = kwargs["inputs_embeds"].shape[1]

        # `inputs` goes first positionally so that extra positional arguments in `*args`
        # bind to the parameters after it instead of colliding with an `inputs=` keyword.
        generation_output = super().generate(inputs, *args, **kwargs)

        if only_passed_inputs_embeds and isinstance(generation_output, torch.Tensor):
            generation_output = generation_output[:, input_sequence_length:]

        return generation_output
|
|
|
|
@add_start_docstrings(
    """
    The DeciLM Model transformer with a sequence classification head on top (linear layer).

    [`DeciLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    DECILM_START_DOCSTRING,
)
class DeciLMForSequenceClassification(DeciLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeciLMModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Run the post-construction step inherited from the base class.
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Per-token scores; one position per row is selected below.
        logits = self.score(transformer_outputs[0])

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        pad_id = self.config.pad_token_id
        if pad_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_id is None or input_ids is None:
            # Without a pad id or without token ids, the last position of each row is used.
            sequence_lengths = -1
        else:
            # Index just before the first pad token; the modulo wraps -1 (first token is pad,
            # or no pad token at all) around to the last position.
            first_pad = torch.eq(input_ids, pad_id).int().argmax(-1)
            sequence_lengths = (first_pad - 1) % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)

        row_index = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[row_index, sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Infer the problem type once and store it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(pooled_logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            if loss is None:
                return output
            return (loss,) + output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
|
|
|
@add_start_docstrings(
    """
    The DeciLM Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DECILM_START_DOCSTRING,
)
class DeciLMForQuestionAnswering(DeciLMPreTrainedModel):
    # The decoder is stored under `transformer` in this head, hence the different prefix.
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = DeciLMModel(config)
        # Two outputs per token: start logit and end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Run the post-construction step inherited from the base class.
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.transformer.embed_tokens = value

    @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Split the two-channel head output into separate start / end logits.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Drop a trailing extra dimension if the positions carry one.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Always align devices with the logits, whatever the positions' dimensionality,
            # so the loss does not fail when they were created on a different device.
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)

            # Positions outside the sequence are clamped to `ignored_index`, which the loss skips.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
@add_start_docstrings(
    """
    The DeciLM Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    DECILM_START_DOCSTRING,
)
class DeciLMForTokenClassification(DeciLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeciLMModel(config)
        # Dropout rate: `config.classifier_dropout`, else `config.hidden_dropout`, else 0.1.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Run the post-construction step inherited from the base class.
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor`, *optional*):
            Labels for computing the token classification loss, one label per token position. Both the logits and
            the labels are flattened before a Cross-Entropy loss is computed, so the labels must contain exactly one
            entry per scored token. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            # Align devices with the logits, as the other heads in this file do for their labels.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
| |
| |
| |
|
|
|
|
| def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int: |
| |
| intermediate_size = int(2 * ffn_mult * n_embd / 3) |
| return _find_multiple(intermediate_size, 256) |
|
|
|
|
| def _find_multiple(n: int, k: int) -> int: |
| |
| if n % k == 0: |
| return n |
| return n + k - (n % k) |
|
|
|
|
class DeciLMLinearMLP(nn.Module):
    """A single bias-free linear layer mapping `hidden_size` to `hidden_size`."""

    def __init__(self,
                 config: "DeciLMConfig",
                 ):
        super().__init__()
        self.linear_mlp = nn.Linear(in_features=config.hidden_size,
                                    out_features=config.hidden_size,
                                    bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to `x` and return the result."""
        # Call the module itself rather than `.forward` so hooks registered on
        # `linear_mlp` are executed.
        return self.linear_mlp(x)
|
|
|
|
class DeciLMLinearAttention(nn.Module):
    """A single bias-free linear layer mapping `hidden_size` to `hidden_size`."""

    def __init__(self,
                 config: "DeciLMConfig",
                 ):
        super().__init__()
        self.linear_attn = nn.Linear(in_features=config.hidden_size,
                                     out_features=config.hidden_size,
                                     bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to `x` and return the result."""
        # Call the module itself rather than `.forward` so hooks registered on
        # `linear_attn` are executed.
        return self.linear_attn(x)
|
|
|
|
def sparsity_backward_hook(*args, **kwargs):
    """Callback that always refuses to run; every argument is ignored.

    Judging by its name it is meant to be installed as a backward hook; the registration
    site is not part of this block.

    Raises:
        NotImplementedError: on every call.
    """
    reason = "No support for sparsity when training HF DeciLM (inference is ok though)"
    raise NotImplementedError(reason)
|
|