from typing import Optional, Tuple from dataclasses import dataclass import torch import torch.nn as nn from transformers.modeling_outputs import ( SequenceClassifierOutput, ) from typing import Optional, Tuple import torch import torch.utils.checkpoint from torch import nn from dataclasses import dataclass from transformers.activations import ACT2FN, ACT2CLS from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutput, CausalLMOutputWithPast from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask import xformers.ops as xops from collections import OrderedDict logger = logging.get_logger(__name__) import torch import torch.nn as nn from einops import rearrange, einsum from transformers.pytorch_utils import Conv1D import torch from torch.amp import autocast from torch import nn, einsum, Tensor from einops import rearrange, repeat from typing import Optional, Union from .configuration_decodon import DeCodonConfig logger = logging.get_logger(__name__) def rotate_half(x): x = rearrange(x, "... (d r) -> ... d r", r=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) return rearrange(x, "... d r -> ... (d r)") @autocast(device_type="cuda", enabled=False) def apply_rotary_emb(freqs, t, start_index=0, scale=1.0): """ Applies rotary embeddings to a tensor. Parameters ---------- freqs : Tensor The frequencies to apply to the tensor: (seq_len, dim) t : Tensor The tensor to apply the rotary embeddings to: (..., seq_len, n_heads, dim) start_index : int The starting index to apply the rotary embeddings. (default: 0) scale : float The scale to apply to the rotary embeddings. (default: 1.0) Returns ------- Tensor The tensor with the rotary embeddings applied.: (..., seq_len, n_heads, dim) """ # if t.ndim == 3: # seq_len = t.shape[seq_dim] # freqs = freqs[-seq_len:].to(t) rot_dim = freqs.shape[-1] end_index = start_index + rot_dim assert ( rot_dim <= t.shape[-1] ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" t_left, t, t_right = ( t[..., :start_index], t[..., start_index:end_index], t[..., end_index:], ) if isinstance(scale, float): scale = torch.tensor(scale, device=t.device, dtype=t.dtype) t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) return torch.cat((t_left, t, t_right), dim=-1) # learned rotation helpers def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None): if freq_ranges is not None: rotations = einsum("..., f -> ... f", rotations, freq_ranges) rotations = rearrange(rotations, "... r f -> ... (r f)") rotations = repeat(rotations, "... n -> ... (n r)", r=2) return apply_rotary_emb(rotations, t, start_index=start_index) """ Inspired from https://github.com/lucidrains/rotary-embedding-torch """ class RotaryEmbedding(nn.Module): """ Rotary Embeddings Implemenetation inspired by https://github.com/lucidrains/rotary-embedding-torch. Rotary Positional Embeddings (RoPE) encode position information of tokens with a rotation matrix that naturally incorporates explicit relative position dependency. Parameters ---------- emb_dim : int Embedding dimension. Usually set to the dim of each head in the attention module. freqs : Optional[Tensor] Custom frequencies to apply to query/key tensors. (default: None) theta : float Base constant used for computing rotation angles. learned_freq : bool (default: False) Whether to learn the frequencies. use_xpos : bool (default: False) Whether to employ XPos technique for resolving length extrapolation issue. NOTE: This can only be enabled for autoregressive models like GPT. xpos_scale_base : int (default: 512) The base for the scale factor used in XPos technique. interpolate_factor : float (default: 1.0) Length interpolation factor for extending context length of the pretrained model. Final model's context length = pretrained_model_context_length * interpolate_factor. theta_rescale_factor : float (default: 1.0) The factor to rescale the theta. cache_if_possible : bool (default: True) Whether to cache the frequencies/scales if possible. """ def __init__( self, emb_dim, freqs: Optional[Tensor] = None, theta=1e4, learned_freq=False, use_xpos=False, xpos_scale_base=512, interpolate_factor=1.0, theta_rescale_factor=1.0, cache_if_possible=True, ): super().__init__() # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ theta *= theta_rescale_factor ** (emb_dim / (emb_dim - 2)) if freqs is None: freqs = 1.0 / ( theta ** (torch.arange(0, emb_dim, 2)[: (emb_dim // 2)].float() / emb_dim) ) # freqs = torch.ones(num_freqs).float() self.cache_if_possible = cache_if_possible self.register_buffer("cached_freqs", None, persistent=False) self.register_buffer("cached_scales", None, persistent=False) self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) self.learned_freq = learned_freq # interpolation factors assert interpolate_factor >= 1.0 self.interpolate_factor = interpolate_factor # xpos self.use_xpos = use_xpos if not use_xpos: self.register_buffer("scale", None, persistent=False) return scale = (torch.arange(0, emb_dim, 2) + 0.4 * emb_dim) / (1.4 * emb_dim) self.scale_base = xpos_scale_base self.register_buffer("scale", scale, persistent=False) @property def device(self): return self.freqs.device def rotate_queries_or_keys(self, t, offset=0, freq_seq_len=None, scale=None): """ Parameters ---------- t : Tensor tensor to rotate: (batch_size, seq_len, num_heads, head_dim) """ seq_len = t.shape[1] assert ( not self.use_xpos or scale is not None ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings" if freq_seq_len is not None: assert freq_seq_len >= seq_len seq_len = freq_seq_len seq = ( torch.arange(seq_len, device=t.device, dtype=t.dtype) + offset ) / self.interpolate_factor freqs = self.forward( seq, seq_len=seq_len, offset=offset, ).to(t.dtype) freqs = rearrange(freqs, "n d -> n 1 d") if scale is not None: scale = rearrange(scale, "n d -> n 1 d") if scale is None: scale = torch.tensor(1.0, device=t.device, dtype=t.dtype) return apply_rotary_emb(freqs, t, scale=scale) def rotate_queries_and_keys(self, q, k): """ Parameters ---------- q : Tensor queries tensor: (batch_size, seq_len, num_heads, head_dim) k : Tensor keys tensor: (batch_size, seq_len, num_heads, head_dim) """ assert self.use_xpos seq_len = q.shape[-3] seq = ( torch.arange(seq_len, device=q.device, dtype=q.dtype) ) / self.interpolate_factor freqs = self.forward(seq, seq_len=seq_len) scale = self.get_scale(seq, seq_len=seq_len) freqs = rearrange(freqs, "n d -> n 1 d") scale = rearrange(scale, "n d -> n 1 d") rotated_q = apply_rotary_emb(freqs, q, scale=scale) rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1) rotated_q = rotated_q.type(q.dtype) rotated_k = rotated_k.type(k.dtype) return rotated_q, rotated_k def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0): assert self.use_xpos should_cache = self.cache_if_possible and seq_len is not None if ( should_cache and self.cached_scales is not None and (seq_len + offset) <= self.cached_scales.shape[0] ): return self.cached_scales[offset : (offset + seq_len)] scale = 1.0 if self.use_xpos: power = (t - len(t) // 2) / self.scale_base scale = self.scale ** rearrange(power, "n -> n 1") scale = torch.cat((scale, scale), dim=-1) if should_cache: self.register_buffer("cached_scales", scale, persistent=False) return scale def rotate_queries_with_cached_keys(self, q, k, offset=0): q_len, k_len = q.shape[1], k.shape[1] assert q_len <= k_len rotated_q, rotated_k = self.rotate_queries_and_keys(q, k) rotated_q = rotated_q[:, -1:, ...] return rotated_q, rotated_k seq = ( torch.arange(k_len, device=q.device, dtype=q.dtype) ) / self.interpolate_factor if self.use_xpos: q_scale = self.get_scale(seq[-q_len:]).to(q.dtype) k_scale = self.get_scale(seq).to(k.dtype) else: k_scale = 1.0 q_scale = 1.0 rotated_q = self.rotate_queries_or_keys( q, scale=q_scale, offset=k_len - q_len + offset ) rotated_k = self.rotate_queries_or_keys(k, scale=k_scale**-1) return rotated_q, rotated_k @autocast(device_type="cuda", enabled=False) def forward(self, t: Tensor, seq_len=None, offset=0): should_cache = ( self.cache_if_possible and not self.learned_freq and seq_len is not None ) if ( should_cache and self.cached_freqs is not None and (offset + seq_len) <= self.cached_freqs.shape[0] ): return self.cached_freqs[offset : (offset + seq_len)].detach() freqs = self.freqs freqs = einsum("..., f -> ... f", t, freqs) freqs = repeat(freqs, "... n -> ... (n r)", r=2) if should_cache: self.register_buffer("cached_freqs", freqs.detach(), persistent=False) return freqs class MultiHeadedSelfAttention(nn.Module): """ Multi-Headed Self Attention module supported with Flash Attention and Rotary Embeddings. Parameters ---------- q_input_dim: int The input dimension of the query tensor. kv_input_dim: int The input dimension of the key and value tensors. qk_proj_dim: int The projected dimension of the query and key tensors. v_proj_dim: int The projected dimension of the value tensors. num_heads: int Number of attention heads. dropout: float Dropout rate to apply to the attention scores. projection_layer: str The type of projection layer to use. Either 'linear' or 'conv'. Basically both are linear projections, but 'conv' uses Conv1D layer as proposed in the original GPT2 paper. use_flash_attn: bool Whether to use Flash Attention or not. If True, Flash Attention will be used. NOTE: Flash Attention is required to be installed. use_rotary_emb: bool Whether to use Rotary Embeddings or not. rotary_theta: int The base for the geometric progression used to compute the rotation angles. rotary_use_xpos: bool Whether to use XPos technique for resolving length extrapolation issue. NOTE: This can only be enabled for autoregressive models like GPT. """ def __init__( self, q_input_dim, kv_input_dim, qk_proj_dim, v_proj_dim, num_heads, dropout: float = 0.0, projection_layer: str = "linear", use_flash_attn: bool = True, use_rotary_emb: bool = False, rotary_theta: int = 1e4, rotary_use_xpos: bool = False, is_cross_attention: bool = False, **kwargs, ): super().__init__() assert ( qk_proj_dim % num_heads == 0 ), "qk_proj_dim must be divisible by num_heads" assert v_proj_dim % num_heads == 0, "v_proj_dim must be divisible by num_heads" self.num_heads = num_heads self.dropout_rate = dropout self.projection_layer = projection_layer self.use_rotary_emb = use_rotary_emb self.is_cross_attention = is_cross_attention if use_flash_attn and not is_cross_attention: try: from flash_attn import flash_attn_qkvpacked_func self.use_flash_attn = True self.flashattn_fn = flash_attn_qkvpacked_func except ImportError: print("flash_attn not installed, reverting to default attention") self.use_flash_attn = False self.flashattn_fn = None else: self.use_flash_attn = False self.flashattn_fn = None if self.projection_layer == "linear": self.query = nn.Linear(q_input_dim, qk_proj_dim) self.key = nn.Linear(kv_input_dim, qk_proj_dim) self.value = nn.Linear(kv_input_dim, v_proj_dim) elif self.projection_layer == "conv": self.query = Conv1D(qk_proj_dim, q_input_dim) self.key = Conv1D(qk_proj_dim, kv_input_dim) self.value = Conv1D(v_proj_dim, kv_input_dim) else: raise ValueError( f"projection_layer must be either 'linear' or 'conv', got {projection_layer}" ) if self.use_rotary_emb: self.rotary_emb = RotaryEmbedding( emb_dim=qk_proj_dim // num_heads // 2, theta=rotary_theta, use_xpos=rotary_use_xpos, ) self.dr_rate = dropout self.dropout = nn.Dropout(dropout) def forward( self, x_q, x_kv, is_causal=False, attention_bias=None, attention_mask=None, output_attentions=False, query=None, key=None, value=None, use_cache=False, ): """ Applies a classical self attention operation. Parameters ---------- x_q: torch.Tensor The query tensor of shape (batch_size, query_seq_len, emb_dim) x_kv: torch.Tensor The key/value tensor of shape (batch_size, kv_seq_len, emb_dim) attention_bias: torch.Tensor The attention bias to apply to the attention scores. (default: None) attention_mask: torch.Tensor The attention mask to apply to the attention scores. Shape: (batch_size, q_len, kv_seq_len) """ assert (x_q is not None and x_kv is not None) or ( query is not None and key is not None and value is not None ), "Either x_q and x_kv or query, key and value must be provided" past_memory_provided = ( query is not None and key is not None and value is not None ) if query is None: q_len = x_q.size(1) k_len = x_kv.size(1) query = self.query(x_q) key = self.key(x_kv) value = self.value(x_kv) else: q_len = query.size(1) k_len = key.size(1) if use_cache: cache = (key.clone(), value.clone(), query.clone()) q = rearrange(query, "b q (h d) -> b q h d", h=self.num_heads) k = rearrange(key, "b k (h d) -> b k h d", h=self.num_heads) v = rearrange(value, "b v (h d) -> b v h d", h=self.num_heads) if self.use_rotary_emb: if use_cache and past_memory_provided: q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k) if self.rotary_emb.use_xpos: q, k = self.rotary_emb.rotate_queries_and_keys(q, k) else: q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) if ( self.use_flash_attn and not use_cache and not output_attentions and attention_bias is None ): qkv = torch.stack([q, k, v], dim=2).to(torch.bfloat16) x = self.flashattn_fn( qkv=qkv, dropout_p=self.dropout_rate if self.training else 0.0, causal=is_causal, deterministic=False, return_attn_probs=False, ) x = x.to(x_q.dtype) elif self.use_flash_attn and not output_attentions: attn_bias = xops.LowerTriangularMask() if is_causal else attention_bias if attention_mask is not None: if attn_bias is None: attn_bias = attention_mask else: if isinstance(attn_bias, torch.Tensor): attn_bias = attn_bias + attention_mask else: attn_bias.add_bias(bias=attention_mask) attn_bias = attn_bias.materialize( shape=(q_len, k_len), device=q.device, dtype=q.dtype, ) else: if isinstance(attn_bias, torch.Tensor) and len(attn_bias.shape) == 3: attn_bias = ( attn_bias.unsqueeze(1) .expand(-1, self.num_heads, -1, -1) .float() ) # (batch_size, num_heads, q_len, k_len) else: attn_bias = attn_bias.materialize( shape=(q_len, k_len), device=q.device, dtype=q.dtype, ) if isinstance(attn_bias, xops.LowerTriangularMask): attn_bias = attn_bias.materialize( shape=(q_len, k_len), device=q.device, dtype=q.dtype, ) # print(attention_mask.shape, attn_bias.shape) # print(attn_bias[0, 0, 0, :]) need_adjustment = False if attn_bias.shape[-2] % 8 != 0: nearest_multiple_q = 8 * (1 + attn_bias.shape[-2] // 8) need_adjustment = True else: nearest_multiple_q = attn_bias.shape[-2] if attn_bias.shape[-1] % 8 != 0: nearest_multiple_k = 8 * (1 + attn_bias.shape[-1] // 8) need_adjustment = True else: nearest_multiple_k = attn_bias.shape[-1] if need_adjustment: new_attn_bias = torch.zeros( attn_bias.shape[0], attn_bias.shape[1], nearest_multiple_q, nearest_multiple_k, ).to(attn_bias.device) new_attn_bias[:, :, : attn_bias.shape[-2], : attn_bias.shape[-1]] = ( attn_bias ) x = xops.memory_efficient_attention( query=q, key=k, value=v, op=None, attn_bias=new_attn_bias[:, :, :q_len, :k_len], p=self.dr_rate, ) else: attn_bias = attn_bias.to(q.dtype) attn_bias = attn_bias.repeat(1, self.num_heads, 1, 1) x = xops.memory_efficient_attention( query=q, key=k, value=v, op=None, attn_bias=attn_bias, p=self.dr_rate, ) # x: (batch_size, query_seq_len, n_head, head_dim) else: # if output_attentions: attention_scores = einsum(q, k, "b q h d, b k h d -> b h q k") attention_scores = attention_scores / (q.size(-1) ** 0.5) if attention_bias is not None: attn_bias = attention_bias.unsqueeze(1).expand( -1, self.num_heads, -1, -1 ) # elif is_causal: # attn_bias = xops.LowerTriangularMask().materialize( # shape=attention_scores.shape, device=attention_scores.device # ) else: attn_bias = None if attention_mask is not None: if attn_bias is None: attn_bias = attention_mask else: attn_bias = attn_bias + attention_mask attention_scores = attention_scores + attn_bias attention_probs = attention_scores.softmax(dim=-1) attention_probs = self.dropout(attention_probs) x = einsum(attention_probs, v, "b h q k, b k h d -> b q h d") x = rearrange(x, "b q h d -> b q (h d)", h=self.num_heads) if use_cache: if output_attentions: return x, attention_probs, cache else: return x, None, cache else: if output_attentions: return x, attention_probs else: return x, None class DeCodonPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ base_model_prefix = "decodon" supports_gradient_checkpointing = True def _init_weights(self, module): """MAGNETO Initialize the weights""" if isinstance(module, nn.Linear): nn.init.xavier_normal_(module.weight, gain=self.config.gamma_init) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, DeCodonLayer): module.gradient_checkpointing = value class DeCodonEmbeddings(nn.Module): """ DeCodon Embeddings Word, position and token type embeddings for DeCodon. """ def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding( config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id ) self.position_embeddings = nn.Embedding( config.max_position_embeddings, config.hidden_size ) self.token_type_embeddings = nn.Embedding( config.type_vocab_size, config.hidden_size ) self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.position_embedding_type = getattr( config, "position_embedding_type", "absolute" ) self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False, ) self.register_buffer( "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, past_key_values_length: int = 0, ) -> torch.Tensor: if input_ids is not None: input_shape = input_ids.size() else: input_shape = inputs_embeds.size()[:-1] seq_length = input_shape[1] if position_ids is None: position_ids = self.position_ids[ :, past_key_values_length : seq_length + past_key_values_length ] # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves # issue #5664 if token_type_ids is None: if hasattr(self, "token_type_ids"): buffered_token_type_ids = self.token_type_ids[:, :seq_length] buffered_token_type_ids_expanded = buffered_token_type_ids.expand( input_shape[0], seq_length ) token_type_ids = buffered_token_type_ids_expanded else: token_type_ids = torch.zeros( input_shape, dtype=torch.long, device=self.position_ids.device ) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) embeddings += position_embeddings # embeddings = self.ln(embeddings) embeddings = self.dropout(embeddings) return embeddings class DeCodonAttention(nn.Module): """ DeCodon Attention Layer This module supports self-attention and dilated attention with Rotary Positional Embeddings (RoPE). """ def __init__(self, config): super().__init__() self.pre_layer_norm = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps ) self.post_attn_dense = nn.Linear(config.hidden_size, config.hidden_size) self.post_layer_norm = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps ) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.self_attention = MultiHeadedSelfAttention( q_input_dim=config.hidden_size, kv_input_dim=config.hidden_size, qk_proj_dim=config.hidden_size, v_proj_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_probs_dropout_prob, projection_layer="conv", use_flash_attn=config.use_flash_attn, use_rotary_emb=config.use_rotary_emb, rotary_theta=config.rotary_theta, rotary_use_xpos=True, ) def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, past_key_values: Optional[Tuple[torch.FloatTensor]] = None, use_cache: Optional[bool] = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: attn_input = self.pre_layer_norm(hidden_states) if past_key_values is not None: query = self.self_attention.query(attn_input) key = self.self_attention.key(attn_input) value = self.self_attention.value(attn_input) past_key, past_value, past_query = past_key_values # past_new_query = query[:, :-1, :] # past_new_key = key[:, :-1, :] # past_new_value = value[:, :-1, :] # print( # (past_new_query[0] != past_query[0]).sum(), # past_new_query.size(), # past_new_query[past_new_query != past_query].cpu().numpy(), # past_query[past_new_query != past_query].cpu().numpy(), # past_query.sum().item(), # ) # print( # (past_new_key[0] == past_key[0]).sum(), # past_new_key.size(), # # past_new_key[0, 0, :1024], # # past_key[0, 0, :1024], # past_new_key[past_new_key != past_key].cpu().numpy(), # past_key[past_new_key != past_key].cpu().numpy(), # past_key.sum().item(), # ) # print( # (past_new_value[0] == past_value[0]).sum(), # past_new_value.size(), # # past_new_value[0, 0, :1024], # # past_value[0, 0, :1024], # past_new_value[past_new_value != past_value].cpu().numpy(), # past_value[past_new_value != past_value].cpu().numpy(), # past_value.sum().item(), # ) # print(query.shape, key.shape, value.shape) # print(past_query.shape, past_key.shape, past_value.shape) key = torch.cat( (past_key, key), dim=1 ) # (batch_size, seq_len, hidden_size) value = torch.cat( (past_value, value), dim=1 ) # (batch_size, seq_len, hidden_size) query = torch.cat((past_query, query), dim=1) # print(query.shape, key.shape, value.shape) # print() attn_outputs = self.self_attention( x_q=None, x_kv=None, query=query, key=key, value=value, is_causal=True, attention_mask=attention_mask, output_attentions=output_attentions, use_cache=use_cache, attention_bias=None, ) else: attn_outputs = self.self_attention( x_q=attn_input, x_kv=attn_input, is_causal=True, attention_bias=None, attention_mask=attention_mask, output_attentions=output_attentions, use_cache=use_cache, ) attn_output = attn_outputs[0] attn_output = self.post_layer_norm(attn_output) attn_output = self.post_attn_dense(attn_output) attn_output = self.dropout(attn_output) attn_output = hidden_states + attn_output return (attn_output,) + attn_outputs[1:] class DeCodonFFN(nn.Module): """ DeCodon Position-wise Feed-Forward Network """ def __init__(self, config): super().__init__() embed_dim = config.hidden_size self.pre_layer_norm = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps ) self.intermediate_dense = Conv1D(config.intermediate_size, embed_dim) self.post_layer_norm = nn.LayerNorm( config.intermediate_size, eps=config.layer_norm_eps ) self.post_dense = Conv1D(embed_dim, config.intermediate_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]] ) -> torch.FloatTensor: hidden_states = self.pre_layer_norm(hidden_states) hidden_states = self.intermediate_dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.post_layer_norm(hidden_states) hidden_states = self.post_dense(hidden_states) hidden_states = self.dropout(hidden_states) return hidden_states class DeCodonLayer(nn.Module): """ DeCodon (Decoder) Layer consists of an attention layer and a position-wise feed-forward network. """ def __init__(self, config): super().__init__() self.attention = DeCodonAttention(config) self.output = DeCodonFFN(config) def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, past_key_values: Optional[Tuple[torch.FloatTensor]] = None, use_cache: Optional[bool] = False, ) -> Union[ Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]], ]: self_attention_outputs = self.attention( hidden_states, attention_mask, output_attentions=output_attentions, past_key_values=past_key_values, use_cache=use_cache, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[ 1: ] # add self attentions if we output attention weights layer_output = self.output(attention_output) outputs = (layer_output,) + outputs return outputs class DeCodonStack(nn.Module): """ DeCodon Stack consists of multiple DeCodon layers. """ def __init__(self, config): super().__init__() self.config = config self.blocks = nn.ModuleList( [DeCodonLayer(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None, output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, use_cache: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor], BaseModelOutput]: if past_key_values is None: past_key_values = [None] * len(self.blocks) past_length = 0 else: past_length = past_key_values[0][0].size(-2) all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None presents = () if use_cache else None for i, (block, past_key_value) in enumerate(zip(self.blocks, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) block_outputs = block( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, past_key_values=past_key_value, use_cache=use_cache, ) hidden_states = block_outputs[0] if use_cache: presents = presents + (block_outputs[2],) if output_attentions: all_self_attentions = all_self_attentions + (block_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [ hidden_states, presents, all_hidden_states, all_self_attentions, ] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, ) class DeCodonModule(DeCodonPreTrainedModel): """ The DeCodon Module (Decoder only) without any task-specific head on top. """ def __init__(self, config): super().__init__(config) self.embeddings = DeCodonEmbeddings(config) self.decoder = DeCodonStack(config) self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def set_input_embeddings(self, new_embeddings): self.embeddings.word_embeddings = new_embeddings def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, use_cache: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutput]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if past_key_values is not None: past_length = past_key_values[0][0].size(-2) else: past_length = 0 batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length)), device=device) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] buffered_token_type_ids_expanded = buffered_token_type_ids.expand( batch_size, seq_length ) token_type_ids = buffered_token_type_ids_expanded else: token_type_ids = torch.zeros( input_shape, dtype=torch.long, device=device ) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. # extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( # attention_mask, input_shape # ) embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, ) extended_attention_mask = _prepare_4d_causal_attention_mask( attention_mask=attention_mask, input_shape=(batch_size, input_shape[-1]), inputs_embeds=embedding_output, past_key_values_length=past_length, ) # extended_attention_mask = attention_mask decoder_outputs = self.decoder( embedding_output, attention_mask=extended_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, past_key_values=past_key_values, return_dict=return_dict, use_cache=use_cache, ) sequence_output = decoder_outputs[0] if not return_dict: return (sequence_output,) + decoder_outputs[1:] return BaseModelOutputWithPast( last_hidden_state=sequence_output, past_key_values=decoder_outputs.past_key_values, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, ) @dataclass class DeCodonForPreTrainingOutput(CausalLMOutputWithPast): """ Output type of [`BERTransForPreTraining`]. Args: loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). org_logits (`torch.FloatTensor` of shape `(batch_size, 1)`): Prediction scores for organism classification (scores for each organism label before SoftMax). hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None past_key_values: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None class DeCodon(DeCodonPreTrainedModel): config_class = DeCodonConfig _tied_weights_keys = [] def __init__(self, config): super().__init__(config) self.gpt = DeCodonModule(config) # causal language modeling head if config.lm_type == "gpt": self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) DeCodon._tied_weights_keys.append("lm_head.weight") else: self.lm_head = nn.Sequential( OrderedDict( [ ("dropout", nn.Dropout(config.hidden_dropout_prob)), ( "transform", nn.Linear(config.hidden_size, config.hidden_size), ), ("act", nn.ReLU()), ( "norm", nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), ), ( "pred", nn.Linear( config.hidden_size, config.vocab_size, bias=False ), ), ] ) ) DeCodon._tied_weights_keys.append("lm_head.pred.weight") # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.gpt.embeddings.word_embeddings def get_output_embeddings(self): return ( self.lm_head.pred.weight if isinstance(self.lm_head, nn.Sequential) else self.lm_head.weight if self.config.lm_type == "gpt" else None ) def set_output_embeddings(self, new_embeddings): if isinstance(self.lm_head, nn.Sequential): self.lm_head.pred.weight = new_embeddings else: self.lm_head.weight = new_embeddings def prepare_inputs_for_generation( self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs ): token_type_ids = kwargs.get("token_type_ids", None) attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) use_cache = kwargs.get("use_cache", True) if past_key_values is not None and use_cache: past_length = past_key_values[0][0].shape[1] if input_ids.shape[1] > past_length: remove_prefix_len = past_length else: remove_prefix_len = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_len:] if token_type_ids is not None: token_type_ids = token_type_ids[:, remove_prefix_len:] if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) else: position_ids = None if inputs_embeds is not None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache", True), } ) return model_inputs @staticmethod def _reorder_cache( past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. """ return tuple( tuple( past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past ) for layer_past in past_key_values ) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, organism: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, use_cache: Optional[bool] = False, **kwargs, ) -> Union[Tuple[torch.Tensor], DeCodonForPreTrainingOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` organism (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Organism labels kwargs (`Dict[str, any]`, optional, defaults to *{}*): Used to hide legacy arguments that have been deprecated. Returns: Example: ```python >>> from transformers import AutoTokenizer, BertForPreTraining >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("bertrans-base") >>> model = BERTransForPreTraining.from_pretrained("bertrans-base") >>> inputs = tokenizer("AAAAGGGGGGCCCCCCTTTTT", return_tensors="pt") >>> outputs = model(**inputs) >>> prediction_logits = outputs.prediction_logits >>> organism_logits = outputs.organism_logits >>> biotype_logits = outputs.biotype_logits ``` """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if input_ids is not None: batch_size, sequence_length = input_ids.shape[:2] else: batch_size, sequence_length = inputs_embeds.shape[:2] if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: sequence_lengths = ( torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1 ).to(input_ids.device) else: sequence_lengths = -1 logger.warning( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) gpt_outputs = self.gpt( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=use_cache, ) hidden_states = gpt_outputs[0] # (batch_size, sequence_length, hidden_size) lm_logits = self.lm_head( hidden_states ) # (batch_size, sequence_length, vocab_size) loss = None if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(lm_logits.device) # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = nn.CrossEntropyLoss() lm_loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) loss = lm_loss else: lm_loss = None if not return_dict: output = (lm_logits,) + gpt_outputs[1:] return ((loss,) + output) if loss is not None else output return DeCodonForPreTrainingOutput( loss=loss, logits=lm_logits, past_key_values=gpt_outputs.past_key_values, hidden_states=gpt_outputs.hidden_states, attentions=gpt_outputs.attentions, ) def freeze(self, layer_indices: Optional[list] = None): if layer_indices is None or len(layer_indices) == 0: for param in self.gpt.parameters(): param.requires_grad = False else: for param in self.gpt.embeddings.parameters(): param.requires_grad = False if isinstance(layer_indices, int): layer_indices = [layer_indices] layer_indices = [i % len(self.gpt.decoder.blocks) for i in layer_indices] for i in range(len(self.gpt.decoder.blocks)): if i not in layer_indices: for param in self.gpt.decoder.blocks[i].parameters(): param.requires_grad = False class DeCodonForSequenceTask(DeCodonPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config self.gpt = DeCodonModule(config) if config.cls_type.lower() == "cls": layer_indices = config.layer_indices layer_indices = ( [] if layer_indices is None else ( [layer_indices] if isinstance(layer_indices, int) else layer_indices ) ) layer_indices = [i % len(self.gpt.decoder.blocks) for i in layer_indices] n_layers = len(layer_indices) self.layer_indices = layer_indices self.classifier = nn.Sequential( nn.LayerNorm(config.hidden_size * n_layers), nn.Linear(config.hidden_size * n_layers, config.hidden_size), ACT2CLS[config.cls_hidden_act](), nn.Dropout(config.cls_dropout_prob), nn.Linear( config.hidden_size, config.num_labels * config.num_tasks, ), ) else: raise ValueError(f"Invalid cls_type: {config.cls_type}.") self.init_weights() def freeze(self, layers_idx: Optional[list] = None): if layers_idx is None or len(layers_idx) == 0: for param in self.gpt.parameters(): param.requires_grad = False else: for param in self.gpt.embeddings.parameters(): param.requires_grad = False if isinstance(layers_idx, int): layers_idx = [layers_idx] layers_idx = [i % self.config.num_hidden_layers for i in layers_idx] for i in range(self.config.num_hidden_layers): if i not in layers_idx: for param in self.gpt.decoder.blocks[i].parameters(): param.requires_grad = False def forward( self, input_ids: Optional[torch.Tensor] = None, target: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ): return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if input_ids is not None: batch_size, sequence_length = input_ids.shape[:2] else: batch_size, sequence_length = inputs_embeds.shape[:2] if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: sequence_lengths = ( torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1 ).to( input_ids.device ) # (batch_size,) else: sequence_lengths = -1 logger.warning( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) gpt_outputs = self.gpt( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=True, return_dict=return_dict, ) all_hidden_states = gpt_outputs.hidden_states if self.config.cls_type.lower() not in ["crossattention", "ca", "cls"]: logits, _ = self.classifier(all_hidden_states, attention_mask) elif self.config.cls_type.lower() in ["crossattention", "ca"]: bs, seq_len = input_ids.shape query_tasks = self.task_embeddings.weight # (num_tasks, hidden_size) query_tasks = query_tasks.unsqueeze(0).expand( bs, -1, -1 ) # (batch_size, num_tasks, hidden_size) cls_outputs = self.classifier( query_tasks, all_hidden_states, attention_mask, output_attentions=output_attentions, ) # (batch_size, num_tasks, num_labels) logits, ca = cls_outputs logits = logits.squeeze() elif self.config.cls_type.lower() == "cls": bs, seq_len = input_ids.shape # here we select latest token's hidden states as pooled output pooled_hidden_states = [ h[torch.arange(bs, device=h.device), sequence_lengths - 1, :] for i, h in enumerate(all_hidden_states) if i in self.layer_indices ] pooled_output = torch.cat( pooled_hidden_states, dim=-1 ) # (batch_size, hidden_size * n_layers) logits = self.classifier(pooled_output) loss = None if target is not None: if self.config.problem_type == "regression": logits = logits.view(-1, self.config.num_labels * self.config.num_tasks) target = target.view(-1, self.config.num_labels * self.config.num_tasks) mask = target != -500.0 if self.config.loss_fn == "mse": loss_fct = nn.MSELoss() loss = loss_fct(logits[mask], target[mask]) elif self.config.loss_fn == "mae": loss_fct = nn.L1Loss() loss = loss_fct(logits[mask], target[mask]) elif self.config.loss_fn == "huber": loss_fct = nn.SmoothL1Loss() loss = loss_fct(logits[mask], target[mask]) else: raise ValueError(f"Invalid loss_fn: {self.config.loss_fn}.") else: loss_fct = nn.CrossEntropyLoss() logits = logits.view(-1, self.config.num_labels * self.config.num_tasks) target = target.view( -1, ) loss = loss_fct(logits, target) if not return_dict: output = (logits,) + gpt_outputs[2:] return ((loss,) + output) if loss is not None else output if output_attentions: if ca is not None: attentions = gpt_outputs.attentions + [ca] else: attentions = gpt_outputs.attentions else: attentions = None return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=pooled_output, attentions=attentions, )