# ultrapro-tagger / modules/hydra_layers.py
# Hugging Face Hub commit 772b344 (verified) by Hydragee: "Upload folder using huggingface_hub"
# modules/hydra_layers.py
import math
import re
from typing import Any
import torch
from torch import Tensor
from torch.nn import (
Module, ModuleList, Parameter, Buffer,
Linear, LayerNorm, RMSNorm, Dropout, Flatten, Identity,
init
)
from torch.nn.functional import pad, scaled_dot_product_attention, silu, gelu
from einops import rearrange
# --- GLU.PY ---
class GatedUnit(Module):
    """Gated linear-unit base: split the input in half and gate one half with the other.

    The input is cut into two equal chunks along ``dim``; the result is
    ``self._activation(first_half) * second_half``. Subclasses provide
    ``_activation``.
    """

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim  # axis that is split in two

    def _activation(self, x: Tensor) -> Tensor:
        # Abstract hook: concrete gated units override this.
        raise NotImplementedError

    def forward(self, x: Tensor) -> Tensor:
        gate_input, gated_value = x.chunk(2, dim=self.dim)
        gate = self._activation(gate_input)
        return gate * gated_value
class SwiGLU(GatedUnit):
    """Gated unit whose gate activation is SiLU (swish): ``silu(a) * b``."""

    def __init__(self, dim: int = -1) -> None:
        super().__init__(dim=dim)

    def _activation(self, x: Tensor) -> Tensor:
        # SiLU(x) = x * sigmoid(x)
        return silu(x)
# --- HYDRA_POOL.PY UTILS ---
class IndexedAdd(Module):
def __init__(self, n_indices: int, dim: int, weight_shape: tuple[int, ...] | None = None, *, inplace: bool = False, device=None, dtype=None) -> None:
super().__init__()
self.dim = dim
self.inplace = inplace
self.index = Buffer(torch.empty(2, n_indices, device=device, dtype=torch.int32))
self.weight = Parameter(torch.ones(*(sz if sz != -1 else n_indices for sz in weight_shape), device=device, dtype=dtype)) if weight_shape is not None else None
def forward(self, dst: Tensor, src: Tensor) -> Tensor:
src = src.index_select(self.dim, self.index[0])
if self.weight is not None: src.mul_(self.weight)
return dst.index_add_(self.dim, self.index[1], src) if self.inplace else dst.index_add(self.dim, self.index[1], src)
class BatchLinear(Module):
def __init__(self, batch_shape: tuple[int, ...] | int, in_features: int, out_features: int, *, bias: bool = False, flatten: bool = False, bias_inplace: bool = True, device=None, dtype=None) -> None:
super().__init__()
if isinstance(batch_shape, int): batch_shape = (batch_shape,)
self.flatten = -(len(batch_shape) + 1) if flatten else 0
self.weight = Parameter(torch.empty(*batch_shape, in_features, out_features, device=device, dtype=dtype))
bt = self.weight.flatten(end_dim=-3).mT
for idx in range(bt.size(0)): init.kaiming_uniform_(bt[idx], a=math.sqrt(5))
self.bias = Parameter(torch.zeros(*batch_shape, out_features, device=device, dtype=dtype)) if bias else None
self.bias_inplace = bias_inplace
def forward(self, x: Tensor) -> Tensor:
x = torch.matmul(x.unsqueeze(-2), self.weight).squeeze(-2)
if self.bias is not None:
if self.bias_inplace: x.add_(self.bias)
else: x = x + self.bias
if self.flatten: x = x.flatten(self.flatten)
return x
class Mean(Module):
def __init__(self, dim: tuple[int, ...] | int = -1, *, keepdim: bool = False) -> None:
super().__init__()
self.dim = dim
self.keepdim = keepdim
def forward(self, x: Tensor) -> Tensor:
return x.mean(self.dim, self.keepdim)
class _MidBlock(Module):
    """Residual cross-attention + SwiGLU feed-forward block used inside ``HydraPool``.

    ``x`` holds one ``attn_dim`` vector per class (as produced by HydraPool's
    first attention step). Queries are ``q_proj(x) + q_cls``, split into heads
    of size ``head_dim`` and RMS-normalized per head. They attend over the
    provided keys/values ``k``/``v`` (assumed to be ``(..., heads, seq, head_dim)``,
    as HydraPool passes them). The attention output is projected and added back
    to ``x``, then a pre-LayerNorm SwiGLU feed-forward is added residually.
    """

    def __init__(self, attn_dim: int, head_dim: int, n_classes: int, *, ff_ratio: float, ff_dropout: float, q_cls_inplace: bool = True, device=None, dtype=None) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.q_cls_inplace = q_cls_inplace
        hidden_dim = int(attn_dim * ff_ratio)
        self.q_proj = Linear(attn_dim, attn_dim, bias=False, device=device, dtype=dtype)
        # Learned per-class query offset, zero-initialized.
        self.q_cls = Parameter(torch.zeros(n_classes, attn_dim, device=device, dtype=dtype))
        # Normalizes each head's query vector, hence normalized_shape=head_dim.
        self.q_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)
        self.attn_out = Linear(attn_dim, attn_dim, bias=False, device=device, dtype=dtype)
        self.ff_norm = LayerNorm(attn_dim, device=device, dtype=dtype)
        # Twice the hidden width: SwiGLU halves it again.
        self.ff_in = Linear(attn_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype)
        self.ff_act = SwiGLU()
        self.ff_drop = Dropout(ff_dropout)
        self.ff_out = Linear(hidden_dim, attn_dim, bias=False, device=device, dtype=dtype)

    def _forward_q(self, x: Tensor) -> Tensor:
        """Return per-head queries of shape ``(..., heads, classes, head_dim)``."""
        q = self.q_proj(x)
        if self.q_cls_inplace:
            # q is a fresh tensor returned by q_proj, so this does not modify x.
            q.add_(self.q_cls)
        else:
            q = q + self.q_cls
        # Split into heads BEFORE normalizing: q_norm is sized for head_dim, so
        # normalizing the full attn_dim vector raised whenever attn_dim != head_dim.
        # Per-head normalization also matches how HydraPool normalizes q and k.
        q = rearrange(q, "... s (h e) -> ... h s e", e=self.head_dim)
        return self.q_norm(q)

    def _forward_attn(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None) -> Tensor:
        """Residual cross-attention of the class vectors over ``k``/``v``."""
        a = scaled_dot_product_attention(self._forward_q(x), k, v, attn_mask=attn_mask)
        a = rearrange(a, "... h s e -> ... s (h e)")
        return x + self.attn_out(a)

    def _forward_ff(self, x: Tensor) -> Tensor:
        """Residual pre-norm SwiGLU feed-forward."""
        f = self.ff_out(self.ff_drop(self.ff_act(self.ff_in(self.ff_norm(x)))))
        return x + f

    def forward(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None = None) -> Tensor:
        return self._forward_ff(self._forward_attn(x, k, v, attn_mask))
class HydraPool(Module):
    """Attention pooling with one learned query per class.

    Inputs ``x`` of shape ``(..., seq, input_dim)`` are projected to per-head
    keys and values (``kv``). ``n_classes`` queries attend over them, giving one
    ``attn_dim`` vector per class. An optional residual SwiGLU feed-forward
    (when ``ff_ratio > 0``) and ``mid_blocks`` extra cross-attention blocks
    refine those vectors. A per-class ``BatchLinear`` + SwiGLU then maps each one
    to ``output_dim`` values, so the output is ``(..., n_classes, output_dim)``.

    Queries come either from the learned tensor ``q`` (``roots == (0, 0, 0)``)
    or are assembled every forward pass from class embeddings ``cls``, shared
    ``roots`` and the ``IndexedAdd`` links ``clsroots`` / ``clscls``, where
    ``roots = (n_roots, n_classroots, n_subclasses)``.
    """

    def __init__(self, attn_dim: int, head_dim: int, n_classes: int, *, mid_blocks: int = 0, roots: tuple[int, int, int] = (0, 0, 0), ff_ratio: float = 3.0, ff_dropout: float = 0.0, input_dim: int = -1, output_dim: int = 1, device=None, dtype=None) -> None:
        super().__init__()
        if input_dim < 0:
            input_dim = attn_dim
        n_heads = attn_dim // head_dim
        self.n_classes = n_classes
        self.head_dim = head_dim
        self.output_dim = output_dim
        self._has_roots = False
        self._has_ff = False
        # Query mode read by _forward_q:
        #   None  -> assemble queries from cls/roots/clsroots/clscls on every call
        #   False -> RMS-normalize the learned self.q on every call
        #   other -> return self.q unchanged (never set in this class; presumably
        #            for queries that were pre-normalized externally)
        self._q_normed = None
        if roots != (0, 0, 0):
            self._has_roots = True
            n_roots, n_classroots, n_subclasses = roots
            self.cls = Parameter(torch.randn(n_heads, n_classes, head_dim, device=device, dtype=dtype))
            self.roots = Parameter(torch.randn(n_heads, n_roots, head_dim, device=device, dtype=dtype)) if n_roots > 0 else None
            # weight_shape (heads, -1, 1): one learned scale per head and per link.
            self.clsroots = IndexedAdd(n_classroots, dim=-2, weight_shape=(n_heads, -1, 1), device=device, dtype=dtype) if n_classroots > 0 else None
            self.clscls = IndexedAdd(n_subclasses, dim=-2, weight_shape=(n_heads, -1, 1), inplace=True, device=device, dtype=dtype) if n_subclasses > 0 else None
            # Not read by forward in this mode; registered so the query shape is in
            # the state dict (for_state reads it).
            self.q = Buffer(torch.empty(n_heads, n_classes, head_dim, device=device, dtype=dtype))
        else:
            self.q = Parameter(torch.randn(n_heads, n_classes, head_dim, device=device, dtype=dtype))
            self._q_normed = False
        self.kv = Linear(input_dim, attn_dim * 2, bias=False, device=device, dtype=dtype)
        # Shared per-head RMS normalization for queries and keys.
        self.qk_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)
        if ff_ratio > 0.0:
            self._has_ff = True
            hidden_dim = int(attn_dim * ff_ratio)
            self.ff_norm = LayerNorm(attn_dim, device=device, dtype=dtype)
            self.ff_in = Linear(attn_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype)
            self.ff_act = SwiGLU()
            self.ff_drop = Dropout(ff_dropout)
            self.ff_out = Linear(hidden_dim, attn_dim, bias=False, device=device, dtype=dtype)
        self.mid_blocks = ModuleList(_MidBlock(attn_dim, head_dim, n_classes, ff_ratio=ff_ratio, ff_dropout=ff_dropout, device=device, dtype=dtype) for _ in range(mid_blocks))
        # Per-class projection to 2 * output_dim; out_act (SwiGLU) halves it.
        self.out_proj = BatchLinear(n_classes, attn_dim, output_dim * 2, device=device, dtype=dtype)
        self.out_act = SwiGLU()

    def create_head(self) -> Module:
        """Return a module reducing the trailing ``output_dim`` axis of ``forward``'s output."""
        return Flatten(-2) if self.output_dim == 1 else Mean(-1)

    def _forward_q(self) -> Tensor:
        """Return the per-head query tensor used by the first attention step."""
        if self._q_normed is None:
            q = self.qk_norm(self.roots) if self.roots is not None else self.cls
            if self.clsroots is not None:
                # Out-of-place: a new tensor of cls plus weighted, gathered rows of q.
                q = self.clsroots(self.cls, q)
            if self.clscls is not None:
                # clscls accumulates IN PLACE. Without roots and clsroots, q is still
                # the cls Parameter itself; adding into it would raise with grad
                # enabled (in-place op on a leaf) or silently corrupt the weights
                # under no_grad. Work on a copy in that case.
                if q is self.cls:
                    q = q.clone()
                q = self.clscls(q, q.detach())
            return self.qk_norm(q)
        if self._q_normed is False:
            return self.qk_norm(self.q)
        return self.q

    def _forward_attn(self, x: Tensor, attn_mask: Tensor | None) -> tuple[Tensor, Tensor, Tensor]:
        """Pool ``x`` with the class queries.

        Returns the pooled ``(..., n_classes, attn_dim)`` tensor together with the
        normalized keys and the values, both ``(..., heads, seq, head_dim)``, so
        the mid blocks can reuse them.
        """
        q = self._forward_q().expand(*x.shape[:-2], -1, -1, -1)
        x = self.kv(x)
        k, v = rearrange(x, "... s (n h e) -> n ... h s e", n=2, e=self.head_dim).unbind(0)
        k = self.qk_norm(k)
        x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return rearrange(x, "... h s e -> ... s (h e)"), k, v

    def _forward_ff(self, x: Tensor) -> Tensor:
        """Residual pre-norm SwiGLU feed-forward; identity when ``ff_ratio <= 0``."""
        if not self._has_ff:
            return x
        return x + self.ff_out(self.ff_drop(self.ff_act(self.ff_in(self.ff_norm(x)))))

    def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
        x, k, v = self._forward_attn(x, attn_mask)
        x = self._forward_ff(x)
        for block in self.mid_blocks:
            x = block(x, k, v, attn_mask)
        return self.out_act(self.out_proj(x))

    @staticmethod
    def for_state(state_dict: dict[str, Any], prefix: str = "", *, ff_dropout: float = 0.0, device=None, dtype=None) -> "HydraPool":
        """Construct a HydraPool whose hyperparameters match the tensors in ``state_dict``.

        Only tensor shapes and key names are inspected; ``prefix`` is prepended to
        every key. The returned module is freshly initialized, so the caller still
        has to load the weights (e.g. with ``load_state_dict``).
        """
        n_heads, n_classes, head_dim = state_dict[f"{prefix}q"].shape
        attn_dim = n_heads * head_dim
        roots_t = state_dict.get(f"{prefix}roots")
        clsroots_t = state_dict.get(f"{prefix}clsroots.index")
        clscls_t = state_dict.get(f"{prefix}clscls.index")
        roots = (
            roots_t.size(1) if roots_t is not None else 0,
            clsroots_t.size(1) if clsroots_t is not None else 0,
            clscls_t.size(1) if clscls_t is not None else 0,
        )
        input_dim = state_dict[f"{prefix}kv.weight"].size(1)
        output_dim = state_dict[f"{prefix}out_proj.weight"].size(2) // 2
        ffout_t = state_dict.get(f"{prefix}ff_out.weight")
        # __init__ recomputes hidden_dim as int(attn_dim * ff_ratio); the +0.5 keeps
        # float rounding in that product from truncating to hidden_dim - 1.
        hidden_dim = ffout_t.size(1) + 0.5 if ffout_t is not None else 0
        ff_ratio = hidden_dim / attn_dim
        pattern = re.compile(rf"^{re.escape(prefix)}mid_blocks\.([0-9]+)\.")
        mid_blocks = max([-1, *(int(match[1]) for key in state_dict if (match := pattern.match(key)) is not None)]) + 1
        return HydraPool(attn_dim, head_dim, n_classes, mid_blocks=mid_blocks, roots=roots, ff_ratio=ff_ratio, ff_dropout=ff_dropout, input_dim=input_dim, output_dim=output_dim, device=device, dtype=dtype)