artalk-youtube / app /modules /bitwise_vae.py
Ammunity's picture
phase5: upload ARTalk app/ (model code)
0ce25d3 verified
#!/usr/bin/env python
# Copyright (c) Xuangeng Chu (xg.chu@outlook.com)
import os
import math
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
# from .flame_model import FLAMEModel, RenderMesh
from .data_stats import TFHP_MEAN, TFHP_STD, ALLTALKEMICA_MEAN, ALLTALKEMICA_STD
class BITWISE_VAE(nn.Module):
    """Transformer VAE that turns 106-d motion sequences into multi-scale bit codes.

    The model operates on a window of ``2 * patch_nums[-1]`` steps split into a
    "prev" half and a "this" half. The mask from ``build_attn_mask`` lets the
    "prev" steps attend only to "prev" steps, while "this" steps attend to the
    whole window. Quantization is done by a residual multi-scale binary spherical
    quantizer (``MultiScaleBSQ``).

    Args:
        model_cfg: mapping providing ``V_CODE_DIM``, ``V_PATCH_NUMS``,
            ``T_HIDDEN_DIM``, ``T_DEPTH`` and ``T_NUM_HEADS``. Required; the
            ``None`` default is kept only for signature compatibility.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``model_cfg`` is ``None``.
    """

    def __init__(self, model_cfg=None, **kwargs):
        super().__init__()
        if model_cfg is None:
            # Every hyper-parameter below is read from model_cfg; fail with a
            # clear message instead of "'NoneType' object is not subscriptable".
            raise ValueError("BITWISE_VAE requires a model_cfg mapping, got None")
        self.motion_dim = 106
        self.code_dim = model_cfg['V_CODE_DIM']
        self.patch_nums = model_cfg['V_PATCH_NUMS']
        self.encoder = TransformerEncoder(
            inp_dim=self.motion_dim, hidden_dim=model_cfg['T_HIDDEN_DIM'], code_dim=self.code_dim,
            depth=model_cfg['T_DEPTH'], n_heads=model_cfg['T_NUM_HEADS'],
        )
        self.decoder = TransformerDecoder(
            code_dim=self.code_dim, hidden_dim=model_cfg['T_HIDDEN_DIM'], out_dim=self.motion_dim,
            depth=model_cfg['T_DEPTH'], n_heads=model_cfg['T_NUM_HEADS'],
        )
        self.quantize = MultiScaleBSQ(codebook_dim=self.code_dim, scale_schedule=self.patch_nums)
        # Block mask over the [prev | this] window, shared by encoder and decoder.
        attn_mask = self.build_attn_mask(self.patch_nums[-1])
        self.register_buffer('attn_mask', attn_mask)
        # Learned absolute position embeddings covering the full 2 * patch_nums[-1] window.
        enc_pos_embed = torch.empty(1, self.patch_nums[-1] * 2, self.motion_dim)  # 1, L, C
        nn.init.trunc_normal_(enc_pos_embed, mean=0, std=math.sqrt(1 / self.motion_dim / 3))
        self.enc_pos_embed = nn.Parameter(enc_pos_embed)
        dec_pos_embed = torch.empty(1, self.patch_nums[-1] * 2, self.code_dim)  # 1, L, C
        nn.init.trunc_normal_(dec_pos_embed, mean=0, std=math.sqrt(1 / self.code_dim / 3))
        self.dec_pos_embed = nn.Parameter(dec_pos_embed)
        # Statistics used to normalize motion before encoding and to undo it after decoding.
        self.register_buffer("motion_mean", torch.tensor(ALLTALKEMICA_MEAN).float())
        self.register_buffer("motion_std", torch.tensor(ALLTALKEMICA_STD).float())

    def get_flame_verts(self, flame_model, shape_params, motion_params, with_global=False):
        """Run ``flame_model`` on motion parameters and return its output.

        ``motion_params[..., :100]`` is passed as expression parameters and
        ``motion_params[..., 100:]`` as pose parameters. Unless ``with_global``
        is true, the first 3 pose values are zeroed (presumably the global
        rotation -- confirm against the FLAME model).

        If ``shape_params`` is 2-D, the model is called once. If it is 3-D, the
        model is called per leading index and the results are stacked on dim 0.

        Raises:
            ValueError: if ``shape_params`` is neither 2-D nor 3-D.
        """
        exp_code, pose_code = motion_params[..., :100], motion_params[..., 100:]
        if not with_global:
            pose_code = torch.cat([torch.zeros_like(pose_code[..., :3]), pose_code[..., 3:]], dim=-1)
        if shape_params.dim() == 2:
            verts = flame_model(shape_params=shape_params, expression_params=exp_code, pose_params=pose_code)
        elif shape_params.dim() == 3:
            verts = []
            for bidx in range(shape_params.shape[0]):
                this_verts = flame_model(
                    shape_params=shape_params[bidx], expression_params=exp_code[bidx], pose_params=pose_code[bidx]
                )
                verts.append(this_verts)
            verts = torch.stack(verts, dim=0)
        else:
            raise ValueError("Invalid shape of shape_params: {}".format(shape_params.shape))
        return verts

    def norm_with_stats(self, motion_code):
        """Standardize motion with the stored ``motion_mean`` / ``motion_std``."""
        normed_motion_code = (motion_code - self.motion_mean) / self.motion_std
        return normed_motion_code

    def unnorm_with_stats(self, motion_code):
        """Invert ``norm_with_stats``."""
        unnormed_motion_code = motion_code * self.motion_std + self.motion_mean
        return unnormed_motion_code

    @torch.no_grad()
    def build_attn_mask(self, patch_nums):
        """Build a ``(1, 1, 2P, 2P)`` additive attention mask for ``P = patch_nums``.

        Entries are 0 (attend) or -inf (blocked). Rows ``[0, P)`` (the "prev"
        half) are blocked from columns ``[P, 2P)``; rows ``[P, 2P)`` see everything.
        """
        zero_attn_bias_block = torch.zeros(patch_nums, patch_nums)
        minf_attn_bias_block = torch.ones(patch_nums, patch_nums) * (-torch.inf)
        attn_mask = torch.cat([
            torch.cat([zero_attn_bias_block, minf_attn_bias_block], dim=-1),
            torch.cat([zero_attn_bias_block, zero_attn_bias_block], dim=-1),
        ], dim=0)
        return attn_mask[None, None]

    @torch.no_grad()
    def quant_to_vqidx(self, prev_motion, this_motion=None):
        """Encode un-normalized motion into bit indices.

        Args:
            prev_motion: motion for the "prev" half; expected
                ``(B, patch_nums[-1], motion_dim)`` -- TODO confirm with callers.
            this_motion: optional "this" half with the same layout. When given,
                both halves are encoded jointly under the block mask. Otherwise
                only ``prev_motion`` is encoded, using the prev slice of the
                embeddings and mask.

        Returns:
            ``(prev_code_idx, this_code_idx)``; ``this_code_idx`` is ``None``
            when ``this_motion`` is ``None``.
        """
        seq_len = self.patch_nums[-1]
        if this_motion is not None:
            all_motion = torch.cat([prev_motion, this_motion], dim=1)
            enc_in = self.norm_with_stats(all_motion)
            enc_out = self.encoder(enc_in + self.enc_pos_embed, attn_mask=self.attn_mask)
            prev_enc_out, this_enc_out = enc_out[:, :seq_len], enc_out[:, seq_len:]
            _, prev_code_idx, _ = self.quantize(prev_enc_out)
            _, this_code_idx, _ = self.quantize(this_enc_out)
        else:
            enc_in = self.norm_with_stats(prev_motion)
            enc_out = self.encoder(
                enc_in + self.enc_pos_embed[:, :seq_len], attn_mask=self.attn_mask[:, :, :seq_len, :seq_len]
            )
            _, prev_code_idx, _ = self.quantize(enc_out)
            this_code_idx = None
        return prev_code_idx, this_code_idx

    @torch.no_grad()
    def flip_quant_to_vqidx(self, prev_motion, this_motion, flip_ratio):
        """Encode both halves, then quantize only the "this" half with random bit flips.

        Each bit of the "this" half is flipped independently with probability
        ``flip_ratio`` by ``MultiScaleBSQ.flip_quant_to_vqidx``. Returns the
        resulting "this" bit indices.
        """
        seq_len = self.patch_nums[-1]
        all_motion = torch.cat([prev_motion, this_motion], dim=1)
        enc_in = self.norm_with_stats(all_motion)
        enc_out = self.encoder(enc_in + self.enc_pos_embed, attn_mask=self.attn_mask)
        this_enc_out = enc_out[:, seq_len:]
        _, this_code_idx = self.quantize.flip_quant_to_vqidx(this_enc_out, flip_ratio)
        return this_code_idx

    @torch.no_grad()
    def vqidx_to_motion(self, prev_code_idx, this_code_idx):
        """Decode "prev" and "this" bit indices back to un-normalized motion.

        Returns ``(prev_motion, this_motion)``, split at ``patch_nums[-1]``
        along dim 1.
        """
        seq_len = self.patch_nums[-1]
        prev_vq_out = self.quantize.vqidx_to_feat(prev_code_idx, multi_scale=False)
        this_vq_out = self.quantize.vqidx_to_feat(this_code_idx, multi_scale=False)
        vq_out = torch.cat([prev_vq_out, this_vq_out], dim=1)
        dec_out = self.decoder(vq_out + self.dec_pos_embed, attn_mask=self.attn_mask)
        motion_code = self.unnorm_with_stats(dec_out)
        return motion_code[:, :seq_len], motion_code[:, seq_len:]

    # for training of var model
    @torch.no_grad()
    def vqidx_to_ms_vqfeat(self, code_idx):
        """Return the multi-scale "next-scale" features for all levels (training)."""
        vqfeat = self.quantize.vqidx_to_feat(code_idx, multi_scale=True)
        return vqfeat

    # for inference of var model
    @torch.no_grad()
    def vqidx_to_ar_vqfeat(self, pidx, code_idx):
        """Return the next-scale features up to level ``pidx + 1`` (inference)."""
        next_ar_vqfeat = self.quantize.vqidx_to_ar_vqfeat(pidx, code_idx)
        return next_ar_vqfeat
class TransformerEncoder(nn.Module):
    """Map per-step input vectors to per-step latent codes.

    Layout: ``Linear + LeakyReLU`` lift to ``hidden_dim``, then ``depth``
    residual pairs of (self-attention, MLP with ``1.5 * hidden_dim`` width),
    then a ``Linear`` projection to ``code_dim``.
    """

    def __init__(self, inp_dim, hidden_dim, code_dim, depth=6, n_heads=8):
        super().__init__()
        mlp_dim = int(1.5 * hidden_dim)
        self.inp_mapping = nn.Sequential(nn.Linear(inp_dim, hidden_dim), nn.LeakyReLU(0.2, True))
        self.code_mapping = nn.Linear(hidden_dim, code_dim)
        # Flat, alternating [attn, mlp, attn, mlp, ...] list; this layout defines
        # the state_dict keys, so it is preserved exactly.
        layers = [
            layer
            for _ in range(depth)
            for layer in (
                SimpleSelfAttention(hidden_dim, n_heads=n_heads),
                nn.Sequential(
                    nn.Linear(hidden_dim, mlp_dim),
                    nn.GELU(approximate='tanh'),
                    nn.Linear(mlp_dim, hidden_dim),
                ),
            )
        ]
        self.encoder_transformer = nn.ModuleList(layers)

    def forward(self, inp_BLC, attn_mask=None):
        """Encode ``inp_BLC``; ``attn_mask`` is forwarded to the attention layers only."""
        hidden = self.inp_mapping(inp_BLC)
        for layer in self.encoder_transformer:
            extra_args = (attn_mask,) if isinstance(layer, SimpleSelfAttention) else ()
            hidden = hidden + layer(hidden, *extra_args)
        return self.code_mapping(hidden)
class TransformerDecoder(nn.Module):
    """Map per-step latent codes back to per-step output vectors.

    Layout: ``Linear + LeakyReLU`` lift to ``hidden_dim``, then ``depth``
    residual pairs of (self-attention, MLP with ``1.5 * hidden_dim`` width),
    then a ``Linear`` read-out to ``out_dim``. The read-out starts small
    (xavier-uniform with gain 0.05, zero bias).
    """

    def __init__(self, code_dim, hidden_dim, out_dim, depth=6, n_heads=8):
        super().__init__()
        mlp_dim = int(1.5 * hidden_dim)
        self.inp_mapping = nn.Sequential(nn.Linear(code_dim, hidden_dim), nn.LeakyReLU(0.2, True))
        readout = nn.Linear(hidden_dim, out_dim)
        self.out_mapping = readout
        nn.init.xavier_uniform_(readout.weight, gain=0.05)
        nn.init.constant_(readout.bias, 0)
        # Flat, alternating [attn, mlp, attn, mlp, ...] list; this layout defines
        # the state_dict keys, so it is preserved exactly.
        layers = [
            layer
            for _ in range(depth)
            for layer in (
                SimpleSelfAttention(hidden_dim, n_heads=n_heads),
                nn.Sequential(
                    nn.Linear(hidden_dim, mlp_dim),
                    nn.GELU(approximate='tanh'),
                    nn.Linear(mlp_dim, hidden_dim),
                ),
            )
        ]
        self.decoder_transformer = nn.ModuleList(layers)

    def forward(self, inp_BLC, attn_mask=None):
        """Decode ``inp_BLC``; ``attn_mask`` is forwarded to the attention layers only."""
        hidden = self.inp_mapping(inp_BLC)
        for layer in self.decoder_transformer:
            extra_args = (attn_mask,) if isinstance(layer, SimpleSelfAttention) else ()
            hidden = hidden + layer(hidden, *extra_args)
        return self.out_mapping(hidden)
class SimpleSelfAttention(nn.Module):
    """Pre-LayerNorm multi-head self-attention without a residual connection.

    Callers add the residual themselves. ``attn_mask`` is passed straight to
    ``scaled_dot_product_attention``.
    """

    def __init__(self, hidden_dim, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        # NOTE(review): the softmax scale is 1/sqrt(hidden_dim), not the usual
        # 1/sqrt(hidden_dim // n_heads). Changing it would alter the outputs of
        # any existing checkpoint, so it is left as-is.
        self.scale = int(hidden_dim) ** (-0.5)
        self.rearrange_qkv = Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=self.n_heads)
        self.rearrange_out = Rearrange("b h n d -> b n (h d)")
        self.norm = nn.LayerNorm(hidden_dim, eps=1e-5)
        self.to_qkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=False)
        self.to_out = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, attn_mask=None):
        """Attend over ``x`` of shape ``[B, L, C]`` and return ``[B, L, C]``."""
        _batch, _length, _channels = x.shape  # unpacking rejects non 3-D input
        packed = self.to_qkv(self.norm(x))  # [B, L, 3*C]
        query, key, value = self.rearrange_qkv(packed).unbind(0)  # each [B, H, L, C // H]
        attended = F.scaled_dot_product_attention(
            query=query, key=key, value=value,
            attn_mask=attn_mask, dropout_p=0.0, scale=self.scale,
        )
        return self.to_out(self.rearrange_out(attended))
class MultiScaleBSQ(nn.Module):
    """Residual multi-scale binary spherical quantizer over the time axis.

    At each level ``pt`` of ``scale_schedule``:
    - the current residual is area-downsampled to ``pt`` steps (if ``pt != T``);
    - the result is bit-quantized by a shared ``BSQ``;
    - it is linearly upsampled back to ``T`` steps;
    - it is subtracted from the residual.

    Bit indices of all levels are concatenated along dim 1, in schedule order.

    Args:
        codebook_dim: bits per step (implicit codebook size ``2 ** codebook_dim``).
        scale_schedule: per-level lengths. The ``vqidx_to_*`` helpers treat the
            last entry as the full sequence length ``T``.

    Raises:
        ValueError: if ``scale_schedule`` is ``None``.
    """

    def __init__(self, codebook_dim=32, scale_schedule=None):
        super().__init__()
        if scale_schedule is None:
            raise ValueError("MultiScaleBSQ requires a scale_schedule, got None")
        # codebook size -> 2 ** codebook_dim
        self.codebook_dim = codebook_dim
        self.scale_lvls = len(scale_schedule)
        self.scale_schedule = scale_schedule
        self.bsq_quant = BSQ(codebook_dim=codebook_dim)

    @staticmethod
    def _resample_time(x_BTC, size, mode):
        """Resample a ``(B, T, C)`` tensor along T to ``size`` steps with ``F.interpolate``."""
        return F.interpolate(x_BTC.permute(0, 2, 1), size=size, mode=mode).permute(0, 2, 1).contiguous()

    def _bits_to_feat(self, bit_indices):
        """Map {0, 1} bits to the BSQ code values {-1, +1} / sqrt(codebook_dim)."""
        return (bit_indices.float() * 2 - 1.0) / (self.codebook_dim ** 0.5)

    def forward(self, f_BTC):
        """Quantize ``f_BTC`` of shape ``(B, T, C)``.

        Returns:
            ``(quantized_out, all_bit_indices, all_losses)``:
            - ``quantized_out`` has shape ``(B, T, C)`` and keeps straight-through
              gradients.
            - ``all_bit_indices`` concatenates each level's bits along dim 1.
            - ``all_losses`` stacks the per-level BSQ aux losses on the last dim.
        """
        _, T, _ = f_BTC.shape
        quantized_out, residual = 0., f_BTC
        all_losses, all_bit_indices = [], []
        for pt in self.scale_schedule:
            level_in = self._resample_time(residual, pt, 'area') if pt != T else residual
            quantized, bit_indices, loss = self.bsq_quant(level_in)
            if pt != T:
                quantized = self._resample_time(quantized, T, 'linear')
            residual = residual - quantized.detach()  # remove_residual_detach = False
            quantized_out = quantized_out + quantized
            all_bit_indices.append(bit_indices)
            all_losses.append(loss)
        return quantized_out, torch.cat(all_bit_indices, dim=1), torch.stack(all_losses, dim=-1)

    @torch.no_grad()
    def flip_quant_to_vqidx(self, f_BTC, flip_ratio):
        """Like ``forward``, but flip each bit independently with probability ``flip_ratio``.

        The flipped bits, not the clean ones, drive the residual update of later
        levels.

        Returns:
            ``(quantized_out, all_bit_indices)`` built from the flipped bits.
        """
        _, T, _ = f_BTC.shape
        quantized_out, residual = 0., f_BTC
        all_bit_indices = []
        for pt in self.scale_schedule:
            level_in = self._resample_time(residual, pt, 'area') if pt != T else residual
            _, bit_indices, _ = self.bsq_quant(level_in)
            # Sampled on the default (CPU) generator and then moved to the bits'
            # device; kept this way so results stay reproducible under existing seeds.
            mask_flip = torch.rand(bit_indices.shape).to(bit_indices.device) < flip_ratio
            pred_bit_indices = bit_indices.clone()
            pred_bit_indices[mask_flip] = 1 - pred_bit_indices[mask_flip]
            quantized = self._bits_to_feat(pred_bit_indices)
            if pt != T:
                quantized = self._resample_time(quantized, T, 'linear')
            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized
            all_bit_indices.append(pred_bit_indices)
        return quantized_out, torch.cat(all_bit_indices, dim=1)

    @torch.no_grad()
    def vqidx_to_feat(self, bit_indices, multi_scale=False):
        """Rebuild features from concatenated per-level bit indices.

        Args:
            bit_indices: bits of all levels concatenated along dim 1, each level
                occupying ``scale_schedule[i]`` steps.
            multi_scale: if ``False``, return the full-resolution reconstruction
                ``(B, T, C)``. If ``True``, return the area-downsampled partial
                reconstructions that serve as inputs for levels ``1 ..
                scale_lvls-1``, concatenated along dim 1. For a single-level
                schedule that sequence is empty, shape ``(B, 0, C)``.
        """
        B, T, C = bit_indices.shape[0], self.scale_schedule[-1], self.codebook_dim
        ori_h_BTC = self._bits_to_feat(bit_indices)
        pn_start, pn_next = 0, self.scale_schedule[0]
        if multi_scale:
            if self.scale_lvls == 1:
                # No "next" scale exists; torch.cat([]) would raise here.
                return ori_h_BTC.new_zeros(B, 0, C)
            ori_h_BCT = ori_h_BTC.permute(0, 2, 1).contiguous()
            f_hat = bit_indices.new_zeros(B, C, T, dtype=torch.float32)
            next_scales = []
            for pidx in range(self.scale_lvls - 1):
                h_BCT = F.interpolate(ori_h_BCT[..., pn_start:pn_next], size=T, mode='linear')
                f_hat.add_(h_BCT)
                pn_start, pn_next = pn_next, pn_next + self.scale_schedule[pidx + 1]
                next_scales.append(F.interpolate(f_hat, size=self.scale_schedule[pidx + 1], mode='area'))
            return torch.cat(next_scales, dim=-1).permute(0, 2, 1).contiguous()
        f_hat = bit_indices.new_zeros(B, T, C, dtype=torch.float32)
        for pidx in range(self.scale_lvls - 1):
            h_BCT = F.interpolate(
                ori_h_BTC[:, pn_start:pn_next].permute(0, 2, 1).contiguous(), size=T, mode='linear'
            )
            f_hat.add_(h_BCT.permute(0, 2, 1).contiguous())
            pn_start, pn_next = pn_next, pn_next + self.scale_schedule[pidx + 1]
        # The last level is added at full resolution without interpolation.
        f_hat.add_(ori_h_BTC[:, pn_start:])
        return f_hat

    # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
    @torch.no_grad()
    def vqidx_to_ar_vqfeat(self, this_pidx, bit_indices):  # only used in VAR inference
        """Return next-scale inputs for levels ``1 .. this_pidx + 1``.

        The result is concatenated along dim 1. ``bit_indices`` must contain at
        least the bits of levels ``0 .. this_pidx``.

        Raises:
            IndexError: if ``this_pidx`` is outside ``[0, scale_lvls - 2]``
                (there is no level ``this_pidx + 1`` to prepare input for).
        """
        if not 0 <= this_pidx < self.scale_lvls - 1:
            raise IndexError(
                f"this_pidx must be in [0, {self.scale_lvls - 2}] for a "
                f"{self.scale_lvls}-level schedule, got {this_pidx}"
            )
        B, T, C = bit_indices.shape[0], self.scale_schedule[-1], self.codebook_dim
        f_hat = bit_indices.new_zeros(B, C, T, dtype=torch.float32)
        ori_h_BCT = self._bits_to_feat(bit_indices).permute(0, 2, 1).contiguous()
        pn_start, pn_next = 0, self.scale_schedule[0]
        next_scales = []
        for pidx in range(this_pidx + 1):
            h_BCL = F.interpolate(ori_h_BCT[..., pn_start:pn_next], size=T, mode='linear').contiguous()
            f_hat.add_(h_BCL)
            pn_start, pn_next = pn_next, pn_next + self.scale_schedule[pidx + 1]
            next_scales.append(
                F.interpolate(f_hat.clone(), size=self.scale_schedule[pidx + 1], mode='area').contiguous()
            )
        return torch.cat(next_scales, dim=-1).permute(0, 2, 1).contiguous()
class BSQ(nn.Module):
    """Binary spherical quantizer: sign-quantize L2-normalized vectors.

    Each input vector is L2-normalized along the last dim. Every coordinate is
    then replaced by ``+/- 1 / sqrt(codebook_dim)``, so the code vector also
    has unit norm. Gradients pass straight through to the normalized input.
    """

    def __init__(self, codebook_dim=32):
        super().__init__()
        self.inv_temperature = 100.0
        self.commit_loss_weight = 0.2
        self.entropy_loss_weight = 0.1
        self.codebook_dim = codebook_dim

    def forward(self, f_BTC):
        """Quantize ``f_BTC`` (last dim must equal ``codebook_dim``).

        Returns:
            ``(quantized, bit_indices, aux_loss)``:
            - ``quantized`` has the input's shape.
            - ``bit_indices`` is the int32 tensor ``quantized > 0``.
            - ``aux_loss`` is a scalar: weighted soft-entropy penalty plus
              weighted commitment loss.
        """
        f_BTC = F.normalize(f_BTC, dim=-1)
        # straight-through quantization (see quantize)
        quantized = self.quantize(f_BTC)  # B, T, C
        persample_entropy, cb_entropy = self.soft_entropy_loss(f_BTC)
        entropy_penalty = (persample_entropy - cb_entropy) / self.inv_temperature
        commit_loss = torch.mean(((quantized.detach() - f_BTC) ** 2).sum(dim=-1))
        aux_loss = entropy_penalty * self.entropy_loss_weight + commit_loss * self.commit_loss_weight
        bit_indices = (quantized > 0).int()  # B, T, C
        return quantized, bit_indices, aux_loss

    def quantize(self, z):
        """Map each coordinate of ``z`` to ``+/- 1/sqrt(codebook_dim)`` with a straight-through gradient.

        Coordinates ``<= 0`` (and NaN) map to the negative value.

        Raises:
            ValueError: if ``z.shape[-1] != codebook_dim``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if z.shape[-1] != self.codebook_dim:
            raise ValueError(f"Expected {self.codebook_dim} dimensions, got {z.shape[-1]}")
        q_scale = 1. / (self.codebook_dim ** 0.5)
        ones = torch.ones_like(z)
        zhat = q_scale * torch.where(z > 0, ones, -ones)  # on unit sphere
        return z + (zhat - z).detach()

    def soft_entropy_loss(self, z):
        """Soft per-bit entropy terms, treating bits as independent.

        Returns:
            ``(per_sample_entropy, codebook_entropy)``:
            - ``per_sample_entropy`` is the mean over leading dims of the
              per-bit entropies summed over bits.
            - ``codebook_entropy`` is the entropy of the batch-averaged bit
              distribution, summed over bits.
        """
        def get_entropy(count, dim=-1):
            return -(count * torch.log(count + 1e-8)).sum(dim=dim)

        p = torch.sigmoid(-4 * z / (self.codebook_dim ** 0.5) * self.inv_temperature)
        prob = torch.stack([p, 1 - p], dim=-1)  # (b, l, codebook_dim, 2)
        per_sample_entropy = get_entropy(prob, dim=-1).sum(dim=-1).mean()  # (b,l,codebook_dim)->(b,l)->scalar
        # macro average of the probability of each bit
        avg_prob = prob.mean(dim=[0, 1])  # (codebook_dim, 2)
        codebook_entropy = get_entropy(avg_prob, dim=-1)
        # the codebook entropy is approximated by the sum of per-bit entropies
        return per_sample_entropy, codebook_entropy.sum()