import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from infinity.models.videovae.utils.misc import shift_dim
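
# EMA vector-quantization codebooks for the video VAE: a single-scale `Codebook`
# and a coarse-to-fine `MultiScaleCodebook` that quantizes residuals over a
# (temporal, spatial) patch-size schedule.
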
class Codebook(nn.Module):
    """Vector-quantization codebook with EMA updates.

    Codes are assigned by nearest squared-Euclidean distance. During training the
    codebook is updated with an exponential moving average of the encoder outputs
    assigned to each code, rarely used codes can be randomly restarted, and a
    running per-code usage estimate is kept in `codebook_usage`.
    """

    def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0, usage_sigma=0.99, fp32_quant=False):
        super().__init__()
        self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
        # EMA statistics: per-code counts (N) and running sums of assigned vectors (z_avg).
        self.register_buffer('N', torch.zeros(n_codes))
        self.register_buffer('z_avg', self.embeddings.data.clone())
        self.register_buffer('codebook_usage', torch.zeros(n_codes))

        self.call_cnt = 0
        self.usage_sigma = usage_sigma

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True
        self.no_random_restart = no_random_restart
        self.restart_thres = restart_thres

        self.fp32_quant = fp32_quant
    def _tile(self, x):
        # Pad x to at least n_codes rows by repeating it with a small amount of noise.
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x
    def _init_embeddings(self, z):
        # Data-dependent initialization: seed the codebook with random encoder outputs
        # (rank 0 broadcasts its choice so every worker starts from the same codebook).
        self._need_init = False
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        y = self._tile(flat_inputs)

        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))
    def calculate_batch_codebook_usage_percentage(self, batch_encoding_indices):
        """Return, for each code, the fraction of positions in the batch that selected it."""
        all_indices = batch_encoding_indices.flatten()
        total_indices = all_indices.numel()

        codebook_usage_percentage = torch.zeros(self.n_codes, device=all_indices.device)
        unique_indices, counts = torch.unique(all_indices, return_counts=True)
        percentages = counts.float() / total_indices
        codebook_usage_percentage[unique_indices.long()] = percentages
        return codebook_usage_percentage
    def forward(self, z):
        """Quantize z of shape (B, C, ...) against the codebook.

        Returns a dict with the straight-through quantized latents ('embeddings'),
        the code indices ('encodings'), the commitment loss, the batch perplexity,
        and codebook-usage statistics.
        """
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)

        # Squared Euclidean distance between every input vector and every code.
        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
            - 2 * flat_inputs @ self.embeddings.t() \
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)

        encoding_indices = torch.argmin(distances, dim=1)
        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])

        embeddings = F.embedding(encoding_indices, self.embeddings)
        embeddings = shift_dim(embeddings, -1, 1)

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        if self.training:
            # EMA update of code counts and code vectors (summed across workers if distributed).
            n_total = encode_onehot.sum(dim=0)
            encode_sum = flat_inputs.t() @ encode_onehot
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)

            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

            # Laplace-smooth the counts before normalizing the running sums.
            n = self.N.sum()
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
            encode_normalized = self.z_avg / weights.unsqueeze(1)
            self.embeddings.data.copy_(encode_normalized)

            y = self._tile(flat_inputs)
            _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
            if dist.is_initialized():
                dist.broadcast(_k_rand, 0)

            # Randomly restart codes whose EMA count fell below restart_thres.
            if not self.no_random_restart:
                usage = (self.N.view(self.n_codes, 1) >= self.restart_thres).float()
                self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

        # Straight-through estimator: forward uses the quantized values, backward passes
        # gradients through to z.
        embeddings_st = (embeddings - z).detach() + z

        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        try:
            usage = self.calculate_batch_codebook_usage_percentage(encoding_indices)
        except Exception:
            usage = torch.zeros(self.n_codes, device=encoding_indices.device)

        # Exponential moving average of per-code usage across forward calls.
        if self.call_cnt == 0:
            self.codebook_usage.data = usage
        else:
            self.codebook_usage.data = self.usage_sigma * self.codebook_usage.data + (1 - self.usage_sigma) * usage

        self.call_cnt += 1

        avg_usage = (self.codebook_usage.data > (1 / self.n_codes)).sum() / self.n_codes

        return dict(embeddings=embeddings_st, encodings=encoding_indices,
                    commitment_loss=commitment_loss, perplexity=perplexity, avg_usage=avg_usage, batch_usage=usage)
    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings

from typing import List, Optional, Tuple, Sequence, Union

class ResConvAfterUpsample(nn.Conv3d):
    """3-D conv applied after upsampling, blended with its input.

    A negative `quant_resi` selects a 3x3x3 kernel (1x1x1 otherwise); `abs(quant_resi)`
    is the mixing ratio between the convolved features and the identity path.
    """

    def __init__(self, embed_dim, quant_resi):
        ks = 3 if quant_resi < 0 else 1
        super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BCthw):
        return h_BCthw.mul(1 - self.resi_ratio) + super().forward(h_BCthw).mul_(self.resi_ratio)
class SharedResConvAfterUpsample(nn.Module):
    """Wraps a single ResConvAfterUpsample and returns it for every index (all scales share one conv)."""

    def __init__(self, qresi: ResConvAfterUpsample):
        super().__init__()
        self.qresi: ResConvAfterUpsample = qresi

    def __getitem__(self, _) -> ResConvAfterUpsample:
        return self.qresi
class ResConvAfterUpsampleList(nn.Module):
    """Holds one conv per tick in [0, 1]; indexing with a float returns the conv with the closest tick."""

    def __init__(self, qresi_ls: nn.ModuleList):
        super().__init__()
        self.qresi_ls = qresi_ls
        K = len(qresi_ls)
        self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)

    def __getitem__(self, at_from_0_to_1: float) -> ResConvAfterUpsample:
        return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'
class ResConvAfterUpsampleModuleList(nn.ModuleList):
    """nn.ModuleList variant of ResConvAfterUpsampleList, indexed by a float position in [0, 1]."""

    def __init__(self, qresi: List):
        super().__init__(qresi)
        K = len(qresi)
        self.ticks = np.linspace(1/3/K, 1-1/3/K, K) if K == 4 else np.linspace(1/2/K, 1-1/2/K, K)

    def __getitem__(self, at_from_0_to_1: float) -> ResConvAfterUpsample:
        return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'
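

# How MultiScaleCodebook selects its post-upsample residual convs (see its __init__):
# share_quant_resi == 1 shares a single conv across all scales, == 0 allocates one conv
# per scale (default_qresi_counts of them), and any other value K builds K convs that
# the scales are mapped onto by nearest tick.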
class MultiScaleCodebook(nn.Module):
    """Multi-scale residual vector quantization for video latents.

    The latent (B, C, T, H, W) is quantized coarse-to-fine over the
    (t_patch_nums, v_patch_nums) scale schedule: at each scale the residual between
    the input and the accumulated reconstruction is downsampled, quantized against
    the shared EMA codebook, upsampled back to full resolution, refined by a
    ResConvAfterUpsample head, and added to the running reconstruction.
    """

    def __init__(self, n_codes,
                 embedding_dim, no_random_restart=False,
                 restart_thres=1.0, usage_sigma=0.99, fp32_quant=False,
                 quant_resi=-0.5, share_quant_resi=4, default_qresi_counts=10,
                 t_patch_nums=(1, 1, 2, 2, 2, 4, 4, 4, 4, 4),
                 v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
                 ):
        super().__init__()
        self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
        # EMA statistics: per-code counts (N) and running sums of assigned vectors (z_avg).
        self.register_buffer('N', torch.zeros(n_codes))
        self.register_buffer('z_avg', self.embeddings.data.clone())
        self.register_buffer('codebook_usage', torch.zeros(n_codes))

        self.call_cnt = 0
        self.usage_sigma = usage_sigma

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True
        self.no_random_restart = no_random_restart
        self.restart_thres = restart_thres

        self.fp32_quant = fp32_quant

        self.t_patch_nums = t_patch_nums
        self.v_patch_nums = v_patch_nums
        self.quant_resi_ratio = quant_resi

        # Post-upsample residual convs: shared (1), one per scale (0), or K convs mapped by tick.
        if share_quant_resi == 1:
            self.quant_resi = SharedResConvAfterUpsample(ResConvAfterUpsample(embedding_dim, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
        elif share_quant_resi == 0:
            self.quant_resi = ResConvAfterUpsampleModuleList([(ResConvAfterUpsample(embedding_dim, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(default_qresi_counts or len(self.v_patch_nums))])
        else:
            self.quant_resi = ResConvAfterUpsampleList(nn.ModuleList([(ResConvAfterUpsample(embedding_dim, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(share_quant_resi)]))

        # Residuals are downsampled with 'area' pooling; code maps are upsampled back with trilinear.
        self.z_interplote_down = 'area'
        self.z_interplote_up = 'trilinear'
    def _tile(self, x):
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x
    def _init_embeddings(self, z):
        self._need_init = False
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        y = self._tile(flat_inputs)

        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))
    def calculate_batch_codebook_usage_percentage(self, batch_encoding_indices):
        all_indices = batch_encoding_indices.flatten()
        total_indices = all_indices.numel()

        codebook_usage_percentage = torch.zeros(self.n_codes, device=all_indices.device)
        unique_indices, counts = torch.unique(all_indices, return_counts=True)
        percentages = counts.float() / total_indices
        codebook_usage_percentage[unique_indices.long()] = percentages
        return codebook_usage_percentage
    def forward(self, z):
        """Quantize z (B, C, T, H, W) over all scales.

        Returns the straight-through multi-scale reconstruction ('embeddings'), the
        per-scale index maps ('encodings', a list), the commitment loss averaged over
        scales, the perplexity, and codebook-usage statistics (computed from the
        finest scale).
        """
        if self._need_init and self.training:
            self._init_embeddings(z)

        B, C, T, H, W = z.shape

        z_no_grad = z.detach()
        accu_h = torch.zeros_like(z_no_grad)

        if self.training:
            all_flat_inputs, all_encode_onehot = [], []

        commitment_loss = 0.0
        scale_num = len(self.v_patch_nums)
        ms_encoding_indices = []

        with torch.cuda.amp.autocast(enabled=False):
            for si, (tpn, pn) in enumerate(zip(self.t_patch_nums, self.v_patch_nums)):
                tpn = min(tpn, T)

                # Residual still to be explained at this scale.
                rest_z = z_no_grad - accu_h.data

                # All but the final scale quantize a downsampled residual.
                if si != scale_num - 1:
                    rest_z = F.interpolate(rest_z, size=(tpn, pn, pn), mode=self.z_interplote_down)

                z_NC = rest_z.permute(0, 2, 3, 4, 1).reshape(-1, C)

                # Squared Euclidean distance between residual vectors and every code.
                d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(self.embeddings.square(), dim=1, keepdim=False)
                d_no_grad.addmm_(z_NC, self.embeddings.t(), alpha=-2, beta=1)

                encoding_indices = torch.argmin(d_no_grad, dim=1)
                encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(z_NC)
                encoding_indices = encoding_indices.view(rest_z.shape[0], *rest_z.shape[2:])

                ms_encoding_indices.append(encoding_indices)

                h_BTHWC = F.embedding(encoding_indices, self.embeddings)
                h_BCTHW = h_BTHWC.permute(0, 4, 1, 2, 3).contiguous()

                # Upsample the quantized residual back to full resolution and refine it.
                h_BCTHW = F.interpolate(h_BCTHW, size=(T, H, W), mode=self.z_interplote_up).contiguous()

                quant_head = si / max(1, (scale_num - 1))
                h_BCTHW = self.quant_resi[quant_head](h_BCTHW)

                accu_h = accu_h + h_BCTHW

                commitment_loss += 0.25 * F.mse_loss(accu_h, z.detach())

                if self.training:
                    all_flat_inputs.append(z_NC)
                    all_encode_onehot.append(encode_onehot)

            if self.training:
                # EMA codebook update over the assignments from every scale.
                encode_onehot = torch.cat(all_encode_onehot, dim=0)
                flat_inputs = torch.cat(all_flat_inputs, dim=0)

                n_total = encode_onehot.sum(dim=0)
                encode_sum = flat_inputs.t() @ encode_onehot
                if dist.is_initialized():
                    dist.all_reduce(n_total)
                    dist.all_reduce(encode_sum)

                self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
                self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

                n = self.N.sum()
                weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
                encode_normalized = self.z_avg / weights.unsqueeze(1)
                self.embeddings.data.copy_(encode_normalized)

                y = self._tile(flat_inputs)
                _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
                if dist.is_initialized():
                    dist.broadcast(_k_rand, 0)

                # Randomly restart codes whose EMA count fell below restart_thres.
                if not self.no_random_restart:
                    usage = (self.N.view(self.n_codes, 1) >= self.restart_thres).float()
                    self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

        commitment_loss *= 1.0 / scale_num
        # Straight-through estimator on the accumulated multi-scale reconstruction.
        embeddings_st = (accu_h - z_no_grad).detach() + z

        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        try:
            usage = self.calculate_batch_codebook_usage_percentage(encoding_indices)
        except Exception:
            usage = torch.zeros(self.n_codes, device=encoding_indices.device)

        # Exponential moving average of per-code usage across forward calls.
        if self.call_cnt == 0:
            self.codebook_usage.data = usage
        else:
            self.codebook_usage.data = self.usage_sigma * self.codebook_usage.data + (1 - self.usage_sigma) * usage

        self.call_cnt += 1

        avg_usage = (self.codebook_usage.data > (1 / self.n_codes)).sum() / self.n_codes

        return dict(embeddings=embeddings_st, encodings=ms_encoding_indices,
                    commitment_loss=commitment_loss, perplexity=perplexity, avg_usage=avg_usage, batch_usage=usage)
    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings
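

# A minimal usage sketch (assumed shapes and settings, not the repo's training config):
# run this module directly to smoke-test both codebooks on a random latent.
if __name__ == "__main__":
    z = torch.randn(2, 8, 4, 16, 16)  # (B, C, T, H, W) latent, illustrative only

    vq = Codebook(n_codes=512, embedding_dim=8)
    out = vq(z)
    print(out['embeddings'].shape, float(out['commitment_loss']), float(out['perplexity']))

    # The final (t, v) patch sizes match the latent's (T, H, W), so the last scale is
    # quantized at full resolution (the last scale is never interpolated anyway).
    ms_vq = MultiScaleCodebook(n_codes=512, embedding_dim=8,
                               t_patch_nums=(1, 2, 4), v_patch_nums=(4, 8, 16))
    ms_out = ms_vq(z)
    print(ms_out['embeddings'].shape, len(ms_out['encodings']), float(ms_out['commitment_loss']))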