"""
Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
"""

from typing import Optional
from collections import namedtuple

import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange

from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb


class TemporalAxialAttention(nn.Module):
    """Multi-head self-attention along the T axis of a (B, T, H, W, D) tensor.

    Every (B, H, W) position is folded into the batch dimension, so attention
    only mixes entries that share a spatial location. Queries and keys are
    rotated with ``rotary_emb`` before attention.

    Args:
        dim: Size of the last axis of the input and of the output.
        heads: Number of attention heads.
        dim_head: Channels per head; the fused q/k/v width is
            ``heads * dim_head``.
        rotary_emb: Rotary embedding object; ``forward`` calls its
            ``rotate_queries_or_keys`` method with its ``freqs`` attribute.
        is_causal: Forwarded as ``is_causal`` to
            ``F.scaled_dot_product_attention``.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
    ) -> None:
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads

        # Single bias-free projection producing q, k and v side by side.
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the T axis.

        Args:
            x: Tensor of shape (B, T, H, W, D); D must match the ``dim`` the
                module was built with.

        Returns:
            Tensor of shape (B, T, H, W, dim).
        """
        # Unpacking also rejects inputs that are not 5-dimensional.
        B, T, H, W, _ = x.shape

        # Split heads and fold the spatial grid into the batch axis so that
        # T becomes the sequence axis seen by attention.
        q, k, v = (
            rearrange(t, "B T H W (h d) -> (B H W) h T d", h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        # Rotary embedding is applied to queries and keys only.
        freqs = self.rotary_emb.freqs
        q = self.rotary_emb.rotate_queries_or_keys(q, freqs)
        k = self.rotary_emb.rotate_queries_or_keys(k, freqs)

        q, k, v = (t.contiguous() for t in (q, k, v))

        out = F.scaled_dot_product_attention(
            query=q, key=k, value=v, is_causal=self.is_causal
        )

        # Restore the (B, T, H, W, heads * dim_head) layout.
        out = rearrange(out, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        out = out.to(q.dtype)

        # linear proj
        return self.to_out(out)


class SpatialAxialAttention(nn.Module):
    """Multi-head self-attention over the H x W grid of a (B, T, H, W, D) tensor.

    Every (B, T) slice is folded into the batch dimension, so attention only
    mixes positions that share a T index. Axial rotary frequencies obtained
    from ``rotary_emb.get_axial_freqs(H, W)`` are applied to queries and keys.
    Attention is never causal.

    Args:
        dim: Size of the last axis of the input and of the output.
        heads: Number of attention heads.
        dim_head: Channels per head; the fused q/k/v width is
            ``heads * dim_head``.
        rotary_emb: Rotary embedding object providing ``get_axial_freqs``.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
    ) -> None:
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads

        # Single bias-free projection producing q, k and v side by side.
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the flattened H x W grid.

        Args:
            x: Tensor of shape (B, T, H, W, D); D must match the ``dim`` the
                module was built with.

        Returns:
            Tensor of shape (B, T, H, W, dim).
        """
        # Unpacking also rejects inputs that are not 5-dimensional.
        B, T, H, W, _ = x.shape

        # Split heads but keep H and W as separate axes for the rotary step.
        q, k, v = (
            rearrange(t, "B T H W (h d) -> (B T) h H W d", h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        # Rotary embedding is applied to queries and keys only.
        freqs = self.rotary_emb.get_axial_freqs(H, W)
        q = apply_rotary_emb(freqs, q)
        k = apply_rotary_emb(freqs, k)

        # prepare for attn: flatten the grid into one sequence axis
        q, k, v = (
            rearrange(
                t, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads
            )
            for t in (q, k, v)
        )

        out = F.scaled_dot_product_attention(
            query=q, key=k, value=v, is_causal=False
        )

        # Restore the (B, T, H, W, heads * dim_head) layout.
        out = rearrange(out, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
        out = out.to(q.dtype)

        # linear proj
        return self.to_out(out)