Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import copy | |
| from typing import Callable, Optional, Union | |
| import torch | |
| from torch import Tensor | |
| from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, MultiheadAttention | |
| from torch.nn import functional as F | |
class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers.

    Parameters
    ----------
    decoder_layer: torch.nn.Module
        Layer used for the decoder.
    num_layers: int
        Number of sub-decoder-layers in the decoder.
    norm: torch.nn.Module, optional
        Layer normalization component applied to the final output,
        by default None (no final normalization).
    """

    __constants__ = ["norm"]

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        tgt: Tensor,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: Optional[bool] = None,
    ) -> Tensor:
        """Pass the inputs (and mask) through the decoder layers in turn.

        Parameters
        ----------
        tgt: Tensor
            Target sequence fed to the decoder stack.
        tgt_mask: Tensor, optional
            Attention mask for the target sequence, by default None.
        tgt_key_padding_mask: Tensor, optional
            Padding mask for the target keys, by default None.
        tgt_is_causal: bool, optional
            Whether attention is causal. If None (default), causal attention
            is assumed, since this is a decoder-only stack.

        Returns
        -------
        Tensor
            Output of the last decoder layer, normalized by ``self.norm``
            when one was supplied.
        """
        output = tgt
        # BUGFIX: the previous code unconditionally overwrote the caller's
        # tgt_is_causal argument with True, silently ignoring an explicit
        # value. Apply the causal default only when the caller passed None.
        if tgt_is_causal is None:
            tgt_is_causal = True
        for mod in self.layers:
            output = mod(
                output,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_is_causal=tgt_is_causal,
            )
        if self.norm is not None:
            output = self.norm(output)
        return output
| class DecoderOnlyLayer(Module): | |
| r""" | |
| Parameters | |
| ---------- | |
| d_model: int | |
| Number of expected features in the input. | |
| nhead: int | |
| Number of heads in the multiheadattention models. | |
| dim_feedforward: int | |
| Dimension of the feedforward network model, by default 2048. | |
| dropout: float | |
| The dropout value, by default 0.1. | |
| activation: str | |
| The activation function of the intermediate layer, by default 'relu'. | |
| layer_norm_eps: float | |
| The eps value in layer normalization components, by default 1e-5. | |
| batch_first: Bool | |
| If ``True``, then the input and output tensors are provided | |
| as (batch, seq, feature), by default ``False`` (seq, batch, feature). | |
| norm_first: Bool | |
| If ``True``, layer norm is done prior to self attention, multihead | |
| attention and feedforward operations, respectively. Otherwise it's done after, | |
| by default ``False`` (after). | |
| bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive | |
| bias. Default: ``True``. | |
| """ | |
| __constants__ = ["norm_first"] | |
| def __init__( | |
| self, | |
| d_model: int, | |
| nhead: int, | |
| dim_feedforward: int = 2048, | |
| dropout: float = 0.1, | |
| activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, | |
| layer_norm_eps: float = 1e-5, | |
| batch_first: bool = False, | |
| norm_first: bool = False, | |
| bias: bool = True, | |
| device=None, | |
| dtype=None, | |
| ) -> None: | |
| factory_kwargs = {"device": device, "dtype": dtype} | |
| super().__init__() | |
| self.self_attn = MultiheadAttention( | |
| d_model, | |
| nhead, | |
| dropout=dropout, | |
| batch_first=batch_first, | |
| bias=bias, | |
| **factory_kwargs, | |
| ) | |
| self.multihead_attn = MultiheadAttention( | |
| d_model, | |
| nhead, | |
| dropout=dropout, | |
| batch_first=batch_first, | |
| bias=bias, | |
| **factory_kwargs, | |
| ) | |
| # Implementation of Feedforward model | |
| self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) | |
| self.dropout = Dropout(dropout) | |
| self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) | |
| self.norm_first = norm_first | |
| self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) | |
| self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) | |
| self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) | |
| self.dropout1 = Dropout(dropout) | |
| self.dropout2 = Dropout(dropout) | |
| self.dropout3 = Dropout(dropout) | |
| # Legacy string support for activation function. | |
| if isinstance(activation, str): | |
| self.activation = _get_activation_fn(activation) | |
| else: | |
| self.activation = activation | |
| def __setstate__(self, state): | |
| if "activation" not in state: | |
| state["activation"] = F.relu | |
| super().__setstate__(state) | |
| def forward( | |
| self, | |
| tgt: Tensor, | |
| tgt_mask: Optional[Tensor] = None, | |
| tgt_key_padding_mask: Optional[Tensor] = None, | |
| tgt_is_causal: bool = False, | |
| ) -> Tensor: | |
| r"""Pass the inputs (and mask) through the decoder layer.""" | |
| # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf | |
| x = tgt | |
| if self.norm_first: | |
| x = x + self._sa_block( | |
| self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal | |
| ) | |
| x = x + self._mha_block( | |
| self.norm2(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal | |
| ) | |
| x = x + self._ff_block(self.norm3(x)) | |
| else: | |
| x = self.norm1( | |
| x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal) | |
| ) | |
| x = self.norm2( | |
| x + self._mha_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal) | |
| ) | |
| x = self.norm3(x + self._ff_block(x)) | |
| return x | |
| # self-attention block | |
| def _sa_block( | |
| self, | |
| x: Tensor, | |
| attn_mask: Optional[Tensor], | |
| key_padding_mask: Optional[Tensor], | |
| is_causal: bool = False, | |
| ) -> Tensor: | |
| x = self.self_attn( | |
| x, | |
| x, | |
| x, | |
| attn_mask=attn_mask, | |
| key_padding_mask=key_padding_mask, | |
| is_causal=is_causal, | |
| need_weights=False, | |
| )[0] | |
| return self.dropout1(x) | |
| # multihead attention block | |
| def _mha_block( | |
| self, | |
| x: Tensor, | |
| attn_mask: Optional[Tensor], | |
| key_padding_mask: Optional[Tensor], | |
| is_causal: bool = False, | |
| ) -> Tensor: | |
| x = self.multihead_attn( | |
| x, | |
| x, | |
| x, | |
| attn_mask=attn_mask, | |
| key_padding_mask=key_padding_mask, | |
| is_causal=is_causal, | |
| need_weights=False, | |
| )[0] | |
| return self.dropout2(x) | |
| # feed forward block | |
| def _ff_block(self, x: Tensor) -> Tensor: | |
| x = self.linear2(self.dropout(self.activation(self.linear1(x)))) | |
| return self.dropout3(x) | |
| def _get_clones(module, N): | |
| # FIXME: copy.deepcopy() is not defined on nn.module | |
| return ModuleList([copy.deepcopy(module) for i in range(N)]) | |
| def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: | |
| if activation == "relu": | |
| return F.relu | |
| elif activation == "gelu": | |
| return F.gelu | |
| raise RuntimeError(f"activation should be relu/gelu, not {activation}") | |
| def _detect_is_causal_mask( | |
| mask: Optional[Tensor], | |
| is_causal: Optional[bool] = None, | |
| size: Optional[int] = None, | |
| ) -> bool: | |
| """Return whether the given attention mask is causal.""" | |
| # Prevent type refinement | |
| make_causal = is_causal is True | |
| if is_causal is None and mask is not None: | |
| sz = size if size is not None else mask.size(-2) | |
| # ruff: noqa: F821 | |
| causal_comparison = _generate_square_subsequent_mask( | |
| sz, device=mask.device, dtype=mask.dtype | |
| ) | |
| # Do not use `torch.equal` so we handle batched masks by | |
| # broadcasting the comparison. | |
| if mask.size() == causal_comparison.size(): | |
| make_causal = bool((mask == causal_comparison).all()) | |
| else: | |
| make_causal = False | |
| return make_causal | |
| def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: | |
| if src.is_nested: | |
| return None | |
| else: | |
| src_size = src.size() | |
| if len(src_size) == 2: | |
| # unbatched: S, E | |
| return src_size[0] | |
| else: | |
| # batched: B, S, E if batch_first else S, B, E | |
| seq_len_pos = 1 if batch_first else 0 | |
| return src_size[seq_len_pos] | |