# Source: flickr8k-backend/core/attention.py
# Uploaded by Rohan3 -- commit 4aabce3 ("deploy backend")
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from config import *
class SelfCrossAttn(nn.Module):
    """Multi-head attention over a 2-D feature map, with a residual connection.

    The input map is group-normalised, flattened to a sequence of ``H * W``
    positions and used as the queries. Keys and values come either from the
    same sequence (self-attention) or, when the module was built with
    ``cross=True`` *and* a ``text`` tensor is passed to ``forward``, from
    linear projections of ``text`` (cross-attention).

    Args:
        ch: Number of channels of the input feature map. Must be divisible
            by ``heads``.
        heads: Number of attention heads.
        text_emb_dim: Size of the last dimension of the ``text`` tensor.
        cross: If true, create the query/key/value projections used for
            cross-attention.
        group_size: Number of groups passed to ``nn.GroupNorm``.

    Raises:
        ValueError: If ``ch`` is not divisible by ``heads``.
    """

    def __init__(self, ch, heads = 8, text_emb_dim = 1024, cross = False, group_size = unet_group_size):
        super().__init__()
        # Explicit check rather than `assert`, which is removed under `python -O`
        # and would defer the failure to an opaque shape error in forward().
        if ch % heads != 0:
            raise ValueError(f"ch ({ch}) must be divisible by heads ({heads})")

        self.heads = heads
        self.dim = ch // heads           # channels per head
        self.scale = self.dim ** -0.5    # 1 / sqrt(head dim)
        self.cross = cross

        self.norm = nn.GroupNorm(group_size, ch)

        # Fused q/k/v projection for self-attention. It is created even when
        # cross=True because forward() falls back to self-attention whenever
        # no text tensor is supplied.
        self.qkv_latent = nn.Linear(ch, ch * 3, bias=True)

        if cross:
            self.q = nn.Linear(ch, ch)
            self.k_text = nn.Linear(text_emb_dim, ch)
            self.v_text = nn.Linear(text_emb_dim, ch)

        self.proj = nn.Linear(ch, ch)
        # `attn_dropout` is a module-level name (expected to come from `config`).
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj_drop = nn.Dropout(attn_dropout)

    def forward(self, x: torch.Tensor, text: Optional[torch.Tensor] = None):
        """Apply attention to ``x`` and add the result back onto ``x``.

        Args:
            x: Feature map of shape ``(B, C, H, W)``.
            text: Optional conditioning tensor of shape ``(B, T, D)`` or
                ``(B, D)`` (treated as a single token), where ``D`` is
                ``text_emb_dim``. Ignored unless the module was built with
                ``cross=True``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, C, H, W = x.shape
        N = H * W

        # (B, C, H, W) -> normalise -> (B, C, N) -> (B, N, C)
        x_flat = self.norm(x).flatten(2).transpose(1, 2)

        if self.cross and text is not None:
            # Cross-attention: queries from the latent, keys/values from text.
            if text.dim() == 2:
                text = text[:, None, :]      # (B, D) -> (B, 1, D)
            q = self.q(x_flat)               # (B, N, C)
            k = self.k_text(text)            # (B, T, C)
            v = self.v_text(text)            # (B, T, C)
        else:
            # Self-attention over the latent positions.
            qkv = self.qkv_latent(x_flat)    # (B, N, 3C)
            q, k, v = qkv.chunk(3, dim=2)    # 3 x (B, N, C)

        # Pass the real key/value length instead of -1. With -1, a text batch
        # that does not match B could be silently re-folded into the batch
        # axis (e.g. (1, 4, C) viewed as (2, 2, heads, dim)); with an explicit
        # length the view raises instead.
        kv_len = k.shape[1:-1].numel()

        # Split channels into heads: (B, L, C) -> (B, heads, L, C / heads)
        q = q.view(B, N, self.heads, self.dim).transpose(1, 2)
        k = k.view(B, kv_len, self.heads, self.dim).transpose(1, 2)
        v = v.view(B, kv_len, self.heads, self.dim).transpose(1, 2)

        # Scaled dot-product scores: (B, heads, N, kv_len)
        attn_weights = (q @ k.transpose(2, 3)) * self.scale
        # Subtract the per-row max before the softmax (numerical stability).
        attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_drop(attn_weights)

        out = attn_weights @ v                                  # (B, heads, N, C / heads)
        out = out.transpose(1, 2).contiguous().view(B, N, C)    # merge heads -> (B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        out = out.transpose(1, 2).contiguous().view(x.shape)    # (B, N, C) -> (B, C, H, W)

        # Residual connection uses the un-normalised input.
        return out + x