# dinac_ae/dit/mlp.py
# data-archetype's picture
# Upload DINAC-AE export package
# 1b703d5
"""Small MLP factory for DINAC-AE DiT blocks."""
from __future__ import annotations
from collections.abc import Callable
from typing import Protocol, cast
import torch.nn.functional as F
from torch import Tensor, nn
from dit.mlp_types import MLPType
class Resettable(Protocol):
    """Structural type for objects that provide a ``reset_parameters`` method."""

    def reset_parameters(self) -> None:
        """Reset the implementing object's parameters."""
        ...
def reset_module_parameters(module: nn.Module) -> None:
    """Call ``reset_parameters`` on ``module``.

    The module is only *cast* to :class:`Resettable` for the type checker;
    no runtime check is performed, so a module without the method raises
    ``AttributeError`` from the attribute lookup.
    """
    resettable = cast(Resettable, module)
    resettable.reset_parameters()
class SimpleActivationMLP(nn.Module):
    """Two-layer feedforward block computing ``down(activation(up(x)))``.

    ``up`` maps ``in_features -> hidden_features`` and ``down`` maps back to
    ``in_features``. Weights are Xavier-uniform initialised and any enabled
    bias starts at zero.
    """

    in_features: int
    hidden_features: int
    activation: Callable[[Tensor], Tensor]
    activation_name: str
    up: nn.Linear
    down: nn.Linear

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        *,
        activation: Callable[[Tensor], Tensor],
        activation_name: str,
        bias_up: bool,
        bias_down: bool,
    ) -> None:
        """Record the configuration, build both projections and initialise them.

        Args:
            in_features: Input (and output) width; coerced with ``int``.
            hidden_features: Width between the two projections; coerced with ``int``.
            activation: Callable applied between ``up`` and ``down``.
            activation_name: Label stored alongside the callable; coerced with ``str``.
            bias_up: Whether ``up`` has a bias term.
            bias_down: Whether ``down`` has a bias term.
        """
        super().__init__()
        width_in = int(in_features)
        width_hidden = int(hidden_features)
        self.in_features = width_in
        self.hidden_features = width_hidden
        self.activation = activation
        self.activation_name = str(activation_name)
        # ``up`` is created before ``down`` so random-number consumption order is kept.
        self.up = nn.Linear(width_in, width_hidden, bias=bias_up)
        self.down = nn.Linear(width_hidden, width_in, bias=bias_down)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise ``up`` then ``down``: Xavier-uniform weights, zero biases."""
        for projection in (self.up, self.down):
            nn.init.xavier_uniform_(projection.weight)
            bias = projection.bias
            if bias is not None:
                nn.init.zeros_(bias)

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        """Project up, apply the activation, then project back down."""
        hidden = self.up(x)
        activated = self.activation(hidden)
        return self.down(activated)
def build_dit_mlp(
    *,
    mlp_type: MLPType,
    in_features: int,
    hidden_budget: int,
    activation_config: object | None = None,
    block_index: int = 0,
    bias_up: bool = False,
    bias_down: bool = False,
) -> nn.Module:
    """Construct the MLP variant selected by ``mlp_type``.

    Args:
        mlp_type: One of ``MLPType.GELU``, ``MLPType.SILU`` or ``MLPType.RELU``.
        in_features: Input/output width of the MLP.
        hidden_budget: Used directly as the hidden width.
        activation_config: Accepted but not used by this function.
        block_index: Accepted but not used by this function.
        bias_up: Whether the up projection has a bias.
        bias_down: Whether the down projection has a bias.

    Returns:
        A :class:`SimpleActivationMLP` using the matching activation.

    Raises:
        ValueError: If ``mlp_type`` is none of the supported values.
    """
    # Part of the signature only; neither value affects the result.
    del activation_config, block_index

    # Members are compared one at a time, in the same order as before.
    if mlp_type == MLPType.GELU:
        activation_fn, label = F.gelu, "gelu"
    elif mlp_type == MLPType.SILU:
        activation_fn, label = F.silu, "silu"
    elif mlp_type == MLPType.RELU:
        activation_fn, label = F.relu, "relu"
    else:
        raise ValueError(f"Unsupported exported MLP type: {mlp_type}")

    return SimpleActivationMLP(
        in_features=int(in_features),
        hidden_features=int(hidden_budget),
        activation=activation_fn,
        activation_name=label,
        bias_up=bool(bias_up),
        bias_down=bool(bias_down),
    )
# Public names exported by ``from <module> import *``.
__all__ = ["SimpleActivationMLP", "build_dit_mlp", "reset_module_parameters"]