|
|
from typing import Optional, Tuple |
|
|
from dataclasses import dataclass |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from transformers.modeling_outputs import ( |
|
|
SequenceClassifierOutput, |
|
|
) |
|
|
|
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.utils.checkpoint |
|
|
from torch import nn |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from transformers.activations import ACT2FN, ACT2CLS |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.utils import logging |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutput, CausalLMOutputWithPast |
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask |
|
|
import xformers.ops as xops |
|
|
|
|
|
from collections import OrderedDict |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from einops import rearrange, einsum |
|
|
from transformers.pytorch_utils import Conv1D |
|
|
|
|
|
|
|
|
import torch |
|
|
from torch.amp import autocast |
|
|
from torch import nn, einsum, Tensor |
|
|
|
|
|
from einops import rearrange, repeat |
|
|
from typing import Optional, Union |
|
|
|
|
|
from .configuration_decodon import DeCodonConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
def rotate_half(x):
    """Rotate adjacent feature pairs by 90 degrees: (x1, x2) -> (-x2, x1).

    This is the pairwise rotation used by rotary positional embeddings
    (RoPE): the last dimension is interpreted as interleaved (real, imag)
    pairs, and each pair is multiplied by i.

    Parameters
    ----------
    x : Tensor
        Tensor whose last dimension has even size: (..., dim).

    Returns
    -------
    Tensor
        Same shape as ``x`` with each feature pair rotated.
    """
    # Native torch ops instead of einops: avoids per-call pattern parsing in
    # a hot path and drops the third-party dependency. Equivalent to
    # rearrange "... (d r) -> ... d r" with r=2.
    pairs = x.unflatten(-1, (-1, 2))
    x1, x2 = pairs.unbind(dim=-1)
    rotated = torch.stack((-x2, x1), dim=-1)
    # Merge the pair axis back into the feature axis ("... d r -> ... (d r)").
    return rotated.flatten(start_dim=-2)
|
|
|
|
|
|
|
|
@autocast(device_type="cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
    """
    Applies rotary embeddings to a tensor.

    Runs with autocast disabled: the sin/cos rotation is numerically
    sensitive and is kept in full precision.

    Parameters
    ----------
    freqs : Tensor
        The rotation angles per position: (seq_len, dim)
    t : Tensor
        The tensor to apply the rotary embeddings to: (..., seq_len, n_heads, dim)
    start_index : int
        The starting feature index to apply the rotary embeddings. (default: 0)
    scale : float or Tensor
        XPos-style magnitude scale applied alongside the rotation. (default: 1.0)

    Returns
    -------
    Tensor
        The tensor with rotary embeddings applied to features
        [start_index, start_index + freqs.shape[-1]); features outside that
        window pass through unchanged.
    """
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    # Split the feature axis into [untouched prefix | rotated slice | suffix].
    left = t[..., :start_index]
    middle = t[..., start_index:end_index]
    right = t[..., end_index:]

    # Promote a plain float scale to a tensor on the right device/dtype.
    if isinstance(scale, float):
        scale = torch.tensor(scale, device=t.device, dtype=t.dtype)

    # Classic RoPE update: x * cos + rotate_half(x) * sin, optionally scaled.
    rotated = (middle * freqs.cos() * scale) + (rotate_half(middle) * freqs.sin() * scale)
    return torch.cat((left, rotated, right), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """Apply per-position *learned* rotation angles to ``t`` via rotary embedding.

    ``rotations`` holds raw angles; when ``freq_ranges`` is given, each angle is
    first scaled by every frequency in the range and the (rotation, frequency)
    axes are flattened into a single feature axis.
    """
    if freq_ranges is not None:
        scaled = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(scaled, "... r f -> ... (r f)")

    # Duplicate each angle so it covers both halves of a feature pair.
    doubled = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(doubled, t, start_index=start_index)
|
|
|
|
|
|
|
|
""" |
|
|
Inspired from https://github.com/lucidrains/rotary-embedding-torch |
|
|
""" |
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """
    Rotary Embeddings implementation inspired by https://github.com/lucidrains/rotary-embedding-torch.

    Rotary Positional Embeddings (RoPE) encode position information of tokens with a
    rotation matrix that naturally incorporates explicit relative position dependency.

    Parameters
    ----------
    emb_dim : int
        Embedding dimension. Usually set to the dim of each head in the attention module.
    freqs : Optional[Tensor]
        Custom frequencies to apply to query/key tensors. (default: None)
    theta : float
        Base constant used for computing rotation angles.
    learned_freq : bool (default: False)
        Whether to learn the frequencies.
    use_xpos : bool (default: False)
        Whether to employ XPos technique for resolving length extrapolation issue.
        NOTE: This can only be enabled for autoregressive models like GPT.
    xpos_scale_base : int (default: 512)
        The base for the scale factor used in XPos technique.
    interpolate_factor : float (default: 1.0)
        Length interpolation factor for extending context length of the pretrained model.
        Final model's context length = pretrained_model_context_length * interpolate_factor.
    theta_rescale_factor : float (default: 1.0)
        The factor to rescale the theta.
    cache_if_possible : bool (default: True)
        Whether to cache the frequencies/scales if possible.
    """

    def __init__(
        self,
        emb_dim,
        freqs: Optional[Tensor] = None,
        theta=1e4,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        cache_if_possible=True,
    ):
        super().__init__()

        # Rescale theta for context extension; the exponent emb_dim/(emb_dim-2)
        # follows the reference rotary-embedding-torch implementation.
        theta *= theta_rescale_factor ** (emb_dim / (emb_dim - 2))

        if freqs is None:
            # Default RoPE spectrum: theta ** (-2i / emb_dim), one per feature pair.
            freqs = 1.0 / (
                theta
                ** (torch.arange(0, emb_dim, 2)[: (emb_dim // 2)].float() / emb_dim)
            )

        self.cache_if_possible = cache_if_possible

        # Lazily-filled, non-persistent caches (populated in forward()/get_scale()).
        self.register_buffer("cached_freqs", None, persistent=False)
        self.register_buffer("cached_scales", None, persistent=False)

        # Stored as a Parameter so the frequencies can optionally be learned.
        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # Positions are divided by this factor, stretching the usable context.
        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        self.use_xpos = use_xpos
        if not use_xpos:
            # Keep a None "scale" buffer so attribute access stays uniform.
            self.register_buffer("scale", None, persistent=False)
            return

        # XPos per-channel decay factors (one per feature pair).
        scale = (torch.arange(0, emb_dim, 2) + 0.4 * emb_dim) / (1.4 * emb_dim)
        self.scale_base = xpos_scale_base
        self.register_buffer("scale", scale, persistent=False)

    @property
    def device(self):
        # The module lives wherever its frequency parameter lives.
        return self.freqs.device

    def rotate_queries_or_keys(self, t, offset=0, freq_seq_len=None, scale=None):
        """
        Rotate a single query or key tensor.

        Parameters
        ----------
        t : Tensor
            tensor to rotate: (batch_size, seq_len, num_heads, head_dim)
        offset : int
            Position offset added to the sequence indices (e.g. cache length).
        freq_seq_len : Optional[int]
            When given (>= seq_len), compute frequencies for this longer length.
        scale : Optional[Tensor]
            XPos scale; required (by the assert below) when use_xpos is True.
        """
        seq_len = t.shape[1]
        assert (
            not self.use_xpos or scale is not None
        ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"

        if freq_seq_len is not None:
            assert freq_seq_len >= seq_len
            seq_len = freq_seq_len

        # Positions, shifted by offset and stretched by the interpolation factor.
        seq = (
            torch.arange(seq_len, device=t.device, dtype=t.dtype) + offset
        ) / self.interpolate_factor

        freqs = self.forward(
            seq,
            seq_len=seq_len,
            offset=offset,
        ).to(t.dtype)

        # Broadcast over the heads axis: (seq, dim) -> (seq, 1, dim).
        freqs = rearrange(freqs, "n d -> n 1 d")

        if scale is not None:
            scale = rearrange(scale, "n d -> n 1 d")

        if scale is None:
            scale = torch.tensor(1.0, device=t.device, dtype=t.dtype)

        return apply_rotary_emb(freqs, t, scale=scale)

    def rotate_queries_and_keys(self, q, k):
        """
        Rotate queries and keys together with XPos scaling (scale for q,
        inverse scale for k), which keeps q.k^T length-extrapolatable.

        Parameters
        ----------
        q : Tensor
            queries tensor: (batch_size, seq_len, num_heads, head_dim)
        k : Tensor
            keys tensor: (batch_size, seq_len, num_heads, head_dim)
        """
        assert self.use_xpos
        # NOTE(review): seq_len comes from q only; if k is longer, freqs/scale
        # broadcast over k via the leading dims — confirm intended for caching.
        seq_len = q.shape[-3]

        seq = (
            torch.arange(seq_len, device=q.device, dtype=q.dtype)
        ) / self.interpolate_factor

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len)

        freqs = rearrange(freqs, "n d -> n 1 d")
        scale = rearrange(scale, "n d -> n 1 d")

        # q gets scale, k gets 1/scale so the product depends only on distance.
        rotated_q = apply_rotary_emb(freqs, q, scale=scale)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
        """Compute (and possibly cache) the XPos scale for positions ``t``."""
        assert self.use_xpos

        should_cache = self.cache_if_possible and seq_len is not None

        # Serve from the cache when it already covers offset + seq_len rows.
        if (
            should_cache
            and self.cached_scales is not None
            and (seq_len + offset) <= self.cached_scales.shape[0]
        ):
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        # Redundant with the assert above, but kept as-is.
        if self.use_xpos:
            # Center the power around the middle of the sequence.
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            # Duplicate to cover both halves of each feature pair.
            scale = torch.cat((scale, scale), dim=-1)

        if should_cache:
            # register_buffer is re-invoked to overwrite the cached buffer.
            self.register_buffer("cached_scales", scale, persistent=False)

        return scale

    def rotate_queries_with_cached_keys(self, q, k, offset=0):
        """Rotate a (short) query block against longer cached keys.

        NOTE(review): this currently always takes the early return below, which
        requires use_xpos (asserted inside rotate_queries_and_keys) and slices
        the rotated queries to the last position only. Everything after the
        first return is unreachable dead code (an alternative implementation
        that handles the non-xpos case) — confirm which variant is intended.
        """
        q_len, k_len = q.shape[1], k.shape[1]
        assert q_len <= k_len

        rotated_q, rotated_k = self.rotate_queries_and_keys(q, k)

        # Keep only the newest query position (incremental decoding).
        rotated_q = rotated_q[:, -1:, ...]

        return rotated_q, rotated_k

        # --- unreachable below this line (kept byte-identical) ---
        seq = (
            torch.arange(k_len, device=q.device, dtype=q.dtype)
        ) / self.interpolate_factor

        if self.use_xpos:
            q_scale = self.get_scale(seq[-q_len:]).to(q.dtype)
            k_scale = self.get_scale(seq).to(k.dtype)

        else:
            k_scale = 1.0
            q_scale = 1.0

        rotated_q = self.rotate_queries_or_keys(
            q, scale=q_scale, offset=k_len - q_len + offset
        )
        rotated_k = self.rotate_queries_or_keys(k, scale=k_scale**-1)

        return rotated_q, rotated_k

    @autocast(device_type="cuda", enabled=False)
    def forward(self, t: Tensor, seq_len=None, offset=0):
        """Compute (and possibly cache) rotation angles for positions ``t``.

        Returns a (seq_len, emb_dim) tensor where each frequency is repeated
        twice consecutively to cover both halves of a feature pair.
        """
        # Never cache learned frequencies — they change during training.
        should_cache = (
            self.cache_if_possible and not self.learned_freq and seq_len is not None
        )

        if (
            should_cache
            and self.cached_freqs is not None
            and (offset + seq_len) <= self.cached_freqs.shape[0]
        ):
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        # Outer product of positions and frequencies (einsum here is bound to
        # torch.einsum at module level; pattern-first call form is valid).
        freqs = einsum("..., f -> ... f", t, freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache:
            # register_buffer is re-invoked to overwrite the cached buffer.
            self.register_buffer("cached_freqs", freqs.detach(), persistent=False)

        return freqs
|
|
|
|
|
|
|
|
|
|
|
class MultiHeadedSelfAttention(nn.Module):
    """
    Multi-Headed Self Attention module supported with Flash Attention and Rotary Embeddings.

    Three execution paths, chosen in forward():
      1. flash-attn packed kernel (no bias, no cache, no attention probs),
      2. xformers memory-efficient attention (additive bias supported),
      3. vanilla einsum attention (needed for output_attentions / use_cache).

    Parameters
    ----------
    q_input_dim: int
        The input dimension of the query tensor.
    kv_input_dim: int
        The input dimension of the key and value tensors.
    qk_proj_dim: int
        The projected dimension of the query and key tensors.
    v_proj_dim: int
        The projected dimension of the value tensors.
    num_heads: int
        Number of attention heads.
    dropout: float
        Dropout rate to apply to the attention scores.
    projection_layer: str
        The type of projection layer to use. Either 'linear' or 'conv'.
        Basically both are linear projections, but 'conv' uses Conv1D layer as proposed in the original GPT2 paper.
    use_flash_attn: bool
        Whether to use Flash Attention or not. If True, Flash Attention will be used.
        NOTE: Flash Attention is required to be installed.
    use_rotary_emb: bool
        Whether to use Rotary Embeddings or not.
    rotary_theta: int
        The base for the geometric progression used to compute the rotation angles.
    rotary_use_xpos: bool
        Whether to use XPos technique for resolving length extrapolation issue.
        NOTE: This can only be enabled for autoregressive models like GPT.
    is_cross_attention: bool
        If True, flash attention is disabled (the packed kernel requires
        query/key/value to share one sequence).
    """

    def __init__(
        self,
        q_input_dim,
        kv_input_dim,
        qk_proj_dim,
        v_proj_dim,
        num_heads,
        dropout: float = 0.0,
        projection_layer: str = "linear",
        use_flash_attn: bool = True,
        use_rotary_emb: bool = False,
        rotary_theta: int = 1e4,
        rotary_use_xpos: bool = False,
        is_cross_attention: bool = False,
        **kwargs,
    ):
        super().__init__()
        assert (
            qk_proj_dim % num_heads == 0
        ), "qk_proj_dim must be divisible by num_heads"
        assert v_proj_dim % num_heads == 0, "v_proj_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.projection_layer = projection_layer
        self.use_rotary_emb = use_rotary_emb
        self.is_cross_attention = is_cross_attention

        # flash-attn is optional at runtime: fall back gracefully when missing.
        if use_flash_attn and not is_cross_attention:
            try:
                from flash_attn import flash_attn_qkvpacked_func

                self.use_flash_attn = True
                self.flashattn_fn = flash_attn_qkvpacked_func
            except ImportError:
                print("flash_attn not installed, reverting to default attention")
                self.use_flash_attn = False
                self.flashattn_fn = None
        else:
            self.use_flash_attn = False
            self.flashattn_fn = None

        if self.projection_layer == "linear":
            self.query = nn.Linear(q_input_dim, qk_proj_dim)
            self.key = nn.Linear(kv_input_dim, qk_proj_dim)
            self.value = nn.Linear(kv_input_dim, v_proj_dim)
        elif self.projection_layer == "conv":
            # GPT-2 style Conv1D: a linear layer with transposed weight layout.
            self.query = Conv1D(qk_proj_dim, q_input_dim)
            self.key = Conv1D(qk_proj_dim, kv_input_dim)
            self.value = Conv1D(v_proj_dim, kv_input_dim)
        else:
            raise ValueError(
                f"projection_layer must be either 'linear' or 'conv', got {projection_layer}"
            )

        if self.use_rotary_emb:
            # NOTE(review): emb_dim is head_dim // 2, i.e. rotary rotates only
            # the first half of each head's features — confirm this is intended.
            self.rotary_emb = RotaryEmbedding(
                emb_dim=qk_proj_dim // num_heads // 2,
                theta=rotary_theta,
                use_xpos=rotary_use_xpos,
            )

        self.dr_rate = dropout
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x_q,
        x_kv,
        is_causal=False,
        attention_bias=None,
        attention_mask=None,
        output_attentions=False,
        query=None,
        key=None,
        value=None,
        use_cache=False,
    ):
        """
        Applies a classical self attention operation.

        Parameters
        ----------
        x_q: torch.Tensor
            The query tensor of shape (batch_size, query_seq_len, emb_dim)
        x_kv: torch.Tensor
            The key/value tensor of shape (batch_size, kv_seq_len, emb_dim)
        is_causal: bool
            Whether to apply a lower-triangular (causal) mask. (default: False)
        attention_bias: torch.Tensor
            The attention bias to apply to the attention scores. (default: None)
        attention_mask: torch.Tensor
            The attention mask to apply to the attention scores. Shape: (batch_size, q_len, kv_seq_len)
        output_attentions: bool
            If True, also return attention probabilities (forces vanilla path).
        query, key, value: torch.Tensor
            Pre-projected q/k/v (e.g. reassembled from a cache); when provided,
            x_q/x_kv are not used for projection.
        use_cache: bool
            If True, additionally return a (key, value, query) clone tuple.

        Returns
        -------
        (output, attention_probs or None) — plus the cache tuple when
        use_cache is True.
        """
        assert (x_q is not None and x_kv is not None) or (
            query is not None and key is not None and value is not None
        ), "Either x_q and x_kv or query, key and value must be provided"

        past_memory_provided = (
            query is not None and key is not None and value is not None
        )

        if query is None:
            q_len = x_q.size(1)
            k_len = x_kv.size(1)

            query = self.query(x_q)
            key = self.key(x_kv)
            value = self.value(x_kv)

        else:
            q_len = query.size(1)
            k_len = key.size(1)

        if use_cache:
            # Cache the *projected*, pre-rotary tensors (clones detach storage
            # from subsequent in-place graph changes).
            cache = (key.clone(), value.clone(), query.clone())

        # Split the projection dim into heads: (batch, seq, heads, head_dim).
        q = rearrange(query, "b q (h d) -> b q h d", h=self.num_heads)
        k = rearrange(key, "b k (h d) -> b k h d", h=self.num_heads)
        v = rearrange(value, "b v (h d) -> b v h d", h=self.num_heads)

        if self.use_rotary_emb:
            if use_cache and past_memory_provided:
                q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

            # NOTE(review): when the cached branch above ran, q/k have already
            # been rotated (and q sliced to the newest position); the xpos
            # branch below then rotates them a second time — verify against
            # the generation path.
            if self.rotary_emb.use_xpos:
                q, k = self.rotary_emb.rotate_queries_and_keys(q, k)
            else:
                q = self.rotary_emb.rotate_queries_or_keys(q)
                k = self.rotary_emb.rotate_queries_or_keys(k)

        if (
            self.use_flash_attn
            and not use_cache
            and not output_attentions
            and attention_bias is None
        ):
            # Path 1: packed flash-attn kernel, bf16, no additive bias support.
            qkv = torch.stack([q, k, v], dim=2).to(torch.bfloat16)
            x = self.flashattn_fn(
                qkv=qkv,
                dropout_p=self.dropout_rate if self.training else 0.0,
                causal=is_causal,
                deterministic=False,
                return_attn_probs=False,
            )

            # NOTE(review): x_q is None when pre-projected q/k/v were passed;
            # this line would then raise AttributeError — confirm that that
            # combination never reaches this path.
            x = x.to(x_q.dtype)
        elif self.use_flash_attn and not output_attentions:
            # Path 2: xformers memory-efficient attention with additive bias.
            attn_bias = xops.LowerTriangularMask() if is_causal else attention_bias

            if attention_mask is not None:
                if attn_bias is None:
                    attn_bias = attention_mask
                else:
                    if isinstance(attn_bias, torch.Tensor):
                        attn_bias = attn_bias + attention_mask
                    else:
                        # NOTE(review): add_bias returns a new object; the
                        # result is discarded here before materializing —
                        # verify the mask is actually folded in.
                        attn_bias.add_bias(bias=attention_mask)

                        attn_bias = attn_bias.materialize(
                            shape=(q_len, k_len),
                            device=q.device,
                            dtype=q.dtype,
                        )
            else:
                if isinstance(attn_bias, torch.Tensor) and len(attn_bias.shape) == 3:
                    # Per-batch (b, q, k) bias -> per-head (b, h, q, k).
                    attn_bias = (
                        attn_bias.unsqueeze(1)
                        .expand(-1, self.num_heads, -1, -1)
                        .float()
                    )
                else:
                    attn_bias = attn_bias.materialize(
                        shape=(q_len, k_len),
                        device=q.device,
                        dtype=q.dtype,
                    )

            if isinstance(attn_bias, xops.LowerTriangularMask):
                attn_bias = attn_bias.materialize(
                    shape=(q_len, k_len),
                    device=q.device,
                    dtype=q.dtype,
                )

            # xformers requires the bias's trailing dims aligned to multiples
            # of 8: pad up, then hand the kernel a view sliced back to
            # (q_len, k_len) so the underlying storage stays aligned.
            # NOTE(review): the padding/repeat code below indexes 4 dims and
            # assumes a 4-D bias — a 2-D materialized bias would not fit;
            # verify which shapes actually reach here.
            need_adjustment = False
            if attn_bias.shape[-2] % 8 != 0:
                nearest_multiple_q = 8 * (1 + attn_bias.shape[-2] // 8)
                need_adjustment = True
            else:
                nearest_multiple_q = attn_bias.shape[-2]

            if attn_bias.shape[-1] % 8 != 0:
                nearest_multiple_k = 8 * (1 + attn_bias.shape[-1] // 8)
                need_adjustment = True
            else:
                nearest_multiple_k = attn_bias.shape[-1]

            if need_adjustment:
                new_attn_bias = torch.zeros(
                    attn_bias.shape[0],
                    attn_bias.shape[1],
                    nearest_multiple_q,
                    nearest_multiple_k,
                ).to(attn_bias.device)
                new_attn_bias[:, :, : attn_bias.shape[-2], : attn_bias.shape[-1]] = (
                    attn_bias
                )

                x = xops.memory_efficient_attention(
                    query=q,
                    key=k,
                    value=v,
                    op=None,
                    attn_bias=new_attn_bias[:, :, :q_len, :k_len],
                    p=self.dr_rate,
                )
            else:
                attn_bias = attn_bias.to(q.dtype)
                # NOTE(review): if the bias was already expanded to num_heads
                # above, this repeat multiplies the head dim again — verify.
                attn_bias = attn_bias.repeat(1, self.num_heads, 1, 1)
                x = xops.memory_efficient_attention(
                    query=q,
                    key=k,
                    value=v,
                    op=None,
                    attn_bias=attn_bias,
                    p=self.dr_rate,
                )

        else:
            # Path 3: vanilla attention (supports output_attentions/use_cache).
            # NOTE(review): `einsum` is bound to torch.einsum at module level
            # (torch import comes after the einops import); this einops-style
            # call (tensors first, pattern last) would fail under that binding
            # — verify whether this fallback path is ever exercised.
            attention_scores = einsum(q, k, "b q h d, b k h d -> b h q k")
            # Scale by sqrt(head_dim).
            attention_scores = attention_scores / (q.size(-1) ** 0.5)

            if attention_bias is not None:
                attn_bias = attention_bias.unsqueeze(1).expand(
                    -1, self.num_heads, -1, -1
                )

            else:
                attn_bias = None

            if attention_mask is not None:
                if attn_bias is None:
                    attn_bias = attention_mask
                else:
                    attn_bias = attn_bias + attention_mask

            # NOTE(review): raises if both attention_bias and attention_mask
            # are None on this path (attn_bias stays None) — confirm callers
            # always supply at least one here.
            attention_scores = attention_scores + attn_bias

            attention_probs = attention_scores.softmax(dim=-1)
            attention_probs = self.dropout(attention_probs)

            x = einsum(attention_probs, v, "b h q k, b k h d -> b q h d")

        # Merge heads back into a single feature dim: (b, seq, h * d).
        x = rearrange(x, "b q h d -> b q (h d)", h=self.num_heads)

        if use_cache:
            if output_attentions:
                return x, attention_probs, cache
            else:
                return x, None, cache
        else:
            if output_attentions:
                return x, attention_probs
            else:
                return x, None
|
|
|
|
|
class DeCodonPreTrainedModel(PreTrainedModel):
    """
    Abstract base for all DeCodon models: provides weight initialization and
    the standard HuggingFace interface for downloading/loading pretrained
    checkpoints.
    """

    base_model_prefix = "decodon"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """MAGNETO-style weight initialization for a single submodule."""
        if isinstance(module, nn.Linear):
            # Xavier-normal scaled by the MAGNETO gain keeps deep residual
            # stacks stable at initialization.
            nn.init.xavier_normal_(module.weight, gain=self.config.gamma_init)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Padding embedding stays exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Only DeCodonLayer instances carry the checkpointing flag.
        if isinstance(module, DeCodonLayer):
            module.gradient_checkpointing = value
|
|
|
|
|
|
|
|
class DeCodonEmbeddings(nn.Module):
    """
    DeCodon Embeddings

    Sums word, token-type and (when position_embedding_type == "absolute")
    position embeddings, then applies dropout.

    NOTE(review): ``self.ln`` is constructed but never applied in forward() —
    confirm whether normalization is intentionally deferred to the first
    decoder layer's pre-LN.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )

        self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # "absolute" adds learned position embeddings; anything else skips them.
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        # Static index buffers so forward() never re-allocates them.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        # Derive (batch, seq) from whichever input was given.
        shape = (
            input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        )
        seq_length = shape[1]

        # Default positions continue after any cached prefix.
        if position_ids is None:
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ]

        # Default token types come from the zero buffer (or fresh zeros).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(
                    shape[0], seq_length
                )
            else:
                token_type_ids = torch.zeros(
                    shape, dtype=torch.long, device=self.position_ids.device
                )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)

        if self.position_embedding_type == "absolute":
            embeddings = embeddings + self.position_embeddings(position_ids)

        return self.dropout(embeddings)
|
|
|
|
|
|
|
|
class DeCodonAttention(nn.Module):
    """
    DeCodon Attention Layer

    Wraps MultiHeadedSelfAttention (causal, rotary+xpos) with MAGNETO-style
    pre/post layer norms, an output projection, dropout, and a residual
    connection. Supports an incremental q/k/v cache via ``past_key_values``.
    """

    def __init__(self, config):
        super().__init__()

        self.pre_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.post_attn_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.self_attention = MultiHeadedSelfAttention(
            q_input_dim=config.hidden_size,
            kv_input_dim=config.hidden_size,
            qk_proj_dim=config.hidden_size,
            v_proj_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            projection_layer="conv",
            use_flash_attn=config.use_flash_attn,
            use_rotary_emb=config.use_rotary_emb,
            rotary_theta=config.rotary_theta,
            rotary_use_xpos=True,
        )

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        attn_input = self.pre_layer_norm(hidden_states)

        if past_key_values is None:
            # Fresh pass: the attention module projects q/k/v itself.
            attn_outputs = self.self_attention(
                x_q=attn_input,
                x_kv=attn_input,
                is_causal=True,
                attention_bias=None,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        else:
            # Incremental pass: project only the new tokens, then prepend the
            # cached projections along the sequence axis.
            # NOTE(review): the cache stores *projected* q/k/v, and queries are
            # concatenated too, so attention is recomputed over the full
            # sequence each step — confirm this is the intended cache design.
            new_query = self.self_attention.query(attn_input)
            new_key = self.self_attention.key(attn_input)
            new_value = self.self_attention.value(attn_input)

            past_key, past_value, past_query = past_key_values

            attn_outputs = self.self_attention(
                x_q=None,
                x_kv=None,
                query=torch.cat((past_query, new_query), dim=1),
                key=torch.cat((past_key, new_key), dim=1),
                value=torch.cat((past_value, new_value), dim=1),
                is_causal=True,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
                attention_bias=None,
            )

        # MAGNETO sub-LN: normalize, project, drop, then add the residual.
        out = self.post_layer_norm(attn_outputs[0])
        out = self.post_attn_dense(out)
        out = self.dropout(out)

        return (hidden_states + out,) + attn_outputs[1:]
|
|
|
|
|
|
|
|
class DeCodonFFN(nn.Module):
    """
    DeCodon position-wise feed-forward network (MAGNETO sub-LN variant):
    pre-LN -> expand (Conv1D) -> activation -> post-LN -> project back -> dropout.

    NOTE(review): no residual connection is applied here or by the caller
    (DeCodonLayer) — confirm this matches the trained checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size

        self.pre_layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.intermediate_dense = Conv1D(inner, hidden)
        self.post_layer_norm = nn.LayerNorm(inner, eps=config.layer_norm_eps)
        self.post_dense = Conv1D(hidden, inner)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Resolve the activation by name when the config supplies a string.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(
        self, hidden_states: Optional[Tuple[torch.FloatTensor]]
    ) -> torch.FloatTensor:
        normed = self.pre_layer_norm(hidden_states)
        expanded = self.intermediate_act_fn(self.intermediate_dense(normed))
        projected = self.post_dense(self.post_layer_norm(expanded))
        return self.dropout(projected)
|
|
|
|
|
|
|
|
class DeCodonLayer(nn.Module):
    """
    Single DeCodon decoder layer: causal self-attention followed by a
    position-wise feed-forward network.

    NOTE(review): the FFN output is returned without a residual connection
    around it — confirm against the trained checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = DeCodonAttention(config)
        self.output = DeCodonFFN(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor],
        Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
    ]:
        attn_results = self.attention(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        # Tail entries carry attention probs and/or the KV cache when requested.
        extras = attn_results[1:]

        ffn_out = self.output(attn_results[0])
        return (ffn_out,) + extras
|
|
|
|
|
|
|
|
class DeCodonStack(nn.Module):
    """
    Stack of DeCodon decoder layers, with optional collection of per-layer
    hidden states, attention maps, and KV caches.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [DeCodonLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        use_cache: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        if past_key_values is None:
            past_key_values = [None] * len(self.blocks)
            past_length = 0
        else:
            # Length of the cached prefix (kept for parity; currently unused).
            past_length = past_key_values[0][0].size(-2)

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        presents = () if use_cache else None

        for block, layer_past in zip(self.blocks, past_key_values):
            if output_hidden_states:
                # Record the input to each layer.
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                past_key_values=layer_past,
                use_cache=use_cache,
            )
            hidden_states = outputs[0]

            # outputs[1] is the attention map slot, outputs[2] the cache tuple.
            if use_cache:
                presents = presents + (outputs[2],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            candidates = (
                hidden_states,
                presents,
                all_hidden_states,
                all_self_attentions,
            )
            return tuple(v for v in candidates if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
|
|
|
|
|
|
|
|
class DeCodonModule(DeCodonPreTrainedModel):
    """
    The DeCodon Module (Decoder only) without any task-specific head on top.

    Wraps the embedding layer and the decoder block stack, and handles input
    validation, attention-mask expansion, and key/value-cache bookkeeping.
    """

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = DeCodonEmbeddings(config)
        self.decoder = DeCodonStack(config)
        # NOTE(review): ln_f is created but never applied in forward() below —
        # presumably the final LayerNorm is applied inside DeCodonStack; confirm.
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing (HF convention).
        self.post_init()

    def set_input_embeddings(self, new_embeddings):
        # Replace the word-embedding table (used e.g. by resize_token_embeddings).
        self.embeddings.word_embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Embed the inputs and run them through the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be provided.
        Returns a `BaseModelOutputWithPast` (or a tuple when
        `return_dict=False`) whose `last_hidden_state` is the decoder's final
        hidden-state tensor.
        """
        # Resolve per-call flags against the config defaults.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Length of the cached prefix; cache tensors have the sequence axis at -2.
        if past_key_values is not None:
            past_length = past_key_values[0][0].size(-2)
        else:
            past_length = 0

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            # Default to attending over every position of the new tokens.
            attention_mask = torch.ones(((batch_size, seq_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the token_type_ids buffer registered on the embedding
                # module, truncated to seq_length and expanded over the batch.
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
                    batch_size, seq_length
                )
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(
                    input_shape, dtype=torch.long, device=device
                )

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Expand the 2D padding mask into the 4D causal mask expected by the
        # attention layers, accounting for the cached prefix length.
        extended_attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, input_shape[-1]),
            inputs_embeds=embedding_output,
            past_key_values_length=past_length,
        )

        decoder_outputs = self.decoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            past_key_values=past_key_values,
            return_dict=return_dict,
            use_cache=use_cache,
        )

        sequence_output = decoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + decoder_outputs[1:]

        return BaseModelOutputWithPast(
            last_hidden_state=sequence_output,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
        )
|
|
|
|
|
|
|
|
@dataclass
class DeCodonForPreTrainingOutput(CausalLMOutputWithPast):
    """
    Output type of [`DeCodon`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Causal language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(torch.FloatTensor)`, *optional*, returned when `use_cache=True` is passed):
            Cached key/value states that can be passed back to the model to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
|
|
class DeCodon(DeCodonPreTrainedModel):
    """
    DeCodon decoder-only model with a language-modeling head on top, for
    causal (next-token) pre-training and generation.
    """

    config_class = DeCodonConfig
    # Class-level default; each instance overwrites this with its own list in
    # __init__ (see BUGFIX note there).
    _tied_weights_keys = []

    def __init__(self, config):
        super().__init__(config)

        self.gpt = DeCodonModule(config)

        if config.lm_type == "gpt":
            # Plain linear head projecting hidden states to vocabulary logits.
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            # BUGFIX: the original appended to the *class* attribute
            # (DeCodon._tied_weights_keys), so the shared list grew with a
            # duplicate entry on every instantiation. Use an instance list.
            self._tied_weights_keys = ["lm_head.weight"]
        else:
            # MLP head: dropout -> transform -> activation -> norm -> projection.
            self.lm_head = nn.Sequential(
                OrderedDict(
                    [
                        ("dropout", nn.Dropout(config.hidden_dropout_prob)),
                        (
                            "transform",
                            nn.Linear(config.hidden_size, config.hidden_size),
                        ),
                        ("act", nn.ReLU()),
                        (
                            "norm",
                            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
                        ),
                        (
                            "pred",
                            nn.Linear(config.hidden_size, config.vocab_size, bias=False),
                        ),
                    ]
                )
            )
            self._tied_weights_keys = ["lm_head.pred.weight"]

        # Initialize weights and apply final processing (HF convention).
        self.post_init()

    def get_input_embeddings(self):
        return self.gpt.embeddings.word_embeddings

    def get_output_embeddings(self):
        # BUGFIX: return the projection *module*, not its .weight Parameter —
        # HF's weight-tying and embedding-resize logic expect an nn.Module
        # exposing a `.weight` attribute.
        if isinstance(self.lm_head, nn.Sequential):
            return self.lm_head.pred
        if self.config.lm_type == "gpt":
            return self.lm_head
        return None

    def set_output_embeddings(self, new_embeddings):
        # Mirror of get_output_embeddings: HF passes a module here.
        if isinstance(self.lm_head, nn.Sequential):
            self.lm_head.pred = new_embeddings
        else:
            self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs
    ):
        """
        Assemble the model inputs for one decoding step (HF generation API).

        When a key/value cache exists, the already-processed prefix is stripped
        from `input_ids` / `token_type_ids`, and `position_ids` (derived from
        the attention mask when not given) are aligned with the remaining
        tokens.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
        use_cache = kwargs.get("use_cache", True)

        if past_key_values is not None and use_cache:
            # BUGFIX: cache tensors carry the sequence axis at -2 (the rest of
            # this file reads past length as `.size(-2)`); the original used
            # `shape[1]`, which is the head axis for 4D caches.
            past_length = past_key_values[0][0].size(-2)

            if input_ids.shape[1] > past_length:
                # Drop exactly the prefix that is already in the cache.
                remove_prefix_len = past_length
            else:
                # Default: keep only the final (newly generated) token.
                remove_prefix_len = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_len:]

            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, remove_prefix_len:]

        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly from the padding mask.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None:
                # BUGFIX: align position_ids with the truncated input_ids; the
                # original passed positions for the full sequence alongside a
                # single-token input.
                position_ids = position_ids[:, -input_ids.shape[1] :]
        else:
            position_ids = None

        # BUGFIX: inputs_embeds can only seed the *first* step; once a cache
        # exists, generation must continue from the sampled input_ids.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
            }
        )

        return model_inputs

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        organism: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], DeCodonForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the causal language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Labels are
            shifted by one position internally (tokens < n predict token n).
        organism (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Organism labels (accepted for API compatibility; not used by this head).
        kwargs (`Dict[str, any]`, optional, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.

        Returns:
            [`DeCodonForPreTrainingOutput`] or a tuple: LM `logits` of shape
            `(batch_size, sequence_length, config.vocab_size)` plus, when `labels` is given, the scalar `loss`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        gpt_outputs = self.gpt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
        )

        hidden_states = gpt_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)

            # Shift so that tokens < n predict token n (standard causal LM).
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + gpt_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return DeCodonForPreTrainingOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=gpt_outputs.past_key_values,
            hidden_states=gpt_outputs.hidden_states,
            attentions=gpt_outputs.attentions,
        )

    def freeze(self, layer_indices: Optional[list] = None):
        """
        Freeze backbone parameters.

        With no `layer_indices`, every backbone parameter is frozen. Otherwise
        the embeddings plus every decoder block whose index is NOT listed are
        frozen (negative indices wrap around).
        """
        # BUGFIX: normalise a bare int before len() is ever called on it; the
        # original called len(layer_indices) first and crashed on freeze(3).
        if isinstance(layer_indices, int):
            layer_indices = [layer_indices]

        if layer_indices is None or len(layer_indices) == 0:
            for param in self.gpt.parameters():
                param.requires_grad = False
        else:
            for param in self.gpt.embeddings.parameters():
                param.requires_grad = False

            layer_indices = [i % len(self.gpt.decoder.blocks) for i in layer_indices]

            for i in range(len(self.gpt.decoder.blocks)):
                if i not in layer_indices:
                    for param in self.gpt.decoder.blocks[i].parameters():
                        param.requires_grad = False
|
|
|
|
|
|
|
|
|
|
|
class DeCodonForSequenceTask(DeCodonPreTrainedModel):
    """
    DeCodon backbone with a sequence-level classification/regression head.

    Only ``config.cls_type == "cls"`` is supported by __init__: the hidden
    state at the pooling position of each selected decoder layer is
    concatenated and fed through a small MLP classifier.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.gpt = DeCodonModule(config)

        if config.cls_type.lower() == "cls":
            # Normalise config.layer_indices (None / int / list) to a list of
            # non-negative block indices (negative values wrap around).
            layer_indices = config.layer_indices
            if layer_indices is None:
                layer_indices = []
            elif isinstance(layer_indices, int):
                layer_indices = [layer_indices]
            layer_indices = [i % len(self.gpt.decoder.blocks) for i in layer_indices]

            n_layers = len(layer_indices)
            self.layer_indices = layer_indices
            # MLP head over the concatenation of the selected layers' pooled states.
            self.classifier = nn.Sequential(
                nn.LayerNorm(config.hidden_size * n_layers),
                nn.Linear(config.hidden_size * n_layers, config.hidden_size),
                ACT2CLS[config.cls_hidden_act](),
                nn.Dropout(config.cls_dropout_prob),
                nn.Linear(
                    config.hidden_size,
                    config.num_labels * config.num_tasks,
                ),
            )
        else:
            raise ValueError(f"Invalid cls_type: {config.cls_type}.")

        self.init_weights()

    def freeze(self, layers_idx: Optional[list] = None):
        """
        Freeze backbone parameters.

        With no `layers_idx`, the entire backbone is frozen. Otherwise the
        embeddings plus every decoder block whose index is NOT listed are
        frozen (negative indices wrap around).
        """
        # BUGFIX: normalise a bare int before len() is ever called on it; the
        # original called len(layers_idx) first and crashed on freeze(3).
        if isinstance(layers_idx, int):
            layers_idx = [layers_idx]

        if layers_idx is None or len(layers_idx) == 0:
            for param in self.gpt.parameters():
                param.requires_grad = False
        else:
            for param in self.gpt.embeddings.parameters():
                param.requires_grad = False

            layers_idx = [i % self.config.num_hidden_layers for i in layers_idx]

            for i in range(self.config.num_hidden_layers):
                if i not in layers_idx:
                    for param in self.gpt.decoder.blocks[i].parameters():
                        param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """
        Run the backbone and the task head.

        target (`torch.Tensor`, *optional*):
            Regression targets (entries equal to -500.0 are treated as
            missing) or class indices, depending on `config.problem_type`.

        Returns a `SequenceClassifierOutput` (or a tuple when
        `return_dict=False`) with `logits` of shape
        `(batch_size, num_labels * num_tasks)` after flattening.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        # Pooling index per sequence: position of the first pad token minus
        # one; -1 (last position) when padding cannot be detected.
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (
                    torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
                ).to(input_ids.device)
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        # BUGFIX: always request a ModelOutput from the backbone — the code
        # below reads `.hidden_states` / `.attentions`, which the tuple form
        # (return_dict=False) does not provide and previously raised
        # AttributeError. The caller-facing `return_dict` contract is honoured
        # when building the final return value.
        gpt_outputs = self.gpt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        all_hidden_states = gpt_outputs.hidden_states

        # BUGFIX: pre-initialise branch-local results; the original left `ca`
        # (and `pooled_output`) unbound on the "cls" path and hit a NameError
        # below when output_attentions was requested.
        ca = None
        pooled_output = None

        if self.config.cls_type.lower() not in ["crossattention", "ca", "cls"]:
            logits, _ = self.classifier(all_hidden_states, attention_mask)
        elif self.config.cls_type.lower() in ["crossattention", "ca"]:
            bs, seq_len = input_ids.shape

            # One learned query per task, broadcast over the batch.
            query_tasks = self.task_embeddings.weight
            query_tasks = query_tasks.unsqueeze(0).expand(bs, -1, -1)

            cls_outputs = self.classifier(
                query_tasks,
                all_hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )

            logits, ca = cls_outputs

            logits = logits.squeeze()
        elif self.config.cls_type.lower() == "cls":
            bs, seq_len = input_ids.shape

            # Pool the hidden state at index `sequence_lengths - 1` from each
            # selected layer and concatenate along the feature axis.
            # NOTE(review): sequence_lengths already had 1 subtracted above, so
            # this indexes two positions before the first pad token — confirm
            # the intended offset.
            pooled_hidden_states = [
                h[torch.arange(bs, device=h.device), sequence_lengths - 1, :]
                for i, h in enumerate(all_hidden_states)
                if i in self.layer_indices
            ]
            pooled_output = torch.cat(pooled_hidden_states, dim=-1)

            logits = self.classifier(pooled_output)

        loss = None
        if target is not None:
            if self.config.problem_type == "regression":
                logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
                target = target.view(-1, self.config.num_labels * self.config.num_tasks)

                # Entries equal to -500.0 mark missing targets; mask them out.
                mask = target != -500.0

                if self.config.loss_fn == "mse":
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits[mask], target[mask])
                elif self.config.loss_fn == "mae":
                    loss_fct = nn.L1Loss()
                    loss = loss_fct(logits[mask], target[mask])
                elif self.config.loss_fn == "huber":
                    loss_fct = nn.SmoothL1Loss()
                    loss = loss_fct(logits[mask], target[mask])
                else:
                    raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.")
            else:
                loss_fct = nn.CrossEntropyLoss()

                logits = logits.view(-1, self.config.num_labels * self.config.num_tasks)
                target = target.view(-1)

                loss = loss_fct(logits, target)

        if not return_dict:
            # Skip the backbone's first two entries (last_hidden_state,
            # past_key_values); ModelOutput supports tuple-style slicing.
            output = (logits,) + gpt_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        if output_attentions:
            if ca is not None:
                # BUGFIX: concatenate tuples — the original added a list to a
                # tuple, which raises TypeError.
                attentions = gpt_outputs.attentions + (ca,)
            else:
                attentions = gpt_outputs.attentions
        else:
            attentions = None

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            # NOTE(review): this field carries the pooled features (or None)
            # rather than the per-layer hidden states — kept for compatibility.
            hidden_states=pooled_output,
            attentions=attentions,
        )