# Source: STAR/fairseq/models/speech_dlm/modules/speech_dlm_decoder_layer.py
# Author: Yixuan Li
# Commit message: add fairseq folder
# Commit: 85ba398
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Tuple, Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
class CrossChannelTransformerDecoderLayer(nn.Module):
    """Multi-channel decoder layer with cross-channel attention.

    Implements the Cross-Attention Transformer Decoder Layer block of
    https://arxiv.org/pdf/2203.16502.pdf. Every channel first goes through a
    multi-head self-attention block (and, when enabled, an encoder-attention
    block); afterwards a multi-head cross-channel attention block lets each
    channel attend to the outputs the *other* channels produced in the first
    stage, and finally a feed-forward block is applied. One set of modules is
    used for all channels, i.e. the weights are shared across channels.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): if ``True`` no encoder-attention
            sub-block is built (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention and
            cross-channel attention modules (default: False).
        add_zero_attn (bool, optional): forwarded to the self-attention and
            cross-channel attention modules (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        layer_name = self.__class__.__name__

        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = FairseqDropout(args.dropout, module_name=layer_name)
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        # NOTE: ``cross_self_attention`` belongs to encoder-decoder systems
        # (self-attention over encoder output + decoder input). It is unrelated
        # to the cross-channel attention stored in ``cross_channel_attn``.
        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        attn_options = {"add_bias_kv": add_bias_kv, "add_zero_attn": add_zero_attn}
        self.self_attn = self.build_self_attention(
            self.embed_dim, args, **attn_options
        )
        self.cross_channel_attn = self.build_cross_channel_attention(
            self.embed_dim, args, **attn_options
        )

        # Fall back to ReLU when no activation is configured.
        if getattr(args, "activation_fn", None) is not None:
            activation_name = str(args.activation_fn)
        else:
            activation_name = "relu"
        self.activation_fn = utils.get_activation_fn(activation=activation_name)

        act_dropout = getattr(args, "activation_dropout", 0) or 0
        if act_dropout == 0:
            # Older models stored this value as ``args.relu_dropout``.
            act_dropout = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(act_dropout), module_name=layer_name
        )

        self.normalize_before = args.decoder_normalize_before

        # Use LayerNorm rather than FusedLayerNorm for exporting;
        # ``char_inputs`` is used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.cross_channel_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.need_attn = True
        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the first feed-forward projection, wrapped in quant-noise."""
        projection = nn.Linear(input_dim, output_dim)
        return quant_noise(projection, q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the second feed-forward projection, wrapped in quant-noise."""
        projection = nn.Linear(input_dim, output_dim)
        return quant_noise(projection, q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the self-attention module shared by all channels."""
        uses_cross_self_attention = getattr(args, "cross_self_attention", False)
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not uses_cross_self_attention,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def build_cross_channel_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the cross-channel attention module shared by all channels."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=False,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def build_encoder_attention(self, embed_dim, args):
        """Build the encoder-decoder attention module."""
        encoder_dim = getattr(args, "encoder_embed_dim", None)
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=encoder_dim,
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def prepare_for_onnx_export_(self):
        """Switch the layer into ONNX-trace mode."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Add the residual branch to ``x``."""
        return residual + x

    @staticmethod
    def _load_prev_attn_state(attn_module, prev_state, incremental_state, channel):
        """Write an externally supplied attention state into a channel's buffer.

        ``prev_state`` holds ``(prev_key, prev_value)`` and optionally a third
        entry used as ``prev_key_padding_mask``.
        """
        prev_key, prev_value = prev_state[:2]
        buffer: Dict[str, Optional[Tensor]] = {
            "prev_key": prev_key,
            "prev_value": prev_value,
        }
        if len(prev_state) >= 3:
            buffer["prev_key_padding_mask"] = prev_state[2]
        assert incremental_state is not None
        attn_module._set_input_buffer(incremental_state[channel], buffer)

    def _feed_forward(self, x):
        """Apply the position-wise feed-forward sub-block to one channel."""
        shortcut = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, shortcut)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x

    def forward(
        self,
        x_list_tensor: List[torch.Tensor],
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[
            List[Dict[str, Dict[str, Optional[Tensor]]]]
        ] = None,
        prev_self_attn_state: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Run all channels through the layer.

        Args:
            x_list_tensor (List[Tensor]): list of input tensors in different channels,
                each tensor is of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder output attended to by the
                encoder-attention block (and prepended to the keys when
                ``cross_self_attention`` is on)
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (optional): list of incremental_state dictionaries over
                different channels (sequence generation mode)
            prev_self_attn_state (List[Tuple[Tensor, Tensor]], optional): list of tuples
                (self_attn_state, cross_channel_attn_state) over different channels
            prev_attn_state (optional): encoder-attention state written into
                every channel's incremental state
            self_attn_mask (Tensor, optional): attention mask given to the
                self-attention and cross-channel attention blocks
            self_attn_padding_mask (Tensor, optional): key padding mask given
                to the self-attention and cross-channel attention blocks
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            a tuple ``(x, attn_list, state)`` where ``x`` is the per-channel
            outputs stacked with ``torch.stack``, ``attn_list`` holds one
            attention entry per channel, and ``state`` is, in ONNX-trace mode
            with an incremental state, a per-channel list of
            ``(self_attn_state, cross_channel_attn_state)`` tuples, otherwise
            ``None``.
        """
        n_channels = len(x_list_tensor)
        need_attn = True if need_head_weights else need_attn

        # One incremental-state dictionary per channel.
        if incremental_state is not None:
            assert isinstance(incremental_state, list)
            assert len(incremental_state) == n_channels

        # One (self_attn_state, cross_channel_attn_state) tuple per channel.
        if prev_self_attn_state is not None:
            assert isinstance(prev_self_attn_state, list)
            assert len(prev_self_attn_state) == n_channels
            for channel_prev_state in prev_self_attn_state:
                assert isinstance(channel_prev_state, tuple)
                assert len(channel_prev_state) == 2

        # Keep the caller's masks: each channel extends them afresh and the
        # cross-channel attention uses them unmodified.
        base_attn_mask = self_attn_mask
        base_padding_mask = self_attn_padding_mask

        # ---- Stage 1: self-attention (+ encoder attention) per channel ----
        stage_one = []
        attn_list = []
        for ch, x in enumerate(x_list_tensor):
            shortcut = x
            if self.normalize_before:
                x = self.self_attn_layer_norm(x)

            if prev_self_attn_state is not None:
                self._load_prev_attn_state(
                    self.self_attn,
                    prev_self_attn_state[ch][0],
                    incremental_state,
                    ch,
                )
            ch_state = None if incremental_state is None else incremental_state[ch]

            cached = self.self_attn._get_input_buffer(ch_state)
            has_cached_keys = (
                incremental_state is not None
                and cached is not None
                and "prev_key" in cached
            )
            if self.cross_self_attention and not has_cached_keys:
                # Prepend the encoder output to the keys/values and widen the
                # masks with matching (all-zero / encoder) columns.
                if base_attn_mask is not None:
                    assert encoder_out is not None
                    encoder_columns = x.new_zeros(x.size(0), encoder_out.size(0))
                    self_attn_mask = torch.cat(
                        (encoder_columns, base_attn_mask), dim=1
                    )
                if base_padding_mask is not None:
                    if encoder_padding_mask is None:
                        assert encoder_out is not None
                        encoder_padding_mask = base_padding_mask.new_zeros(
                            encoder_out.size(1), encoder_out.size(0)
                        )
                    self_attn_padding_mask = torch.cat(
                        (encoder_padding_mask, base_padding_mask), dim=1
                    )
                assert encoder_out is not None
                keys_values = torch.cat((encoder_out, x), dim=0)
            else:
                keys_values = x

            x, attn = self.self_attn(
                query=x,
                key=keys_values,
                value=keys_values,
                key_padding_mask=self_attn_padding_mask,
                incremental_state=ch_state,
                need_weights=False,
                attn_mask=self_attn_mask,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, shortcut)
            if not self.normalize_before:
                x = self.self_attn_layer_norm(x)

            if self.encoder_attn is not None and encoder_out is not None:
                shortcut = x
                if self.normalize_before:
                    x = self.encoder_attn_layer_norm(x)
                if prev_attn_state is not None:
                    self._load_prev_attn_state(
                        self.encoder_attn, prev_attn_state, incremental_state, ch
                    )
                x, attn = self.encoder_attn(
                    query=x,
                    key=encoder_out,
                    value=encoder_out,
                    key_padding_mask=encoder_padding_mask,
                    incremental_state=ch_state,
                    static_kv=True,
                    need_weights=need_attn or (not self.training and self.need_attn),
                    need_head_weights=need_head_weights,
                )
                x = self.dropout_module(x)
                x = self.residual_connection(x, shortcut)
                if not self.normalize_before:
                    x = self.encoder_attn_layer_norm(x)

            stage_one.append(x)
            attn_list.append(attn)

        # ---- Stage 2: cross-channel attention ----
        # Results go into a new list because every channel must see the
        # stage-one outputs of the other channels.
        n_streams = len(stage_one)
        stage_two = []
        for ch, x in enumerate(stage_one):
            shortcut = x
            if self.normalize_before:
                x = self.cross_channel_attn_layer_norm(x)

            if prev_self_attn_state is not None:
                self._load_prev_attn_state(
                    self.cross_channel_attn,
                    prev_self_attn_state[ch][1],
                    incremental_state,
                    ch,
                )
            ch_state = None if incremental_state is None else incremental_state[ch]

            if n_streams > 1:
                # Keys/values: the other channels, concatenated along dim 0,
                # starting with the channel right after the current one.
                others = [
                    stage_one[(ch + offset) % n_streams]
                    for offset in range(1, n_streams)
                ]
                x_other = torch.cat(others, dim=0)
            else:
                # With a single channel this degenerates to self-attention.
                x_other = stage_one[ch]

            x, attn = self.cross_channel_attn(
                query=x,
                key=x_other,
                value=x_other,
                key_padding_mask=base_padding_mask,
                incremental_state=ch_state,
                need_weights=False,
                attn_mask=base_attn_mask,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, shortcut)
            if not self.normalize_before:
                x = self.cross_channel_attn_layer_norm(x)
            stage_two.append(x)

        # ---- Stage 3: feed-forward block per channel ----
        outputs = [self._feed_forward(x) for x in stage_two]

        # Trick for the checkpoint activation: return one stacked tensor.
        x_list_tensor = torch.stack(outputs)

        if not (self.onnx_trace and incremental_state is not None):
            return x_list_tensor, attn_list, None

        # ONNX-trace mode: also hand back the cached attention states.
        state_keys = ["prev_key", "prev_value"]
        if self_attn_padding_mask is not None:
            state_keys.append("prev_key_padding_mask")

        self_and_cross_attn_state_list = []
        for ch in range(n_channels):
            channel_states = []
            for attn_module in (self.self_attn, self.cross_channel_attn):
                buffer = attn_module._get_input_buffer(incremental_state[ch])
                assert buffer is not None
                channel_states.append([buffer[key] for key in state_keys])
            self_and_cross_attn_state_list.append(tuple(channel_states))
        return x_list_tensor, attn_list, self_and_cross_attn_state_list

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Record whether attention weights are wanted at inference time."""
        self.need_attn = need_attn
# Rewrite fairseq.modules.TransformerDecoderLayer
# to be compatible with checkpoint_activations
# (avoid forwarding model multiple times)
class StandardTransformerDecoderLayer(nn.Module):
    """Rewrite fairseq.modules.TransformerDecoderLayer to avoid forwarding
    model multiple times and be compatible with checkpoint_activations.

    The input is expected to be a list of tensors from different channels,
    each is forwarded to the same model (shared attention weights).

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): if ``True`` no encoder-attention
            sub-block is built (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module
            (default: False).
        add_zero_attn (bool, optional): forwarded to the self-attention module
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        # Self-attention over [encoder output; decoder input] (encoder-decoder
        # systems); see the handling in forward().
        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )

        # Fall back to ReLU when no activation is configured.
        self.activation_fn = utils.get_activation_fn(
            activation=str(args.activation_fn)
            if getattr(args, "activation_fn", None) is not None
            else "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.decoder_normalize_before

        # Use LayerNorm rather than FusedLayerNorm for exporting;
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the first feed-forward projection, wrapped in quant-noise."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the second feed-forward projection, wrapped in quant-noise."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the self-attention module shared by all channels."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not getattr(args, "cross_self_attention", False),
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def build_encoder_attention(self, embed_dim, args):
        """Build the encoder-decoder attention module."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
        )

    def prepare_for_onnx_export_(self):
        """Switch the layer into ONNX-trace mode."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Add the residual branch to ``x``."""
        return residual + x

    def forward(
        self,
        x_list_tensor: List[torch.Tensor],
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[
            List[Dict[str, Dict[str, Optional[Tensor]]]]
        ] = None,
        prev_self_attn_state: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Run every channel through the (shared) decoder layer.

        Args:
            x_list_tensor (List[Tensor]): list of input tensors in different channels,
                each tensor is of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder output attended to by the
                encoder-attention block (and prepended to the keys when
                ``cross_self_attention`` is on)
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (optional): list of incremental_state dictionaries over
                different channels (sequence generation mode)
            prev_self_attn_state (optional): list with one self-attention
                state per channel; each state holds ``(prev_key, prev_value)``
                and optionally a third entry used as ``prev_key_padding_mask``
            prev_attn_state (optional): encoder-attention state written into
                every channel's incremental state
            self_attn_mask (Tensor, optional): attention mask for the
                self-attention block
            self_attn_padding_mask (Tensor, optional): key padding mask for
                the self-attention block
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            a tuple ``(x, attn_list, state)`` where ``x`` is the per-channel
            outputs stacked with ``torch.stack``, ``attn_list`` holds one
            attention entry per channel, and ``state`` is, in ONNX-trace mode
            with an incremental state, a per-channel list of self-attention
            states, otherwise ``None``.
        """
        n_channels = len(x_list_tensor)
        if need_head_weights:
            need_attn = True

        # incremental_state is a list of dictionaries over different channels
        if incremental_state is not None:
            assert isinstance(incremental_state, list)
            assert len(incremental_state) == n_channels

        # prev_self_attn_state is a list of self_attn_state over different channels
        if prev_self_attn_state is not None:
            assert isinstance(prev_self_attn_state, list)
            assert len(prev_self_attn_state) == n_channels

        # Keep the caller's masks untouched. Each channel derives its own
        # extended masks from these, so the encoder columns are prepended
        # exactly once per channel instead of accumulating across channels.
        self_attn_mask_orig = self_attn_mask
        self_attn_padding_mask_orig = self_attn_padding_mask

        x_list = []
        attn_list = []
        for i, x in enumerate(x_list_tensor):
            residual = x

            if self.normalize_before:
                x = self.self_attn_layer_norm(x)

            if prev_self_attn_state is not None:
                prev_key, prev_value = prev_self_attn_state[i][:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_self_attn_state[i]) >= 3:
                    # Third entry of *this channel's* state is its padding mask.
                    saved_state["prev_key_padding_mask"] = prev_self_attn_state[i][2]
                assert incremental_state is not None
                self.self_attn._set_input_buffer(incremental_state[i], saved_state)

            # The attention buffer lives in the channel's own dictionary, not
            # in the list of dictionaries.
            channel_state = (
                incremental_state[i] if incremental_state is not None else None
            )
            _self_attn_input_buffer = self.self_attn._get_input_buffer(channel_state)

            channel_attn_mask = self_attn_mask_orig
            channel_padding_mask = self_attn_padding_mask_orig
            if self.cross_self_attention and not (
                incremental_state is not None
                and _self_attn_input_buffer is not None
                and "prev_key" in _self_attn_input_buffer
            ):
                # No cached keys yet: attend over [encoder_out; x] and widen
                # the masks with the matching encoder columns.
                if self_attn_mask_orig is not None:
                    assert encoder_out is not None
                    channel_attn_mask = torch.cat(
                        (
                            x.new_zeros(x.size(0), encoder_out.size(0)),
                            self_attn_mask_orig,
                        ),
                        dim=1,
                    )
                if self_attn_padding_mask_orig is not None:
                    if encoder_padding_mask is None:
                        assert encoder_out is not None
                        encoder_padding_mask = self_attn_padding_mask_orig.new_zeros(
                            encoder_out.size(1), encoder_out.size(0)
                        )
                    channel_padding_mask = torch.cat(
                        (encoder_padding_mask, self_attn_padding_mask_orig), dim=1
                    )
                assert encoder_out is not None
                y = torch.cat((encoder_out, x), dim=0)
            else:
                y = x

            x, attn = self.self_attn(
                query=x,
                key=y,
                value=y,
                key_padding_mask=channel_padding_mask,
                incremental_state=channel_state,
                need_weights=False,
                attn_mask=channel_attn_mask,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.self_attn_layer_norm(x)

            if self.encoder_attn is not None and encoder_out is not None:
                residual = x
                if self.normalize_before:
                    x = self.encoder_attn_layer_norm(x)
                if prev_attn_state is not None:
                    prev_key, prev_value = prev_attn_state[:2]
                    saved_state: Dict[str, Optional[Tensor]] = {
                        "prev_key": prev_key,
                        "prev_value": prev_value,
                    }
                    if len(prev_attn_state) >= 3:
                        saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                    assert incremental_state is not None
                    # Write into this channel's dictionary (not the list).
                    self.encoder_attn._set_input_buffer(
                        incremental_state[i], saved_state
                    )

                x, attn = self.encoder_attn(
                    query=x,
                    key=encoder_out,
                    value=encoder_out,
                    key_padding_mask=encoder_padding_mask,
                    incremental_state=channel_state,
                    static_kv=True,
                    need_weights=need_attn or (not self.training and self.need_attn),
                    need_head_weights=need_head_weights,
                )
                x = self.dropout_module(x)
                x = self.residual_connection(x, residual)
                if not self.normalize_before:
                    x = self.encoder_attn_layer_norm(x)

            # Position-wise feed-forward block.
            residual = x
            if self.normalize_before:
                x = self.final_layer_norm(x)

            x = self.activation_fn(self.fc1(x))
            x = self.activation_dropout_module(x)
            x = self.fc2(x)
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.final_layer_norm(x)

            x_list.append(x)
            attn_list.append(attn)

        # Trick for the checkpoint activation
        x_list_tensor = torch.stack(x_list)

        if self.onnx_trace and incremental_state is not None:
            # ONNX-trace mode: also hand back the cached self-attention states.
            self_attn_state_list = []
            for i in range(n_channels):
                saved_state = self.self_attn._get_input_buffer(incremental_state[i])
                assert saved_state is not None
                if self_attn_padding_mask_orig is not None:
                    self_attn_state = [
                        saved_state["prev_key"],
                        saved_state["prev_value"],
                        saved_state["prev_key_padding_mask"],
                    ]
                else:
                    self_attn_state = [
                        saved_state["prev_key"],
                        saved_state["prev_value"],
                    ]
                self_attn_state_list.append(self_attn_state)
            return x_list_tensor, attn_list, self_attn_state_list
        return x_list_tensor, attn_list, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Record whether attention weights are wanted at inference time."""
        self.need_attn = need_attn