# modules/hydra_layers.py

import math
import re
from typing import Any
import torch
from torch import Tensor
from torch.nn import (
    Module, ModuleList, Parameter, Buffer,
    Linear, LayerNorm, RMSNorm, Dropout, Flatten, Identity,
    init
)
from torch.nn.functional import pad, scaled_dot_product_attention, silu, gelu
from einops import rearrange

# --- GLU.PY ---
class GatedUnit(Module):
    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim
    def _activation(self, x: Tensor) -> Tensor:
        raise NotImplementedError
    def forward(self, x: Tensor) -> Tensor:
        f, g = x.chunk(2, dim=self.dim)
        return self._activation(f) * g

class SwiGLU(GatedUnit):
    def __init__(self, dim: int = -1) -> None:
        super().__init__(dim)
    def _activation(self, x: Tensor) -> Tensor:
        return silu(x)
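
# Illustrative note (the shapes below are assumptions for this sketch, not values
# used elsewhere in this file): a gated unit halves its input along `dim`, so
# layers feeding it are sized at twice the target width, e.g.
#     ff_in = Linear(64, 192 * 2, bias=False)
#     y = SwiGLU()(ff_in(torch.randn(2, 10, 64)))   # y.shape == (2, 10, 192)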

# --- HYDRA_POOL.PY UTILS ---
class IndexedAdd(Module):
    # Gathers rows of `src` at index[0] along `dim`, optionally scales them by a
    # learned weight, and adds them into `dst` at index[1]. The index buffer is
    # created empty and is expected to be populated externally (e.g. via
    # load_state_dict).
    def __init__(self, n_indices: int, dim: int, weight_shape: tuple[int, ...] | None = None, *, inplace: bool = False, device=None, dtype=None) -> None:
        super().__init__()
        self.dim = dim
        self.inplace = inplace
        self.index = Buffer(torch.empty(2, n_indices, device=device, dtype=torch.int32))
        # -1 entries in weight_shape are replaced by n_indices.
        self.weight = Parameter(torch.ones(*(sz if sz != -1 else n_indices for sz in weight_shape), device=device, dtype=dtype)) if weight_shape is not None else None

    def forward(self, dst: Tensor, src: Tensor) -> Tensor:
        src = src.index_select(self.dim, self.index[0])  # gather a copy of the selected rows
        if self.weight is not None: src.mul_(self.weight)
        return dst.index_add_(self.dim, self.index[1], src) if self.inplace else dst.index_add(self.dim, self.index[1], src)
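
# Illustrative note (index values and sizes here are assumptions for this
# sketch): row 0 of `index` selects source rows, row 1 selects destination rows.
#     add = IndexedAdd(3, dim=-2, weight_shape=(-1, 1))   # weight: (3, 1)
#     add.index.copy_(torch.tensor([[0, 0, 1], [2, 4, 4]], dtype=torch.int32))
#     out = add(torch.zeros(5, 8), torch.randn(2, 8))      # rows 2 and 4 of `out` receive weighted copies of src rows 0, 0 and 1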

class BatchLinear(Module):
    # Independent linear layers applied in parallel over leading batch dims:
    # weight has shape (*batch_shape, in_features, out_features) and each
    # (in, out) slice is initialized like a separate nn.Linear.
    def __init__(self, batch_shape: tuple[int, ...] | int, in_features: int, out_features: int, *, bias: bool = False, flatten: bool = False, bias_inplace: bool = True, device=None, dtype=None) -> None:
        super().__init__()
        if isinstance(batch_shape, int): batch_shape = (batch_shape,)
        self.flatten = -(len(batch_shape) + 1) if flatten else 0
        self.weight = Parameter(torch.empty(*batch_shape, in_features, out_features, device=device, dtype=dtype))
        # Kaiming-initialize each (out, in) slice, matching nn.Linear's default.
        bt = self.weight.flatten(end_dim=-3).mT
        for idx in range(bt.size(0)): init.kaiming_uniform_(bt[idx], a=math.sqrt(5))
        self.bias = Parameter(torch.zeros(*batch_shape, out_features, device=device, dtype=dtype)) if bias else None
        self.bias_inplace = bias_inplace

    def forward(self, x: Tensor) -> Tensor:
        x = torch.matmul(x.unsqueeze(-2), self.weight).squeeze(-2)
        if self.bias is not None:
            if self.bias_inplace: x.add_(self.bias)
            else: x = x + self.bias
        if self.flatten: x = x.flatten(self.flatten)  # merge the batch and feature dims into one
        return x
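
# Illustrative note (assumed shapes): one independent linear map per leading
# batch entry, e.g. a separate projection head per class:
#     proj = BatchLinear(10, 64, 2)                  # weight: (10, 64, 2)
#     y = proj(torch.randn(3, 10, 64))               # y.shape == (3, 10, 2)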

class Mean(Module):
    def __init__(self, dim: tuple[int, ...] | int = -1, *, keepdim: bool = False) -> None:
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim
    def forward(self, x: Tensor) -> Tensor:
        return x.mean(self.dim, self.keepdim)

# Refinement block: the class tokens cross-attend to the keys/values cached from
# the first pooling attention, followed by a gated (SwiGLU) feed-forward.
class _MidBlock(Module):
    def __init__(self, attn_dim: int, head_dim: int, n_classes: int, *, ff_ratio: float, ff_dropout: float, q_cls_inplace: bool = True, device=None, dtype=None) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.q_cls_inplace = q_cls_inplace
        hidden_dim = int(attn_dim * ff_ratio)
        self.q_proj = Linear(attn_dim, attn_dim, bias=False, device=device, dtype=dtype)
        self.q_cls = Parameter(torch.zeros(n_classes, attn_dim, device=device, dtype=dtype))
        self.q_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)
        self.attn_out = Linear(attn_dim, attn_dim, bias=False, device=device, dtype=dtype)
        self.ff_norm = LayerNorm(attn_dim, device=device, dtype=dtype)
        self.ff_in = Linear(attn_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype)
        self.ff_act = SwiGLU()
        self.ff_drop = Dropout(ff_dropout)
        self.ff_out = Linear(hidden_dim, attn_dim, bias=False, device=device, dtype=dtype)

    def _forward_q(self, x: Tensor) -> Tensor:
        x = self.q_proj(x)
        if self.q_cls_inplace: x.add_(self.q_cls)
        else: x = x + self.q_cls
        # Split heads before normalizing: q_norm is an RMSNorm over head_dim,
        # matching how qk_norm is applied to per-head tensors in HydraPool.
        x = rearrange(x, "... s (h e) -> ... h s e", e=self.head_dim)
        return self.q_norm(x)

    def _forward_attn(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None) -> Tensor:
        a = scaled_dot_product_attention(self._forward_q(x), k, v, attn_mask=attn_mask)
        a = rearrange(a, "... h s e -> ... s (h e)")
        return x + self.attn_out(a)

    def _forward_ff(self, x: Tensor) -> Tensor:
        f = self.ff_out(self.ff_drop(self.ff_act(self.ff_in(self.ff_norm(x)))))
        return x + f

    def forward(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None = None) -> Tensor:
        return self._forward_ff(self._forward_attn(x, k, v, attn_mask))

# Attention pooling: learned per-class queries attend over the input sequence,
# producing one output_dim-sized vector per class.
class HydraPool(Module):
    def __init__(self, attn_dim: int, head_dim: int, n_classes: int, *, mid_blocks: int = 0, roots: tuple[int, int, int] = (0, 0, 0), ff_ratio: float = 3.0, ff_dropout: float = 0.0, input_dim: int = -1, output_dim: int = 1, device=None, dtype=None) -> None:
        super().__init__()
        if input_dim < 0: input_dim = attn_dim
        self.n_classes = n_classes
        self.head_dim = head_dim
        self.output_dim = output_dim
        self._has_roots = False
        self._has_ff = False
        self._q_normed = None

        if roots != (0, 0, 0):
            # Structured queries: each class query starts from `cls`, gathers
            # weighted contributions from shared `roots` (clsroots) and from
            # detached copies of other class queries (clscls).
            self._has_roots = True
            n_roots, n_classroots, n_subclasses = roots
            self.cls = Parameter(torch.randn(attn_dim // head_dim, n_classes, head_dim, device=device, dtype=dtype))
            self.roots = Parameter(torch.randn(attn_dim // head_dim, n_roots, head_dim, device=device, dtype=dtype)) if n_roots > 0 else None
            self.clsroots = IndexedAdd(n_classroots, dim=-2, weight_shape=(attn_dim // head_dim, -1, 1), device=device, dtype=dtype) if n_classroots > 0 else None
            self.clscls = IndexedAdd(n_subclasses, dim=-2, weight_shape=(attn_dim // head_dim, -1, 1), inplace=True, device=device, dtype=dtype) if n_subclasses > 0 else None
            self.q = Buffer(torch.empty(attn_dim // head_dim, n_classes, head_dim, device=device, dtype=dtype))
        else:
            # Plain learned queries: one per class and head.
            self.q = Parameter(torch.randn(attn_dim // head_dim, n_classes, head_dim, device=device, dtype=dtype))
            self._q_normed = False

        self.kv = Linear(input_dim, attn_dim * 2, bias=False, device=device, dtype=dtype)
        self.qk_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)

        if ff_ratio > 0.0:
            self._has_ff = True
            hidden_dim = int(attn_dim * ff_ratio)
            self.ff_norm = LayerNorm(attn_dim, device=device, dtype=dtype)
            self.ff_in = Linear(attn_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype)
            self.ff_act = SwiGLU()
            self.ff_drop = Dropout(ff_dropout)
            self.ff_out = Linear(hidden_dim, attn_dim, bias=False, device=device, dtype=dtype)

        self.mid_blocks = ModuleList(_MidBlock(attn_dim, head_dim, n_classes, ff_ratio=ff_ratio, ff_dropout=ff_dropout, device=device, dtype=dtype) for _ in range(mid_blocks))
        self.out_proj = BatchLinear(n_classes, attn_dim, output_dim * 2, device=device, dtype=dtype)
        self.out_act = SwiGLU()

    def create_head(self) -> Module:
        return Flatten(-2) if self.output_dim == 1 else Mean(-1)

    def _forward_q(self) -> Tensor:
        # _q_normed is None  -> queries are rebuilt from cls/roots on every call;
        # _q_normed is False -> self.q is a raw parameter, normalized here;
        # _q_normed is True  -> self.q is assumed to already be normalized.
        if self._q_normed is None:
            q = self.qk_norm(self.roots) if self.roots is not None else self.cls
            if self.clsroots is not None: q = self.clsroots(self.cls, q)
            if self.clscls is not None: q = self.clscls(q, q.detach())
            return self.qk_norm(q)
        elif self._q_normed is False: return self.qk_norm(self.q)
        else: return self.q

    def _forward_attn(self, x: Tensor, attn_mask: Tensor | None) -> tuple[Tensor, Tensor, Tensor]:
        q = self._forward_q().expand(*x.shape[:-2], -1, -1, -1)
        x = self.kv(x)
        k, v = rearrange(x, "... s (n h e) -> n ... h s e", n=2, e=self.head_dim).unbind(0)
        k = self.qk_norm(k)
        x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return rearrange(x, "... h s e -> ... s (h e)"), k, v

    def _forward_ff(self, x: Tensor) -> Tensor:
        if not self._has_ff: return x
        return x + self.ff_out(self.ff_drop(self.ff_act(self.ff_in(self.ff_norm(x)))))

    def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
        x, k, v = self._forward_attn(x, attn_mask)  # keys/values are reused by the mid blocks
        x = self._forward_ff(x)
        for block in self.mid_blocks: x = block(x, k, v, attn_mask)
        return self.out_act(self.out_proj(x))

    @staticmethod
    def for_state(state_dict: dict[str, Any], prefix: str = "", *, ff_dropout: float = 0.0, device=None, dtype=None) -> "HydraPool":
        # Reconstructs constructor arguments from a saved state_dict so that an
        # architecturally matching HydraPool can be instantiated and loaded.
        n_heads, n_classes, head_dim = state_dict[f"{prefix}q"].shape
        attn_dim = n_heads * head_dim
        roots_t, clsroots_t, clscls_t = state_dict.get(f"{prefix}roots"), state_dict.get(f"{prefix}clsroots.index"), state_dict.get(f"{prefix}clscls.index")
        roots = (roots_t.size(1) if roots_t is not None else 0, clsroots_t.size(1) if clsroots_t is not None else 0, clscls_t.size(1) if clscls_t is not None else 0)
        input_dim = state_dict[f"{prefix}kv.weight"].size(1)
        output_dim = state_dict[f"{prefix}out_proj.weight"].size(2) // 2
        ffout_t = state_dict.get(f"{prefix}ff_out.weight")
        # +0.5 so that int(attn_dim * ff_ratio) in __init__ recovers hidden_dim
        # exactly despite floating-point truncation.
        hidden_dim = ffout_t.size(1) + 0.5 if ffout_t is not None else 0
        ff_ratio = hidden_dim / attn_dim
        pattern = re.compile(rf"^{re.escape(prefix)}mid_blocks\.([0-9]+)\.")
        mid_blocks = max([-1, *(int(match[1]) for key in state_dict if (match := pattern.match(key)) is not None)]) + 1
        return HydraPool(attn_dim, head_dim, n_classes, mid_blocks=mid_blocks, roots=roots, ff_ratio=ff_ratio, ff_dropout=ff_dropout, input_dim=input_dim, output_dim=output_dim, device=device, dtype=dtype)
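
if __name__ == "__main__":
    # Minimal smoke test; the sizes below are illustrative assumptions, not
    # configurations taken from this repository.
    torch.manual_seed(0)
    pool = HydraPool(attn_dim=64, head_dim=16, n_classes=10, mid_blocks=1)
    head = pool.create_head()
    x = torch.randn(2, 37, 64)              # (batch, sequence, features)
    logits = head(pool(x))                  # (2, 10): one score per class since output_dim == 1
    print(logits.shape)

    # Rebuild an architecturally matching module from a saved state_dict.
    rebuilt = HydraPool.for_state(pool.state_dict())
    rebuilt.load_state_dict(pool.state_dict())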