| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ PyTorch HybriDNA model.""" |
| import inspect |
| import math |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| from einops import rearrange, repeat |
|
|
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import DynamicCache |
| from transformers.modeling_attn_mask_utils import ( |
| AttentionMaskConverter, |
| ) |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| SequenceClassifierOutputWithPast, |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.generation.utils import GenerationMixin |
| from transformers.utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_greater_or_equal_2_10, |
| logging, |
| replace_return_docstrings, |
| ) |
| from transformers.utils.import_utils import ( |
| is_causal_conv1d_available, |
| is_flash_attn_2_available, |
| is_mamba_ssm_available, |
| ) |
| from hf.configuration_hybridna import HybriDNAConfig |
| from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated |
|
|
| |
# Optional FlashAttention-2 kernels. The window-size flag is defined on both paths so code that reads it
# never hits a NameError when flash-attn is not installed.
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

    # Older flash-attn releases do not accept the `window_size` argument (sliding window attention).
    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError:
    _flash_supports_window_size = False

# Optional Mamba / Mamba-2 kernels. Every imported name gets a `None` fallback so later `is None` /
# `all(...)` checks work instead of raising NameError.
try:
    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
except ImportError:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
    mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None

# Optional fused causal depthwise conv1d kernels.
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_update, causal_conv1d_fn = None, None

# The fused path is only usable when every kernel below was importable.
is_fast_path_available = all(
    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
|
|
|
|
# Module logger from the transformers logging utilities (provides `warning_once`, used below).
logger = logging.get_logger(name=__name__)

# Name of the config class; presumably consumed by the docstring helpers imported above.
_CONFIG_FOR_DOC: str = "HybriDNAConfig"
|
|
|
|
| |
def _get_unpad_data(attention_mask):
    """
    Compute the bookkeeping needed to strip padding before calling variable-length flash attention.

    Args:
        attention_mask (`torch.Tensor`):
            Padding mask, expected to be `(batch_size, seq_len)`; non-zero entries mark real tokens.

    Returns:
        tuple: `(indices, cu_seqlens, max_seqlen_in_batch)` where `indices` are the flat positions of the
        non-padding tokens, `cu_seqlens` is the int32 running total of per-row lengths with a leading 0,
        and `max_seqlen_in_batch` is the longest row length as a Python number.
    """
    lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
    kept_positions = attention_mask.flatten().nonzero(as_tuple=False).flatten()
    cumulative = F.pad(lengths.cumsum(dim=0, dtype=torch.int32), (1, 0))
    longest = lengths.max().item()
    return kept_positions, cumulative, longest
|
|
|
|
| |
class HybriDNARMSNorm(nn.Module):
    """
    Root-mean-square layer norm, equivalent to T5LayerNorm: no mean subtraction and no bias.

    The statistics are computed in float32, the normalized tensor is cast back to the input dtype, and only
    then is it multiplied by the learned per-channel `weight`.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Per-channel scale, initialised to 1 (identity).
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
|
|
|
|
| |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times so grouped-query attention can use the full query head count.

    Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`: the tensor goes from
    `(batch, num_key_value_heads, seq_len, head_dim)` to `(batch, num_key_value_heads * n_rep, seq_len, head_dim)`.
    When `n_rep == 1` the input object itself is returned.
    """
    # Unpack first so a non-4-D input is rejected even when no repetition is needed.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
class HybridMambaAttentionDynamicCache(DynamicCache):
    """
    A dynamic cache that handles both the attention cache (which grows along a seq_len dimension) and the mamba
    cache (which has a constant shape regardless of seq_len).

    It keeps two sets of per-layer lists (one tensor per hidden layer):
    `key_cache` / `value_cache` for attention layers and `conv_states` / `ssm_states` for mamba layers.

    - Attention layers: `key_cache` / `value_cache` start as empty `(batch_size, 0)` placeholders and, after the
      first `update`, hold `(batch_size, num_key_value_heads, seq_len, head_dim)` tensors. Their `conv_states` /
      `ssm_states` are empty `(batch_size, 0)` placeholders.
    - Mamba layers: `key_cache` / `value_cache` are empty `(batch_size, 0)` placeholders, `conv_states` has shape
      `(batch_size, mamba_expand * hidden_size + 2 * mamba_d_state, mamba_d_conv)` and `ssm_states` has shape
      `(batch_size, num_heads, head_dim, mamba_d_state)`.
    """

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.dtype = dtype
        self.layers_block_type = config.layers_block_type
        # Bookkeeping flags for the mamba layers; not read inside this class.
        self.has_previous_state = False
        self.seq_offset = 0
        intermediate_size = config.mamba_expand * config.hidden_size
        ssm_state_size = config.mamba_d_state
        conv_kernel_size = config.mamba_d_conv
        # NOTE(review): the head count uses `config.intermediate_size` while the conv width uses
        # `mamba_expand * hidden_size` (the mixer does the same); presumably the config keeps them equal.
        num_heads = config.intermediate_size // config.head_dim
        head_dim = config.head_dim
        # The conv state spans the concatenated x, B and C channels.
        conv1d_dim = intermediate_size + 2 * ssm_state_size
        self.conv_states = []
        self.ssm_states = []
        for i in range(config.num_hidden_layers):
            if self.layers_block_type[i] == "mamba":
                self.conv_states += [
                    torch.zeros(batch_size, conv1d_dim, conv_kernel_size, device=device, dtype=dtype)
                ]
                self.ssm_states += [
                    torch.zeros(batch_size, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype)
                ]
            else:
                # (batch_size, 0) placeholder: keeps a batch axis so `reorder_cache` can index_select every layer.
                self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
                self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]

        # Empty (batch_size, 0) placeholders; `update` replaces them on the first call for a layer.
        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Append `key_states` / `value_states` to layer `layer_idx` along the sequence axis (dim 2) and return
        the full cached keys and values for that layer. `cache_kwargs` is accepted for API compatibility and ignored.
        """
        # A trailing dimension of 0 identifies the empty placeholder created in `__init__`.
        if self.key_cache[layer_idx].shape[-1] == 0:
            self.key_cache[layer_idx] = key_states.contiguous()
            self.value_cache[layer_idx] = value_states.contiguous()
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states.contiguous()], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states.contiguous()], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the number of tokens currently stored in the attention key cache.

        Mamba layers keep fixed-size states and their `key_cache` entries are `(batch_size, 0)` placeholders, so
        reading `shape[-2]` there would return the batch size. If `layer_idx` does not refer to an attention
        layer, the first attention layer is inspected instead. An entry that has never been updated yields 0.
        """
        attention_layers = [i for i, block_type in enumerate(self.layers_block_type) if block_type != "mamba"]
        if layer_idx not in attention_layers:
            if not attention_layers:
                return 0
            layer_idx = attention_layers[0]
        if len(self.key_cache) <= layer_idx:
            return 0
        cached_keys = self.key_cache[layer_idx]
        # Same emptiness test as `update`.
        if cached_keys.shape[-1] == 0:
            return 0
        return cached_keys.shape[-2]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

            device = self.conv_states[layer_idx].device
            self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
            device = self.ssm_states[layer_idx].device
            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

    def __len__(self):
        """Return the number of layers in the cache."""
        return len(self.key_cache)

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None) -> "DynamicCache":
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
|
|
|
|
| |
class HybriDNAAttention(nn.Module):
    """
    Eager (pure PyTorch) multi-head attention with grouped-query key/value heads.

    Queries use `num_attention_heads` heads; keys/values use `num_key_value_heads` heads that are repeated to
    match the query heads. No positional encoding is applied here (`position_ids` is accepted but unused), and
    this eager path applies no sliding window: all masking comes from the additive `attention_mask`.
    """

    def __init__(self, config: HybriDNAConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        hidden = config.hidden_size
        heads = config.num_attention_heads
        self.hidden_size = hidden
        self.num_heads = heads
        self.head_dim = hidden // heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each key/value head.
        self.num_key_value_groups = heads // self.num_key_value_heads
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if self.num_heads * self.head_dim != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Creation order q, k, v, o is kept so random initialisation is unchanged.
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states: `(batch, q_len, hidden_size)`.
            attention_mask: optional additive 4-D mask, presumably `(batch, 1, q_len, >= kv_len)`; its last axis
                is cut to the key length.
            past_key_value: optional cache; current keys/values are appended for `self.layer_idx`.
            output_attentions: when True the post-dropout attention probabilities are returned.

        Returns:
            `(attn_output, attn_weights_or_None, past_key_value)`.
        """
        bsz, q_len, _ = hidden_states.size()

        def split_heads(projected: torch.Tensor, n_heads: int) -> torch.Tensor:
            # (batch, q_len, n_heads * head_dim) -> (batch, n_heads, q_len, head_dim)
            return projected.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states), self.num_heads)
        key_states = split_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        value_states = split_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        key_states, value_states = (
            repeat_kv(key_states, self.num_key_value_groups),
            repeat_kv(value_states, self.num_key_value_groups),
        )

        scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

        # Softmax in float32 for stability, then back to the activation dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
        probs = nn.functional.dropout(probs, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(probs, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        merged = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(merged)

        returned_weights = probs if output_attentions else None
        return attn_output, returned_weights, past_key_value
|
|
|
|
| |
class HybriDNAFlashAttention2(HybriDNAAttention):
    """
    HybriDNA flash attention module. This module inherits from `HybriDNAAttention` as the weights of the module
    stay untouched. The only required change is on the forward pass, where it needs to correctly call the public
    API of flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # flash-attn < 2.1 aligns the causal mask to the top-left corner rather than the bottom-right one;
        # `_flash_attention_forward` compensates when there is a single query token.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        Attention forward pass using FlashAttention-2.

        Args:
            hidden_states: `(batch, q_len, hidden_size)`.
            attention_mask: optional 2-D padding mask `(batch_size, seq_len)` (1 = token, 0 = padding) covering
                the cached and current tokens; passed to `_flash_attention_forward`.
            past_key_value: optional hybrid cache; current keys/values are appended for `self.layer_idx`.
            position_ids, use_cache, cache_position: accepted for interface compatibility; not used.
            output_attentions: flash attention never materializes attention probabilities, so the returned
                weights are always `None`.

        Returns:
            `(attn_output, None, past_key_value)` with `attn_output` of shape `(batch, q_len, hidden_size)`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, q_len, heads * head_dim) -> (batch, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            # The cache keeps every past key/value (it is never truncated here), so the padding mask must keep
            # covering the full key length too; the sliding window itself is enforced by flash attention's
            # `window_size` argument.
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # Total number of keys attended over in this step (cached tokens + current chunk).
        kv_seq_len = key_states.shape[-2]

        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
        )

        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                " make sure to upgrade flash-attn library."
            )

        # GQA: expand key/value heads to the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # flash-attn kernels operate on fp16/bf16. If the activations were upcast to float32 upstream
        # (e.g. by layers kept in float32), cast q/k/v back to a half-precision dtype.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # flash-attn expects (batch, seq_len, heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # Attention probabilities are never materialized by flash attention, regardless of `output_attentions`.
        return attn_output, None, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Number of query positions; used to re-pad the output.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
            use_sliding_windows (`bool`, *optional*):
                Whether to activate sliding window attention.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # With a top-left aligned mask a single query would only see the first key, so causal masking is
            # disabled for one-token queries.
            causal = self.is_causal and query_length != 1

        # Pass `window_size` only when needed so flash-attn builds without the argument keep working.
        window_kwargs = {}
        if use_sliding_windows:
            window_kwargs["window_size"] = (self.config.sliding_window, self.config.sliding_window)

        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
                **window_kwargs,
            )
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
                **window_kwargs,
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """
        Remove padding tokens so the varlen flash-attention kernel can be used.

        Args:
            query_layer: `(batch, query_length, heads, head_dim)`.
            key_layer / value_layer: `(batch, kv_seq_len, heads, head_dim)`.
            attention_mask: `(batch, mask_len)` padding mask, assumed to be at least `kv_seq_len` long; only its
                trailing `kv_seq_len` columns are used.
            query_length: number of query positions.

        Returns:
            `(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))`.
        """
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # Align the mask with the keys by keeping its last `kv_seq_len` columns.
        if kv_seq_len != attention_mask.shape[-1]:
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        if query_length == kv_seq_len:
            # Queries and keys cover the same positions: reuse the key unpadding.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query token per sequence: every row contributes exactly one query.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # Queries are the trailing `query_length` positions of the mask.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
|
|
|
|
| |
class HybriDNASdpaAttention(HybriDNAAttention):
    """
    Attention computed with `torch.nn.functional.scaled_dot_product_attention` (SDPA).

    Shares every weight with `HybriDNAAttention`; only the forward pass differs. Because SDPA cannot return
    attention probabilities, a request for them falls back to the eager parent implementation.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            logger.warning_once(
                "HybriDNAModel is using HybriDNASdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        batch_size, seq_len, _ = hidden_states.size()

        def to_heads(projected: torch.Tensor, n_heads: int) -> torch.Tensor:
            # (batch, seq_len, n_heads * head_dim) -> (batch, n_heads, seq_len, head_dim)
            return projected.view(batch_size, seq_len, n_heads, self.head_dim).transpose(1, 2)

        q = to_heads(self.q_proj(hidden_states), self.num_heads)
        k = to_heads(self.k_proj(hidden_states), self.num_key_value_heads)
        v = to_heads(self.v_proj(hidden_states), self.num_key_value_heads)

        if past_key_value is not None:
            k, v = past_key_value.update(k, v, self.layer_idx)

        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        mask = None if attention_mask is None else attention_mask[:, :, :, : k.shape[-2]]

        # NOTE(review): presumably a workaround for SDPA CUDA backends mishandling non-contiguous inputs combined
        # with a custom mask; confirm it is still needed for the pinned torch version.
        if attention_mask is not None and q.device.type == "cuda":
            q, k, v, mask = q.contiguous(), k.contiguous(), v.contiguous(), mask.contiguous()

        # Let SDPA build the causal mask only when no explicit mask is given and there is more than one query.
        use_is_causal = self.is_causal and attention_mask is None and seq_len > 1
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=dropout_p,
            is_causal=use_is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        return self.o_proj(attn_output), None, past_key_value
|
|
|
|
# Attention implementation name -> module class; presumably looked up with the configured attention
# implementation ("eager", "flash_attention_2" or "sdpa") when building the attention layers.
HYBRIDNA_ATTENTION_CLASSES = dict(
    eager=HybriDNAAttention,
    flash_attention_2=HybriDNAFlashAttention2,
    sdpa=HybriDNASdpaAttention,
)
|
|
class HybriDNAMamba2RMSNorm(nn.Module):
    """
    RMS norm (T5LayerNorm / LlamaRMSNorm formulation) with an optional SiLU gate.

    When `normalize=True` and a `residual` tensor is passed, the input is first multiplied by `silu(residual)`;
    the (possibly gated) tensor is then RMS-normalized and scaled by `weight`. All arithmetic runs in float32
    and the result is cast back to the input dtype at the very end.
    """

    def __init__(self, hidden_size, eps=1e-6, normalize=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        self.normalize = normalize

    def forward(self, hidden_states, residual=None):
        out_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)

        apply_gate = self.normalize and residual is not None
        if apply_gate:
            x = x * nn.functional.silu(residual.to(torch.float32))

        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps) * self.weight
        return x.to(out_dtype)
|
|
| class HybriDNAMamba2Mixer(nn.Module): |
| """ |
| Using the found relation to the attention mechanism under certain conditions (State-Space-Duality SSD), |
| we use the Multi-input SSM which can be seen as a counterpart to the Multi-value Attention with analogues: |
| - X ~= V |
| - B ~= Q |
| - C ~= K |
| - A (1-SS(a)) ~= Attention Mask |
| |
| For an overview, see the mamba2 paper, section 6, figure 4. |
| """ |
|
|
def __init__(self, config: HybriDNAConfig, layer_idx: int):
    """
    Build the Mamba-2 mixer: input projection, causal depthwise conv, SSM parameters, gated norm and output
    projection.

    Args:
        config: model config; reads hidden_size, mamba_d_state, mamba_d_conv, mamba_expand, head_dim,
            intermediate_size, chunk_size, mamba_proj_bias, mamba_conv_bias, use_mamba_kernels and hidden_act.
        layer_idx: index of this layer, used to address its `conv_states` / `ssm_states` slots in the cache.
    """
    super().__init__()

    # Sizes.
    self.hidden_size = config.hidden_size
    self.ssm_state_size = config.mamba_d_state
    self.conv_kernel_size = config.mamba_d_conv
    self.intermediate_size = config.mamba_expand * config.hidden_size
    self.head_dim = config.head_dim
    # NOTE(review): heads come from `config.intermediate_size`, not `self.intermediate_size`; the `(h p)`
    # rearranges in `_ssd` presumably require the two to agree.
    self.num_heads = config.intermediate_size // self.head_dim
    self.chunk_size = config.chunk_size
    # dt is clamped to [dt_min, dt_max]; these bounds effectively disable the clamp.
    self.dt_min, self.dt_max = 0, float("inf")
    self.layer_idx = layer_idx
    self.use_bias = config.mamba_proj_bias
    self.use_conv_bias = config.mamba_conv_bias
    self.use_triton_kernels = config.use_mamba_kernels

    # Channels seen by the conv: x (intermediate_size) plus B and C (ssm_state_size each).
    conv1d_dim = self.intermediate_size + 2 * self.ssm_state_size

    # NOTE: module/parameter creation order below is kept fixed so random initialisation is unchanged.
    # Output width intermediate_size + conv1d_dim + num_heads == 2 * (intermediate_size + ssm_state_size) + num_heads.
    self.in_proj = nn.Linear(
        in_features=self.hidden_size,
        out_features=self.intermediate_size + conv1d_dim + self.num_heads,
        bias=self.use_bias,
    )

    # Depthwise conv (groups == channels); the `kernel_size - 1` padding is trimmed in `_conv1d` to stay causal.
    self.conv1d = nn.Conv1d(
        in_channels=conv1d_dim,
        out_channels=conv1d_dim,
        bias=self.use_conv_bias,
        kernel_size=self.conv_kernel_size,
        groups=conv1d_dim,
        padding=self.conv_kernel_size - 1,
    )

    self.activation = config.hidden_act
    self.act = ACT2FN[self.activation]

    # Per-head time-step bias (added to dt before the softplus in the SSD scans).
    self.dt_bias = nn.Parameter(torch.rand(size=(self.num_heads,)))

    # Stored as log(a) with a ~ U(1, 16); `_ssd` uses A = -exp(A_log).
    a_init = torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16)
    self.A_log = nn.Parameter(torch.log(a_init))

    # Per-head skip-connection weight (D * x residual).
    self.D = nn.Parameter(torch.ones(self.num_heads))

    self.norm = HybriDNAMamba2RMSNorm(self.intermediate_size, eps=1e-5, normalize=True)
    self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
|
|
def _conv1d(self, xBC, seq_len, use_triton_kernels, cache, cached_start, cached_forward):
    """
    Apply the causal depthwise `conv1d` followed by the activation to the concatenated (x, B, C) projection.

    Args:
        xBC: `(batch, seq_len, conv_dim)`. In single-step decoding (`cached_forward`) it is `(batch, 1, conv_dim)`;
            the eager decode path also accepts a 2-D `(batch, conv_dim)` tensor.
        seq_len: number of leading positions kept from the padded `nn.Conv1d` output (eager, non-cached path).
        use_triton_kernels: use the `causal_conv1d` kernels when they are importable.
        cache: hybrid cache whose `conv_states[self.layer_idx]` is `(batch, conv_dim, conv_kernel_size)`.
        cached_start: prefill step; seeds the conv state with the trailing `conv_kernel_size` inputs.
        cached_forward: single-token decoding step that rolls the conv state forward by one position.

    Returns:
        The activated convolution output, with the same rank as `xBC` on the eager path and 3-D on the fused
        decode path.
    """
    if cached_start:
        # Keep the trailing `conv_kernel_size` positions: a negative left pad trims, a positive one zero-fills.
        xBC_t = rearrange(xBC, "b l d -> b d l")
        cache.conv_states[self.layer_idx].copy_(
            nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0))
        )

    if is_fast_path_available and use_triton_kernels:
        if cached_forward:
            # Fused single-step update; also updates the conv state in the cache.
            xBC = causal_conv1d_update(
                xBC.squeeze(1),
                cache.conv_states[self.layer_idx],
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )
            xBC = xBC.unsqueeze(1)
        else:
            xBC = causal_conv1d_fn(
                xBC.transpose(1, 2),
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias,
                activation=self.activation,
            ).transpose(1, 2)
    elif cached_forward:
        conv_state = cache.conv_states[self.layer_idx]
        # Shift the window left by one slot (in place, so the cache tensor is updated).
        conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
        has_seq_axis = xBC.dim() == 3
        # The newest input goes into the last slot. A (batch, 1, conv_dim) input must drop its length-1 sequence
        # axis first; assigning it directly into the (batch, conv_dim) slot fails for batch > 1.
        conv_state[:, :, -1] = xBC[:, 0, :] if has_seq_axis else xBC
        xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)
        if self.conv1d.bias is not None:
            xBC = xBC + self.conv1d.bias
        xBC = self.act(xBC)
        if has_seq_axis:
            # Restore the sequence axis so the layout matches the input (and the fused path above).
            xBC = xBC.unsqueeze(1)
    else:
        # `nn.Conv1d` pads both sides by `conv_kernel_size - 1`; keeping the first `seq_len` outputs makes it causal.
        xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :seq_len].transpose(1, 2))

    return xBC
|
|
def _ssd_naive(self, x, dt, A, B, C, chunk_size, dt_min, dt_max, initial_states=None, return_final_states=False):
    """
    Pure-PyTorch chunked SSD (state space duality) scan, used when the Triton kernels are not used.

    Arguments:
        x: (batch_size, seq_len, num_heads, head_dim)
        dt: (batch_size, seq_len, num_heads) raw time steps; `self.dt_bias` is added, then softplus and a clamp
            to [dt_min, dt_max] are applied here.
        A: (num_heads)
        B: (batch_size, seq_len, num_heads, ssm_state_size) (a size-1 head axis broadcasts)
        C: (batch_size, seq_len, num_heads, ssm_state_size) (a size-1 head axis broadcasts)
        chunk_size: length of the chunks the sequence is split into.
        initial_states: optional state before the first chunk, shaped (batch_size, 1, num_heads, head_dim, ssm_state_size).
        return_final_states: also return the state after the last chunk.
    Return:
        y: (batch_size, seq_len, num_heads, head_dim), or `(y, final_state)` when `return_final_states` is True.
    """

    def pad_by_size(x, pad_size):
        """
        Zero-pad x with `pad_size` steps at the end of the seq_len dim (dim=1).

        Assumes that we only have tensors of either size 4 or 3
        """
        assert 2 < len(x.shape) < 5

        pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(x.shape) == 4 else (0, 0, 0, pad_size, 0, 0)

        return nn.functional.pad(x, pad_shape, mode="constant", value=0)

    def segsum(x):
        """
        Stable segment sum: out[..., i, j] = sum(x[..., j+1 : i+1]) for j <= i and -inf above the diagonal,
        computed with a masked cumulative sum instead of a difference of cumulative sums.
        """
        T = x.size(-1)
        x = repeat(x, "... d -> ... d e", e=T)
        mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=-1)
        x = x.masked_fill(~mask, 0)
        x_segsum = torch.cumsum(x, dim=-2)
        mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=0)
        x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
        return x_segsum

    seq_len = x.shape[1]
    # Pad up to the next multiple of chunk_size; 0 when seq_len is already a multiple (previously a whole
    # extra chunk of zeros was appended in that case, which only cost compute).
    pad_size = (chunk_size - seq_len % chunk_size) % chunk_size

    # Discretize the time step.
    dt = nn.functional.softplus(dt + self.dt_bias)
    dt = torch.clamp(dt, dt_min, dt_max)

    # Skip connection D * x on the un-discretized input, padded to line up with y below.
    D_residual = self.D.unsqueeze(-1) * pad_by_size(x, pad_size)

    # Discretize x and A.
    x = x * dt.unsqueeze(-1)
    A = A.to(x.dtype) * dt

    # Split the sequence axis into (chunk, position-in-chunk).
    x, A, B, C = [
        rearrange(pad_by_size(t, pad_size), "b (c l) ... -> b c l ...", l=chunk_size) for t in (x, A, B, C)
    ]

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Intra-chunk outputs (diagonal blocks).
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, x)

    # 2. State at the end of each chunk, assuming a zero state at the chunk start.
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, x)

    # 3. Inter-chunk recurrence: propagate states across chunk boundaries (including the initial state).
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Contribution of each chunk's incoming state to its outputs (off-diagonal blocks).
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")

    y = y + D_residual

    # Drop the padded tail.
    if pad_size > 0:
        y = y[:, :seq_len, :, :]

    if not return_final_states:
        return y
    else:
        return y, final_state
|
|
def _ssd(
    self, x, B, C, dt, initial_state, return_final_state, use_triton_kernels, cache, cached_start, cached_forward
):
    """
    Apply the Mamba2 SSD scan to the already-convolved ``x``, ``B`` and ``C`` streams.

    Two regimes are handled:

    * ``cached_forward`` is true: one recurrent step. The inputs carry a length-1 sequence axis
      (dim 1), which is squeezed away. ``cache.ssm_states[self.layer_idx]`` is updated in place,
      either by ``selective_state_update`` or by an explicit einsum recurrence.
    * otherwise: a chunked scan over the whole sequence, done by the fused
      ``mamba_chunk_scan_combined`` kernel or by ``self._ssd_naive``. When ``cached_start`` is
      set, the resulting final state is copied into ``cache.ssm_states[self.layer_idx]``.

    Args:
        x: scan input whose last dim packs ``(num_heads, head_dim)`` (split with ``p=self.head_dim``).
        B, C: input/output projections of width ``ssm_state_size``. A single group is used; a
            size-1 head axis is inserted for the chunked path.
        dt: raw per-head time steps. ``self.dt_bias`` and softplus are applied here or inside the kernels.
        initial_state: optional starting state for the chunked scan (unused by the step path).
        return_final_state: also return the state reached after the last token.
        use_triton_kernels: use the fused Triton kernels instead of the PyTorch code.
        cache: hybrid cache owning ``ssm_states``; accessed when ``cached_start`` or ``cached_forward``.
        cached_start: first call with a cache.
        cached_forward: later (step-by-step) call with a cache.

    Returns:
        Tuple ``(y, last_state)``. ``y`` has the same packed ``(h p)`` last dim as ``x``.
        ``last_state`` is ``None`` unless a final state was produced, i.e. when
        ``return_final_state`` is set, or when ``cached_start`` is set on the chunked path.
    """
    A = -torch.exp(self.A_log.float())

    if cached_forward:
        # Single-token step: remove the length-1 sequence axis from every stream.
        x, B, C, dt = (t.squeeze(1) for t in (x, B, C, dt))
        ssm_state = cache.ssm_states[self.layer_idx]
        p = self.head_dim

        if use_triton_kernels:
            # The kernel expects per-(head, head_dim) parameters, so the per-head ones are broadcast.
            y = selective_state_update(
                state=ssm_state,
                x=rearrange(x, "b (h p) -> b h p", p=p),
                dt=repeat(dt, "b h -> b h p", p=p),
                A=repeat(A, "h -> h p n", p=p, n=self.ssm_state_size).to(dtype=torch.float32),
                B=B,
                C=C,
                D=repeat(self.D, "h -> h p", p=p),
                z=None,
                dt_bias=repeat(self.dt_bias, "h -> h p", p=p),
                dt_softplus=True,
            )
        else:
            # Reference recurrence: state <- exp(dt * A) * state + dt * B x ; y = C . state + D x
            dt = nn.functional.softplus(dt + self.dt_bias.to(dtype=dt.dtype))
            x = rearrange(x, "b (h p) -> b h p", p=p)
            decay = rearrange(torch.exp(dt * A), "b h -> b h 1 1")
            ssm_state.copy_(ssm_state * decay + torch.einsum("bh,bn,bhp->bhpn", dt, B, x))
            y = torch.einsum("bhpn,bn->bhp", ssm_state, C) + rearrange(self.D, "h -> h 1") * x

        y = rearrange(y, "b h p -> b 1 (h p)")
        last_state = ssm_state.clone() if return_final_state else None
        return y, last_state

    # Full-sequence chunked scan.
    need_state = cached_start or return_final_state
    x_heads = rearrange(x, pattern="b l (h p) -> b l h p", p=self.head_dim)
    B_grouped = rearrange(B, pattern="b l n -> b l 1 n")
    C_grouped = rearrange(C, pattern="b l n -> b l 1 n")

    if use_triton_kernels:
        scan_out = mamba_chunk_scan_combined(
            x=x_heads,
            dt=dt,
            A=A,
            B=B_grouped,
            C=C_grouped,
            chunk_size=self.chunk_size,
            D=self.D,
            z=None,
            initial_states=initial_state,
            dt_bias=self.dt_bias,
            dt_softplus=True,
            seq_idx=None,
            dt_limit=(self.dt_min, self.dt_max),
            return_final_states=need_state,
        )
    else:
        if initial_state is not None:
            # The reference scan concatenates the initial state along a leading chunk axis (dim 1).
            initial_state = rearrange(initial_state, "b n h p -> b 1 n h p")
        scan_out = self._ssd_naive(
            x=x_heads,
            dt=dt,
            A=A,
            B=B_grouped,
            C=C_grouped,
            chunk_size=self.chunk_size,
            initial_states=initial_state,
            dt_min=self.dt_min,
            dt_max=self.dt_max,
            return_final_states=need_state,
        )

    last_state = None
    if need_state:
        scan_out, last_state = scan_out
        if cached_start:
            # Seed the cache so the following calls can run in step mode.
            cache.ssm_states[self.layer_idx].copy_(last_state)

    return rearrange(scan_out, "b l h p -> b l (h p)"), last_state
|
|
def _forward(
    self,
    hidden_states,
    use_triton_kernels,
    initial_state=None,
    return_final_state=False,
    cache: Optional[HybridMambaAttentionDynamicCache] = None,
):
    """
    Mamba2 mixer computation: in-projection, causal conv, SSD scan, gated norm, out-projection.

    Args:
        hidden_states: input activations; ``hidden_states.shape[1]`` is used as the sequence length.
        use_triton_kernels: select the fused kernels where the code supports them.
        initial_state: optional SSM state to start the scan from. It is rejected when the cache
            is past its first call.
        return_final_state: forwarded to the fused kernel / ``self._ssd``.
            NOTE(review): the final state obtained this way is not returned by this method.
        cache: optional hybrid cache. ``cache.seq_offset == 0`` marks the first cached call; any
            other offset means step-by-step decoding.

    Returns:
        The output of ``self.out_proj``.

    Raises:
        ValueError: if ``initial_state`` is given together with a cache past its first call.
    """
    # Which caching phase is this call in?
    cached_start = cache is not None and cache.seq_offset == 0
    cached_forward = cache is not None and not cached_start

    if cached_forward and initial_state is not None:
        raise ValueError("Subsequent caching and passing initial states is not possible at the same time!")

    zxbcdt = self.in_proj(hidden_states)

    # Cache-free training: one fused kernel covers conv, scan, gated RMSNorm and the out-projection.
    if self.training and cache is None and is_fast_path_available and use_triton_kernels:
        fused = mamba_split_conv1d_scan_combined(
            zxbcdt=zxbcdt,
            conv1d_weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
            conv1d_bias=self.conv1d.bias,
            dt_bias=self.dt_bias,
            A=-torch.exp(self.A_log),
            D=self.D,
            chunk_size=self.chunk_size,
            seq_idx=None,
            activation=self.activation,
            rmsnorm_weight=self.norm.weight,
            rmsnorm_eps=self.norm.eps,
            outproj_weight=self.out_proj.weight,
            outproj_bias=self.out_proj.bias,
            headdim=self.head_dim,
            ngroups=1,
            norm_before_gate=False,
            dt_limit=(self.dt_min, self.dt_max),
            initial_states=initial_state,
            return_final_states=return_final_state,
        )
        if return_final_state:
            # The kernel then returns (output, final_state); only the output is handed back.
            fused, _ = fused
        return fused

    # Split the projection into: MLP gate z0, MLP input x0, norm gate z, conv input xBC, dt.
    ssm_width = self.intermediate_size
    state_width = self.ssm_state_size
    d_mlp = (zxbcdt.shape[-1] - 2 * ssm_width - 2 * state_width - self.num_heads) // 2
    z0, x0, z, xBC, dt = torch.split(
        zxbcdt,
        [d_mlp, d_mlp, ssm_width, ssm_width + 2 * state_width, self.num_heads],
        dim=-1,
    )

    xBC = self._conv1d(
        xBC=xBC,
        seq_len=hidden_states.shape[1],
        use_triton_kernels=use_triton_kernels,
        cache=cache,
        cached_start=cached_start,
        cached_forward=cached_forward,
    )
    x, B, C = torch.split(xBC, [ssm_width, state_width, state_width], dim=-1)

    y, _ = self._ssd(
        x=x,
        B=B,
        C=C,
        dt=dt,
        initial_state=initial_state,
        return_final_state=return_final_state,
        use_triton_kernels=use_triton_kernels,
        cache=cache,
        cached_start=cached_start,
        cached_forward=cached_forward,
    )

    # Gated normalisation, then prepend the optional gated-MLP branch before projecting out.
    y = self.norm(y, residual=z)
    if d_mlp > 0:
        y = torch.cat([self.act(z0) * x0, y], dim=-1)
    return self.out_proj(y)
|
|
def forward(
    self, hidden_states, initial_state=None, return_final_state=False, cache_params: HybridMambaAttentionDynamicCache = None
):
    """
    Choose the kernel backend, log a one-time notice about fallbacks, and run ``self._forward``.

    Triton kernels are used only when the in-projection weights live on a CUDA device and
    ``self.use_triton_kernels`` is enabled.
    """
    on_cuda = "cuda" in self.in_proj.weight.device.type
    use_triton_kernels = on_cuda and self.use_triton_kernels

    if not use_triton_kernels:
        logger.warning_once(
            "Fast path is not available because the GPU is not properly utilized. "
            "Falling back to naive implementation."
        )
    elif not is_fast_path_available:
        logger.warning_once(
            "Faster path is not available because `(causal_conv1d_fn, causal_conv1d_update)` is None. "
            "Falling back to slower implementation. To install follow https://github.com/Dao-AILab/causal-conv1d"
        )

    return self._forward(hidden_states, use_triton_kernels, initial_state, return_final_state, cache_params)
|
|
|
|
| |
class HybriDNAMLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Two bias-free input projections (gate and value) and one bias-free output projection.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
class HybriDNAAttentionDecoderLayer(nn.Module):
    """
    Pre-norm decoder layer built around self-attention.

    Computes ``x = x + attn(input_layernorm(x))`` followed by ``x = x + mlp(pre_ff_layernorm(x))``.
    The attention class is picked from ``HYBRIDNA_ATTENTION_CLASSES`` by ``config._attn_implementation``.
    """

    def __init__(self, config: HybriDNAConfig, layer_idx: int):
        super().__init__()
        attention_cls = HYBRIDNA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config, layer_idx)
        self.feed_forward = HybriDNAMLP(config)
        hidden, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = HybriDNARMSNorm(hidden, eps=eps)
        self.pre_ff_layernorm = HybriDNARMSNorm(hidden, eps=eps)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
                output_attentions: Optional[bool] = False,
                use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
                ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Returns:
            A tuple whose first item is the layer output. It is followed by the attention weights when
            ``output_attentions`` is set, then by the cache returned by the attention module when
            ``use_cache`` is set.
        """
        attn_output, attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        # Residual around attention, then residual around the feed-forward block.
        hidden_states = hidden_states + attn_output
        hidden_states = hidden_states + self.feed_forward(self.pre_ff_layernorm(hidden_states))

        outputs = [hidden_states]
        if output_attentions:
            outputs.append(attn_weights)
        if use_cache:
            outputs.append(present_key_value)
        return tuple(outputs)
|
|
|
|
class HybriDNAMambaDecoderLayer(nn.Module):
    """
    Pre-norm decoder layer built around a Mamba2 mixer.

    Computes ``x = x + mamba(input_layernorm(x))`` followed by ``x = x + mlp(pre_ff_layernorm(x))``.
    ``attention_mask``, ``position_ids`` and ``cache_position`` are accepted so the signature matches
    the attention layer, but they are not used.
    """

    def __init__(self, config: HybriDNAConfig, layer_idx: int):
        super().__init__()
        self.mamba = HybriDNAMamba2Mixer(config=config, layer_idx=layer_idx)
        self.feed_forward = HybriDNAMLP(config)
        hidden, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = HybriDNARMSNorm(hidden, eps=eps)
        self.pre_ff_layernorm = HybriDNARMSNorm(hidden, eps=eps)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
                output_attentions: Optional[bool] = False,
                use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
                ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Returns:
            A tuple whose first item is the layer output. It is followed by ``None`` when
            ``output_attentions`` is set (no attention weights exist here), then by ``past_key_value``
            itself when ``use_cache`` is set.
        """
        mixer_output = self.mamba(
            hidden_states=self.input_layernorm(hidden_states),
            cache_params=past_key_value,
        )
        # Residual around the mixer, then residual around the feed-forward block.
        hidden_states = hidden_states + mixer_output
        hidden_states = hidden_states + self.feed_forward(self.pre_ff_layernorm(hidden_states))

        outputs = [hidden_states]
        if output_attentions:
            # Keep the slot so tuple positions line up with the attention layers.
            outputs.append(None)
        if use_cache:
            outputs.append(past_key_value)
        return tuple(outputs)
|
|
|
|
# Class-level documentation prefix shared by the HybriDNA model classes; it is passed to
# `add_start_docstrings` in their class decorators.
HYBRIDNA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`HybriDNAConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
@add_start_docstrings(
    "The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
    HYBRIDNA_START_DOCSTRING,
)
class HybriDNAPreTrainedModel(PreTrainedModel):
    # Common base of every HybriDNA model. It binds the config class and declares the capability
    # flags read by `PreTrainedModel` (gradient checkpointing, flash-attention 2, SDPA, cache classes).
    config_class = HybriDNAConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Layer classes listed for HF's `_no_split_modules` device-map handling.
    _no_split_modules = ["HybriDNAAttentionDecoderLayer", "HybriDNAMambaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        # Linear, Conv1d and Embedding weights are drawn from Normal(0, config.initializer_range).
        # Biases and the padding embedding row are zeroed. Other module types are left as they are.
        std = self.config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if not isinstance(module, nn.Embedding):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
|
|
|
|
# Argument documentation shared by the `forward` methods of the HybriDNA models; it is attached via
# `add_start_docstrings_to_model_forward`. The SSM-state shape below matches the Mamba2 mixer's
# per-layer state update (heads x head_dim x state size).
HYBRIDNA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
            self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
            Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
            Convolution state tensors have shape `(batch_size, d_inner, d_conv)`, and ssm state tensors have shape
            `(batch_size, num_heads, head_dim, d_state)`.
            See the `HybridMambaAttentionDynamicCache` class for more details.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
|
|
# Maps each entry of `config.layers_block_type` to the decoder-layer class that implements it.
ALL_DECODER_LAYER_TYPES = dict(
    attention=HybriDNAAttentionDecoderLayer,
    mamba=HybriDNAMambaDecoderLayer,
)
|
|
|
|
@add_start_docstrings(
    "The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
    HYBRIDNA_START_DOCSTRING,
)
class HybriDNAModel(HybriDNAPreTrainedModel):
    """
    Decoder consisting of *config.num_hidden_layers* layers. Layer `i` is a [`HybriDNAAttentionDecoderLayer`] or a
    [`HybriDNAMambaDecoderLayer`], depending on `config.layers_block_type[i]`.
    Args:
        config: HybriDNAConfig
    """

    def __init__(self, config: HybriDNAConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Build the hybrid stack: `config.layers_block_type[i]` selects the class of layer `i`.
        decoder_layers = []
        for i in range(config.num_hidden_layers):
            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
            decoder_layers.append(layer_class(config, layer_idx=i))
        self.layers = nn.ModuleList(decoder_layers)

        self._attn_implementation = config._attn_implementation
        self.final_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        if use_cache and past_key_values is None:
            logger.warning_once(
                "HybriDNA requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
                "provided, so no cache will be returned."
            )

        # Absolute positions of the new tokens, continuing after the tokens the cache has already seen.
        if cache_position is None:
            past_seen_tokens = past_key_values.seq_offset if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order must match the decoder layers' `forward` signature.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            # Mamba layers put `None` in the attention-weights slot; only real weights are collected.
            if output_attentions:
                if layer_outputs[1] is not None:
                    all_self_attns += (layer_outputs[1],)

        hidden_states = self.final_layernorm(hidden_states)

        # Add the hidden states produced after the final layer norm.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Explicit `None` check: a cache object may define `__len__` (as `DynamicCache` does), so its
        # truthiness does not reliably indicate whether a cache was passed.
        if past_key_values is not None:
            if not past_key_values.has_previous_state:
                past_key_values.has_previous_state = True
            # Advance the offset that seeds `cache_position` on the next call and that the Mamba mixers
            # compare against 0 to detect their first cached call.
            past_key_values.seq_offset += hidden_states.shape[1]

        next_cache = None if not use_cache else past_key_values

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """
        Build the attention mask passed to the decoder layers.

        Args:
            attention_mask: optional user mask. A 2D padding mask (0 = padded position) is merged into the
                causal mask. Masks of other ranks are not merged.
            input_tensor: the input embeddings; used for batch size, sequence length, dtype and device.
            cache_position: absolute positions of the current tokens; its last entry + 1 is the key length.

        Returns:
            For ``flash_attention_2``: ``attention_mask`` if it contains a 0, otherwise ``None``.
            Otherwise: an additive float mask of shape ``(batch, 1, sequence_length, target_length)``, holding
            ``torch.finfo(dtype).min`` at masked positions and 0 elsewhere. For SDPA on CUDA, fully masked rows
            are passed through ``AttentionMaskConverter._unmask_unattended``.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = cache_position[-1] + 1

        # Block future keys: positions strictly after each query's absolute position get `min_dtype`.
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for the in-place edit below
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
            # Re-open rows that are entirely masked (e.g. leading rows under left padding); presumably needed
            # by SDPA's memory-efficient kernel, as in upstream HF models.
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
|
|
|
|
class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
    # Causal language-modelling model: a HybriDNAModel backbone plus a bias-free vocabulary projection.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: HybriDNAConfig):
        super().__init__(config)
        self.model = HybriDNAModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        backbone = self.model
        return backbone.embed_tokens

    def set_input_embeddings(self, value):
        """Install ``value`` as the backbone's token embedding module."""
        backbone = self.model
        backbone.embed_tokens = value

    def get_output_embeddings(self):
        """Return the vocabulary projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone."""
        return self.model

    @add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: Optional[Union[int, None]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            num_logits_to_keep (`int` or `None`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
                `input_ids`. Only last token logits are needed for generation, and calculating them only for that token
                can save memory, which becomes pretty significant for long sequences.
        Returns:
        Example:
        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> model = AutoModelForCausalLM.from_pretrained("Mishamq/HybriDNA-300M", trust_remote_code=True)
        >>> tokenizer = AutoTokenizer.from_pretrained("Mishamq/HybriDNA-300M", trust_remote_code=True)
        >>> prompt = "ACGTACGTACGTACGT"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        ```"""
        # Fill unset output flags from the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if num_logits_to_keep is not None:
            # Project only the trailing positions (e.g. the last token during generation).
            hidden_states = hidden_states[..., -num_logits_to_keep:, :]
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Next-token objective: the logits at position t are scored against labels[t + 1].
            flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            flat_labels = labels[..., 1:].contiguous().view(-1).to(flat_logits.device)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        return output if loss is None else (loss,) + output

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        **kwargs,
    ):
        # Only a HybridMambaAttentionDynamicCache counts as a live cache. `None` or any other cache
        # type is replaced by a freshly allocated one below.
        reuse_cache = isinstance(past_key_values, HybridMambaAttentionDynamicCache)

        if reuse_cache:
            past_length = attention_mask.shape[1] if cache_position is None else cache_position[0]
            max_cache_length = self.config.sliding_window

            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                # Some inputs exist only in the cache (e.g. the prompt was given as embeddings):
                # keep the tokens the mask covers beyond the cached length.
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                # input_ids holds the full history: drop what the cache already processed.
                input_ids = input_ids[:, past_length:]
            # Otherwise input_ids is assumed to contain only unprocessed tokens.

            exceeds_window = (
                max_cache_length is not None
                and attention_mask is not None
                and past_length + input_ids.shape[1] > max_cache_length
            )
            if exceeds_window:
                attention_mask = attention_mask[:, -max_cache_length:]
        else:
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )

        position_ids = kwargs.get("position_ids", None)
        if position_ids is None and attention_mask is not None:
            # Count positions over the unmasked tokens only; masked slots get the placeholder 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if reuse_cache:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # `inputs_embeds` are consumed only on the step that creates the cache.
        leading_inputs = (
            {"inputs_embeds": inputs_embeds}
            if inputs_embeds is not None and not reuse_cache
            else {"input_ids": input_ids}
        )
        return {
            **leading_inputs,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "num_logits_to_keep": self.config.num_logits_to_keep,
            "cache_position": cache_position,
        }
|
|
|
|
| @add_start_docstrings( |
| """ |
| The HybriDNA Model with a sequence classification head on top (linear layer). |
| [`HybriDNAForSequenceClassification`] uses the last token in order to do the classification, as other causal models |
| (e.g. GPT-2) do. |
| Since it does classification on the last token, it requires to know the position of the last token. If a |
| `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If |
| no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the |
| padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in |
| each row of the batch). |
| """, |
| HYBRIDNA_START_DOCSTRING, |
| ) |
| class HybriDNAForSequenceClassification(HybriDNAPreTrainedModel): |
def __init__(self, config):
    super().__init__(config)
    num_labels = config.num_labels
    self.num_labels = num_labels
    # Backbone followed by a bias-free linear head producing one logit per label.
    self.model = HybriDNAModel(config)
    self.score = nn.Linear(config.hidden_size, num_labels, bias=False)
    # p=0, so this dropout is an identity in both training and evaluation.
    self.dropout = nn.Dropout(p=0)
    # Initialize weights and apply final processing
    self.post_init()
|
|
def get_input_embeddings(self):
    """Return the token embedding module of the backbone ``self.model``."""
    backbone = self.model
    return backbone.embed_tokens
|
|
def set_input_embeddings(self, value):
    """Install ``value`` as the token embedding module of the backbone ``self.model``."""
    backbone = self.model
    backbone.embed_tokens = value
|
|
| @add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING) |
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[List[torch.Tensor]] = None, |
| inputs_embeds: Optional[torch.Tensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple, SequenceClassifierOutputWithPast]: |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., |
| config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If |
| `config.num_labels > 1` a classification loss is computed (Cross-Entropy). |
| """ |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| transformer_outputs = self.model( |
| input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
| hidden_states = transformer_outputs[0] |
| hidden_states = self.dropout(hidden_states) |
| logits = self.score(hidden_states) |
|
|
| if input_ids is not None: |
| batch_size = input_ids.shape[0] |
| else: |
| batch_size = inputs_embeds.shape[0] |
|
|
| if self.config.pad_token_id is None and batch_size != 1: |
| raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") |
| if self.config.pad_token_id is None: |
| sequence_lengths = -1 |
| else: |
| if input_ids is not None: |
| |
| sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 |
| sequence_lengths = sequence_lengths % input_ids.shape[-1] |
| sequence_lengths = sequence_lengths.to(logits.device) |
| else: |
| sequence_lengths = -1 |
|
|
| pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] |
|
|
| loss = None |
| if labels is not None: |
| labels = labels.to(logits.device) |
| if self.config.problem_type is None: |
| if self.num_labels == 1: |
| self.config.problem_type = "regression" |
| elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): |
| self.config.problem_type = "single_label_classification" |
| else: |
| self.config.problem_type = "multi_label_classification" |
|
|
| if self.config.problem_type == "regression": |
| loss_fct = MSELoss() |
| if self.num_labels == 1: |
| loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) |
| else: |
| loss = loss_fct(pooled_logits, labels) |
| elif self.config.problem_type == "single_label_classification": |
| loss_fct = CrossEntropyLoss() |
| loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) |
| elif self.config.problem_type == "multi_label_classification": |
| loss_fct = BCEWithLogitsLoss() |
| loss = loss_fct(pooled_logits, labels) |
| if not return_dict: |
| output = (pooled_logits,) + transformer_outputs[1:] |
| return ((loss,) + output) if loss is not None else output |
|
|
| return SequenceClassifierOutputWithPast( |
| loss=loss, |
| logits=pooled_logits, |
| past_key_values=transformer_outputs.past_key_values, |
| hidden_states=transformer_outputs.hidden_states, |
| attentions=transformer_outputs.attentions, |
| ) |
|
|
@add_start_docstrings(
    """
    The HybriDNA Model with a sequence classification head on top (linear layer along with RC Echo Embedding).
    The input sequence is concatenated with its reverse complement before being processed by the model.
    [`HybriDNAForSequenceClassificationRCEcho`]
    """,
    HYBRIDNA_START_DOCSTRING,
)
class HybriDNAForSequenceClassificationRCEcho(HybriDNAPreTrainedModel):
    # When `input_ids` are given, each row is concatenated with its reverse complement, the backbone
    # runs on the doubled sequence, and the hidden states of the second (reverse-complement) half are
    # mean-pooled (respecting the attention mask) before dropout and a bias-free linear head.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = HybriDNAModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.dropout = nn.Dropout(0.05)

        # Initialize weights and apply final processing (after all submodules exist).
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def _reverse_complement_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Reverse complement DNA token IDs.

        Token mapping: A=7, C=8, G=9, T=10, N=11
        Complement: A(7)<->T(10), C(8)<->G(9). Every other id (N=11, special tokens 0-6, ...)
        is left unchanged.

        Args:
            input_ids: token ids whose sequence axis is dim 1 (the axis that is reversed).

        Returns:
            A new tensor (the input is not modified) with complemented ids in reversed order along dim 1.
        """
        rc = input_ids.clone()

        # Masks are computed from the untouched input, so the swaps cannot interfere with each other.
        is_A = (input_ids == 7)
        is_T = (input_ids == 10)
        rc[is_A] = 10
        rc[is_T] = 7

        is_C = (input_ids == 8)
        is_G = (input_ids == 9)
        rc[is_C] = 9
        rc[is_G] = 8

        rc = torch.flip(rc, dims=[1])
        return rc

    @add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:
            `SequenceClassifierOutputWithPast` when `return_dict` is true; otherwise a tuple
            `(loss, logits, *backbone_outputs[1:])`, with `loss` omitted when no `labels` are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None:
            pad_token_id = self.config.pad_token_id
            if pad_token_id is not None:
                # Overwrite padding positions with token id 0 before building the reverse complement.
                forward_ids = input_ids.masked_fill(input_ids == pad_token_id, 0)
            else:
                # No pad id configured: there is nothing to overwrite.
                forward_ids = input_ids
            rc_ids = self._reverse_complement_tokens(forward_ids)

            # Echo layout: [forward | reverse complement] along the sequence axis.
            input_ids = torch.cat([forward_ids, rc_ids], dim=1)

            if attention_mask is not None:
                # The RC half is the forward half flipped along dim 1, so its padding positions
                # are flipped too; reusing the unflipped mask would mask real RC tokens instead.
                attention_mask = torch.cat([attention_mask, attention_mask.flip(dims=[1])], dim=1)
        # NOTE(review): with only `inputs_embeds` no echo is built, so the pooling below covers the
        # second half of the plain sequence; `position_ids`, if given, are also not doubled to match
        # the echoed `input_ids` — confirm callers never pass them here.

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Mean-pool only the second (reverse-complement) half of the sequence.
        half = hidden_states.shape[1] // 2
        echo_hidden_states = hidden_states[:, half:, :]
        if attention_mask is not None:
            echo_mask = attention_mask[:, half:]
            summed = (echo_hidden_states * echo_mask.unsqueeze(-1)).sum(dim=1)
            # clamp keeps fully-masked rows at 0 instead of producing 0/0 = NaN
            token_counts = echo_mask.sum(dim=1, keepdim=True).clamp(min=1)
            pooled_hidden_states = summed / token_counts
        else:
            pooled_hidden_states = echo_hidden_states.mean(dim=1)

        pooled_hidden_states = self.dropout(pooled_hidden_states)
        logits = self.score(pooled_hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer (and cache on the config) the problem type the first time labels are seen.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
|