| """ |
| Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py |
| """ |
|
|
| from typing import Optional |
| from collections import namedtuple |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from einops import rearrange |
| from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb |
|
|
|
|
class TemporalAxialAttention(nn.Module):
    """Multi-head self-attention over the time axis, applied independently at every spatial location."""

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal

    def forward(self, x: torch.Tensor):
        B, T, H, W, D = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # Fold the spatial axes into the batch so attention runs along T only.
        q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)

        # Rotary position embedding along the time axis.
        q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
        k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=self.is_causal)

        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # Output projection back to the model dimension.
        x = self.to_out(x)
        return x
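
# Shape walk-through of TemporalAxialAttention.forward, as implied by the rearranges above:
#   x        : (B, T, H, W, D)
#   q, k, v  : (B*H*W, heads, T, head_dim) after to_qkv + rearrange
#   attn     : causal along T, computed independently for every (batch, spatial location) pair
#   output   : (B, T, H, W, heads*head_dim), projected back to D by to_out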


class SpatialAxialAttention(nn.Module):
    """Multi-head self-attention over the spatial (H, W) axes, applied independently at every timestep."""

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb

    def forward(self, x: torch.Tensor):
        B, T, H, W, D = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # Fold the time axis into the batch so attention runs over the spatial grid only.
        q = rearrange(q, "B T H W (h d) -> (B T) h H W d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B T) h H W d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B T) h H W d", h=self.heads)

        # Axial rotary position embedding over (H, W).
        freqs = self.rotary_emb.get_axial_freqs(H, W)
        q = apply_rotary_emb(freqs, q)
        k = apply_rotary_emb(freqs, k)

        # Flatten the spatial grid into a single sequence for scaled_dot_product_attention.
        q = rearrange(q, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
        k = rearrange(k, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
        v = rearrange(v, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)

        x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=False)

        x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # Output projection back to the model dimension.
        x = self.to_out(x)
        return x
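

if __name__ == "__main__":
    # Minimal usage sketch. The RotaryEmbedding arguments below are assumptions,
    # not taken from this file: a 1D rotary embedding over the time axis for the
    # temporal block, and an axial "pixel" rotary embedding over (H, W) for the
    # spatial block, mirroring the rotary-embedding-torch constructor.
    dim, heads, dim_head = 256, 8, 32
    temporal_attn = TemporalAxialAttention(
        dim, heads, dim_head, rotary_emb=RotaryEmbedding(dim=dim_head)
    )
    spatial_attn = SpatialAxialAttention(
        dim, heads, dim_head,
        rotary_emb=RotaryEmbedding(dim=dim_head // 2, freqs_for="pixel", max_freq=256),
    )
    x = torch.randn(2, 4, 8, 8, dim)  # (B, T, H, W, D)
    x = temporal_attn(x)  # causal attention along T at each spatial location
    x = spatial_attn(x)   # full attention over the H*W grid within each frame
    print(x.shape)        # torch.Size([2, 4, 8, 8, 256])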