# Non-code page residue from the original source (kept for provenance):
# uploaded by ArthurY — commit message "update source" — commit c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Callable, Optional, Union
import torch
from torch import Tensor
from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, MultiheadAttention
from torch.nn import functional as F
class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers.

    Parameters
    ----------
    decoder_layer: torch.nn.Module
        Layer used for the decoder; it is deep-copied ``num_layers`` times.
    num_layers: int
        Number of sub-decoder-layers in the decoder.
    norm: torch.nn.Module, optional
        Layer normalization component applied to the final output, or ``None``
        to skip the final normalization.
    """

    __constants__ = ["norm"]

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        tgt: Tensor,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: Optional[bool] = None,
    ) -> Tensor:
        """Pass the inputs (and mask) through the decoder layers in turn.

        Parameters
        ----------
        tgt: torch.Tensor
            The sequence to decode.
        tgt_mask: torch.Tensor, optional
            Attention mask forwarded to every layer.
        tgt_key_padding_mask: torch.Tensor, optional
            Key padding mask forwarded to every layer.
        tgt_is_causal: bool, optional
            Whether attention is causal. ``None`` (the default) means causal,
            matching the decoder-only setting.
        """
        output = tgt
        # Decoder-only models default to causal attention, but honor an
        # explicit caller choice instead of silently overriding it (the
        # previous code hard-coded ``True``, ignoring the argument).
        if tgt_is_causal is None:
            tgt_is_causal = True
        for mod in self.layers:
            output = mod(
                output,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_is_causal=tgt_is_causal,
            )
        if self.norm is not None:
            output = self.norm(output)
        return output
class DecoderOnlyLayer(Module):
r"""
Parameters
----------
d_model: int
Number of expected features in the input.
nhead: int
Number of heads in the multiheadattention models.
dim_feedforward: int
Dimension of the feedforward network model, by default 2048.
dropout: float
The dropout value, by default 0.1.
activation: str
The activation function of the intermediate layer, by default 'relu'.
layer_norm_eps: float
The eps value in layer normalization components, by default 1e-5.
batch_first: Bool
If ``True``, then the input and output tensors are provided
as (batch, seq, feature), by default ``False`` (seq, batch, feature).
norm_first: Bool
If ``True``, layer norm is done prior to self attention, multihead
attention and feedforward operations, respectively. Otherwise it's done after,
by default ``False`` (after).
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
bias. Default: ``True``.
"""
__constants__ = ["norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = False,
norm_first: bool = False,
bias: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.self_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
bias=bias,
**factory_kwargs,
)
self.multihead_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
bias=bias,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.dropout3 = Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
else:
self.activation = activation
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = F.relu
super().__setstate__(state)
def forward(
self,
tgt: Tensor,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_is_causal: bool = False,
) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer."""
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
x = tgt
if self.norm_first:
x = x + self._sa_block(
self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal
)
x = x + self._mha_block(
self.norm2(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal
)
x = x + self._ff_block(self.norm3(x))
else:
x = self.norm1(
x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)
)
x = self.norm2(
x + self._mha_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)
)
x = self.norm3(x + self._ff_block(x))
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
is_causal: bool = False,
) -> Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
need_weights=False,
)[0]
return self.dropout1(x)
# multihead attention block
def _mha_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
is_causal: bool = False,
) -> Tensor:
x = self.multihead_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
need_weights=False,
)[0]
return self.dropout2(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
def _get_clones(module, N):
# FIXME: copy.deepcopy() is not defined on nn.module
return ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
raise RuntimeError(f"activation should be relu/gelu, not {activation}")
def _detect_is_causal_mask(
mask: Optional[Tensor],
is_causal: Optional[bool] = None,
size: Optional[int] = None,
) -> bool:
"""Return whether the given attention mask is causal."""
# Prevent type refinement
make_causal = is_causal is True
if is_causal is None and mask is not None:
sz = size if size is not None else mask.size(-2)
# ruff: noqa: F821
causal_comparison = _generate_square_subsequent_mask(
sz, device=mask.device, dtype=mask.dtype
)
# Do not use `torch.equal` so we handle batched masks by
# broadcasting the comparison.
if mask.size() == causal_comparison.size():
make_causal = bool((mask == causal_comparison).all())
else:
make_causal = False
return make_causal
def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]:
if src.is_nested:
return None
else:
src_size = src.size()
if len(src_size) == 2:
# unbatched: S, E
return src_size[0]
else:
# batched: B, S, E if batch_first else S, B, E
seq_len_pos = 1 if batch_first else 0
return src_size[seq_len_pos]