Spaces:
Runtime error
Runtime error
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: MIT | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a | |
| # copy of this software and associated documentation files (the "Software"), | |
| # to deal in the Software without restriction, including without limitation | |
| # the rights to use, copy, modify, merge, publish, distribute, sublicense, | |
| # and/or sell copies of the Software, and to permit persons to whom the | |
| # Software is furnished to do so, subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in | |
| # all copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL | |
| # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | |
| # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | |
| # DEALINGS IN THE SOFTWARE. | |
| # Copyright (c) Kyutai, all rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from ..utils.compile import torch_compile_lazy | |
def gating_forward_kernel(
    weight_in: torch.Tensor, weight_out: torch.Tensor, activation, x: torch.Tensor
) -> torch.Tensor:
    """Apply a gated feed-forward transform to ``x``.

    The input is projected with ``weight_in``, the projected features are
    split into two equal halves along the last dimension, the first half is
    passed through ``activation`` and multiplied element-wise with the second
    half, and the product is projected with ``weight_out``.

    Args:
        weight_in: weight of the input projection, as accepted by
            ``F.linear`` (its output feature count must be even).
        weight_out: weight of the output projection, as accepted by
            ``F.linear``.
        activation: callable mapping a tensor to a tensor of the same shape,
            applied to the gate half.
        x: input tensor whose last dimension matches ``weight_in``. Any
            number of leading dimensions is accepted, including empty ones.

    Returns:
        The gated and re-projected tensor, with the same leading dimensions
        as ``x``.
    """
    x = F.linear(x, weight_in)
    # Keep whatever leading dimensions the input has instead of requiring
    # exactly (B, T, C), and give the half width explicitly: a `-1` in `view`
    # is ambiguous (and raises) when the tensor has zero elements.
    *leading, width = x.shape
    x = x.view(*leading, 2, width // 2)
    # Index 0 along the new axis is the gate, index 1 is the value.
    x = activation(x[..., 0, :]) * x[..., 1, :]
    return F.linear(x, weight_out)
class ActivationGating(nn.Module):
    """Gated feed-forward layer built around a caller-supplied activation.

    Two bias-free linear layers are created: ``linear_in`` maps ``dim`` to
    twice the hidden width (gate and value halves) and ``linear_out`` maps the
    hidden width back to ``dim``.

    Args:
        dim (int): dimension of the input and output of the layer.
        dim_feedforward (int): feed-forward width of the equivalent non-gated
            layer; the hidden width used here is derived from it so that the
            parameter count does not exceed ``2 * dim * dim_feedforward``.
        activation (any callable Tensor to Tensor): activation applied to the
            gate half.
        **factory_kwargs: other kwargs passed to the linear layers, in
            particular device and dtype.
    """

    # Class-level marker; it is not read anywhere in this file
    # (presumably consumed by FSDP wrapping code elsewhere - TODO confirm).
    _fsdp_final = True

    def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
        super().__init__()
        # A plain FFN with dim_feedforward = 4 * d holds 2 * d * 4d = 8 d^2
        # weights. The gated variant holds 2 * h * d + h * d = 3 * h * d, so
        # matching the budget gives h = 8 d / 3; following Hervé's advice,
        # 21 / 8 is used as the approximation of that ratio. For any other
        # dim_feedforward the same budget argument gives
        # h = 2 * dim_feedforward / 3.
        is_standard_ratio = dim_feedforward == 4 * dim
        hidden = (21 * dim) // 8 if is_standard_ratio else (2 * dim_feedforward) // 3

        self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
        self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        """Run the gated feed-forward computation on ``x``."""
        w_in = self.linear_in.weight
        w_out = self.linear_out.weight
        return gating_forward_kernel(w_in, w_out, self.activation, x)
| def _get_activation(name: str): | |
| if name in ["sigmoid", "tanh", "relu"]: | |
| return getattr(torch, name) | |
| elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]: | |
| return getattr(torch.nn.functional, name) | |
| elif name == "identity": | |
| return torch.nn.Identity() | |
| else: | |
| raise ValueError(f"Unknown activation {name}") | |
def _make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    """Build an ``ActivationGating`` layer from an activation name.

    ``name`` is resolved through ``_get_activation``; the remaining arguments
    are forwarded unchanged to ``ActivationGating``.
    """
    activation = _get_activation(name)
    return ActivationGating(dim, dim_feedforward, activation, **factory_kwargs)
def make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    """Build a gating layer and check that it stays within its size budget.

    The layer is created by ``_make_gating``; its total parameter count must
    not exceed ``2 * dim * dim_feedforward``.

    Returns:
        The constructed gating module.

    Raises:
        AssertionError: if the module holds more parameters than the budget.
    """
    module = _make_gating(name, dim, dim_feedforward, **factory_kwargs)

    budget = 2 * dim * dim_feedforward
    n_params = sum(weight.numel() for weight in module.parameters())
    assert (
        n_params <= budget
    ), f"{name} gating has {n_params} params, max is {budget}"

    return module