# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from ..utils.compile import torch_compile_lazy


@torch_compile_lazy
def gating_forward_kernel(
    weight_in: torch.Tensor, weight_out: torch.Tensor, activation, x: torch.Tensor
):
    """Apply a gated feed-forward transform to ``x``.

    ``x`` is projected with ``weight_in``; the projected features are split in
    two equal halves along the last dimension. The first half goes through
    ``activation`` and is multiplied elementwise by the second half, and the
    product is projected back with ``weight_out``. No biases are applied.

    Args:
        weight_in: input projection weight, laid out as for ``F.linear``
            (``(2 * hidden, dim)`` as built by ``ActivationGating``). Its
            output size must be even so it can be split in two halves.
        weight_out: output projection weight (``(dim, hidden)`` as built by
            ``ActivationGating``).
        activation: callable applied to the first half of the projection.
        x: input tensor whose last dimension matches ``weight_in``'s input
            size. Any number of leading dimensions is accepted, e.g.
            ``(B, T, dim)`` or ``(B, dim)``.

    Returns:
        Tensor with the same leading dimensions as ``x`` and a last dimension
        equal to ``weight_out``'s output size.
    """
    x = F.linear(x, weight_in)
    # Split only the last dim into (2, hidden) rather than unpacking a fixed
    # (B, T, _) shape, so inputs of any rank are supported. F.linear returns a
    # contiguous tensor, so view() is valid here.
    x = x.view(*x.shape[:-1], 2, -1)
    x = activation(x[..., 0, :]) * x[..., 1, :]
    x = F.linear(x, weight_out)
    return x


class ActivationGating(nn.Module):
    """
    Gating FFN layer, using the given activation.

    Args:
        dim (int): dimension of the input and output of the transformer.
        dim_feedforward (int): hidden size of the equivalent non-gated FFN.
            The gated hidden size is ``(21 * dim) // 8`` when
            ``dim_feedforward == 4 * dim`` and ``(2 * dim_feedforward) // 3``
            otherwise, keeping the parameter count at or below the
            ``2 * dim * dim_feedforward`` weights of a bias-free non-gated FFN.
        activation (any callable Tensor to Tensor): activation function to use.
        **factory_kwargs: other kwargs passed to the linear layer, in particular device and dtype.
    """

    # NOTE(review): not read in this file; presumably consumed by FSDP
    # wrapping code elsewhere in the project — confirm its exact meaning there.
    _fsdp_final = True

    def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
        super().__init__()
        # A non-gated FFN with hidden size 4 * d has 8 d^2 params. The gated
        # version has 2 * h * d (linear_in) + h * d (linear_out) = 3 h d params,
        # so matching 8 d^2 gives h = 8 d / 3; following Hervé's advice we use
        # 21 / 8 as an approximation. For any other dim_feedforward, matching
        # 3 h d = 2 * dim_feedforward * d gives h = 2 * dim_feedforward / 3.
        if dim_feedforward == 4 * dim:
            hidden = (21 * dim) // 8
        else:
            hidden = (2 * dim_feedforward) // 3

        # linear_in produces both halves (activated part and gate) at once.
        self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
        self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        """Run the gated FFN on ``x`` (last dimension ``dim``)."""
        return gating_forward_kernel(
            self.linear_in.weight, self.linear_out.weight, self.activation, x
        )


def _get_activation(name: str):
    """Return the activation callable registered under ``name``.

    ``"sigmoid"``, ``"tanh"`` and ``"relu"`` map to the ``torch`` functions;
    ``"leaky_relu"``, ``"elu"``, ``"gelu"``, ``"silu"``, ``"mish"`` and
    ``"softsign"`` map to ``torch.nn.functional``; ``"identity"`` returns an
    ``nn.Identity`` module.

    Raises:
        ValueError: if ``name`` is not one of the names above.
    """
    if name in ["sigmoid", "tanh", "relu"]:
        return getattr(torch, name)
    elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
        return getattr(torch.nn.functional, name)
    elif name == "identity":
        return torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {name}")


def _make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    """Build an ``ActivationGating`` using the activation named ``name``."""
    return ActivationGating(
        dim, dim_feedforward, _get_activation(name), **factory_kwargs
    )


def make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    """Build a gated FFN and check it fits the non-gated parameter budget.

    Args:
        name: activation name, see ``_get_activation``.
        dim: input/output dimension.
        dim_feedforward: hidden size of the equivalent non-gated FFN.
        **factory_kwargs: forwarded to the linear layers (e.g. device, dtype).

    Returns:
        The gating module.

    Raises:
        ValueError: for an unknown activation name.
    """
    gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs)
    # Weight count of a bias-free non-gated FFN: dim -> dim_feedforward -> dim.
    max_params = 2 * dim * dim_feedforward
    params = sum(p.numel() for p in gating.parameters())
    # Internal sanity check: the hidden-size formulas in ActivationGating are
    # chosen so this always holds (note: stripped when Python runs with -O).
    assert (
        params <= max_params
    ), f"{name} gating has {params} params, max is {max_params}"
    return gating