# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified by nanoT5 authors
# https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/t5_model.py
import copy
import math
from typing import Optional
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.modeling_outputs import ModelOutput
from transformers.models.t5.configuration_t5 import T5Config
from transformers.models.t5.modeling_t5 import (
T5LayerNorm,
T5DenseGatedActDense,
)
from .spectrogram import MelSpectrogram


@dataclass
class EncoderOutput(ModelOutput):
    hidden_states: Optional[torch.FloatTensor] = None
    attention_mask: Optional[torch.FloatTensor] = None


@dataclass
class Seq2SeqLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    encoder_outputs: Optional[EncoderOutput] = None


class T5LayerFF(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
        assert config.is_gated_act  # this port supports only the gated-activation FFN (T5 v1.1 style)
self.DenseReluDense = T5DenseGatedActDense(config)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + self.dropout(forwarded_states)
return hidden_states


class T5Attention(nn.Module):
def __init__(self, config: T5Config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.relative_attention_max_distance = config.relative_attention_max_distance
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(
self.relative_attention_num_buckets, self.n_heads
)
@staticmethod
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(
relative_position, torch.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large,
torch.full_like(relative_position_if_large, num_buckets - 1),
)
relative_buckets += torch.where(
is_small, relative_position, relative_position_if_large
)
return relative_buckets
def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[
:, None
]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
None, :
]
relative_position = (
memory_position - context_position
) # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(
relative_position_bucket
) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(
0
) # shape (1, num_heads, query_length, key_length)
return values
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
):
"""
        Self-attention (if key_value_states is None) or cross-attention over the source sequence (provided by key_value_states).
"""
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
key_length = (
real_seq_length if key_value_states is None else key_value_states.shape[1]
)
def shape(states):
"""projection"""
return states.view(
batch_size, -1, self.n_heads, self.key_value_proj_dim
).transpose(1, 2)
def unshape(states):
"""reshape"""
return (
states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
)
query_states = self.q(hidden_states)
if key_value_states is None:
key_states, value_states = self.k(hidden_states), self.v(hidden_states)
else:
key_states, value_states = self.k(key_value_states), self.v(
key_value_states
)
query_states, key_states, value_states = (
shape(query_states),
shape(key_states),
shape(value_states),
)
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
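        # NOTE: unlike standard scaled dot-product attention, the scores are not
        # divided by sqrt(d_kv); T5 compensates by initializing the query
        # projection with std ~ (d_model * d_kv) ** -0.5 (see _init_weights).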
if position_bias is None:
if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length),
device=scores.device,
dtype=scores.dtype,
)
else:
position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device
)
if mask is not None:
# Masking happens here, masked elements in the mask have the value of -inf
position_bias = (
position_bias + mask
) # (batch_size, n_heads, seq_length, key_length)
position_bias_masked = position_bias
scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
attn_output = unshape(
torch.matmul(attn_weights, value_states)
) # (batch_size, seq_length, dim)
attn_output = self.o(attn_output)
return (attn_output, position_bias)


class T5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = T5Attention(
config, has_relative_attention_bias=has_relative_attention_bias
)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
):
normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states)
attention_output = self.SelfAttention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
)
hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:]
return outputs


class T5LayerCrossAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
key_value_states,
attention_mask=None,
position_bias=None,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
)
layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:]
return outputs


class T5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.layer = nn.ModuleList()
self.layer.append(
T5LayerSelfAttention(
config, has_relative_attention_bias=has_relative_attention_bias
)
)
if self.is_decoder:
self.layer.append(T5LayerCrossAttention(config))
self.layer.append(T5LayerFF(config))
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
):
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
)
hidden_states = self_attention_outputs[0]
attention_outputs = self_attention_outputs[1:] # Relative position weights
if self.is_decoder and encoder_hidden_states is not None:
cross_attention_outputs = self.layer[1](
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
)
hidden_states = cross_attention_outputs[0]
# Keep relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[1:]
# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)
outputs = (hidden_states,)
outputs = outputs + attention_outputs
return outputs # hidden-states, (self-attention position bias), (cross-attention position bias)


class T5Stack(nn.Module, ModuleUtilsMixin):
def __init__(self, config, embed_tokens):
super().__init__()
assert embed_tokens is not None
self.config = config
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
self.block = nn.ModuleList(
[
                # only the first block owns the relative position bias; later blocks reuse it
                T5Block(config, has_relative_attention_bias=bool(i == 0))
for i in range(config.num_layers)
]
)
self.final_layer_norm = T5LayerNorm(
config.d_model, eps=config.layer_norm_epsilon
)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> EncoderOutput:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
input_shape = inputs_embeds.size()
batch_size = input_shape[0]
seq_length = input_shape[1]
input_shape = (batch_size, seq_length)
if hasattr(self.config, "is_bf16") and self.config.is_bf16:
inputs_embeds = inputs_embeds.to(torch.bfloat16)
# Masking
if attention_mask is None:
attention_mask = torch.ones(
batch_size, seq_length, device=inputs_embeds.device
)
if (
self.is_decoder
and encoder_attention_mask is None
and encoder_hidden_states is not None
):
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(
batch_size,
encoder_seq_length,
device=inputs_embeds.device,
dtype=torch.long,
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(
attention_mask, input_shape
)
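        # e.g. a (batch, seq) 1/0 mask becomes an additive (batch, 1, 1, seq)
        # mask of zeros and large negative values; for the decoder a causal
        # (batch, 1, seq, seq) triangle is folded in as well.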
# If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
(
encoder_batch_size,
encoder_sequence_length,
_,
) = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device
)
encoder_extended_attention_mask = self.invert_attention_mask(
encoder_attention_mask
)
else:
encoder_extended_attention_mask = None
position_bias = None
encoder_decoder_position_bias = None
hidden_states = self.dropout(inputs_embeds)
for _, layer_module in enumerate(self.block):
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
)
hidden_states = layer_outputs[0]
            # We share the position biases between the layers - the first layer stores them
position_bias = layer_outputs[1]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[2]
hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states)
hidden_states = self.dropout(hidden_states)
return EncoderOutput(
hidden_states=hidden_states,
attention_mask=attention_mask,
)


class T5(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
config.is_encoder_decoder = False
assert not config.tie_word_embeddings
self.config = config
self.model_dim = config.d_model
self.spectrogram = MelSpectrogram(
config.sample_rate, config.n_fft, config.n_mels, config.hop_length
)
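        # NOTE: self.spectrogram is never invoked in forward(); presumably the
        # caller precomputes mel frames (see the `frames` docstrings below) and
        # the module is kept here for external preprocessing.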
self.encoder_embedder = nn.Linear(config.n_mels, config.d_model)
self.decoder_embedder = nn.Embedding(config.vocab_size, config.d_model)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
self.encoder = T5Stack(encoder_config, self.encoder_embedder)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.num_layers = config.num_decoder_layers
self.decoder = T5Stack(decoder_config, self.decoder_embedder)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.generation_config = None
self.apply(self._init_weights)
def generate(
self,
        frames: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
max_length=None,
**kwargs,
) -> torch.LongTensor:
"""
frames: B x L_encoder x mel_bins, float32
attention_mask: B x L_encoder, int64
1 for tokens to attend to, 0 for tokens to ignore
Generation:
Starts with [SOS], ends with [EOS], padding is [PAD] (see Tokenizer)
"""
        B = frames.size(0)  # frames is B x L_encoder x mel_bins; only the batch dim is needed
SOS_TOKEN_ID = self.config.decoder_start_token_id
PAD_TOKEN_ID = self.config.pad_token_id
EOS_TOKEN_ID = self.config.eos_token_id
labels = torch.ones(B, 1, dtype=torch.long, device=frames.device) * SOS_TOKEN_ID
encoder_outputs = None
for _ in range(max_length):
out = self.forward(
frames=frames,
attention_mask=attention_mask,
decoder_input_ids=labels,
encoder_outputs=encoder_outputs,
)
encoder_outputs = out.encoder_outputs
top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1)
labels = torch.cat([labels, top_labels], dim=-1)
            # Stop once every sequence in the batch contains at least one EOS
            if (labels == EOS_TOKEN_ID).sum(-1).clamp(min=0, max=1).sum().item() == B:
                break
        labels[:, -1] = EOS_TOKEN_ID  # ensure every sequence ends with an EOS so the masking below finds one
        # Mask out everything after the first EOS with PAD
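        # e.g. with EOS=1, PAD=0: labels [[2, 5, 1, 9]] have their first EOS at
        # index 2, so the mask keeps indices 0..2 and index 3 becomes PAD.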
B, L = labels.size()
mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (
labels == EOS_TOKEN_ID
).long().argmax(-1).unsqueeze(-1)
labels = labels.masked_fill(~mask, PAD_TOKEN_ID)
return labels
def forward(
self,
        frames: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
tokens: Optional[torch.LongTensor] = None,
encoder_outputs=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> Seq2SeqLMOutput:
"""
frames: B x L_encoder x mel_bins, float32
attention_mask: B x L_encoder, int64
1 for tokens to attend to, 0 for tokens to ignore
tokens: B x L_decoder, int64
"""
if encoder_outputs is None:
encoder_outputs = self.encoder(
frames,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
)
hidden_states = encoder_outputs.hidden_states
if tokens is not None and decoder_input_ids is None:
decoder_input_ids = self._shift_right(tokens)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
)
sequence_output = decoder_outputs[0]
lm_logits = self.lm_head(sequence_output)
loss = None
if tokens is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), tokens.view(-1))
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
encoder_outputs=encoder_outputs,
)
def _init_weights(self, module):
factor = (
self.config.initializer_factor
) # Used for testing weights initialization
if isinstance(module, T5LayerNorm):
module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, T5):
module.decoder_embedder.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, T5DenseGatedActDense):
d_ff, d_model = module.wi_0.weight.data.size()
module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
elif isinstance(module, T5Attention):
d_model = self.config.d_model
key_value_proj_dim = self.config.d_kv
n_heads = self.config.num_heads
module.q.weight.data.normal_(
mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)
)
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.o.weight.data.normal_(
mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)
)
if hasattr(module, "relative_attention_bias"):
module.relative_attention_bias.weight.data.normal_(
mean=0.0, std=factor * ((d_model) ** -0.5)
)
def _shift_right(self, input_ids):
SOS_TOKEN_ID = self.config.decoder_start_token_id
PAD_TOKEN_ID = self.config.pad_token_id
assert SOS_TOKEN_ID is not None and PAD_TOKEN_ID is not None
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = SOS_TOKEN_ID
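        # e.g. tokens [[5, 6, EOS]] become decoder inputs [[SOS, 5, 6]]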
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, PAD_TOKEN_ID)
return shifted_input_ids
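

if __name__ == "__main__":
    # Minimal smoke test: a sketch with placeholder hyperparameters, not the
    # values of any released checkpoint. Because this file uses a relative
    # import, run it as a module inside its package (e.g. python -m pkg.t5,
    # where the module path is a placeholder), not as a bare script.
    config = T5Config(
        d_model=64,
        d_kv=16,
        d_ff=128,
        num_heads=4,
        num_layers=2,
        num_decoder_layers=2,
        vocab_size=100,
        feed_forward_proj="gated-gelu",  # sets config.is_gated_act = True
        tie_word_embeddings=False,
        decoder_start_token_id=0,
        pad_token_id=0,
        eos_token_id=1,
        # extra attributes read by this file / MelSpectrogram
        sample_rate=16000,
        n_fft=400,
        n_mels=80,
        hop_length=160,
    )
    model = T5(config).eval()
    frames = torch.randn(2, 50, config.n_mels)  # B x L_encoder x n_mels
    tokens = torch.randint(2, config.vocab_size, (2, 7))  # B x L_decoder
    with torch.no_grad():
        out = model(frames=frames, tokens=tokens)
        print("loss:", out.loss.item(), "logits:", tuple(out.logits.shape))
        generated = model.generate(frames=frames, max_length=10)
        print("generated:", tuple(generated.shape))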