Spaces:
Sleeping
Sleeping
| import sys, os | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional | |
| from config import * | |
class SelfCrossAttn(nn.Module):
    """Multi-head self- or cross-attention over a 2-D feature map, with a residual.

    The input feature map is group-normalised, flattened into a sequence of
    ``H * W`` tokens (one per spatial position) and attended over. When the
    module was built with ``cross=True`` *and* ``text`` is passed to
    :meth:`forward`, queries come from the feature map and keys/values from the
    text embedding (cross-attention). Otherwise a fused QKV projection over the
    feature map is used (self-attention). The projected attention output is
    added back to the un-normalised input.

    Args:
        ch: Number of input/output channels ``C``; must be divisible by ``heads``.
        heads: Number of attention heads.
        text_emb_dim: Size ``D`` of the text embedding's last dimension
            (only used when ``cross=True``).
        cross: Whether to create the cross-attention projections.
        group_size: Forwarded to ``nn.GroupNorm`` as its ``num_groups`` argument,
            i.e. the number of groups (not channels per group); ``ch`` must be
            divisible by it.
    """

    def __init__(self, ch, heads=8, text_emb_dim=1024, cross=False, group_size=unet_group_size):
        super().__init__()
        assert ch % heads == 0, f"ch ({ch}) must be divisible by heads ({heads})"
        self.heads = heads
        self.dim = ch // heads          # per-head channel width
        self.scale = self.dim ** -0.5   # 1 / sqrt(d_head) dot-product scaling
        self.cross = cross
        self.norm = nn.GroupNorm(group_size, ch)
        # Fused Q/K/V projection for self-attention. Created even in cross mode,
        # because forward() falls back to self-attention when `text` is None.
        self.qkv_latent = nn.Linear(ch, ch * 3, bias=True)
        if cross:
            self.q = nn.Linear(ch, ch)
            self.k_text = nn.Linear(text_emb_dim, ch)
            self.v_text = nn.Linear(text_emb_dim, ch)
        self.proj = nn.Linear(ch, ch)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj_drop = nn.Dropout(attn_dropout)
        # Optional: zero-initialising self.proj (weight and bias) would make the
        # block start out as the identity (output == residual input only).

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(b, L, C)`` into ``(b, heads, L, C // heads)``.

        The tensor's own batch size is used (not the image batch size), so a
        text embedding with batch 1 stays batch 1 and broadcasts over the image
        batch in the attention matmuls instead of being reshuffled across it.
        """
        b, length, _ = t.shape
        return t.view(b, length, self.heads, self.dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, text: Optional[torch.Tensor] = None):
        """Attend over ``x`` (optionally conditioned on ``text``) and add the residual.

        Args:
            x: Feature map of shape ``(B, C, H, W)``.
            text: Optional text embedding, ``(Bt, T, D)`` or ``(Bt, D)`` (treated
                as a single token). ``Bt`` must be ``B`` or 1 (broadcast).
                Ignored unless the module was built with ``cross=True``.

        Returns:
            Tensor of shape ``(B, C, H, W)``: attention output plus ``x``.
        """
        B, C, H, W = x.shape
        N = H * W
        # (B, C, H, W) -> (B, C, N) -> (B, N, C): one token per spatial position.
        tokens = self.norm(x).flatten(2).transpose(1, 2)

        if self.cross and text is not None:
            # Cross-attention: queries from the image, keys/values from the text.
            if text.dim() == 2:
                text = text[:, None, :]      # (Bt, D) -> (Bt, 1, D)
            q = self.q(tokens)               # (B, N, C)
            k = self.k_text(text)            # (Bt, T, C)
            v = self.v_text(text)            # (Bt, T, C)
        else:
            # Self-attention over the latent tokens.
            q, k, v = self.qkv_latent(tokens).chunk(3, dim=2)  # 3 x (B, N, C)

        # Split heads on *both* paths: (b, L, C) -> (b, heads, L, C / heads).
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # (B, h, N, d) @ (b, h, d, L) -> (B, h, N, L); torch.matmul broadcasts a
        # batch-1 text over B. F.softmax already subtracts the row max internally,
        # so no manual stabilisation shift is needed.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v                                   # (B, h, N, d)
        out = out.transpose(1, 2).reshape(B, N, C)       # merge heads -> (B, N, C)
        out = self.proj_drop(self.proj(out))
        out = out.transpose(1, 2).reshape(B, C, H, W)    # back to a feature map
        return out + x  # residual connection