# NMR / src/models/tokenizers/quantizer/ms_ema_reset.py
# (page residue: "Xxx999's picture", "upload", commit 45950ff)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from mmengine.registry import MODELS
# Ref: https://github.com/snap-research/SnapMoGen/blob/main/model/vq/quantizer.py#L330
@MODELS.register_module()
class MultiScaleQuantizeEMAReset(nn.Module):
    """Multi-scale residual vector quantizer with an EMA-updated, data-reset codebook.

    One codebook (``nb_code`` x ``code_dim``) is shared by every scale. For each entry
    of ``scales``, the running residual of an input ``x`` of shape ``(N, code_dim, T)``
    goes through the following steps:

    - it is downsampled in time to ``int(T // scale)`` steps (``'area'`` interpolation);
    - it is quantized against the codebook (nearest entry, or Gumbel-sampled in
      training-mode ``forward``);
    - it is upsampled back to ``T`` steps (``'linear'``);
    - it is refined by a ``quant_resi`` module when there is more than one scale;
    - the result is subtracted from the residual.

    The sum of the refined, upsampled codes is the reconstruction ``f_hat``.

    Args:
        nb_code: number of codebook entries.
        code_dim: size of each entry; must equal the channel dimension of ``x``.
        mu: EMA decay used by :meth:`update_codebook`.
        scales: temporal downsampling factor per level. Methods that rebuild
            full-length latents from index lists take ``T`` from the last level, so
            the last scale is presumably 1 -- TODO confirm against configs.
        share_quant_resi: 0 -> one ``Phi`` per scale; 1 -> a single shared ``Phi``;
            any other value -> that many ``Phi`` modules, chosen by nearest tick.
        quant_resi: blend ratio given to ``Phi``; ``|quant_resi| <= 1e-6`` uses
            ``nn.Identity`` instead.
        temperature: Gumbel-sampling temperature used by training-mode ``forward``.
        start_drop: minimum number of scales kept for a sample under quantize dropout.
        quantize_dropout_prob: fraction of the batch (its first ``int(N * p)``
            samples) whose number of kept scales is drawn at random in training.
    """

    def __init__(self, nb_code, code_dim, mu, scales, share_quant_resi=4, quant_resi=0.5,
                 temperature=0.5, start_drop=1, quantize_dropout_prob=0.):
        super(MultiScaleQuantizeEMAReset, self).__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.mu = mu  # EMA decay for code_sum / code_count (original note: TO_DO)
        self.scales = scales
        self.reset_codebook()
        self.temperature = temperature
        self.start_drop = start_drop
        self.quantize_dropout_prob = quantize_dropout_prob
        self.quant_resi_ratio = quant_resi
        if share_quant_resi == 0:  # non-shared: \phi_{1 to K} for K scales
            self.quant_resi = PhiNonShared(
                [(Phi(code_dim, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
                 range(len(self.scales))])
        elif share_quant_resi == 1:  # fully shared: only a single \phi for K scales
            self.quant_resi = PhiShared(Phi(code_dim, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())  # type: ignore
        else:  # partially shared: \phi_{1 to share_quant_resi} for K scales
            self.quant_resi = PhiPartiallyShared(nn.ModuleList(
                [(Phi(code_dim, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
                 for _ in range(share_quant_resi)]))

    def reset_codebook(self):
        """Mark the codebook as uninitialised and (re-)register it as a zero buffer.

        NOTE(review): ``init``, ``code_sum`` and ``code_count`` are plain attributes,
        not buffers, so they are not saved in ``state_dict``. After loading a
        checkpoint, the first training-mode ``forward`` re-initialises ``codebook``
        from data unless the caller sets ``init`` itself.
        """
        self.init = False
        self.code_sum = None
        self.code_count = None
        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False))

    def _tile(self, x):
        """Return ``x`` unchanged, or repeated with small noise if it has fewer than ``nb_code`` rows."""
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x  # ceil(nb_code / nb_code_x)
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    @torch.no_grad()
    def init_codebook(self, x):
        """Initialise the codebook and EMA statistics from the first ``nb_code`` (tiled) rows of ``x``.

        Runs under ``no_grad`` so the codebook buffer does not keep a reference to
        the autograd graph of the activations it was built from.
        """
        out = self._tile(x)
        self.codebook = out[:self.nb_code]
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    def quantize(self, x, sample_codebook_temp=0.):
        """Return one codebook index per row of ``x`` (shape ``(M, code_dim)``).

        The scores are negative squared Euclidean distances to each codebook entry.
        The index is their argmax. In training mode with a positive
        ``sample_codebook_temp``, Gumbel noise is added first.
        """
        k_w = self.codebook.t()  # (code_dim, nb_code)
        # ||x||^2 - 2 <x, c> + ||c||^2  ->  (M, nb_code)
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - \
            2 * torch.matmul(x, k_w) + \
            torch.sum(k_w ** 2, dim=0, keepdim=True)
        code_idx = gumbel_sample(-distance, dim=-1, temperature=sample_codebook_temp, stochastic=True,
                                 training=self.training)
        return code_idx

    def dequantize(self, code_idx):
        """Look up codebook vectors for ``code_idx``; entries equal to -1 yield zero vectors."""
        mask = code_idx == -1.
        code_idx = code_idx.masked_fill(mask, 0.)  # any valid index, zeroed below
        x = F.embedding(code_idx, self.codebook)
        x[mask] = 0.
        return x

    def get_codebook_entry(self, indices):
        """Dequantize 2-D ``indices`` and swap the last two axes (``(N, T)`` -> ``(N, code_dim, T)``)."""
        return self.dequantize(indices).permute(0, 2, 1)

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        """Return ``exp(entropy)`` of the code-usage histogram of the 1-D ``code_idx``."""
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # nb_code, M
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
        code_count = code_onehot.sum(dim=-1)  # nb_code
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        """EMA-update the codebook from vectors ``x`` ``(M, code_dim)`` and their indices.

        Entries whose EMA count is below 1 are replaced row-for-row by the first
        ``nb_code`` rows of ``x``. This reset only happens when
        ``M > 5 * nb_code``; otherwise every entry takes its EMA mean. Returns the
        perplexity of this batch's code usage.
        """
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)  # nb_code, M
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
        code_sum = torch.matmul(code_onehot, x)  # nb_code, code_dim
        code_count = code_onehot.sum(dim=-1)  # nb_code
        out = self._tile(x)
        code_rand = out[:self.nb_code]
        # Update centres
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count
        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
        if len(code_idx) > self.nb_code * 5:
            self.codebook = usage * code_update + (1 - usage) * code_rand
        else:
            self.codebook = code_update
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    def quantize_all(self, x, m_lens=None, return_latent=False):
        """Quantize ``x`` ``(N, code_dim, T)`` at every scale with the deterministic (temperature 0) path.

        Args:
            x: latent sequence ``(N, code_dim, T)``.
            m_lens: optional per-sample valid lengths (tensor); ``None`` treats every
                time step as valid.
            return_latent: also return the summed reconstruction ``f_hat``.

        Returns:
            A list with one ``(N, T_i)`` index tensor per scale (``-1`` at padded
            steps). When ``return_latent`` is true, also ``f_hat`` ``(N, code_dim, T)``.
        """
        N, width, T = x.shape
        residual = x.clone()
        f_hat = torch.zeros_like(x)
        idx_list = []
        if m_lens is not None:
            full_scale_mask = length_to_mask(m_lens, T, device=x.device)
        else:
            # One flag per (sample, time step): the mask must be (N, T) to broadcast
            # as (N, 1, T) against the residual.
            full_scale_mask = torch.ones(N, T, dtype=torch.bool, device=x.device)
        for i, scale in enumerate(self.scales):
            residual = residual * full_scale_mask.unsqueeze(1)
            if scale != 1:
                rest_down = F.interpolate(residual, size=int(T // scale), mode='area')
            else:
                rest_down = residual
            T_down = rest_down.shape[-1]
            if m_lens is not None:
                mask = length_to_mask((m_lens // scale).long(), T_down, device=x.device)  # (n t)
            else:
                # Without lengths every step at this scale is valid.
                mask = torch.ones(N, T_down, dtype=torch.bool, device=x.device)
            mask = rearrange(mask, 'n t -> (n t)')
            rest_down = rearrange(rest_down, 'n c t -> (n t) c')
            code_idx = self.quantize(rest_down)
            x_d = self.dequantize(code_idx)
            x_d[~mask] = 0
            code_idx[~mask] = -1
            idx_list.append(rearrange(code_idx, '(n t) -> n t', n=N))
            x_d = rearrange(x_d, '(n t) c -> n c t', n=N)
            up_x_d = F.interpolate(x_d, size=T, mode='linear')
            if len(self.scales) > 1:
                up_x_d = self.quant_resi[i / (len(self.scales) - 1)](up_x_d)
            residual -= up_x_d
            f_hat += up_x_d
        if return_latent:
            return idx_list, f_hat
        return idx_list

    def get_codes_from_indices(self, indices_list):
        """Rebuild the summed latent from one ``(N, T_i)`` index tensor per scale.

        ``T`` is taken from the last tensor. Each level is dequantized, linearly
        upsampled to ``T`` and refined by ``quant_resi``, the same as in ``forward``.
        Returns ``(N, T, code_dim)``.
        """
        assert len(indices_list) == len(self.scales)
        T = indices_list[-1].shape[-1]
        code = 0.0
        for i, indices in enumerate(indices_list):
            N, _ = indices.shape
            indices = rearrange(indices, 'n t->(n t)')
            x_d = self.dequantize(indices)
            x_d = rearrange(x_d, '(n t) d -> n d t', n=N)
            up_x_d = F.interpolate(x_d, size=T, mode='linear')
            if len(self.scales) > 1:
                up_x_d = self.quant_resi[i / (len(self.scales) - 1)](up_x_d)
            code += up_x_d
        return code.permute(0, 2, 1)

    def forward(self, x, motion_length):
        """Quantize ``x`` ``(N, code_dim, T)`` across all scales (training/eval pass).

        Args:
            x: latent sequence ``(N, code_dim, T)``.
            motion_length: valid length per sample (passed to ``length_to_mask``).
                Steps beyond it are zeroed and excluded from loss and codebook stats.

        Returns:
            A tuple ``(f_hat, mean_vq_loss, perplexity, 0, code_indices)``:

            - ``f_hat``: straight-through reconstruction. Its value is the quantized
              sum and its gradient flows to ``x``.
            - ``mean_vq_loss``: the masked squared error between ``x`` and the
              detached partial reconstruction after each scale, averaged over scales.
            - ``code_indices``: a flat tensor of indices at valid, kept positions,
              concatenated scale by scale.

        NOTE(review): ``f_hat`` is detached in both the loss and the straight-through
        output, so no gradient reaches ``quant_resi`` through this method -- confirm
        this is intended.
        """
        N, width, T = x.shape
        residual = x.clone()
        f_hat = torch.zeros_like(x)
        mean_vq_loss = 0.
        full_scale_mask = length_to_mask(motion_length, x.shape[-1])
        all_rest_down = []
        all_code_indices = []
        all_mask = []
        if self.training and self.quantize_dropout_prob != 0:
            # Quantize dropout: the first int(N * p) samples keep a random number of
            # scales in [start_drop, len(scales)]; all other samples keep every scale.
            n_quantizers = torch.randint(self.start_drop, len(self.scales) + 1, (N, ))
            n_dropout = int(N * self.quantize_dropout_prob)
            n_quantizers[n_dropout:] = len(self.scales) + 1
            n_quantizers = n_quantizers.to(x.device)
        else:
            n_quantizers = torch.full((N,), len(self.scales) + 1, device=x.device)
        for i, scale in enumerate(self.scales):
            residual = residual * full_scale_mask.unsqueeze(1)
            keep_mask = (torch.full((N,), fill_value=i, device=x.device) < n_quantizers)  # 1:keep, 0:drop
            if scale != 1:
                rest_down = F.interpolate(residual, size=int(T // scale), mode='area')
            else:
                rest_down = residual
            mask = length_to_mask((motion_length // scale).long(), rest_down.shape[-1])  # (n t)
            mask = mask & keep_mask[:, None]
            all_mask.append(rearrange(mask, 'n t -> (n t)'))
            rest_down = rearrange(rest_down, 'n c t -> (n t) c')
            if self.training and not self.init:
                self.init_codebook(rest_down[all_mask[-1]])
            code_idx = self.quantize(rest_down, self.temperature)
            x_d = self.dequantize(code_idx)
            x_d[~all_mask[-1]] = 0.
            all_rest_down.append(rest_down)
            all_code_indices.append(code_idx)
            x_d = rearrange(x_d, '(n t) c -> n c t', n=N)
            up_x_d = F.interpolate(x_d, size=T, mode='linear')
            if len(self.scales) > 1:
                up_x_d = self.quant_resi[i / (len(self.scales) - 1)](up_x_d)
            # quant_resi may add a bias even to zero input, so clear dropped samples afterwards.
            up_x_d[~keep_mask] = 0.
            residual -= up_x_d
            f_hat = f_hat + up_x_d
            loss_mask = full_scale_mask & keep_mask[:, None]
            mean_vq_loss += mean_flat((x - f_hat.detach()).pow(2), loss_mask.unsqueeze(1))
        all_code_indices = torch.cat(all_code_indices, dim=0)
        all_rest_down = torch.cat(all_rest_down, dim=0)
        all_mask = torch.cat(all_mask, dim=0)
        all_code_indices = all_code_indices[all_mask]
        all_rest_down = all_rest_down[all_mask]
        if self.training:
            perplexity = self.update_codebook(all_rest_down, all_code_indices)
        else:
            perplexity = self.compute_perplexity(all_code_indices)
        mean_vq_loss /= len(self.scales)
        # Straight-through estimator: forward value f_hat, identity gradient to x.
        f_hat = x + (f_hat - x).detach()
        return f_hat, mean_vq_loss, perplexity, 0, all_code_indices

    # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
    def idx_to_var_input(self, indices_list):
        """Build teacher-forcing inputs from one ``(N, T_i)`` index tensor per scale.

        For each level ``i < K - 1``, the cumulative refined reconstruction of levels
        ``0..i`` is linearly resampled to ``int(T // scales[i + 1])`` steps. The
        pieces are concatenated along time. Returns ``(N, sum_len, code_dim)``.
        """
        assert len(indices_list) == len(self.scales)
        T = indices_list[-1].shape[-1]
        code = 0.0
        next_scale_input = []
        for i in range(len(indices_list) - 1):
            indices = indices_list[i]
            N, _ = indices.shape
            indices = rearrange(indices, 'n t->(n t)')
            x_d = self.dequantize(indices)
            x_d = rearrange(x_d, '(n t) d -> n d t', n=N)
            up_x_d = F.interpolate(x_d, size=T, mode='linear')
            up_x_d = self.quant_resi[i / (len(self.scales) - 1)](up_x_d)
            code += up_x_d
            next_scale = F.interpolate(code, size=int(T // self.scales[i + 1]), mode='linear')
            next_scale_input.append(next_scale)
        return torch.cat(next_scale_input, dim=-1).permute(0, 2, 1)

    # ===================== get_next_var_input: only used in VAR inference, for getting next step's input =====================
    def get_next_var_input(self, level, indices, code, T):
        """Add level ``level``'s codes to the running reconstruction and return the next input.

        Every level, including the last, is upsampled to ``T`` and refined by
        ``quant_resi``. This matches ``forward`` and ``get_codes_from_indices``, so
        the final ``code`` equals the reconstruction seen in training.

        Args:
            level: index of the scale that ``indices`` belong to.
            indices: ``(N, T_level)`` code indices.
            code: running reconstruction ``(N, code_dim, T)``. A tensor is updated in
                place with ``+=``; presumably ``0.0`` or zeros at level 0 -- confirm
                with the caller.
            T: full-resolution length.

        Returns:
            A tuple ``(code, next_scale)``. ``next_scale`` is ``code`` resampled to the
            next level's length, or ``code`` itself at the last level.
        """
        N, _ = indices.shape
        indices = rearrange(indices, 'n t -> (n t)')
        x_d = self.dequantize(indices)
        x_d = rearrange(x_d, '(n t) d -> n d t', n=N)
        # At the last level (scale 1) this interpolation keeps the length unchanged.
        up_x_d = F.interpolate(x_d, size=T, mode='linear')
        if len(self.scales) > 1:
            up_x_d = self.quant_resi[level / (len(self.scales) - 1)](up_x_d)
        code += up_x_d
        if level != len(self.scales) - 1:
            next_scale = F.interpolate(code, size=int(T // self.scales[level + 1]), mode='linear')
        else:
            next_scale = code
        return code, next_scale
class Phi(nn.Conv1d):
    """Residual refinement: a length- and channel-preserving ``Conv1d`` blended with its input.

    ``forward(h)`` returns ``(1 - r) * h + r * conv(h)`` with ``r = |quant_resi|``.
    The kernel size is 3 with padding 1 and stride 1.
    """

    def __init__(self, embed_dim, quant_resi):
        kernel = 3
        super().__init__(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel,
            stride=1,
            padding=kernel // 2,
        )
        # Only the magnitude matters: a negative ratio acts like its absolute value.
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):  # type: ignore
        kept = h_BChw.mul(1 - self.resi_ratio)
        refined = super().forward(h_BChw).mul_(self.resi_ratio)
        return kept + refined
class PhiShared(nn.Module):
def __init__(self, qresi: Phi):
super().__init__()
self.qresi: Phi = qresi
def __getitem__(self, _) -> Phi:
return self.qresi
class PhiPartiallyShared(nn.Module):
def __init__(self, qresi_ls: nn.ModuleList):
super().__init__()
self.qresi_ls = qresi_ls
K = len(qresi_ls)
self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
def __getitem__(self, at_from_0_to_1: float) -> Phi:
return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()] # type: ignore
def extra_repr(self) -> str:
return f'ticks={self.ticks}'
class PhiNonShared(nn.ModuleList):
def __init__(self, qresi):
super().__init__(qresi)
# self.qresi = qresi
K = len(qresi)
self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
def __getitem__(self, at_from_0_to_1: float) -> Phi: # type: ignore
return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()) # type: ignore
def extra_repr(self) -> str:
return f'ticks={self.ticks}'
def gumbel_sample(logits, temperature=1., stochastic=False, dim=-1, training=True):
    """Return the argmax of ``logits`` along ``dim``, optionally Gumbel-perturbed.

    Noise is used only when ``training`` and ``stochastic`` are both true and
    ``temperature > 0``. In that case the scores are
    ``logits / temperature + gumbel_noise(logits)``. Otherwise this is a plain argmax.
    """
    perturb = training and stochastic and temperature > 0
    if not perturb:
        return logits.argmax(dim=dim)
    noisy_scores = (logits / temperature) + gumbel_noise(logits)
    return noisy_scores.argmax(dim=dim)
def gumbel_noise(t):
    """Draw standard Gumbel noise shaped like ``t`` (same dtype/device) as ``-log(-log(U))``.

    ``U ~ Uniform(0, 1)``. ``log`` clamps its argument from below, so neither
    logarithm is taken of zero.
    """
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)
def log(t, eps=1e-20):
    """Natural log of ``t`` after clamping it from below at ``eps`` (no ``-inf`` at 0, no NaN for negatives)."""
    floored = t.clamp(min=eps)
    return floored.log()
def length_to_mask(length, max_len=None, device: torch.device = None) -> torch.Tensor:
    """Build a boolean padding mask from per-sample lengths.

    Args:
        length: 1-D tensor, list or tuple of valid lengths, one per sample.
        max_len: number of columns; defaults to the largest length (0 if empty).
        device: target device; defaults to the device of ``length`` after any
            list/tuple conversion.

    Returns:
        Bool tensor ``(len(length), max_len)`` where ``mask[b, t]`` is
        ``t < length[b]``.
    """
    # Convert Python sequences first: they have no ``.device`` attribute.
    if isinstance(length, (list, tuple)):
        length = torch.tensor(length)
    if device is None:
        device = length.device
    length = length.to(device)
    if max_len is None:
        max_len = int(length.max()) if length.numel() > 0 else 0
    positions = torch.arange(max_len, device=device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
def mean_flat(tensor: torch.Tensor, mask=None):
    """
    Take the mean over all non-batch dimensions.

    Without ``mask``, return the per-sample mean over every dimension except dim 0
    (shape ``(B,)``).

    With ``mask``, return one scalar: the masked mean over the whole 3-D ``tensor``.
    ``mask`` must broadcast against ``tensor``, for example ``(B, T, 1)`` for a
    ``(B, T, C)`` tensor or ``(B, 1, T)`` for a ``(B, C, T)`` tensor. The
    denominator counts the mask after broadcasting, so both layouts give a true mean
    of the selected elements. Returns 0 when nothing is selected (instead of NaN).
    """
    if mask is None:
        return tensor.mean(dim=list(range(1, len(tensor.shape))))
    assert tensor.dim() == 3
    # Count selected elements in tensor's shape; the old ``mask.sum() * shape[-1]``
    # was only correct when the mask broadcast along the last dimension.
    denom = torch.broadcast_to(mask, tensor.shape).sum()
    denom = torch.where(denom > 0, denom, torch.ones_like(denom))
    loss = (tensor * mask).sum() / denom
    return loss