import os
import sys

sys.path.insert(0, os.path.dirname(__file__))

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from config import *


class SelfCrossAttn(nn.Module):
    """Multi-head attention over a 2-D feature map with a residual connection.

    The input is group-normalised, flattened into a sequence of ``H * W``
    tokens and attended either against itself (self-attention) or, when the
    module was built with ``cross=True`` and a ``text`` tensor is supplied,
    against projections of that text tensor (cross-attention).

    Args:
        ch: Channel count of the input; also the total attention width.
            Must be divisible by ``heads``.
        heads: Number of attention heads.
        text_emb_dim: Size of the last dimension of ``text``; input width of
            the key/value projections used for cross-attention.
        cross: When true, the query/key/value projections for
            cross-attention are created.
        group_size: Forwarded to ``nn.GroupNorm`` as its first argument,
            i.e. the number of groups (despite the name).

    NOTE(review): ``unet_group_size`` and ``attn_dropout`` are not defined in
    this file; presumably they come from ``from config import *`` - confirm.
    """

    def __init__(self, ch, heads=8, text_emb_dim=1024, cross=False, group_size=unet_group_size):
        super().__init__()
        assert ch % heads == 0

        head_dim = ch // heads
        self.heads = heads
        self.dim = head_dim
        self.scale = head_dim ** -0.5
        self.cross = cross

        # Submodules are registered in a fixed order; keep it stable so that
        # parameter ordering and initialisation stay reproducible.
        self.norm = nn.GroupNorm(group_size, ch)

        # Fused q/k/v projection used by the self-attention path.
        self.qkv_latent = nn.Linear(ch, ch * 3, bias=True)

        if cross:
            # Separate projections: queries from the latent, keys/values from text.
            self.q = nn.Linear(ch, ch)
            self.k_text = nn.Linear(text_emb_dim, ch)
            self.v_text = nn.Linear(text_emb_dim, ch)

        self.proj = nn.Linear(ch, ch)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj_drop = nn.Dropout(attn_dropout)

        # Disabled zero-initialisation of the output projection:
        # nn.init.zeros_(self.proj.weight)
        # nn.init.zeros_(self.proj.bias)

    def _split_heads(self, t: torch.Tensor, batch: int, length: int) -> torch.Tensor:
        """Reshape ``(B, L, C)`` into ``(B, heads, L, C // heads)``."""
        return t.view(batch, length, self.heads, self.dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, text: Optional[torch.Tensor] = None):
        """Apply attention to ``x`` and add the result back onto ``x``.

        Args:
            x: 4-D tensor unpacked as ``(B, C, H, W)``.
            text: Optional conditioning tensor. A 2-D ``(B, D)`` tensor is
                treated as a single token ``(B, 1, D)``. Only used when the
                module was constructed with ``cross=True``; otherwise, or when
                it is ``None``, self-attention over ``x`` is performed.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, channels, height, width = x.shape
        tokens = height * width

        # Normalise, then lay the feature map out as a token sequence.
        seq = self.norm(x).flatten(2).transpose(1, 2)  # (B, C, H, W) -> (B, N, C)

        if self.cross and text is not None:
            # Cross-attention: keys and values come from the text tensor.
            context = text.unsqueeze(1) if text.dim() == 2 else text  # (B, T, D)
            query = self.q(seq)            # (B, N, C)
            key = self.k_text(context)     # (B, T, C)
            value = self.v_text(context)   # (B, T, C)
        else:
            # Self-attention: one fused projection split three ways.
            query, key, value = self.qkv_latent(seq).chunk(3, dim=2)  # 3 x (B, N, C)

        # Keys/values infer their length (N or T) from the tensor itself.
        query = self._split_heads(query, batch, tokens)  # (B, heads, N, C/heads)
        key = self._split_heads(key, batch, -1)          # (B, heads, N or T, C/heads)
        value = self._split_heads(value, batch, -1)      # (B, heads, N or T, C/heads)

        scores = (query @ key.transpose(2, 3)) * self.scale  # (B, heads, N, N or T)
        # Shift by the row-wise maximum before the softmax (numerical stability).
        scores = scores - scores.amax(dim=-1, keepdim=True)
        probs = self.attn_drop(F.softmax(scores, dim=-1))

        # Alternative fused implementation, currently disabled:
        # out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0 if not self.training else 0.1)
        mixed = probs @ value  # (B, heads, N, C/heads)

        # Merge heads back into the channel dimension.
        mixed = mixed.transpose(1, 2).contiguous().view(batch, tokens, channels)  # (B, N, C)
        mixed = self.proj_drop(self.proj(mixed))

        # Restore the spatial layout of the input.
        mixed = mixed.transpose(1, 2).contiguous().view(x.shape)  # (B, C, H, W)

        return mixed + x  # residual