# modules/hydra_layers.py import math import re from typing import Any import torch from torch import Tensor from torch.nn import ( Module, ModuleList, Parameter, Buffer, Linear, LayerNorm, RMSNorm, Dropout, Flatten, Identity, init ) from torch.nn.functional import pad, scaled_dot_product_attention, silu, gelu from einops import rearrange # --- GLU.PY --- class GatedUnit(Module): def __init__(self, dim: int = -1) -> None: super().__init__() self.dim = dim def _activation(self, x: Tensor) -> Tensor: raise NotImplementedError def forward(self, x: Tensor) -> Tensor: f, g = x.chunk(2, dim=self.dim) return self._activation(f) * g class SwiGLU(GatedUnit): def __init__(self, dim: int = -1) -> None: super().__init__(dim) def _activation(self, x: Tensor) -> Tensor: return silu(x) # --- HYDRA_POOL.PY UTILS --- class IndexedAdd(Module): def __init__(self, n_indices: int, dim: int, weight_shape: tuple[int, ...] | None = None, *, inplace: bool = False, device=None, dtype=None) -> None: super().__init__() self.dim = dim self.inplace = inplace self.index = Buffer(torch.empty(2, n_indices, device=device, dtype=torch.int32)) self.weight = Parameter(torch.ones(*(sz if sz != -1 else n_indices for sz in weight_shape), device=device, dtype=dtype)) if weight_shape is not None else None def forward(self, dst: Tensor, src: Tensor) -> Tensor: src = src.index_select(self.dim, self.index[0]) if self.weight is not None: src.mul_(self.weight) return dst.index_add_(self.dim, self.index[1], src) if self.inplace else dst.index_add(self.dim, self.index[1], src) class BatchLinear(Module): def __init__(self, batch_shape: tuple[int, ...] | int, in_features: int, out_features: int, *, bias: bool = False, flatten: bool = False, bias_inplace: bool = True, device=None, dtype=None) -> None: super().__init__() if isinstance(batch_shape, int): batch_shape = (batch_shape,) self.flatten = -(len(batch_shape) + 1) if flatten else 0 self.weight = Parameter(torch.empty(*batch_shape, in_features, out_features, device=device, dtype=dtype)) bt = self.weight.flatten(end_dim=-3).mT for idx in range(bt.size(0)): init.kaiming_uniform_(bt[idx], a=math.sqrt(5)) self.bias = Parameter(torch.zeros(*batch_shape, out_features, device=device, dtype=dtype)) if bias else None self.bias_inplace = bias_inplace def forward(self, x: Tensor) -> Tensor: x = torch.matmul(x.unsqueeze(-2), self.weight).squeeze(-2) if self.bias is not None: if self.bias_inplace: x.add_(self.bias) else: x = x + self.bias if self.flatten: x = x.flatten(self.flatten) return x class Mean(Module): def __init__(self, dim: tuple[int, ...] | int = -1, *, keepdim: bool = False) -> None: super().__init__() self.dim = dim self.keepdim = keepdim def forward(self, x: Tensor) -> Tensor: return x.mean(self.dim, self.keepdim) class _MidBlock(Module): def __init__(self, attn_dim: int, head_dim: int, n_classes: int, *, ff_ratio: float, ff_dropout: float, q_cls_inplace: bool = True, device=None, dtype=None) -> None: super().__init__() self.head_dim = head_dim self.q_cls_inplace = q_cls_inplace hidden_dim = int(attn_dim * ff_ratio) self.q_proj = Linear(attn_dim, attn_dim, bias=False, device=device, dtype=dtype) self.q_cls = Parameter(torch.zeros(n_classes, attn_dim, device=device, dtype=dtype)) self.q_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False) self.attn_out = Linear(attn_dim, attn_dim, bias=False, device=device, dtype=dtype) self.ff_norm = LayerNorm(attn_dim, device=device, dtype=dtype) self.ff_in = Linear(attn_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype) self.ff_act = SwiGLU() self.ff_drop = Dropout(ff_dropout) self.ff_out = Linear(hidden_dim, attn_dim, bias=False, device=device, dtype=dtype) def _forward_q(self, x: Tensor) -> Tensor: x = self.q_proj(x) if self.q_cls_inplace: x.add_(self.q_cls) else: x = x + self.q_cls x = self.q_norm(x) return rearrange(x, "... s (h e) -> ... h s e", e=self.head_dim) def _forward_attn(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None) -> Tensor: a = scaled_dot_product_attention(self._forward_q(x), k, v, attn_mask=attn_mask) a = rearrange(a, "... h s e -> ... s (h e)") return x + self.attn_out(a) def _forward_ff(self, x: Tensor) -> Tensor: f = self.ff_out(self.ff_drop(self.ff_act(self.ff_in(self.ff_norm(x))))) return x + f def forward(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None = None) -> Tensor: return self._forward_ff(self._forward_attn(x, k, v, attn_mask)) class HydraPool(Module): def __init__(self, attn_dim: int, head_dim: int, n_classes: int, *, mid_blocks: int = 0, roots: tuple[int, int, int] = (0, 0, 0), ff_ratio: float = 3.0, ff_dropout: float = 0.0, input_dim: int = -1, output_dim: int = 1, device=None, dtype=None) -> None: super().__init__() if input_dim < 0: input_dim = attn_dim self.n_classes = n_classes self.head_dim = head_dim self.output_dim = output_dim self._has_roots = False self._has_ff = False self._q_normed = None if roots != (0, 0, 0): self._has_roots = True n_roots, n_classroots, n_subclasses = roots self.cls = Parameter(torch.randn(attn_dim // head_dim, n_classes, head_dim, device=device, dtype=dtype)) self.roots = Parameter(torch.randn(attn_dim // head_dim, n_roots, head_dim, device=device, dtype=dtype)) if n_roots > 0 else None self.clsroots = IndexedAdd(n_classroots, dim=-2, weight_shape=(attn_dim // head_dim, -1, 1), device=device, dtype=dtype) if n_classroots > 0 else None self.clscls = IndexedAdd(n_subclasses, dim=-2, weight_shape=(attn_dim // head_dim, -1, 1), inplace=True, device=device, dtype=dtype) if n_subclasses > 0 else None self.q = Buffer(torch.empty(attn_dim // head_dim, n_classes, head_dim, device=device, dtype=dtype)) else: self.q = Parameter(torch.randn(attn_dim // head_dim, n_classes, head_dim, device=device, dtype=dtype)) self._q_normed = False self.kv = Linear(input_dim, attn_dim * 2, bias=False, device=device, dtype=dtype) self.qk_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False) if ff_ratio > 0.0: self._has_ff = True hidden_dim = int(attn_dim * ff_ratio) self.ff_norm = LayerNorm(attn_dim, device=device, dtype=dtype) self.ff_in = Linear(attn_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype) self.ff_act = SwiGLU() self.ff_drop = Dropout(ff_dropout) self.ff_out = Linear(hidden_dim, attn_dim, bias=False, device=device, dtype=dtype) self.mid_blocks = ModuleList(_MidBlock(attn_dim, head_dim, n_classes, ff_ratio=ff_ratio, ff_dropout=ff_dropout, device=device, dtype=dtype) for _ in range(mid_blocks)) self.out_proj = BatchLinear(n_classes, attn_dim, output_dim * 2, device=device, dtype=dtype) self.out_act = SwiGLU() def create_head(self) -> Module: return Flatten(-2) if self.output_dim == 1 else Mean(-1) def _forward_q(self) -> Tensor: if self._q_normed is None: q = self.qk_norm(self.roots) if self.roots is not None else self.cls if self.clsroots is not None: q = self.clsroots(self.cls, q) if self.clscls is not None: q = self.clscls(q, q.detach()) return self.qk_norm(q) elif self._q_normed is False: return self.qk_norm(self.q) else: return self.q def _forward_attn(self, x: Tensor, attn_mask: Tensor | None) -> tuple[Tensor, Tensor, Tensor]: q = self._forward_q().expand(*x.shape[:-2], -1, -1, -1) x = self.kv(x) k, v = rearrange(x, "... s (n h e) -> n ... h s e", n=2, e=self.head_dim).unbind(0) k = self.qk_norm(k) x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) return rearrange(x, "... h s e -> ... s (h e)"), k, v def _forward_ff(self, x: Tensor) -> Tensor: if not self._has_ff: return x return x + self.ff_out(self.ff_drop(self.ff_act(self.ff_in(self.ff_norm(x))))) def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor: x, k, v = self._forward_attn(x, attn_mask) x = self._forward_ff(x) for block in self.mid_blocks: x = block(x, k, v, attn_mask) return self.out_act(self.out_proj(x)) @staticmethod def for_state(state_dict: dict[str, Any], prefix: str = "", *, ff_dropout: float = 0.0, device=None, dtype=None) -> "HydraPool": n_heads, n_classes, head_dim = state_dict[f"{prefix}q"].shape attn_dim = n_heads * head_dim roots_t, clsroots_t, clscls_t = state_dict.get(f"{prefix}roots"), state_dict.get(f"{prefix}clsroots.index"), state_dict.get(f"{prefix}clscls.index") roots = (roots_t.size(1) if roots_t is not None else 0, clsroots_t.size(1) if clsroots_t is not None else 0, clscls_t.size(1) if clscls_t is not None else 0) input_dim = state_dict[f"{prefix}kv.weight"].size(1) output_dim = state_dict[f"{prefix}out_proj.weight"].size(2) // 2 ffout_t = state_dict.get(f"{prefix}ff_out.weight") hidden_dim = ffout_t.size(1) + 0.5 if ffout_t is not None else 0 ff_ratio = hidden_dim / attn_dim pattern = re.compile(rf"^{re.escape(prefix)}mid_blocks\.([0-9]+)\.") mid_blocks = max([-1, *(int(match[1]) for key in state_dict if (match := pattern.match(key)) is not None)]) + 1 return HydraPool(attn_dim, head_dim, n_classes, mid_blocks=mid_blocks, roots=roots, ff_ratio=ff_ratio, ff_dropout=ff_dropout, input_dim=input_dim, output_dim=output_dim, device=device, dtype=dtype)