# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa: F722
from typing import List, Optional

import torch.nn as nn
from jaxtyping import Float
from torch import Tensor

class LinearBlock(nn.Module):
    """Simple linear block with layer normalization and activation.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    activation : type[nn.Module]
        Activation function, default nn.GELU
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_channels, out_channels, bias=False),
            nn.LayerNorm(out_channels),
            activation(),
        )

    def forward(self, x: Float[Tensor, "... C1"]) -> Float[Tensor, "... C2"]:
        return self.block(x)
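

# Usage sketch for LinearBlock (illustrative only; the channel sizes and
# batch shape below are assumptions, not values used elsewhere in this file):
#
#   block = LinearBlock(in_channels=32, out_channels=64)
#   y = block(torch.randn(8, 32))  # y has shape (8, 64)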


class ResidualLinearBlock(nn.Module):
    """Linear block with a residual (skip) connection.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    hidden_channels : int, optional
        Number of hidden channels, defaults to in_channels
    activation : type[nn.Module]
        Activation function, default nn.GELU
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: Optional[int] = None,
        activation: type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        if hidden_channels is None:
            hidden_channels = in_channels
        self.blocks = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.LayerNorm(hidden_channels),
            activation(),
            nn.Linear(hidden_channels, out_channels),
            nn.LayerNorm(out_channels),
        )
        # Project the input when the channel count changes so the skip
        # connection can be added elementwise.
        self.shortcut = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Linear(in_channels, out_channels)
        )
        self.activation = activation()

    def forward(self, x: Float[Tensor, "... C1"]) -> Float[Tensor, "... C2"]:
        out = self.blocks(x)
        # add skip connection
        out = self.activation(out + self.shortcut(x))
        return out
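

# Usage sketch for ResidualLinearBlock (illustrative values): when
# in_channels != out_channels the skip path is a learned Linear projection,
# otherwise it is an Identity.
#
#   block = ResidualLinearBlock(in_channels=32, out_channels=64)
#   y = block(torch.randn(8, 32))  # y has shape (8, 64); skip is Linear(32, 64)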


class MLP(nn.Module):
    """Multi-layer perceptron

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    hidden_channels : List[int]
        Sizes of the internal (hidden) layers of the MLP.
    use_residual : bool, optional
        Whether to use residual connections, default False.
    activation : type[nn.Module]
        Activation function, default nn.GELU
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: List[int],
        use_residual: bool = False,
        activation: type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        channels = [in_channels] + hidden_channels + [out_channels]
        for i in range(len(channels) - 1):
            # Use residual blocks for all but the final layer.
            if use_residual and i < len(channels) - 2:
                self.layers.append(
                    ResidualLinearBlock(
                        channels[i],
                        channels[i + 1],
                        activation=activation,
                    )
                )
            else:
                self.layers.append(
                    LinearBlock(channels[i], channels[i + 1], activation=activation)
                )

    def forward(self, x: Float[Tensor, "... C1"]) -> Float[Tensor, "... C2"]:
        """Forward pass"""
        for layer in self.layers:
            x = layer(x)
        return x
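

# Construction sketch for MLP (illustrative channel sizes, not taken from any
# configuration): layer widths are [in_channels] + hidden_channels +
# [out_channels], so the call below builds 3 -> 64 -> 64 -> 1, with residual
# blocks everywhere except the last layer when use_residual=True.
#
#   mlp = MLP(in_channels=3, out_channels=1, hidden_channels=[64, 64], use_residual=True)
#   y = mlp(torch.randn(8, 3))  # y has shape (8, 1)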


class MLPBlock(nn.Module):
    """Two-layer MLP block with a learned residual (skip) connection.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    hidden_channels : int, optional
        Number of hidden channels, defaults to in_channels
    out_channels : int, optional
        Number of output channels, defaults to in_channels
    activation : type[nn.Module]
        Activation function, default nn.GELU
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        activation: type[nn.Module] = nn.GELU,
    ):
        super().__init__()
        if hidden_channels is None:
            hidden_channels = in_channels
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.norm1 = nn.LayerNorm(hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.norm2 = nn.LayerNorm(out_channels)
        # The skip path is always a learned Linear projection in this block.
        self.shortcut = nn.Linear(in_channels, out_channels)
        self.activation = activation()

    def forward(self, x: Float[Tensor, "... C1"]) -> Float[Tensor, "... C2"]:
        out = self.activation(self.norm1(self.fc1(x)))
        out = self.norm2(self.fc2(out))
        # add skip connection
        out = self.activation(out + self.shortcut(x))
        return out
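

if __name__ == "__main__":
    # Minimal smoke test (a sketch for illustration only; the channel sizes
    # and batch shapes below are arbitrary assumptions, not values taken
    # from any configuration in this repository).
    import torch

    mlp = MLP(in_channels=3, out_channels=1, hidden_channels=[32, 32], use_residual=True)
    y = mlp(torch.randn(8, 3))
    print(y.shape)  # expected: torch.Size([8, 1])

    block = MLPBlock(in_channels=16, hidden_channels=64)
    z = block(torch.randn(4, 16))
    print(z.shape)  # expected: torch.Size([4, 16])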