from typing import Callable, Optional

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block with the two input projections fused into `w12`."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,  # unused in this implementation
        drop: float = 0.0,  # unused in this implementation
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # One fused projection produces both the gate and the value branch.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)  # split into gate (x1) and value (x2)
        hidden = F.silu(x1) * x2  # SiLU-gated linear unit
        return self.w3(hidden)


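# Note: the block above computes SwiGLU(x) = (SiLU(x W1) * (x W2)) W3, with W1
# and W2 stored as the single `w12` matrix. Ignoring biases, this costs
# 3 * d * h parameters for input width d and hidden size h, versus 2 * d * h
# for a plain two-layer MLP; SwiGLUFFNFused below rescales h by 2/3 to match
# the parameter budget of a standard MLP with the same requested hidden size.

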
try:
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
except ImportError:
    # xformers is optional: fall back to the pure-PyTorch implementation.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False


class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN that uses the fused xformers kernel when available."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,  # unused in this implementation
        drop: float = 0.0,  # unused in this implementation
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Rescale the hidden width by 2/3 (see the note above) and round up to
        # the next multiple of 8 for better hardware utilization.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
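

# A minimal smoke test sketch (not part of the original module; the tensor
# sizes below are illustrative assumptions). With in_features=64, the fused
# variant uses a hidden width of (int(64 * 2 / 3) + 7) // 8 * 8 = 48.
if __name__ == "__main__":
    import torch

    ffn = SwiGLUFFNFused(in_features=64)
    x = torch.randn(2, 16, 64)  # batch of 2, 16 tokens, 64-dim embeddings
    out = ffn(x)
    assert out.shape == x.shape  # the FFN preserves the embedding dimension
    print(f"output shape: {tuple(out.shape)}, fused xformers path: {XFORMERS_AVAILABLE}")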
|