# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy from typing import Callable, Optional, Union import torch from torch import Tensor from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, MultiheadAttention from torch.nn import functional as F class TransformerDecoder(Module): r"""TransformerDecoder is a stack of N decoder layers Parameters ----------: decoder_layer: torch.nn.Module Layer used for the doceder num_layers: int Number of sub-decoder-layers in the decoder. norm: str Layer normalization component. """ __constants__ = ["norm"] def __init__(self, decoder_layer, num_layers, norm=None): super().__init__() torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers self.norm = norm def forward( self, tgt: Tensor, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None, ) -> Tensor: """Pass the inputs (and mask) through the decoder layer in turn.""" output = tgt tgt_is_causal = True for mod in self.layers: output = mod( output, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, tgt_is_causal=tgt_is_causal, ) if self.norm is not None: output = self.norm(output) return output class DecoderOnlyLayer(Module): r""" Parameters ---------- d_model: int Number of expected features in the input. nhead: int Number of heads in the multiheadattention models. dim_feedforward: int Dimension of the feedforward network model, by default 2048. dropout: float The dropout value, by default 0.1. activation: str The activation function of the intermediate layer, by default 'relu'. layer_norm_eps: float The eps value in layer normalization components, by default 1e-5. batch_first: Bool If ``True``, then the input and output tensors are provided as (batch, seq, feature), by default ``False`` (seq, batch, feature). norm_first: Bool If ``True``, layer norm is done prior to self attention, multihead attention and feedforward operations, respectively. Otherwise it's done after, by default ``False`` (after). bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive bias. Default: ``True``. """ __constants__ = ["norm_first"] def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, bias: bool = True, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.self_attn = MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first, bias=bias, **factory_kwargs, ) self.multihead_attn = MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first, bias=bias, **factory_kwargs, ) # Implementation of Feedforward model self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) self.dropout = Dropout(dropout) self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) self.norm_first = norm_first self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.dropout1 = Dropout(dropout) self.dropout2 = Dropout(dropout) self.dropout3 = Dropout(dropout) # Legacy string support for activation function. if isinstance(activation, str): self.activation = _get_activation_fn(activation) else: self.activation = activation def __setstate__(self, state): if "activation" not in state: state["activation"] = F.relu super().__setstate__(state) def forward( self, tgt: Tensor, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: bool = False, ) -> Tensor: r"""Pass the inputs (and mask) through the decoder layer.""" # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf x = tgt if self.norm_first: x = x + self._sa_block( self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal ) x = x + self._mha_block( self.norm2(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal ) x = x + self._ff_block(self.norm3(x)) else: x = self.norm1( x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal) ) x = self.norm2( x + self._mha_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal) ) x = self.norm3(x + self._ff_block(x)) return x # self-attention block def _sa_block( self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False, ) -> Tensor: x = self.self_attn( x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, is_causal=is_causal, need_weights=False, )[0] return self.dropout1(x) # multihead attention block def _mha_block( self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False, ) -> Tensor: x = self.multihead_attn( x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, is_causal=is_causal, need_weights=False, )[0] return self.dropout2(x) # feed forward block def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout3(x) def _get_clones(module, N): # FIXME: copy.deepcopy() is not defined on nn.module return ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: if activation == "relu": return F.relu elif activation == "gelu": return F.gelu raise RuntimeError(f"activation should be relu/gelu, not {activation}") def _detect_is_causal_mask( mask: Optional[Tensor], is_causal: Optional[bool] = None, size: Optional[int] = None, ) -> bool: """Return whether the given attention mask is causal.""" # Prevent type refinement make_causal = is_causal is True if is_causal is None and mask is not None: sz = size if size is not None else mask.size(-2) # ruff: noqa: F821 causal_comparison = _generate_square_subsequent_mask( sz, device=mask.device, dtype=mask.dtype ) # Do not use `torch.equal` so we handle batched masks by # broadcasting the comparison. if mask.size() == causal_comparison.size(): make_causal = bool((mask == causal_comparison).all()) else: make_causal = False return make_causal def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: if src.is_nested: return None else: src_size = src.size() if len(src_size) == 2: # unbatched: S, E return src_size[0] else: # batched: B, S, E if batch_first else S, B, E seq_len_pos = 1 if batch_first else 0 return src_size[seq_len_pos]