# Adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/modules/torch_modules.py
import math
from functools import wraps

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.utils.checkpoint import checkpoint


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cache_fn(f):
    """Wrap ``f`` so its first result is cached and reused unless ``_cache=False``."""
    cache = None

    @wraps(f)
    def cached_fn(*args, _cache=True, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache

    return cached_fn


class GELU(nn.Module):
    @staticmethod
    def gelu(x):
        """Implementation of the gelu activation function.

        For information: OpenAI GPT's gelu is slightly different (and gives
        slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        """
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

    def forward(self, x):
        return self.gelu(x)


class Sin(nn.Module):
    """Element-wise sine activation."""

    def forward(self, input: Tensor) -> Tensor:
        return torch.sin(input)


def dropout_seq(seq, mask, dropout):
    """Randomly drop a fraction of the sequence dimension (and the matching mask entries)."""
    b, n, *_, device = *seq.shape, seq.device
    logits = torch.randn(b, n, device=device)

    if exists(mask):
        logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)

    keep_prob = 1.0 - dropout
    num_keep = max(1, int(keep_prob * n))
    keep_indices = logits.topk(num_keep, dim=1).indices

    batch_indices = torch.arange(b, device=device)
    batch_indices = rearrange(batch_indices, "b -> b 1")

    seq = seq[batch_indices, keep_indices]

    if exists(mask):
        seq_counts = mask.sum(dim=-1)
        seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
        keep_mask = torch.arange(num_keep, device=device) < rearrange(seq_keep_counts, "b -> b 1")
        mask = mask[batch_indices, keep_indices] & keep_mask

    return seq, mask


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learned per-channel scale."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    """Apply RMSNorm separately to queries and keys (QK-norm)."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class PreNorm(nn.Module):
    """LayerNorm the input (and optionally the context) before calling ``fn``."""

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, context=None, mask=None):
        x = self.norm(x)
        if exists(self.norm_context):
            context = self.norm_context(context)
        # Forward only the kwargs that were actually provided, so the wrapped module
        # (Attention, SelfAttention, or FeedForward) sees a matching signature and
        # a mask is not silently dropped when there is no context norm.
        kwargs = {}
        if exists(context):
            kwargs["context"] = context
        if exists(mask):
            kwargs["mask"] = mask
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int = 1,
        act: nn.Module = nn.GELU,
        input_dim: int = None,
        output_dim: int = None,
    ):
        super().__init__()
        input_dim = default(input_dim, dim)
        output_dim = default(output_dim, dim)

        layers = [nn.Sequential(nn.Linear(input_dim, dim), act())]
        layers = layers + [nn.Sequential(nn.Linear(dim, dim), act()) for _ in range(1, depth)]
        layers.append(nn.Linear(dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head (cross-)attention; attends to ``context`` if given, else to ``x``."""

    def __init__(
        self, query_dim, context_dim=None, heads=8, dim_head=64, scale=None, qk_norm=False
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = default(scale, dim_head**-0.5)
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)
        self.norm = QKNorm(dim_head) if qk_norm else lambda q, k, v: (q, k)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.to_q.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.to_kv.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.to_out.weight)
        if self.to_out.bias is not None:
            nn.init.constant_(self.to_out.bias, 0)

    def forward(self, x, context=None, mask=None):
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        q, k = self.norm(q, k, v)

        if mask is not None:
            # Boolean padding mask over context tokens: True = attend, False = ignore.
            mask = repeat(mask, "b j -> (b h) () j", h=h)

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=self.scale)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)


class CrossAttentionBlock(nn.Module):
    """Pre-norm cross-attention followed by a pre-norm feed-forward, both with residuals."""

    def __init__(
        self,
        dim,
        context_dim=None,
        heads: int = 4,
        dim_head: int = 64,
        act=nn.GELU,
        scale=None,
        qk_norm=False,
    ):
        super().__init__()
        self.attn = PreNorm(
            dim,
            Attention(
                query_dim=dim,
                context_dim=context_dim,
                heads=heads,
                dim_head=dim_head,
                scale=scale,
                qk_norm=qk_norm,
            ),
            context_dim=context_dim,
        )
        self.ff = PreNorm(dim, FeedForward(dim, act=act))

    def forward(self, x, context=None, mask=None):
        x = self.attn(x, context=context, mask=mask) + x
        x = self.ff(x) + x
        return x


class SelfAttention(nn.Module):
    """Multi-head self-attention with an optional boolean padding mask."""

    def __init__(self, dim, heads, dim_head, scale=None, qk_norm=False):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = default(scale, dim_head**-0.5)
        self.heads = heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.norm = QKNorm(dim_head) if qk_norm else lambda q, k, _: (q, k)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.to_qkv.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.to_out.weight)
        if self.to_out.bias is not None:
            nn.init.constant_(self.to_out.bias, 0)

    def forward(self, x, mask=None):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        q, k = self.norm(q, k, v)

        if mask is not None:
            mask = repeat(mask, "b j -> (b h) () j", h=h)

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=self.scale)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)


class SelfAttentionBlock(nn.Module):
    """Pre-norm self-attention followed by a pre-norm feed-forward, both with residuals."""

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int = 64,
        act=nn.GELU,
        scale=None,
        qk_norm=False,
    ):
        super().__init__()
        self.attn = PreNorm(dim, SelfAttention(dim, heads, dim_head, scale, qk_norm))
        self.ff = PreNorm(dim, FeedForward(dim, act=act))

    def forward(self, x, mask=None):
        x = self.attn(x, mask=mask) + x
        x = self.ff(x) + x
        return x


class StackedRandomGenerator:
    """Per-sample random number generation: one seeded torch.Generator per batch element."""

    def __init__(self, device, seeds):
        super().__init__()
        self.generators = [
            torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds
        ]

    def randn(self, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack(
            [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]
        )

    def randn_like(self, input):
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack(
            [
                torch.randint(*args, size=size[1:], generator=gen, **kwargs)
                for gen in self.generators
            ]
        )


def grad_checkpoint(func, args, checkpointing=False):
    """Call ``func(*args)``, using activation checkpointing when ``checkpointing`` is True."""
    if checkpointing:
        return checkpoint(func, *args, use_reentrant=False)
    else:
        return func(*args)
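
# Minimal usage sketch (illustrative only, not part of the upstream module):
# composes CrossAttentionBlock and SelfAttentionBlock on random tensors and
# wraps one call with grad_checkpoint. All shapes, seeds, and hyperparameters
# below are arbitrary assumptions chosen for demonstration.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 16, 128)   # latent tokens: (batch, n_latents, dim)
    ctx = torch.randn(2, 40, 96)  # context tokens: (batch, n_context, context_dim)
    ctx_mask = torch.ones(2, 40, dtype=torch.bool)  # True = attend to this token

    cross = CrossAttentionBlock(dim=128, context_dim=96, heads=4, dim_head=32, qk_norm=True)
    self_blk = SelfAttentionBlock(dim=128, heads=4, dim_head=32)

    h = cross(x, context=ctx, mask=ctx_mask)
    h = grad_checkpoint(self_blk, (h,), checkpointing=False)
    print(h.shape)  # torch.Size([2, 16, 128])

    # Reproducible per-sample noise, one generator per batch element.
    rng = StackedRandomGenerator("cpu", seeds=[1, 2])
    noise = rng.randn((2, 16, 128))
    print(noise.shape)  # torch.Size([2, 16, 128])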