import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from ..attention import attention
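
# Face-motion conditioning modules (summary comment added for readability):
# FaceEncoder compresses a per-frame face feature sequence into a small set of
# motion tokens, and FaceBlock cross-attends a flattened token sequence to those
# tokens so that each latent frame is conditioned on its own motion tokens.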


class CausalConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()
        self.pad_mode = pad_mode
        # Pad only on the left of the time axis so the convolution is causal.
        padding = (kernel_size - 1, 0)
        self.time_causal_padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)
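
# Usage sketch (illustrative only, not part of the original file): the left-only
# padding keeps the convolution causal along time, so output step t never sees
# inputs later than t.
#
#   conv = CausalConv1d(chan_in=8, chan_out=16, kernel_size=3)
#   y = conv(torch.randn(2, 8, 40))  # (B, C_in, T) -> (B, C_out, T); T preserved for stride=1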


class FaceEncoder(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, num_heads: int, dtype=None, device=None):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
        self.out_proj = nn.Linear(1024, out_dim)
        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, out_dim))

    def forward(self, x):
        # x: (B, T, in_dim) per-frame face features.
        x = rearrange(x, "b t c -> b c t")
        b = x.shape[0]
        # Expand to one 1024-dim stream per head, then fold the heads into the batch.
        x = self.conv1_local(x)
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv2(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv3(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm3(x)
        x = self.act(x)
        x = self.out_proj(x)
        # Unfold the heads into a token axis and append a learned padding token per frame.
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        return torch.cat([x, padding], dim=-2)
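
# Shape sketch (illustrative; the example dims are assumptions, not values from this file):
# conv2 and conv3 each halve the time axis (stride=2), so for input (B, T, in_dim) the
# output is (B, ~T/4, num_heads + 1, out_dim), the extra token being the learned padding
# token appended per frame.
#
#   enc = FaceEncoder(in_dim=512, out_dim=1024, num_heads=4)
#   tokens = enc(torch.randn(1, 80, 512))  # -> (1, 20, 5, 1024)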


class RMSNorm(nn.Module):
    def __init__(self, dim, elementwise_affine=True, eps=1e-6, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output
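
# RMSNorm computes x / sqrt(mean(x^2) + eps) over the last dimension, optionally scaled
# by a learned per-channel weight; unlike LayerNorm it does not subtract the mean.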


class FaceBlock(nn.Module):
    def __init__(self, feature_dim, num_heads, dtype=None, device=None):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        head_dim = feature_dim // num_heads
        self.linear1_kv = nn.Linear(feature_dim, feature_dim * 2, device=device, dtype=dtype)
        self.linear1_q = nn.Linear(feature_dim, feature_dim, device=device, dtype=dtype)
        self.linear2 = nn.Linear(feature_dim, feature_dim, device=device, dtype=dtype)
        self.q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
        self.k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
        self.pre_norm_feat = nn.LayerNorm(feature_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.pre_norm_motion = nn.LayerNorm(feature_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)

    def forward(self, x, motion_vec, motion_mask=None):
        # x:           (B, T*S, C) feature tokens, S tokens per latent frame.
        # motion_vec:  (B, T, N, C) face motion tokens from FaceEncoder.
        # motion_mask: optional (B, T, H, W); flattened to match the token sequence of x.
        B, T, N, C = motion_vec.shape

        x_motion = self.pre_norm_motion(motion_vec)
        x_feat = self.pre_norm_feat(x)

        # Queries come from the feature tokens, keys/values from the motion tokens.
        kv = self.linear1_kv(x_motion)
        q = self.linear1_q(x_feat)

        k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.num_heads)
        q = rearrange(q, "B S (H D) -> B S H D", H=self.num_heads)

        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Fold the frame axis into the batch so each frame attends only to its own motion tokens.
        k = rearrange(k, "B L N H D -> (B L) N H D")
        v = rearrange(v, "B L N H D -> (B L) N H D")
        q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T)

        attn = attention(q, k, v)
        attn = attn.reshape(attn.shape[0], attn.shape[1], -1)
        attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T)

        output = self.linear2(attn)

        if motion_mask is not None:
            output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)

        return output
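

if __name__ == "__main__":
    # Minimal smoke test (added sketch, not part of the original file). It only
    # exercises FaceEncoder, which does not depend on the relative `attention`
    # import; because of that relative import, run it as a module from the
    # package root (e.g. `python -m <package>.<this_module>`). Dims are illustrative.
    encoder = FaceEncoder(in_dim=512, out_dim=1024, num_heads=4)
    face_feats = torch.randn(2, 80, 512)   # (B, T, in_dim)
    motion_tokens = encoder(face_feats)
    print(motion_tokens.shape)              # expected: (2, 20, 5, 1024)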