File size: 3,532 Bytes
1b703d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Small MLP factory for DINAC-AE DiT blocks."""

from __future__ import annotations

from collections.abc import Callable
from typing import Protocol, cast

import torch.nn.functional as F
from torch import Tensor, nn

from dit.mlp_types import MLPType


class Resettable(Protocol):
    """Structural type: anything with a no-argument ``reset_parameters`` method.

    This protocol is not decorated with ``@runtime_checkable``, so it is meant
    for static typing (in this module it is only used as a ``typing.cast``
    target), not for ``isinstance`` checks.
    """

    def reset_parameters(self) -> None:
        """Re-initialise the implementer's parameters in place."""
        ...


def reset_module_parameters(module: nn.Module) -> None:
    """Invoke ``module.reset_parameters()``.

    ``typing.cast`` performs no runtime conversion or check; it only tells the
    type checker that ``module`` satisfies :class:`Resettable`. A module that
    lacks ``reset_parameters`` therefore fails at the attribute lookup below.
    """

    resettable = cast(Resettable, module)  # static-typing hint only
    resettable.reset_parameters()


class SimpleActivationMLP(nn.Module):
    """Feedforward MLP: ``down(activation(up(x)))``.

    ``up`` maps ``in_features -> hidden_features`` and ``down`` maps back to
    ``in_features``, so the output's last dimension equals the input's.
    Both weights are Xavier-uniform initialised and any biases are zeroed
    (see :meth:`reset_parameters`).
    """

    in_features: int
    hidden_features: int
    activation: Callable[[Tensor], Tensor]
    # Label for ``activation``; shown in ``repr`` through ``extra_repr`` because
    # a plain callable is not a submodule and would otherwise be invisible.
    activation_name: str
    up: nn.Linear
    down: nn.Linear

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        *,
        activation: Callable[[Tensor], Tensor],
        activation_name: str,
        bias_up: bool,
        bias_down: bool,
    ) -> None:
        """Build both projections and initialise them.

        Args:
            in_features: Size of the input (and output) last dimension.
            hidden_features: Width of the intermediate representation.
            activation: Nonlinearity applied between ``up`` and ``down``.
            activation_name: Human-readable label for ``activation``.
            bias_up: Whether ``up`` carries a bias term.
            bias_down: Whether ``down`` carries a bias term.
        """
        super().__init__()
        self.in_features = int(in_features)
        self.hidden_features = int(hidden_features)
        self.activation = activation
        self.activation_name = str(activation_name)
        self.up = nn.Linear(self.in_features, self.hidden_features, bias=bias_up)
        self.down = nn.Linear(self.hidden_features, self.in_features, bias=bias_down)
        # Override nn.Linear's default init with this module's scheme.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Xavier-uniform init for both weights; zero any biases.

        Projections are processed in ``up`` then ``down`` order, weight before
        bias, so random-number consumption matches a per-layer sequence.
        """

        for proj in (self.up, self.down):
            nn.init.xavier_uniform_(proj.weight)
            if proj.bias is not None:
                nn.init.zeros_(proj.bias)

    def extra_repr(self) -> str:
        """Report the activation label in ``repr(module)``."""

        return f"activation={self.activation_name}"

    def forward(self, x: Tensor) -> Tensor:  # type: ignore[override]
        """Apply ``down(activation(up(x)))`` over the last dimension of ``x``."""

        return self.down(self.activation(self.up(x)))


def build_dit_mlp(
    *,
    mlp_type: MLPType,
    in_features: int,
    hidden_budget: int,
    activation_config: object | None = None,
    block_index: int = 0,
    bias_up: bool = False,
    bias_down: bool = False,
) -> nn.Module:
    """Build the exported MLP variant."""

    _ = activation_config, block_index
    match mlp_type:
        case MLPType.GELU:
            return SimpleActivationMLP(
                in_features=int(in_features),
                hidden_features=int(hidden_budget),
                activation=F.gelu,
                activation_name="gelu",
                bias_up=bool(bias_up),
                bias_down=bool(bias_down),
            )
        case MLPType.SILU:
            return SimpleActivationMLP(
                in_features=int(in_features),
                hidden_features=int(hidden_budget),
                activation=F.silu,
                activation_name="silu",
                bias_up=bool(bias_up),
                bias_down=bool(bias_down),
            )
        case MLPType.RELU:
            return SimpleActivationMLP(
                in_features=int(in_features),
                hidden_features=int(hidden_budget),
                activation=F.relu,
                activation_name="relu",
                bias_up=bool(bias_up),
                bias_down=bool(bias_down),
            )
        case _ as unreachable:
            raise ValueError(f"Unsupported exported MLP type: {unreachable}")


# Names re-exported by ``from <module> import *``.
__all__ = [
    "SimpleActivationMLP",
    "build_dit_mlp",
    "reset_module_parameters",
]