| | from typing import Optional, Tuple |
| | from dataclasses import dataclass |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from transformers.modeling_outputs import ( |
| | SequenceClassifierOutput, |
| | ) |
| |
|
| | from typing import Optional, Tuple |
| |
|
| | import torch |
| | import torch.utils.checkpoint |
| | from torch import nn |
| |
|
| | from dataclasses import dataclass |
| | from transformers.activations import ACT2FN, ACT2CLS |
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.utils import logging |
| | from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutput, CausalLMOutputWithPast |
| | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask |
| | import xformers.ops as xops |
| |
|
| | from collections import OrderedDict |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from einops import rearrange, einsum |
| | from transformers.pytorch_utils import Conv1D |
| |
|
| |
|
| | import torch |
| | from torch.amp import autocast |
| | from torch import nn, einsum, Tensor |
| |
|
| | from einops import rearrange, repeat |
| | from typing import Optional, Union |
| |
|
| | from .configuration_decodon import DeCodonConfig |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
def rotate_half(x):
    """Rotate adjacent feature pairs: (x1, x2) -> (-x2, x1) along the last dim.

    This is the "imaginary part" term used by rotary embeddings: each
    consecutive pair of features is treated as a 2-D point and rotated by
    90 degrees.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(start_dim=-2)
| |
|
| |
|
@autocast(device_type="cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
    """
    Apply rotary position embeddings to a slice of the feature dimension.

    Parameters
    ----------
    freqs : Tensor
        Rotation angles per position: (seq_len, dim)
    t : Tensor
        Tensor to rotate: (..., seq_len, n_heads, dim)
    start_index : int
        First feature index to rotate; features outside
        [start_index, start_index + freqs.shape[-1]) pass through untouched.
        (default: 0)
    scale : float or Tensor
        XPos-style scale multiplied into the rotated slice. (default: 1.0)

    Returns
    -------
    Tensor
        Same shape as ``t`` with the selected feature slice rotated.
    """
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    # Split features into untouched left/right margins and the rotated middle.
    left = t[..., :start_index]
    middle = t[..., start_index:end_index]
    right = t[..., end_index:]

    if isinstance(scale, float):
        scale = torch.tensor(scale, device=middle.device, dtype=middle.dtype)

    rotated = (middle * freqs.cos() * scale) + (rotate_half(middle) * freqs.sin() * scale)
    return torch.cat((left, rotated, right), dim=-1)
| |
|
| |
|
| | |
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """
    Apply (possibly learned) per-position rotation angles to ``t``.

    When ``freq_ranges`` is given, every angle is first scaled by each
    frequency in the range and the result is flattened into one angle axis.
    Angles are then duplicated across each 2-D feature pair before delegating
    to ``apply_rotary_emb``.
    """
    angles = rotations
    if freq_ranges is not None:
        angles = einsum("..., f -> ... f", angles, freq_ranges)
        angles = rearrange(angles, "... r f -> ... (r f)")

    angles = repeat(angles, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(angles, t, start_index=start_index)
| |
|
| |
|
| | """ |
| | Inspired from https://github.com/lucidrains/rotary-embedding-torch |
| | """ |
| |
|
class RotaryEmbedding(nn.Module):
    """
    Rotary Embeddings implementation inspired by https://github.com/lucidrains/rotary-embedding-torch.

    Rotary Positional Embeddings (RoPE) encode position information of tokens with a
    rotation matrix that naturally incorporates explicit relative position dependency.

    Parameters
    ----------
    emb_dim : int
        Embedding dimension. Usually set to the dim of each head in the attention module.
    freqs : Optional[Tensor]
        Custom frequencies to apply to query/key tensors. (default: None)
    theta : float
        Base constant used for computing rotation angles.
    learned_freq : bool (default: False)
        Whether to learn the frequencies.
    use_xpos : bool (default: False)
        Whether to employ XPos technique for resolving length extrapolation issue.
        NOTE: This can only be enabled for autoregressive models like GPT.
    xpos_scale_base : int (default: 512)
        The base for the scale factor used in XPos technique.
    interpolate_factor : float (default: 1.0)
        Length interpolation factor for extending context length of the pretrained model.
        Final model's context length = pretrained_model_context_length * interpolate_factor.
    theta_rescale_factor : float (default: 1.0)
        The factor to rescale the theta.
    cache_if_possible : bool (default: True)
        Whether to cache the frequencies/scales if possible.
    """

    def __init__(
        self,
        emb_dim,
        freqs: Optional[Tensor] = None,
        theta=1e4,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        cache_if_possible=True,
    ):
        super().__init__()

        # Theta rescaling for context extension (as in rotary-embedding-torch);
        # a rescale factor of 1.0 leaves theta unchanged.
        theta *= theta_rescale_factor ** (emb_dim / (emb_dim - 2))

        if freqs is None:
            # Standard RoPE inverse-frequency spectrum: theta^(-2i / emb_dim).
            freqs = 1.0 / (
                theta
                ** (torch.arange(0, emb_dim, 2)[: (emb_dim // 2)].float() / emb_dim)
            )

        self.cache_if_possible = cache_if_possible

        # Lazily-filled caches; re-registered with real tensors on first use
        # (see forward / get_scale).
        self.register_buffer("cached_freqs", None, persistent=False)
        self.register_buffer("cached_scales", None, persistent=False)

        # Stored as a Parameter so the frequencies can optionally be learned;
        # requires_grad toggles learnability.
        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # XPos decay scales; skipped entirely when use_xpos is disabled.
        self.use_xpos = use_xpos
        if not use_xpos:
            self.register_buffer("scale", None, persistent=False)
            return

        scale = (torch.arange(0, emb_dim, 2) + 0.4 * emb_dim) / (1.4 * emb_dim)
        self.scale_base = xpos_scale_base
        self.register_buffer("scale", scale, persistent=False)

    @property
    def device(self):
        # Device of the frequency parameter (the module's reference device).
        return self.freqs.device

    def rotate_queries_or_keys(self, t, offset=0, freq_seq_len=None, scale=None):
        """
        Rotate a single queries-or-keys tensor by absolute position.

        Parameters
        ----------
        t : Tensor
            tensor to rotate: (batch_size, seq_len, num_heads, head_dim)
        offset : int
            Position offset added to every sequence index. (default: 0)
        freq_seq_len : Optional[int]
            If given, compute frequencies for this (>= seq_len) length instead
            of t's own length. (default: None)
        scale : Optional[Tensor]
            Pre-computed XPos scale; must be provided when use_xpos is enabled.
        """
        seq_len = t.shape[1]
        assert (
            not self.use_xpos or scale is not None
        ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"

        if freq_seq_len is not None:
            assert freq_seq_len >= seq_len
            seq_len = freq_seq_len

        # Positions, shifted by offset and shrunk by the interpolation factor
        # for context-length extension.
        seq = (
            torch.arange(seq_len, device=t.device, dtype=t.dtype) + offset
        ) / self.interpolate_factor

        freqs = self.forward(
            seq,
            seq_len=seq_len,
            offset=offset,
        ).to(t.dtype)

        # Insert a singleton heads dimension so freqs broadcast over heads.
        freqs = rearrange(freqs, "n d -> n 1 d")

        if scale is not None:
            # NOTE(review): this reshape assumes a Tensor scale; a float scale
            # (e.g. 1.0) would fail here — confirm callers only pass tensors.
            scale = rearrange(scale, "n d -> n 1 d")

        if scale is None:
            scale = torch.tensor(1.0, device=t.device, dtype=t.dtype)

        return apply_rotary_emb(freqs, t, scale=scale)

    def rotate_queries_and_keys(self, q, k):
        """
        Rotate queries and keys together with XPos scaling (q scaled by s,
        k by s^-1 so their product depends only on relative position).

        Parameters
        ----------
        q : Tensor
            queries tensor: (batch_size, seq_len, num_heads, head_dim)
        k : Tensor
            keys tensor: (batch_size, seq_len, num_heads, head_dim)
        """
        assert self.use_xpos
        # Sequence length is taken from q; q and k are assumed to share it.
        seq_len = q.shape[-3]

        seq = (
            torch.arange(seq_len, device=q.device, dtype=q.dtype)
        ) / self.interpolate_factor

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len)

        # Broadcast over the heads dimension.
        freqs = rearrange(freqs, "n d -> n 1 d")
        scale = rearrange(scale, "n d -> n 1 d")

        # Reciprocal scales for keys (XPos).
        rotated_q = apply_rotary_emb(freqs, q, scale=scale)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1)

        # Restore the callers' dtypes (rotation runs in the freqs dtype).
        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
        """Compute (and, when possible, cache) XPos decay scales for positions ``t``."""
        assert self.use_xpos

        should_cache = self.cache_if_possible and seq_len is not None

        if (
            should_cache
            and self.cached_scales is not None
            and (seq_len + offset) <= self.cached_scales.shape[0]
        ):
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        # use_xpos is always True here (asserted above); the guard is kept as-is.
        if self.use_xpos:
            # Power is centered on the sequence midpoint so scales decay
            # symmetrically around it.
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            # Duplicate so the scale covers both halves of each rotated pair.
            scale = torch.cat((scale, scale), dim=-1)

        if should_cache:
            # Re-register to overwrite the placeholder buffer from __init__.
            self.register_buffer("cached_scales", scale, persistent=False)

        return scale

    def rotate_queries_with_cached_keys(self, q, k, offset=0):
        """
        Rotate a query block against (already concatenated) cached keys.

        NOTE(review): everything after the first ``return`` below is
        unreachable — the offset/scale logic that would handle q_len < k_len
        and the non-XPos case is dead code. As written, this method always
        rotates q and k via ``rotate_queries_and_keys`` (which asserts
        use_xpos) and keeps only the last query position, matching
        incremental decoding with a full concatenated query. Confirm the
        early return is intentional before cleaning this up.
        """
        q_len, k_len = q.shape[1], k.shape[1]
        assert q_len <= k_len

        rotated_q, rotated_k = self.rotate_queries_and_keys(q, k)

        # Keep only the newest query position (incremental decoding step).
        rotated_q = rotated_q[:, -1:, ...]

        return rotated_q, rotated_k

        # --- unreachable below (see NOTE above) ---
        seq = (
            torch.arange(k_len, device=q.device, dtype=q.dtype)
        ) / self.interpolate_factor

        if self.use_xpos:
            q_scale = self.get_scale(seq[-q_len:]).to(q.dtype)
            k_scale = self.get_scale(seq).to(k.dtype)

        else:
            k_scale = 1.0
            q_scale = 1.0

        rotated_q = self.rotate_queries_or_keys(
            q, scale=q_scale, offset=k_len - q_len + offset
        )
        rotated_k = self.rotate_queries_or_keys(k, scale=k_scale**-1)

        return rotated_q, rotated_k

    @autocast(device_type="cuda", enabled=False)
    def forward(self, t: Tensor, seq_len=None, offset=0):
        """Compute per-position rotation angles of shape (..., seq_len, emb_dim)."""
        # Learned frequencies change every step, so they are never cached.
        should_cache = (
            self.cache_if_possible and not self.learned_freq and seq_len is not None
        )

        if (
            should_cache
            and self.cached_freqs is not None
            and (offset + seq_len) <= self.cached_freqs.shape[0]
        ):
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        # Outer product of positions and inverse frequencies, then duplicate
        # each angle for the paired (x1, x2) feature dims.
        freqs = einsum("..., f -> ... f", t, freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache:
            # Re-register to overwrite the placeholder buffer from __init__.
            self.register_buffer("cached_freqs", freqs.detach(), persistent=False)

        return freqs
| |
|
| |
|
| |
|
class MultiHeadedSelfAttention(nn.Module):
    """
    Multi-Headed Self Attention module supported with Flash Attention and Rotary Embeddings.

    Parameters
    ----------
    q_input_dim: int
        The input dimension of the query tensor.
    kv_input_dim: int
        The input dimension of the key and value tensors.
    qk_proj_dim: int
        The projected dimension of the query and key tensors.
    v_proj_dim: int
        The projected dimension of the value tensors.
    num_heads: int
        Number of attention heads.
    dropout: float
        Dropout rate to apply to the attention scores.
    projection_layer: str
        The type of projection layer to use. Either 'linear' or 'conv'.
        Basically both are linear projections, but 'conv' uses Conv1D layer as proposed in the original GPT2 paper.
    use_flash_attn: bool
        Whether to use Flash Attention or not. If True, Flash Attention will be used.
        NOTE: Flash Attention is required to be installed.
    use_rotary_emb: bool
        Whether to use Rotary Embeddings or not.
    rotary_theta: float
        The base for the geometric progression used to compute the rotation angles.
    rotary_use_xpos: bool
        Whether to use XPos technique for resolving length extrapolation issue.
        NOTE: This can only be enabled for autoregressive models like GPT.
    is_cross_attention: bool
        Whether the module attends across two different sequences; disables
        the packed flash-attention fast path.
    """

    def __init__(
        self,
        q_input_dim,
        kv_input_dim,
        qk_proj_dim,
        v_proj_dim,
        num_heads,
        dropout: float = 0.0,
        projection_layer: str = "linear",
        use_flash_attn: bool = True,
        use_rotary_emb: bool = False,
        rotary_theta: float = 1e4,
        rotary_use_xpos: bool = False,
        is_cross_attention: bool = False,
        **kwargs,
    ):
        super().__init__()
        assert (
            qk_proj_dim % num_heads == 0
        ), "qk_proj_dim must be divisible by num_heads"
        assert v_proj_dim % num_heads == 0, "v_proj_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.projection_layer = projection_layer
        self.use_rotary_emb = use_rotary_emb
        self.is_cross_attention = is_cross_attention

        # The packed-QKV flash kernel only supports self-attention, so it is
        # disabled for cross-attention; also gracefully degrade when the
        # flash_attn package is missing.
        if use_flash_attn and not is_cross_attention:
            try:
                from flash_attn import flash_attn_qkvpacked_func

                self.use_flash_attn = True
                self.flashattn_fn = flash_attn_qkvpacked_func
            except ImportError:
                print("flash_attn not installed, reverting to default attention")
                self.use_flash_attn = False
                self.flashattn_fn = None
        else:
            self.use_flash_attn = False
            self.flashattn_fn = None

        if self.projection_layer == "linear":
            self.query = nn.Linear(q_input_dim, qk_proj_dim)
            self.key = nn.Linear(kv_input_dim, qk_proj_dim)
            self.value = nn.Linear(kv_input_dim, v_proj_dim)
        elif self.projection_layer == "conv":
            # GPT-2 style Conv1D: a linear projection with transposed weight layout.
            self.query = Conv1D(qk_proj_dim, q_input_dim)
            self.key = Conv1D(qk_proj_dim, kv_input_dim)
            self.value = Conv1D(v_proj_dim, kv_input_dim)
        else:
            raise ValueError(
                f"projection_layer must be either 'linear' or 'conv', got {projection_layer}"
            )

        if self.use_rotary_emb:
            # Rotate only half of each head's feature dims (partial RoPE).
            self.rotary_emb = RotaryEmbedding(
                emb_dim=qk_proj_dim // num_heads // 2,
                theta=rotary_theta,
                use_xpos=rotary_use_xpos,
            )

        self.dr_rate = dropout
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x_q,
        x_kv,
        is_causal=False,
        attention_bias=None,
        attention_mask=None,
        output_attentions=False,
        query=None,
        key=None,
        value=None,
        use_cache=False,
    ):
        """
        Applies a classical self attention operation.

        Parameters
        ----------
        x_q: torch.Tensor
            The query tensor of shape (batch_size, query_seq_len, emb_dim)
        x_kv: torch.Tensor
            The key/value tensor of shape (batch_size, kv_seq_len, emb_dim)
        is_causal: bool
            Whether to apply a lower-triangular causal mask. (default: False)
        attention_bias: torch.Tensor
            The attention bias to apply to the attention scores. (default: None)
        attention_mask: torch.Tensor
            The attention mask to apply to the attention scores. Shape: (batch_size, q_len, kv_seq_len)
        output_attentions: bool
            Whether to also return attention probabilities; forces the vanilla
            attention path. (default: False)
        query, key, value: torch.Tensor
            Pre-projected q/k/v tensors, used instead of x_q/x_kv (e.g. when a
            decoding cache has already been concatenated in).
        use_cache: bool
            When True, also return a (key, value, query) cache tuple.

        Returns
        -------
        (x, attention_probs) — or (x, attention_probs, cache) when use_cache —
        where attention_probs is None unless output_attentions is True.
        """
        assert (x_q is not None and x_kv is not None) or (
            query is not None and key is not None and value is not None
        ), "Either x_q and x_kv or query, key and value must be provided"

        past_memory_provided = (
            query is not None and key is not None and value is not None
        )

        if query is None:
            q_len = x_q.size(1)
            k_len = x_kv.size(1)

            query = self.query(x_q)
            key = self.key(x_kv)
            value = self.value(x_kv)

        else:
            q_len = query.size(1)
            k_len = key.size(1)

        if use_cache:
            # Cache layout is (key, value, query); consumers (DeCodonAttention)
            # unpack in that order.
            cache = (key.clone(), value.clone(), query.clone())

        # Split the projection dim into heads: (b, seq, h, head_dim).
        q = rearrange(query, "b q (h d) -> b q h d", h=self.num_heads)
        k = rearrange(key, "b k (h d) -> b k h d", h=self.num_heads)
        v = rearrange(value, "b v (h d) -> b v h d", h=self.num_heads)

        if self.use_rotary_emb:
            if use_cache and past_memory_provided:
                # NOTE(review): when this branch runs AND use_xpos is set, q/k
                # appear to be rotated a second time just below — confirm this
                # double application is intended.
                q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
            if self.rotary_emb.use_xpos:
                q, k = self.rotary_emb.rotate_queries_and_keys(q, k)
            else:
                q = self.rotary_emb.rotate_queries_or_keys(q)
                k = self.rotary_emb.rotate_queries_or_keys(k)

        if (
            self.use_flash_attn
            and not use_cache
            and not output_attentions
            and attention_bias is None
        ):
            # Packed flash-attention fast path (no cache, no probs, no bias).
            # NOTE(review): inputs are hard-cast to bfloat16 here — confirm
            # that is acceptable for every deployment target.
            qkv = torch.stack([q, k, v], dim=2).to(torch.bfloat16)
            x = self.flashattn_fn(
                qkv=qkv,
                dropout_p=self.dropout_rate if self.training else 0.0,
                causal=is_causal,
                deterministic=False,
                return_attn_probs=False,
            )

            # NOTE(review): assumes x_q is not None; if q/k/v were passed in
            # directly (x_q is None) this raises — confirm callers.
            x = x.to(x_q.dtype)
        elif self.use_flash_attn and not output_attentions:
            # xformers memory-efficient path (used e.g. when caching).
            attn_bias = xops.LowerTriangularMask() if is_causal else attention_bias

            if attention_mask is not None:
                if attn_bias is None:
                    attn_bias = attention_mask
                else:
                    if isinstance(attn_bias, torch.Tensor):
                        attn_bias = attn_bias + attention_mask
                    else:
                        # Fold the mask into the xformers mask object, then
                        # materialize it to a dense (q_len, k_len) tensor.
                        attn_bias.add_bias(bias=attention_mask)

                        attn_bias = attn_bias.materialize(
                            shape=(q_len, k_len),
                            device=q.device,
                            dtype=q.dtype,
                        )
            else:
                if isinstance(attn_bias, torch.Tensor) and len(attn_bias.shape) == 3:
                    # Expand a (b, q, k) bias across heads -> (b, h, q, k).
                    attn_bias = (
                        attn_bias.unsqueeze(1)
                        .expand(-1, self.num_heads, -1, -1)
                        .float()
                    )
                else:
                    # NOTE(review): if attn_bias is None here (non-causal, no
                    # bias, no mask) this raises AttributeError — confirm
                    # callers always provide a mask/bias on this path.
                    attn_bias = attn_bias.materialize(
                        shape=(q_len, k_len),
                        device=q.device,
                        dtype=q.dtype,
                    )

            if isinstance(attn_bias, xops.LowerTriangularMask):
                attn_bias = attn_bias.materialize(
                    shape=(q_len, k_len),
                    device=q.device,
                    dtype=q.dtype,
                )

            # xformers expects the bias' trailing dims aligned to multiples
            # of 8; pad up when needed.
            need_adjustment = False
            if attn_bias.shape[-2] % 8 != 0:
                nearest_multiple_q = 8 * (1 + attn_bias.shape[-2] // 8)
                need_adjustment = True
            else:
                nearest_multiple_q = attn_bias.shape[-2]

            if attn_bias.shape[-1] % 8 != 0:
                nearest_multiple_k = 8 * (1 + attn_bias.shape[-1] // 8)
                need_adjustment = True
            else:
                nearest_multiple_k = attn_bias.shape[-1]

            if need_adjustment:
                # Pad to aligned sizes, then slice back: the slice keeps the
                # padded tensor's aligned row stride (presumably why the pad
                # exists at all — confirm against the xformers requirement).
                # NOTE(review): assumes attn_bias is 4-D (b, h, q, k) here and
                # leaves the padding in default float32 — verify both.
                new_attn_bias = torch.zeros(
                    attn_bias.shape[0],
                    attn_bias.shape[1],
                    nearest_multiple_q,
                    nearest_multiple_k,
                ).to(attn_bias.device)
                new_attn_bias[:, :, : attn_bias.shape[-2], : attn_bias.shape[-1]] = (
                    attn_bias
                )

                x = xops.memory_efficient_attention(
                    query=q,
                    key=k,
                    value=v,
                    op=None,
                    attn_bias=new_attn_bias[:, :, :q_len, :k_len],
                    p=self.dr_rate,
                )
            else:
                attn_bias = attn_bias.to(q.dtype)
                # NOTE(review): if attn_bias was already expanded to num_heads
                # above, this repeat multiplies the head dim again; if it is
                # still 2-D the repeat fails — confirm which shapes reach here.
                attn_bias = attn_bias.repeat(1, self.num_heads, 1, 1)
                x = xops.memory_efficient_attention(
                    query=q,
                    key=k,
                    value=v,
                    op=None,
                    attn_bias=attn_bias,
                    p=self.dr_rate,
                )

        else:
            # Vanilla attention path (the only one that can return probs).
            # NOTE(review): `einsum` is imported from BOTH einops and torch at
            # the top of this file; the later `from torch import ... einsum`
            # shadows einops'. This call uses the einops convention (tensors
            # first, pattern last), which torch.einsum does not accept —
            # confirm this path actually runs / fix the import shadowing.
            attention_scores = einsum(q, k, "b q h d, b k h d -> b h q k")
            attention_scores = attention_scores / (q.size(-1) ** 0.5)

            if attention_bias is not None:
                # Broadcast a (b, q, k) bias across heads.
                attn_bias = attention_bias.unsqueeze(1).expand(
                    -1, self.num_heads, -1, -1
                )

            else:
                attn_bias = None

            if attention_mask is not None:
                if attn_bias is None:
                    attn_bias = attention_mask
                else:
                    attn_bias = attn_bias + attention_mask

            # NOTE(review): if both attention_bias and attention_mask are None,
            # attn_bias is None and this addition raises — confirm callers.
            attention_scores = attention_scores + attn_bias

            attention_probs = attention_scores.softmax(dim=-1)
            attention_probs = self.dropout(attention_probs)

            x = einsum(attention_probs, v, "b h q k, b k h d -> b q h d")

        # Merge heads back into a single feature dimension.
        x = rearrange(x, "b q h d -> b q (h d)", h=self.num_heads)

        if use_cache:
            if output_attentions:
                return x, attention_probs, cache
            else:
                return x, None, cache
        else:
            if output_attentions:
                return x, attention_probs
            else:
                return x, None
| |
|
class DeCodonPreTrainedModel(PreTrainedModel):
    """
    Abstract base class for DeCodon models: handles weight initialization and
    exposes the standard interface for downloading and loading pretrained
    checkpoints.
    """

    base_model_prefix = "decodon"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize weights following the MAGNETO scheme."""
        if isinstance(module, nn.Linear):
            # Xavier-normal scaled by the configured MAGNETO gain.
            nn.init.xavier_normal_(module.weight, gain=self.config.gamma_init)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Padding embeddings stay at zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on individual decoder layers."""
        if isinstance(module, DeCodonLayer):
            module.gradient_checkpointing = value
| |
|
| |
|
class DeCodonEmbeddings(nn.Module):
    """
    Input embeddings for DeCodon: word + token-type embeddings, plus absolute
    position embeddings when configured.

    NOTE(review): ``self.ln`` is constructed here but never applied in
    ``forward`` — presumably intentional (pre-LayerNorm blocks downstream),
    but confirm; it is kept so checkpoints keep their keys.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )

        # Unused in forward(); kept for state-dict compatibility (see class note).
        self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        # Pre-built [0, 1, ..., max_pos) index row, sliced per batch.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        # All-zero token types used when the caller provides none.
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed token ids (or raw embeddings) with type and position info."""
        input_shape = (
            input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        )
        seq_length = input_shape[1]

        if position_ids is None:
            # Shift by the cache length so incremental decoding stays aligned.
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # Use the registered all-zero buffer, broadcast to the batch.
                token_type_ids = self.token_type_ids[:, :seq_length].expand(
                    input_shape[0], seq_length
                )
            else:
                token_type_ids = torch.zeros(
                    input_shape, dtype=torch.long, device=self.position_ids.device
                )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings = embeddings + self.position_embeddings(position_ids)

        return self.dropout(embeddings)
| |
|
| |
|
class DeCodonAttention(nn.Module):
    """
    Pre-LayerNorm causal self-attention block with a residual connection.

    Wraps ``MultiHeadedSelfAttention`` (RoPE/XPos + optional flash attention)
    and applies: LN -> attention -> LN -> dense -> dropout -> residual add.
    """

    def __init__(self, config):
        super().__init__()

        self.pre_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.post_attn_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.self_attention = MultiHeadedSelfAttention(
            q_input_dim=config.hidden_size,
            kv_input_dim=config.hidden_size,
            qk_proj_dim=config.hidden_size,
            v_proj_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            projection_layer="conv",
            use_flash_attn=config.use_flash_attn,
            use_rotary_emb=config.use_rotary_emb,
            rotary_theta=config.rotary_theta,
            rotary_use_xpos=True,  # decoder-only model, so XPos is enabled
        )

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        normed = self.pre_layer_norm(hidden_states)

        if past_key_values is None:
            attn_outputs = self.self_attention(
                x_q=normed,
                x_kv=normed,
                is_causal=True,
                attention_bias=None,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        else:
            # Project only the new tokens, then prepend the cached
            # projections. Cache layout is (key, value, query), as produced
            # by MultiHeadedSelfAttention when use_cache=True.
            new_query = self.self_attention.query(normed)
            new_key = self.self_attention.key(normed)
            new_value = self.self_attention.value(normed)

            past_key, past_value, past_query = past_key_values

            attn_outputs = self.self_attention(
                x_q=None,
                x_kv=None,
                query=torch.cat((past_query, new_query), dim=1),
                key=torch.cat((past_key, new_key), dim=1),
                value=torch.cat((past_value, new_value), dim=1),
                is_causal=True,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
                attention_bias=None,
            )

        out = attn_outputs[0]
        out = self.post_layer_norm(out)
        out = self.post_attn_dense(out)
        out = self.dropout(out)
        out = hidden_states + out

        # Pass through any (attention_probs, cache) extras unchanged.
        return (out,) + attn_outputs[1:]
| |
|
| |
|
class DeCodonFFN(nn.Module):
    """
    Position-wise feed-forward network with pre- and post-LayerNorm:
    LN -> Conv1D up-projection -> activation -> LN -> Conv1D down-projection
    -> dropout.

    NOTE(review): no residual connection is applied inside this module, and
    DeCodonLayer does not add one around it either — confirm that is intended.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size

        self.pre_layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.intermediate_dense = Conv1D(inner, hidden)
        self.post_layer_norm = nn.LayerNorm(inner, eps=config.layer_norm_eps)
        self.post_dense = Conv1D(hidden, inner)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Resolve a string activation name via ACT2FN; callables pass through.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(
        self, hidden_states: Optional[Tuple[torch.FloatTensor]]
    ) -> torch.FloatTensor:
        """Run the normalized two-layer MLP over the last dimension."""
        x = self.pre_layer_norm(hidden_states)
        x = self.intermediate_act_fn(self.intermediate_dense(x))
        x = self.post_layer_norm(x)
        return self.dropout(self.post_dense(x))
| |
|
| |
|
class DeCodonLayer(nn.Module):
    """
    Single decoder layer: causal self-attention followed by the
    position-wise feed-forward block.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = DeCodonAttention(config)
        self.output = DeCodonFFN(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor],
        Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
    ]:
        attn_result = self.attention(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        # attn_result = (hidden, [attn_probs], [cache]); forward the extras.
        extras = attn_result[1:]
        ffn_out = self.output(attn_result[0])
        return (ffn_out,) + extras
| |
|
| |
|
class DeCodonStack(nn.Module):
    """
    Stack of ``DeCodonLayer`` decoder blocks.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [DeCodonLayer(config) for _ in range(config.num_hidden_layers)]
        )
        # NOTE(review): set but never consulted in forward() — confirm whether
        # gradient checkpointing was meant to be wired up here.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        use_cache: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        """Run every block, optionally collecting hidden states / attentions / caches."""
        if past_key_values is None:
            past_key_values = [None] * len(self.blocks)
            past_length = 0
        else:
            # NOTE(review): computed but unused in this method.
            past_length = past_key_values[0][0].size(-2)

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        presents = () if use_cache else None

        for block, layer_past in zip(self.blocks, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                past_key_values=layer_past,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                presents = presents + (layer_outputs[2],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            candidates = (
                hidden_states,
                presents,
                all_hidden_states,
                all_self_attentions,
            )
            return tuple(v for v in candidates if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
| |
|
| |
|
| | class DeCodonModule(DeCodonPreTrainedModel): |
| | """ |
| | The DeCodon Module (Decoder only) without any task-specific head on top. |
| | """ |
| | |
    def __init__(self, config):
        super().__init__(config)

        # Token/type/position embeddings, decoder stack, and final LayerNorm.
        self.embeddings = DeCodonEmbeddings(config)
        self.decoder = DeCodonStack(config)
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False

        # HuggingFace hook: weight init and any post-load processing.
        self.post_init()
| | |
    def set_input_embeddings(self, new_embeddings):
        # HF PreTrainedModel API hook: swap the word-embedding table.
        self.embeddings.word_embeddings = new_embeddings
| |
|
| | def forward( |
| | self, |
| | input_ids: Optional[torch.LongTensor] = None, |
| | attention_mask: Optional[torch.FloatTensor] = None, |
| | token_type_ids: Optional[torch.LongTensor] = None, |
| | position_ids: Optional[torch.LongTensor] = None, |
| | inputs_embeds: Optional[torch.FloatTensor] = None, |
| | past_key_values: Optional[Tuple[torch.FloatTensor]] = None, |
| | output_attentions: Optional[bool] = None, |
| | output_hidden_states: Optional[bool] = None, |
| | return_dict: Optional[bool] = None, |
| | use_cache: Optional[bool] = False, |
| | ) -> Union[Tuple, BaseModelOutput]: |
| | output_attentions = ( |
| | output_attentions |
| | if output_attentions is not None |
| | else self.config.output_attentions |
| | ) |
| | output_hidden_states = ( |
| | output_hidden_states |
| | if output_hidden_states is not None |
| | else self.config.output_hidden_states |
| | ) |
| | return_dict = ( |
| | return_dict if return_dict is not None else self.config.use_return_dict |
| | ) |
| |
|
| | if input_ids is not None and inputs_embeds is not None: |
| | raise ValueError( |
| | "You cannot specify both input_ids and inputs_embeds at the same time" |
| | ) |
| | elif input_ids is not None: |
| | self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) |
| | input_shape = input_ids.size() |
| | elif inputs_embeds is not None: |
| | input_shape = inputs_embeds.size()[:-1] |
| | else: |
| | raise ValueError("You have to specify either input_ids or inputs_embeds") |
| |
|
| | if past_key_values is not None: |
| | past_length = past_key_values[0][0].size(-2) |
| | else: |
| | past_length = 0 |
| |
|
| | batch_size, seq_length = input_shape |
| | device = input_ids.device if input_ids is not None else inputs_embeds.device |
| |
|
| | if attention_mask is None: |
| | attention_mask = torch.ones(((batch_size, seq_length)), device=device) |
| |
|
| | if token_type_ids is None: |
| | if hasattr(self.embeddings, "token_type_ids"): |
| | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] |
| | buffered_token_type_ids_expanded = buffered_token_type_ids.expand( |
| | batch_size, seq_length |
| | ) |
| | token_type_ids = buffered_token_type_ids_expanded |
| | else: |
| | token_type_ids = torch.zeros( |
| | input_shape, dtype=torch.long, device=device |
| | ) |
| |
|
| | |
| | |
| | |
| | |
| | |
| | embedding_output = self.embeddings( |
| | input_ids=input_ids, |
| | position_ids=position_ids, |
| | token_type_ids=token_type_ids, |
| | inputs_embeds=inputs_embeds, |
| | ) |
| |
|
| | extended_attention_mask = _prepare_4d_causal_attention_mask( |
| | attention_mask=attention_mask, |
| | input_shape=(batch_size, input_shape[-1]), |
| | inputs_embeds=embedding_output, |
| | past_key_values_length=past_length, |
| | ) |
| | |
| |
|
| | decoder_outputs = self.decoder( |
| | embedding_output, |
| | attention_mask=extended_attention_mask, |
| | output_attentions=output_attentions, |
| | output_hidden_states=output_hidden_states, |
| | past_key_values=past_key_values, |
| | return_dict=return_dict, |
| | use_cache=use_cache, |
| | ) |
| |
|
| | sequence_output = decoder_outputs[0] |
| |
|
| | if not return_dict: |
| | return (sequence_output,) + decoder_outputs[1:] |
| |
|
| | return BaseModelOutputWithPast( |
| | last_hidden_state=sequence_output, |
| | past_key_values=decoder_outputs.past_key_values, |
| | hidden_states=decoder_outputs.hidden_states, |
| | attentions=decoder_outputs.attentions, |
| | ) |
| |
|
| |
|
@dataclass
class DeCodonForPreTrainingOutput(CausalLMOutputWithPast):
    """
    Output type of [`DeCodon`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Causal language modeling (next-token prediction) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(torch.FloatTensor)`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Cached key/value states of the self-attention layers that can be
            fed back in to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # These fields mirror CausalLMOutputWithPast; they are redeclared so the
    # @dataclass decorator registers them explicitly on this subclass.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
| |
|
| |
|
class DeCodon(DeCodonPreTrainedModel):
    """DeCodon causal language model: `DeCodonModule` plus a vocab LM head."""

    config_class = DeCodonConfig
    _tied_weights_keys = []

    def __init__(self, config):
        super().__init__(config)

        self.gpt = DeCodonModule(config)

        # LM head: plain linear projection ("gpt") or a small MLP head.
        if config.lm_type == "gpt":
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            # BUGFIX: assign a per-instance list. The original appended to the
            # shared class-level list on every instantiation, so duplicate
            # keys accumulated across instances.
            self._tied_weights_keys = ["lm_head.weight"]
        else:
            self.lm_head = nn.Sequential(
                OrderedDict(
                    [
                        ("dropout", nn.Dropout(config.hidden_dropout_prob)),
                        (
                            "transform",
                            nn.Linear(config.hidden_size, config.hidden_size),
                        ),
                        ("act", nn.ReLU()),
                        (
                            "norm",
                            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
                        ),
                        (
                            "pred",
                            nn.Linear(
                                config.hidden_size, config.vocab_size, bias=False
                            ),
                        ),
                    ]
                )
            )
            self._tied_weights_keys = ["lm_head.pred.weight"]

        # Initialize weights and apply final processing (HF convention).
        self.post_init()

    def get_input_embeddings(self):
        return self.gpt.embeddings.word_embeddings

    def get_output_embeddings(self):
        # BUGFIX: return the output projection *module*, not its weight
        # Parameter -- PreTrainedModel's weight tying assigns to `.weight`
        # on whatever this returns.
        if isinstance(self.lm_head, nn.Sequential):
            return self.lm_head.pred
        if self.config.lm_type == "gpt":
            return self.lm_head
        return None

    def set_output_embeddings(self, new_embeddings):
        # Symmetric with get_output_embeddings: `new_embeddings` is a module.
        if isinstance(self.lm_head, nn.Sequential):
            self.lm_head.pred = new_embeddings
        else:
            self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs
    ):
        """Slice the inputs down to the not-yet-cached suffix for generation."""
        token_type_ids = kwargs.get("token_type_ids", None)
        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
        use_cache = kwargs.get("use_cache", True)

        if past_key_values is not None and use_cache:
            # NOTE(review): dim 1 is assumed to be the sequence axis of the
            # cached keys here, while the decoder stack reads size(-2) --
            # confirm the cache layout is consistent between the two.
            past_length = past_key_values[0][0].shape[1]

            # Keep only tokens the cache has not seen yet; if the prompt is
            # somehow fully cached, keep the final token so the model still
            # receives one input.
            if input_ids.shape[1] > past_length:
                remove_prefix_len = past_length
            else:
                remove_prefix_len = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_len:]

            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, remove_prefix_len:]

        if attention_mask is not None and position_ids is None:
            # Derive position ids from the mask so left-padding is handled.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None and use_cache:
                # BUGFIX: trim position ids to match the input_ids truncation
                # above (GPT-2 convention); otherwise the model receives
                # full-length position ids with a one-token input.
                position_ids = position_ids[:, -input_ids.shape[1]:]
        else:
            position_ids = None

        if inputs_embeds is not None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
            }
        )

        return model_inputs

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        organism: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], DeCodonForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the causal language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        organism (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Organism labels. Accepted for interface compatibility; not used in
            the loss computed here.
        kwargs (`Dict[str, any]`, optional, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.

        Returns:
            `DeCodonForPreTrainingOutput` (or a tuple when `return_dict=False`)
            with the optional LM loss and the per-token vocabulary logits.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        gpt_outputs = self.gpt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
        )

        hidden_states = gpt_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict token n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + gpt_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return DeCodonForPreTrainingOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=gpt_outputs.past_key_values,
            hidden_states=gpt_outputs.hidden_states,
            attentions=gpt_outputs.attentions,
        )

    def freeze(self, layer_indices: Optional[list] = None):
        """
        Freeze the backbone. With no `layer_indices`, everything in `self.gpt`
        is frozen; otherwise the embeddings plus every decoder block NOT in
        `layer_indices` are frozen (negative indices wrap around).
        """
        if layer_indices is None or len(layer_indices) == 0:
            for param in self.gpt.parameters():
                param.requires_grad = False
        else:
            for param in self.gpt.embeddings.parameters():
                param.requires_grad = False

            if isinstance(layer_indices, int):
                layer_indices = [layer_indices]

            # Resolve negative indices modulo the number of blocks.
            layer_indices = [i % len(self.gpt.decoder.blocks) for i in layer_indices]

            for i in range(len(self.gpt.decoder.blocks)):
                if i not in layer_indices:
                    for param in self.gpt.decoder.blocks[i].parameters():
                        param.requires_grad = False
| |
|
| |
|
| |
|
class DeCodonForSequenceTask(DeCodonPreTrainedModel):
    """
    DeCodon with a pooled classification/regression head on top.

    Pools the last-token hidden state from a configurable subset of decoder
    layers ("cls" pooling) and feeds the concatenation through an MLP head
    producing `num_labels * num_tasks` outputs.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.gpt = DeCodonModule(config)

        if config.cls_type.lower() == "cls":
            # Normalize layer_indices to a list and resolve negative /
            # out-of-range values modulo the number of decoder blocks.
            layer_indices = config.layer_indices
            if layer_indices is None:
                layer_indices = []
            elif isinstance(layer_indices, int):
                layer_indices = [layer_indices]
            layer_indices = [i % len(self.gpt.decoder.blocks) for i in layer_indices]

            n_layers = len(layer_indices)
            self.layer_indices = layer_indices
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.hidden_size * n_layers),
                nn.Linear(config.hidden_size * n_layers, config.hidden_size),
                ACT2CLS[config.cls_hidden_act](),
                nn.Dropout(config.cls_dropout_prob),
                nn.Linear(
                    config.hidden_size,
                    config.num_labels * config.num_tasks,
                ),
            )
        else:
            raise ValueError(f"Invalid cls_type: {config.cls_type}.")

        self.init_weights()

    def freeze(self, layers_idx: Optional[list] = None):
        """
        Freeze the backbone. With no `layers_idx`, everything in `self.gpt`
        is frozen; otherwise the embeddings plus every decoder block NOT in
        `layers_idx` are frozen (negative indices wrap around).
        """
        if layers_idx is None or len(layers_idx) == 0:
            for param in self.gpt.parameters():
                param.requires_grad = False
        else:
            for param in self.gpt.embeddings.parameters():
                param.requires_grad = False

            if isinstance(layers_idx, int):
                layers_idx = [layers_idx]

            layers_idx = [i % self.config.num_hidden_layers for i in layers_idx]

            for i in range(self.config.num_hidden_layers):
                if i not in layers_idx:
                    for param in self.gpt.decoder.blocks[i].parameters():
                        param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """
        Args:
            input_ids: token ids of shape `(batch, seq_len)`; mutually
                exclusive with `inputs_embeds`.
            target: labels of shape `(batch, num_labels * num_tasks)` for
                regression (with `-500.0` as the missing-value sentinel) or
                class indices for classification.

        Returns:
            `SequenceClassifierOutput` (or a tuple when `return_dict=False`)
            with the optional loss and task logits.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Locate the pooling index per sequence from the first pad position.
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # argmax over the pad-equality mask gives the first pad index;
                # minus one yields the last non-pad token.
                sequence_lengths = (
                    torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
                ).to(input_ids.device)
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        # BUGFIX: force return_dict=True on the backbone call -- the attribute
        # access `gpt_outputs.hidden_states` below crashed when the caller
        # passed return_dict=False (the backbone then returns a plain tuple).
        # ModelOutput still supports tuple-style slicing for the legacy path.
        gpt_outputs = self.gpt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,  # pooling reads intermediate layers
            return_dict=True,
        )

        all_hidden_states = gpt_outputs.hidden_states

        # NOTE: __init__ only ever constructs the "cls" head (any other
        # cls_type raises), so the former cross-attention branches were
        # unreachable and referenced undefined attributes (task_embeddings,
        # `ca`); they have been removed.
        pooled_hidden_states = [
            # NOTE(review): sequence_lengths already points at the last
            # non-pad token, so `- 1` selects the token before it -- confirm
            # this extra offset is intentional.
            h[torch.arange(batch_size, device=h.device), sequence_lengths - 1, :]
            for i, h in enumerate(all_hidden_states)
            if i in self.layer_indices
        ]
        # Concatenate the selected layers' pooled vectors along features.
        pooled_output = torch.cat(pooled_hidden_states, dim=-1)

        logits = self.classifier(pooled_output)

        loss = None
        if target is not None:
            if self.config.problem_type == "regression":
                logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
                target = target.view(-1, self.config.num_labels * self.config.num_tasks)

                # -500.0 is the sentinel for a missing label; masked out.
                mask = target != -500.0

                if self.config.loss_fn == "mse":
                    loss_fct = nn.MSELoss()
                elif self.config.loss_fn == "mae":
                    loss_fct = nn.L1Loss()
                elif self.config.loss_fn == "huber":
                    loss_fct = nn.SmoothL1Loss()
                else:
                    raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.")
                loss = loss_fct(logits[mask], target[mask])
            else:
                loss_fct = nn.CrossEntropyLoss()

                logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
                target = target.view(-1)

                loss = loss_fct(logits, target)

        if not return_dict:
            output = (logits,) + gpt_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # BUGFIX: the original referenced `ca` here, which was never assigned
        # on the reachable path, raising NameError when output_attentions=True.
        attentions = gpt_outputs.attentions if output_attentions else None

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=pooled_output,
            attentions=attentions,
        )