File size: 22,203 Bytes
85ba398 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder
from fairseq.modules import (
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
)
from .speech_dlm_decoder_layer import (
CrossChannelTransformerDecoderLayer,
StandardTransformerDecoderLayer,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
class CrossChannelTransformerDecoder(FairseqIncrementalDecoder):
    """
    Cross-channel Transformer Decoder Block for parallel spoken dialogue units
    as described in the paper: https://arxiv.org/pdf/2203.16502.pdf;
    consisting of *args.decoder_layers* layers. The first
    ``args.decoder_layers - args.decoder_cross_layers`` layers are
    :class:`StandardTransformerDecoderLayer` and the remaining
    ``args.decoder_cross_layers`` layers are
    :class:`CrossChannelTransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): input token embedding; its weight
            is also used as the main-channel output projection when
            ``args.share_decoder_input_output_embed`` is set
        channels (list): list of channel names (string)
        no_encoder_attn (bool, optional): passed to every decoder layer;
            True disables encoder attention (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, channels, no_encoder_attn=False):
        super().__init__(dictionary)
        self.args = args
        self.register_buffer("version", torch.Tensor([3]))
        self._future_mask = torch.empty(0)

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.decoder_layerdrop = args.decoder_layerdrop
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.channels = channels

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        self.output_embed_dim = args.decoder_output_dim

        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens

        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        if args.quant_noise_pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                args.quant_noise_pq,
                args.quant_noise_pq_block_size,
            )
        else:
            self.quant_noise = None

        self.project_in_dim = (
            nn.Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )
        self.embed_positions = (
            PositionalEmbedding(
                self.max_target_positions,
                embed_dim,
                self.padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        assert 0 <= args.decoder_cross_layers <= args.decoder_layers, (
            "The number of cross-channel attention decoder layers must be non-negative "
            "and must not exceed the number of decoder layers "
            f"(found {args.decoder_cross_layers})"
        )
        # The first `non_cross_layers` layers are standard decoder layers; the
        # remaining `decoder_cross_layers` layers are cross-channel layers.
        self.non_cross_layers = args.decoder_layers - args.decoder_cross_layers

        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(args, no_encoder_attn)
                if i < self.non_cross_layers
                else self.build_cross_decoder_layer(args, no_encoder_attn)
                for i in range(args.decoder_layers)
            ]
        )
        self.num_layers = len(self.layers)

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

        self.project_out_dim = (
            nn.Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim
            else None
        )

        # Cross-channel prediction is enabled when the second comma-separated
        # entry of `main_and_cross_weights` is non-zero.
        self.is_cross_prediction = (
            float(args.main_and_cross_weights.split(",")[1]) != 0
        )
        self.n_output_projections = (
            1 if not self.is_cross_prediction else len(self.channels)
        )

        if self.share_input_output_embed:
            # Output projection is a list of projections
            # where the first proj is for the main-channel,
            # then roll in a circular way.
            # For example: if the main channel has index i
            # the second proj is for channel i+1 (mod N_channels), etc.
            self.output_projection = nn.ModuleList(
                [
                    nn.Linear(
                        embed_tokens.weight.shape[1],  # embedding dim
                        embed_tokens.weight.shape[0],  # vocabulary size
                        bias=False,
                    )
                    for _ in range(self.n_output_projections)
                ]
            )
            # Only share the main-channel projection
            self.output_projection[0].weight = embed_tokens.weight
            for i in range(1, self.n_output_projections):
                nn.init.normal_(
                    self.output_projection[i].weight,
                    mean=0,
                    std=embed_tokens.weight.shape[1] ** -0.5,
                )
        else:
            self.output_projection = nn.ModuleList(
                [
                    nn.Linear(self.output_embed_dim, len(dictionary), bias=False)
                    for _ in range(self.n_output_projections)
                ]
            )
            for i in range(self.n_output_projections):
                nn.init.normal_(
                    self.output_projection[i].weight,
                    mean=0,
                    std=self.output_embed_dim**-0.5,
                )

        # One scalar duration head per output projection, unless disabled.
        self.output_duration_prediction = (
            None
            if str(args.duration_prediction).lower() == "false"
            else nn.ModuleList(
                [
                    nn.Linear(self.output_embed_dim, 1)
                    for _ in range(self.n_output_projections)
                ]
            )
        )

    def _maybe_checkpoint(self, layer, args):
        """Wrap ``layer`` with activation checkpointing if
        ``args.checkpoint_activations`` is set (optionally offloading to CPU)."""
        if getattr(args, "checkpoint_activations", False):
            offload_to_cpu = getattr(args, "offload_activations", False)
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        return layer

    def build_decoder_layer(self, args, no_encoder_attn=False):
        """Build one :class:`StandardTransformerDecoderLayer`."""
        return self._maybe_checkpoint(
            StandardTransformerDecoderLayer(args, no_encoder_attn), args
        )

    def build_cross_decoder_layer(self, args, no_encoder_attn=False):
        """Build one :class:`CrossChannelTransformerDecoderLayer`."""
        return self._maybe_checkpoint(
            CrossChannelTransformerDecoderLayer(args, no_encoder_attn), args
        )

    def forward(
        self,
        prev_output_tokens: Dict[str, Tensor],
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[
            List[Dict[str, Dict[str, Optional[Tensor]]]]
        ] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
    ):
        """
        Args:
            prev_output_tokens (dict[str, LongTensor]): previous decoder outputs,
                dictionary over all channels with the values being the tensors
                of shape `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
            incremental_state (list): list (one entry per channel) of
                dictionaries used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): see :meth:`extract_features_scriptable`.
            alignment_heads (int, optional): see :meth:`extract_features_scriptable`.
            src_lengths: accepted for interface compatibility; unused.

        Returns:
            tuple:
                - the decoder's output, see :meth:`output_layer` (or the
                  per-channel features if *features_only* is True)
                - a dictionary with any model-specific outputs
        """
        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )
        if not features_only:
            x = self.output_layer(x)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens: Dict[str, Tensor],
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[
            List[Dict[str, Dict[str, Optional[Tensor]]]]
        ] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """Delegate to :meth:`extract_features_scriptable`."""
        return self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    # A scriptable subclass of this class has an extract_features method and
    # calls super().extract_features, but super() is not supported in
    # TorchScript. A copy of this function is made to be used in the subclass
    # instead.
    def extract_features_scriptable(
        self,
        prev_output_tokens: Dict[str, Tensor],
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[
            List[Dict[str, Dict[str, Optional[Tensor]]]]
        ] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        The core function of *forward* but only return features.

        The input (prev_output_tokens) is a dictionary over all channels,
        expected to have the following form:
            {
                'channel1' : Tensor((batch x tgt_len)),
                'channel2' : Tensor((batch x tgt_len)),
            }
        The dictionary itself is not modified.

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features, dict over channels of tensors
                  of shape `(batch, tgt_len, output_embed_dim)`
                - a dictionary with "attn" (a one-element list holding either
                  None or a dict over channels of head-averaged attention) and
                  "inner_states" (list of dicts over channels of T x B x C
                  hidden states, starting with the embedded inputs)
        """
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        # Per-channel (possibly truncated) tokens are kept in a local dict so
        # that the caller's `prev_output_tokens` is never mutated.
        tokens: Dict[str, Tensor] = {}
        x_list = []
        for i, channel in enumerate(self.channels):
            channel_tokens = prev_output_tokens[channel]

            # embed positions on the full prefix; truncated below if incremental
            positions: Optional[Tensor] = None
            if self.embed_positions is not None:
                positions = self.embed_positions(
                    channel_tokens,
                    incremental_state=incremental_state[i]
                    if incremental_state is not None
                    else None,
                )

            if incremental_state is not None:
                # Incremental decoding: only the newest step is fed through.
                channel_tokens = channel_tokens[:, -1:]
                if positions is not None:
                    positions = positions[:, -1:]
            tokens[channel] = channel_tokens

            # embed tokens and positions
            x = self.embed_tokens(channel_tokens)

            if self.project_in_dim is not None:
                x = self.project_in_dim(x)

            x = self.embed_scale * x

            if self.quant_noise is not None:
                x = self.quant_noise(x)

            if positions is not None:
                x += positions

            if self.layernorm_embedding is not None:
                x = self.layernorm_embedding(x)

            x = self.dropout_module(x)

            # B x T x C -> T x B x C
            x = x.transpose(0, 1)
            x_list.append(x)

        # The self-attention padding mask is derived from the first channel.
        main_tokens = tokens[self.channels[0]]
        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or main_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = main_tokens.eq(self.padding_idx)

        # Encoder outputs are the same for every layer; resolve them once.
        enc: Optional[Tensor] = None
        enc_padding_mask: Optional[Tensor] = None
        if encoder_out is not None:
            if len(encoder_out["encoder_out"]) > 0:
                enc = encoder_out["encoder_out"][0]
            if len(encoder_out["encoder_padding_mask"]) > 0:
                enc_padding_mask = encoder_out["encoder_padding_mask"][0]

        # decoder layers
        attn: Optional[Dict[str, Tensor]] = None
        inner_states: List[Optional[Dict[str, Tensor]]] = [
            {channel: x_list[i] for i, channel in enumerate(self.channels)}
        ]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x_list[0])
            else:
                self_attn_mask = None

            # need to change to tensor for the checkpoint activation to work
            if isinstance(x_list, list):
                x_list = torch.stack(x_list)

            is_alignment_layer = idx == alignment_layer
            x_list, layer_attn_list, _ = layer(
                x_list,
                enc,
                enc_padding_mask,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=is_alignment_layer,
                need_head_weights=is_alignment_layer,
            )
            inner_states.append(
                {channel: x_list[i] for i, channel in enumerate(self.channels)}
            )
            if is_alignment_layer and all(
                layer_attn is not None for layer_attn in layer_attn_list
            ):
                attn = {
                    channel: layer_attn_list[i].float().to(x_list[0])
                    for i, channel in enumerate(self.channels)
                }

        # change back from tensor to list
        if not isinstance(x_list, list):
            x_list = list(torch.unbind(x_list))

        if attn is not None:
            for channel in attn:
                if alignment_heads is not None:
                    attn[channel] = attn[channel][:alignment_heads]

                # average probabilities over heads
                attn[channel] = attn[channel].mean(dim=0)

        features: Dict[str, Tensor] = {}
        for channel, x in zip(self.channels, x_list):
            if self.layer_norm is not None:
                x = self.layer_norm(x)

            # T x B x C -> B x T x C
            x = x.transpose(0, 1)

            if self.project_out_dim is not None:
                x = self.project_out_dim(x)
            features[channel] = x

        return features, {"attn": [attn], "inner_states": inner_states}

    def output_layer(self, features):
        """Project features to the vocabulary size.

        Args:
            features (dict[str, Tensor]): per-channel decoder features as
                returned by :meth:`extract_features`.

        Return a dictionary of the form:
            {
                'input-channel': {
                    'predicted-channel': token prediction tensor of shape `(batch, tgt_len, vocab)`,
                }
            }

        if duration_prediction is enabled
            {
                'input-channel': {
                    'predicted-channel': {
                        'pred_token': token prediction tensor of shape `(batch, tgt_len, vocab)`,
                        'pred_duration': duration prediction tensor (last dim 1)
                    }
                }
            }

        Without cross prediction, every channel only predicts itself using
        projection 0. With cross prediction, the features of channel i predict
        channel j with projection ``(j - i) mod n_channels``, so projection 0
        is always the same-channel head.
        """
        n_channels = len(self.channels)
        if self.output_duration_prediction is None:
            if self.is_cross_prediction:
                return {
                    channel: {
                        pred_channel: self.output_projection[(j - i) % n_channels](
                            features[channel]
                        )
                        for j, pred_channel in enumerate(self.channels)
                    }
                    for i, channel in enumerate(self.channels)
                }
            return {
                channel: {channel: self.output_projection[0](features[channel])}
                for channel in self.channels
            }

        if self.is_cross_prediction:
            return {
                channel: {
                    pred_channel: {
                        "pred_token": self.output_projection[(j - i) % n_channels](
                            features[channel]
                        ),
                        "pred_duration": self.output_duration_prediction[
                            (j - i) % n_channels
                        ](features[channel]),
                    }
                    for j, pred_channel in enumerate(self.channels)
                }
                for i, channel in enumerate(self.channels)
            }
        return {
            channel: {
                channel: {
                    "pred_token": self.output_projection[0](features[channel]),
                    "pred_duration": self.output_duration_prediction[0](
                        features[channel]
                    ),
                }
            }
            for channel in self.channels
        }

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        """Return a `(dim, dim)` upper-triangular -inf mask, where `dim` is
        ``tensor.size(0)``, cast to the dtype/device of ``tensor``.

        The mask is cached and only rebuilt when it is empty, on another
        device, or smaller than ``dim``.
        """
        dim = tensor.size(0)
        # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
        if (
            self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
            )
        self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]

    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output.

        ``net_output[0]`` is the nested per-channel dict produced by
        :meth:`output_layer` (despite the ``Tensor`` annotation). Token logits
        are (log-)softmaxed over the last dimension; duration predictions, if
        present, are passed through unchanged apart from a cast to float32.
        The returned dict mirrors the structure of ``net_output[0]``.
        ``sample`` is unused.
        """
        logits_dict = net_output[0]
        out_dict = {}
        for channel in logits_dict:
            out_dict[channel] = {}
            for pred_channel in logits_dict[channel]:
                pred = logits_dict[channel][pred_channel]
                has_duration = isinstance(pred, dict)
                pred_token_logits = pred["pred_token"] if has_duration else pred

                if log_probs:
                    out = utils.log_softmax(
                        pred_token_logits, dim=-1, onnx_trace=self.onnx_trace
                    )
                else:
                    out = utils.softmax(
                        pred_token_logits, dim=-1, onnx_trace=self.onnx_trace
                    )

                if has_duration:
                    out_dict[channel][pred_channel] = {
                        "pred_token": out,
                        # move to float32 to avoid inf loss
                        "pred_duration": pred["pred_duration"].float(),
                    }
                else:
                    out_dict[channel][pred_channel] = out
        return out_dict

    def reorder_incremental_state_scripting(
        self,
        incremental_state: List[Dict[str, Dict[str, Optional[Tensor]]]],
        new_order: Tensor,
    ):
        """Main entry point for reordering the incremental state.

        Due to limitations in TorchScript, we call this function in
        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
        calling :func:`reorder_incremental_state` directly.

        ``incremental_state`` holds one state dict per channel; every
        submodule that defines ``reorder_incremental_state`` is applied to
        each channel's state, and a non-None result replaces that entry.
        """
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state"):
                for i, incremental_state_channel in enumerate(incremental_state):
                    result = module.reorder_incremental_state(
                        incremental_state_channel, new_order
                    )
                    if result is not None:
                        incremental_state[i] = result
|