| | """ |
| | Binary Spherical Quantization |
| | Proposed in https://arxiv.org/abs/2406.07548 |
| | |
| | In the simplest setup, each dimension is quantized into {-1, 1}. |
| | An entropy penalty is used to encourage utilization. |
| | """ |
| | import json |
| | import random |
| | import copy |
| | from math import log2, ceil |
| | from functools import partial, cache |
| | from collections import namedtuple |
| | from contextlib import nullcontext |
| |
|
| | import torch.distributed as dist |
| | from torch.distributed import nn as dist_nn |
| |
|
| | import torch |
| | from torch import nn, einsum |
| | import torch.nn.functional as F |
| | from torch.nn import Module |
| | from torch.amp import autocast |
| | import numpy as np |
| |
|
| | from einops import rearrange, reduce, pack, unpack |
| |
|
| | |
| |
|
| | from infinity.models.videovae.utils.dynamic_resolution import predefined_HW_Scales_dynamic |
| | from infinity.models.videovae.utils.dynamic_resolution_two_pyramid import dynamic_resolution_thw, total_pixels2scales |
| | from infinity.models.videovae.modules.quantizer.finite_scalar_quantization import FSQ |
| |
|
| | |
| |
|
| | |
| |
|
# Result of a single BSQ forward pass: the quantized tensor, its integer code indices,
# and the combined auxiliary (commitment + entropy) loss.
Return = namedtuple('Return', 'quantized indices entropy_aux_loss')

# Optional detailed loss terms returned by BSQ.forward(return_loss_breakdown=True).
LossBreakdown = namedtuple('LossBreakdown', 'per_sample_entropy batch_entropy commitment')
| |
|
| | |
| |
|
@cache
def is_distributed():
    """Return True when torch.distributed is initialised with more than one rank.

    NOTE(review): the result is memoised by ``functools.cache``; if this is first
    called before ``init_process_group`` it keeps returning False for the rest of
    the process.
    """
    if not dist.is_initialized():
        return False
    return dist.get_world_size() > 1
| |
|
def maybe_distributed_mean(t):
    """Average ``t`` over all ranks when running distributed; otherwise return it as-is.

    ``torch.distributed.nn.all_reduce`` is the autograd-aware functional collective:
    it returns the reduced tensor (recent PyTorch versions reduce a clone, leaving the
    input untouched), so its return value must be used rather than relying on an
    in-place update.
    """
    if not is_distributed():
        return t

    t = dist_nn.all_reduce(t)
    return t / dist.get_world_size()
| |
|
| | |
| |
|
def exists(v):
    """Return True unless ``v`` is None (falsy values such as 0 or False still exist)."""
    return not (v is None)
| |
|
def identity(t):
    """Return the argument unchanged (no-op callable)."""
    result = t
    return result
| |
|
def default(*args):
    """Return the first argument that is not None, calling it first if it is callable.

    Returns None when every argument is None (or no arguments are given).
    """
    chosen = next((candidate for candidate in args if exists(candidate)), None)
    if chosen is None:
        return None
    if callable(chosen):
        return chosen()
    return chosen
| |
|
def round_up_multiple(num, mult):
    """Round ``num`` up to the nearest multiple of ``mult``."""
    multiples_needed = ceil(num / mult)
    return multiples_needed * mult
| |
|
def pack_one(t, pattern):
    """Pack a single tensor with einops ``pack``; returns ``(packed, packed_shapes)``."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
| |
|
def unpack_one(t, ps, pattern):
    """Inverse of ``pack_one``: unpack with einops and return the first (only) tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
| |
|
def l2norm(t):
    """L2-normalise ``t`` along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
| |
|
| | |
| |
|
def log(t, eps = 1e-5):
    """Natural log of ``t`` with inputs clamped from below at ``eps`` to avoid -inf."""
    safe = torch.clamp(t, min=eps)
    return torch.log(safe)
| |
|
def entropy(prob):
    """Shannon entropy (nats) of the distribution(s) in ``prob`` over the last axis."""
    neg_p_log_p = -prob * log(prob)
    return neg_p_log_p.sum(dim=-1)
| |
|
| | |
| |
|
class CosineSimLinear(Module):
    """Linear map computing scaled cosine similarity between inputs and weight columns.

    The input is L2-normalised over its last dim and the weight (shape
    ``(dim_in, dim_out)``) over its input dim, so each output entry is
    ``scale * cos(x, w[:, j])``.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        unit_x = F.normalize(x, dim=-1)
        unit_w = F.normalize(self.weight, dim=0)
        cosine = torch.matmul(unit_x, unit_w)
        return cosine * self.scale
| |
|
def repeat_schedule(scale_schedule, repeat_scales_num, times):
    """Repeat each of the first ``repeat_scales_num`` scales ``times`` times in place.

    Scales after the first ``repeat_scales_num`` are appended once, unchanged.
    Returns a new list; the input is not modified.
    """
    head = [scale_schedule[idx] for idx in range(repeat_scales_num) for _ in range(times)]
    return [*head, *scale_schedule[repeat_scales_num:]]
| |
|
def get_latent2scale_schedule(T: int, H: int, W: int, mode="original", last_scale_repeat_n=0, args=None):
    """Build the per-scale patch schedule for a latent with T frames and an H x W grid.

    Args:
        T: number of latent frames.
        H, W: latent spatial size; used as the lookup key into the schedule tables
            (a missing key raises KeyError).
        mode: schedule family. Modes starting with "infinity_video_two_pyramid" use
            ``dynamic_resolution_thw``; "original", "dynamic", "dense", "dense_f8",
            "dense_f8_double", "same<N>" and "half" use per-(H, W) tables built here.
        last_scale_repeat_n: number of extra copies of the last scale appended
            (only used by the default two-pyramid branch).
        args: config object; the "elegant" two-pyramid branch reads
            ``args.image_scale_repetition`` / ``args.video_scale_repetition`` as JSON lists.

    Returns:
        For "infinity_video_two_pyramid*" modes: ``(schedule, tower_split_index)`` where
        ``schedule`` is a list of (t, h, w) and the first ``tower_split_index`` entries
        form the first-frame tower. The "elegant" variant returns (h, w) as stored;
        the other two-pyramid variants double h and w.
        For every other mode: a list of (t, h, w) where (h, w) come from the table and
        t is taken from ``predefined_T_Scales`` and capped at T.

    Raises:
        NotImplementedError: for an unknown ``mode``.
    """
    predefined_HW_Scales = {}
    if mode.startswith("infinity_video_two_pyramid"):
        if 'elegant' in mode:
            base_scale_schedule = copy.deepcopy(dynamic_resolution_thw[(H, W)]['scales'])
            # Per-base-scale repetition counts, supplied as JSON lists in the config.
            image_scale_repetition = json.loads(args.image_scale_repetition)
            video_scale_repetition = json.loads(args.video_scale_repetition)

            # NOTE(review): duplicate of the deepcopy above; harmless but redundant.
            base_scale_schedule = copy.deepcopy(dynamic_resolution_thw[(H, W)]['scales'])
            def repeat_scales(base_scale_schedule, scale_repetition):
                # Emit base scale i exactly scale_repetition[i] times, in order.
                scale_schedule = []
                for i in range(len(base_scale_schedule)):
                    scale_schedule.extend([base_scale_schedule[i] for _ in range(scale_repetition[i])])
                return scale_schedule
            # First tower: the image / first-frame scales.
            image_scale_schedule = repeat_scales(base_scale_schedule, image_scale_repetition)
            spatial_time_schedule = []
            spatial_time_schedule.extend(image_scale_schedule)
            firstframe_scalecnt = len(image_scale_schedule)
            if T > 1:
                # Second tower: the remaining T-1 frames at every spatial scale.
                scale_schedule = repeat_scales(base_scale_schedule, video_scale_repetition)
                spatial_time_schedule.extend([(T-1, h, w) for i, (_, h, w) in enumerate(scale_schedule)])

            tower_split_index = firstframe_scalecnt

            # Unlike the other two-pyramid branches, h and w are NOT doubled here.
            return spatial_time_schedule, tower_split_index
        if "motion_boost_v2" in mode:
            times = 6
            base_scale_schedule = copy.deepcopy(dynamic_resolution_thw[(H, W)]['scales'])
            # First tower: the first 3 base scales are each repeated `times` times.
            image_scale_schedule = repeat_schedule(base_scale_schedule, 3, times)
            spatial_time_schedule = []
            spatial_time_schedule.extend(image_scale_schedule)
            firstframe_scalecnt = len(image_scale_schedule)
            if T > 1:
                # Second tower: the first 7 base scales repeated, all with t = T - 1.
                scale_schedule = repeat_schedule(base_scale_schedule, 7, times)
                predefined_t = [T - 1 for _ in range(len(scale_schedule))]
                spatial_time_schedule.extend([(min(int(np.round(predefined_t[i])), T - 1), h, w) for i, (_, h, w) in enumerate(scale_schedule)])

            spatial_time_schedule_double = [(t, 2*h, 2*w) for (t, h, w) in spatial_time_schedule]
            tower_split_index = firstframe_scalecnt
            return spatial_time_schedule_double, tower_split_index
        # Default two-pyramid schedule taken directly from dynamic_resolution_thw.
        spatial_time_schedule = copy.deepcopy(dynamic_resolution_thw[(H, W)]['scales'])
        spatial_time_schedule.extend(spatial_time_schedule[-1:] * last_scale_repeat_n)
        # NOTE(review): assumes the table's 'tower_split_index' marks the end of the
        # first-frame tower, shifted by the repeated last scales; confirm against the table.
        tower_split_index = dynamic_resolution_thw[(H, W)]['tower_split_index'] + last_scale_repeat_n
        if T > 1:

            if mode == "infinity_video_two_pyramid_full_time":
                # Every second-tower scale covers all T-1 remaining frames.
                spatial_time_schedule.extend([(T - 1, h, w) for i, (_, h, w) in enumerate(spatial_time_schedule)])
            else:
                # Temporal extent ramps linearly from 1 to T-1 over the first
                # total_pixels2scales['0.06M'] - 3 scales, then stays at T-1.
                predefined_t = np.linspace(1, T - 1, total_pixels2scales['0.06M']-3).tolist() + [T - 1] * (len(spatial_time_schedule)-total_pixels2scales['0.06M']+3)
                spatial_time_schedule.extend([(min(int(np.round(predefined_t[i])), T - 1), h, w) for i, (_, h, w) in enumerate(spatial_time_schedule)])
            spatial_time_schedule.extend(spatial_time_schedule[-1:] * last_scale_repeat_n)

        spatial_time_schedule_double = [(t, 2*h, 2*w) for (t, h, w) in spatial_time_schedule]
        return spatial_time_schedule_double, tower_split_index
    # Non-pyramid modes: build a table mapping (H, W) -> list of (h, w) scales.
    if mode == "original":
        predefined_HW_Scales = {

            (16, 16): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)],
            (36, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 12), (13, 16), (18, 24), (24, 32), (32, 48), (36, 64)],
            (18, 32): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 8), (8, 10), (10, 14), (12, 18), (14, 22), (16, 26), (18, 32)],
            (30, 53): [(1, 1), (2, 2), (3, 3), (4, 7), (6, 11), (8, 14), (12, 21), (16, 28), (20, 35), (22, 39), (24, 42), (26, 46), (28, 50), (30, 53)]
        }
        predefined_HW_Scales[(32, 32)] = predefined_HW_Scales[(16, 16)] + [(20, 20), (24, 24), (32, 32)]
        predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (64, 64)]
    elif mode == "dynamic":
        predefined_HW_Scales.update(predefined_HW_Scales_dynamic)
    elif mode == "dense":
        predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16+1)]
        predefined_HW_Scales[(32, 32)] = predefined_HW_Scales[(16, 16)] + [(20, 20), (24, 24), (28, 28), (32, 32)]
        predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)]
    elif mode == "dense_f8":

        predefined_HW_Scales[(32, 32)] = [(x, x) for x in range(1, 16+1)] + [(20, 20), (24, 24), (28, 28), (32, 32)]
        predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)]
        predefined_HW_Scales[(128, 128)] = predefined_HW_Scales[(64, 64)] + [(80, 80), (96, 96), (112, 112), (128, 128)]
    elif mode == "dense_f8_double":

        predefined_HW_Scales[(32, 32)] = [(x, x) for x in range(1, 16+1)]
        predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(20, 20), (24, 24), (28, 28), (32, 32)]
        predefined_HW_Scales[(96, 96)] = predefined_HW_Scales[(64, 64)] + [(40, 40), (48, 48)]
        predefined_HW_Scales[(128, 128)] = predefined_HW_Scales[(64, 64)] + [(40, 40), (48, 48), (56, 56), (64, 64)]

        predefined_HW_Scales[(24, 42)] = [(1, 1), (2, 2), (3, 3), (3, 4), (3, 5), (4, 6), (4, 7), (5, 8), (6, 9), (6, 10), (6, 11), (7, 12), (7, 13), (8, 14), (9, 15), (9, 16), (12, 21)]
        predefined_HW_Scales[(36, 64)] = predefined_HW_Scales[(24, 42)] + [(14, 26), (18, 32)]
        predefined_HW_Scales[(60, 108)] = predefined_HW_Scales[(36, 64)] + [(24, 42), (30, 54)]
        predefined_HW_Scales[(90, 160)] = predefined_HW_Scales[(60, 108)] + [(38, 66),(45, 80)]

        # "double": every scale in every table entry is doubled in h and w.
        # Reassigning existing keys while iterating is safe (the dict size is unchanged).
        for k, v in predefined_HW_Scales.items():
            predefined_HW_Scales[k] = [(2*x, 2*y) for (x, y) in v]
    elif mode.startswith("same"):
        # "same<N>": N repetitions of the full-resolution scale.
        num_quant = int(mode[len("same"):])
        predefined_HW_Scales[(16, 16)] = [(16, 16) for _ in range(num_quant)]
        predefined_HW_Scales[(32, 32)] = [(32, 32) for _ in range(num_quant)]
        predefined_HW_Scales[(64, 64)] = [(64, 64) for _ in range(num_quant)]
    elif mode == "half":
        predefined_HW_Scales[(32, 32)] = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)]
        predefined_HW_Scales[(64, 64)] = [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16)]
    else:
        raise NotImplementedError

    # Temporal extent per scale; padded with its last value when there are more
    # spatial scales than entries, and capped at T below.
    predefined_T_Scales = [1, 2, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]

    patch_THW_shape_per_scale = predefined_HW_Scales[(H, W)]
    if len(predefined_T_Scales) < len(patch_THW_shape_per_scale):

        predefined_T_Scales += [predefined_T_Scales[-1]] * (len(patch_THW_shape_per_scale) - len(predefined_T_Scales))
    patch_THW_shape_per_scale = [(min(T, t), h, w ) for (h, w), t in zip(patch_THW_shape_per_scale, predefined_T_Scales[:len(patch_THW_shape_per_scale)])]
    return patch_THW_shape_per_scale
| |
|
def interpolate(tensor, size, mode):
    """Resize the channel and spatial axes of a 5-D tensor, leaving the time axis as-is.

    Arguments:
        tensor: (B, C, T, H, W)
        size: (C1, T, H1, W1); the T entry is ignored, the time axis is never resized
        mode: interpolation mode passed to ``F.interpolate``
    Return:
        tensor: (B, C1, T, H1, W1)
    """
    target_c, _, target_h, target_w = size
    # Swap C and T so F.interpolate treats T as the "channel" axis and resizes (C, H, W).
    time_major = tensor.transpose(1, 2)
    resized = F.interpolate(time_major, size=(target_c, target_h, target_w), mode=mode)
    return resized.transpose(1, 2)
| |
|
| | |
class MultiScaleBSQTP(Module):
    """Multi-scale residual quantizer over a two-tower (first frame / remaining frames) schedule.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf (residual quantization):
    at every scale of the schedule the residual ``target - quantized_out`` is downsampled,
    quantized, upsampled and accumulated. Early scales of each clip use ``lfq_semantic``;
    later scales use ``lfq_detail`` (BSQ when the configured number of levels is 2,
    otherwise FSQ).
    """

    def __init__(
        self,
        *,
        dim,
        soft_clamp_input_value = None,
        aux_loss = False,
        use_stochastic_depth=False,
        drop_rate=0.,
        schedule_mode="original",
        keep_first_quant=False,
        keep_last_quant=False,
        remove_residual_detach=False,
        random_flip = False,
        flip_prob = 0.5,
        flip_mode = "stochastic",
        max_flip_lvl = 1,
        random_flip_1lvl = False,
        flip_lvl_idx = None,
        drop_when_test=False,
        drop_lvl_idx=None,
        drop_lvl_num=0,
        random_short_schedule = False,
        short_schedule_prob = 0.5,
        disable_flip_prob = 0.0,
        casual_multi_scale = False,
        temporal_slicing = False,
        last_scale_repeat_n = 0,
        num_lvl_fsq = None,
        other_args = None,
        **kwargs
    ):
        """
        Args:
            other_args: config object; this class reads ``detail_scale_dim``,
                ``semantic_scale_dim``, ``semantic_scales``, ``semantic_num_lvl``,
                ``detail_num_lvl`` here and ``quant_not_rely_256``, ``scales_256``,
                ``skip_detail_scales_prob`` in ``forward``.
            **kwargs: forwarded to the BSQ sub-quantizers.
        The remaining flags are stored as attributes for use by callers.
        """
        super().__init__()
        self.use_stochastic_depth = use_stochastic_depth
        self.drop_rate = drop_rate
        self.remove_residual_detach = remove_residual_detach
        self.random_flip = random_flip
        self.flip_prob = flip_prob
        self.flip_mode = flip_mode
        self.max_flip_lvl = max_flip_lvl
        self.random_flip_1lvl = random_flip_1lvl
        self.flip_lvl_idx = flip_lvl_idx
        # The two flip modes are mutually exclusive.
        assert not (random_flip and random_flip_1lvl)
        self.disable_flip_prob = disable_flip_prob
        self.casual_multi_scale = casual_multi_scale
        self.temporal_slicing = temporal_slicing
        self.last_scale_repeat_n = last_scale_repeat_n

        self.drop_when_test = drop_when_test
        self.drop_lvl_idx = drop_lvl_idx
        self.drop_lvl_num = drop_lvl_num
        if self.drop_when_test:
            assert drop_lvl_idx is not None
            assert drop_lvl_num > 0
        self.random_short_schedule = random_short_schedule
        self.short_schedule_prob = short_schedule_prob
        # Interpolation modes used when moving between scales.
        self.z_interplote_up = 'trilinear'
        self.z_interplote_down = 'area'

        self.schedule_mode = schedule_mode
        self.keep_first_quant = keep_first_quant
        self.keep_last_quant = keep_last_quant
        if self.use_stochastic_depth and self.drop_rate > 0:
            assert self.keep_first_quant or self.keep_last_quant

        self.full2short = {7:7, 10:7, 13:7, 16:16, 20:16, 24:16}
        if self.schedule_mode == 'dense_f8':
            self.full2short_f8 = {20:20, 24:24, 28:24}
        elif self.schedule_mode == 'dense_f8_double':
            self.full2short_f8 = {16: 14, 17: 14, 19: 14, 20:14, 21:14, 22:14, 24:14}
        elif self.schedule_mode.startswith("infinity_video_two_pyramid"):
            self.full2short_f8 = {11: 11, 13: 11, 14: 11, 16: 11, 29: 26, 28: 26, 26: 26}

        self.other_args = other_args
        print(f'{self.other_args=}')
        self.origin_C = self.other_args.detail_scale_dim
        self.detail_scale_dim, self.semantic_scale_dim = self.other_args.detail_scale_dim, self.other_args.semantic_scale_dim
        self.semantic_scales = other_args.semantic_scales

        # Two levels per channel -> binary spherical quantization; more -> FSQ.
        if self.other_args.semantic_num_lvl == 2:
            self.lfq_semantic = BSQ(
                dim = self.semantic_scale_dim,
                codebook_scale = 1,
                soft_clamp_input_value = soft_clamp_input_value,
                **kwargs,
            )
        else:
            assert self.other_args.semantic_num_lvl >= 2, f'{self.other_args.semantic_num_lvl=} is not supported'
            self.lfq_semantic = FSQ(
                num_lvl = self.other_args.semantic_num_lvl,
                dim = self.semantic_scale_dim,
            )
        if self.other_args.detail_num_lvl == 2:
            self.lfq_detail = BSQ(
                dim = self.detail_scale_dim,
                codebook_scale = 1,
                soft_clamp_input_value = soft_clamp_input_value,
                **kwargs,
            )
        else:
            assert self.other_args.detail_num_lvl >= 2, f'{self.other_args.detail_num_lvl=} is not supported'
            self.lfq_detail = FSQ(
                num_lvl = self.other_args.detail_num_lvl,
                dim = self.detail_scale_dim,
            )

    @property
    def codebooks(self):
        # NOTE(review): BSQ as defined in this file has no `codebook` attribute; this only
        # works if lfq_detail is a quantizer that exposes one (e.g. FSQ) — confirm.
        return self.lfq_detail.codebook

    def get_codes_from_indices(self, indices_list):
        """Decode per-scale indices and sum them at the resolution of the last scale.

        Each entry is dispatched to the detail or semantic quantizer by the size of its
        last axis; codes are trilinearly upsampled to the last entry's (T, H, W) and summed.
        """
        all_codes = []
        for indices in indices_list:
            # NOTE(review): dispatch on indices.shape[-1] assumes the last index axis
            # carries the channel count; verify against the BSQ/FSQ index layout.
            if indices.shape[-1] == self.origin_C:
                codes = self.lfq_detail.indices_to_codes(indices)
            elif indices.shape[-1] == self.semantic_scale_dim:
                codes = self.lfq_semantic.indices_to_codes(indices)
            else:
                raise NotImplementedError(f'indices shape {indices.shape} not supported')
            all_codes.append(codes)
        _, _, T, H, W = all_codes[-1].size()
        summed_codes = 0
        for code in all_codes:
            summed_codes += F.interpolate(code, size=(T, H, W), mode=self.z_interplote_up)
        return summed_codes

    def get_output_from_indices(self, indices):
        """Return the reconstructed latent for a list of per-scale indices.

        ``get_codes_from_indices`` already sums over scales, so no further reduction is
        applied (reducing again would sum over the batch axis).
        """
        return self.get_codes_from_indices(indices)

    def flip_quant(self, x):
        """Negate a random subset of the entries of ``x`` (returns a new tensor).

        'stochastic' flips each entry with probability ``flip_prob``; 'stochastic_dynamic'
        first draws a probability uniformly from [0, flip_prob].
        """
        if self.flip_mode == 'stochastic':
            flip_mask = torch.rand_like(x) < self.flip_prob
        elif self.flip_mode == 'stochastic_dynamic':
            flip_prob = random.uniform(0, self.flip_prob)
            flip_mask = torch.rand_like(x) < flip_prob
        else:
            raise NotImplementedError
        return torch.where(flip_mask, -x, x)

    def forward(
        self,
        x_list,
        mask = None,
        return_all_codes = False,
    ):
        """Residually quantize a latent (optionally with a lower-resolution companion).

        Args:
            x_list: list of one or two latents shaped (B, C, H, W) or (B, C, T, H, W);
                4-D inputs are given a singleton time axis in place. With two entries,
                scales whose in-clip index is below ``other_args.scales_256`` target
                ``x_list[0]``; the rest target ``x_list[-1]``.
            mask: unused.
            return_all_codes: also decode and return the summed codes.

        Returns:
            ``(quantized_out_list, all_indices, all_losses)`` and, if requested, the
            decoded codes. ``quantized_out_list`` holds the low-resolution reconstruction
            first when two inputs are given, then the full reconstruction.
        """
        assert len(x_list) <= 2
        multi_scale = len(x_list) == 2
        for i in range(len(x_list)):
            if x_list[i].ndim == 4:
                x_list[i] = x_list[i].unsqueeze(2)
        B, C, T, H, W = x_list[-1].size()

        if self.schedule_mode.startswith("same"):
            scale_num = int(self.schedule_mode[len("same"):])
            assert T == 1
            scale_schedule = [(1, H, W)] * scale_num
            # Image-only schedule: every scale belongs to the first-frame tower.
            tower_split_index = scale_num
        elif self.schedule_mode.startswith("infinity_video_two_pyramid"):
            scale_schedule, tower_split_index = get_latent2scale_schedule(T, H, W, mode=self.schedule_mode, last_scale_repeat_n=self.last_scale_repeat_n, args=self.other_args)
            scale_num = len(scale_schedule)
        else:
            scale_schedule = get_latent2scale_schedule(T, H, W, mode=self.schedule_mode, args=self.other_args)
            scale_num = len(scale_schedule)
            # These schedules have no second (remaining-frames) tower, so only a single
            # frame fits the first-frame / remaining-frames split used below.
            assert T == 1, f'schedule_mode={self.schedule_mode!r} only supports T == 1, got T={T}'
            tower_split_index = scale_num

        quantized_out = torch.zeros((B, C, 1, 1, 1), device=x_list[-1].device, dtype=x_list[-1].dtype)
        quantized_out_firstframe = None

        all_losses = []
        all_indices = []

        # Number of distinct scales in one tower: walk the schedule until the patch volume
        # shrinks (start of the next tower), counting changes of scale.
        scale_in_one_clip = 1
        for si in range(1, len(scale_schedule)):
            if np.array(scale_schedule[si]).prod() < np.array(scale_schedule[si-1]).prod():
                break
            if scale_schedule[si] != scale_schedule[si-1]:
                scale_in_one_clip += 1

        current_scale_in_one_clip = 0
        must_preserve_scales = []
        if self.other_args.quant_not_rely_256:
            must_preserve_scales = [11]
        with autocast('cuda', enabled = False):
            for si, (pt, ph, pw) in enumerate(scale_schedule):
                # Index of the current distinct scale within its tower.
                if si > 0 and scale_schedule[si] != scale_schedule[si-1]:
                    current_scale_in_one_clip = (current_scale_in_one_clip + 1) % scale_in_one_clip

                # True on the last repetition of a scale (or the final step overall).
                last_step_in_one_scale = si == len(scale_schedule) - 1 or scale_schedule[si] != scale_schedule[si+1]

                # First tower quantizes frame 0; second tower quantizes frames 1..T-1.
                if si < tower_split_index:
                    ss, ee = 0, 1
                else:
                    ss, ee = 1, T

                if multi_scale and current_scale_in_one_clip < self.other_args.scales_256:
                    target = x_list[0][:,:,ss:ee]
                else:
                    target = x_list[-1][:,:,ss:ee]
                tgt_shape = target.shape[-4:]

                skip_this_scale = False
                if current_scale_in_one_clip < self.semantic_scales:
                    C1 = self.semantic_scale_dim
                    lfq = self.lfq_semantic
                else:
                    C1 = self.detail_scale_dim
                    lfq = self.lfq_detail
                    # Detail scales may be randomly dropped unless explicitly preserved.
                    if current_scale_in_one_clip not in must_preserve_scales:
                        skip_this_scale = random.random() < self.other_args.skip_detail_scales_prob

                if not skip_this_scale:
                    quantized_out = interpolate(quantized_out, size=tgt_shape, mode=self.z_interplote_up)
                    interpolate_residual = interpolate(target-quantized_out, size=(C1, pt, ph, pw), mode=self.z_interplote_down)
                    quantized, indices, loss = lfq(interpolate_residual)
                    quantized = interpolate(quantized, size=tgt_shape, mode=self.z_interplote_up)
                    all_indices.append(indices)
                    all_losses.append(loss)
                    quantized_out = quantized_out + quantized

                # End of the first tower: stash its result and restart accumulation.
                if si == tower_split_index - 1:
                    quantized_out_firstframe = quantized_out.clone()
                    quantized_out = quantized_out * 0.

                if multi_scale and si < tower_split_index and last_step_in_one_scale and current_scale_in_one_clip == self.other_args.scales_256-1:
                    quantized_out_firstframe_256 = quantized_out.clone()
                    if self.other_args.quant_not_rely_256:
                        quantized_out = quantized_out * 0.
                if multi_scale and si >= tower_split_index and last_step_in_one_scale and current_scale_in_one_clip == self.other_args.scales_256-1:
                    quantized_out_256 = quantized_out.clone()
                    if self.other_args.quant_not_rely_256:
                        quantized_out = quantized_out * 0.

        quantized_out_list = []
        if T == 1:
            if multi_scale:
                quantized_out_list.append(quantized_out_firstframe_256)
            quantized_out_list.append(quantized_out_firstframe)
        else:
            if multi_scale:
                quantized_out_256 = torch.cat([quantized_out_firstframe_256, quantized_out_256], dim=2)
                quantized_out_list.append(quantized_out_256)
            quantized_out = torch.cat([quantized_out_firstframe, quantized_out], dim=2)
            quantized_out_list.append(quantized_out)

        all_losses = torch.stack(all_losses, dim = -1)

        ret = (quantized_out_list, all_indices, all_losses)

        if not return_all_codes:
            return ret

        all_codes = self.get_codes_from_indices(all_indices)

        return (*ret, all_codes)
| |
|
| |
|
class BSQ(Module):
    """Binary Spherical Quantizer (https://arxiv.org/abs/2406.07548).

    Inputs are L2-normalised and, by default, each component is snapped to
    ``±1/sqrt(codebook_dims)`` by its sign, with a straight-through estimator for the
    gradient. During training the auxiliary loss combines a commitment loss and an
    entropy penalty (per-sample entropy minus codebook entropy).

    NOTE(review): ``forward`` splits the input into ``num_codebooks`` groups of
    ``dim // num_codebooks`` channels while ``quantize`` expects ``dim * num_codebooks``
    channels, so only ``num_codebooks == 1`` appears to be consistent.
    """

    def __init__(
        self,
        *,
        dim = None,
        entropy_loss_weight = 0.1,
        commitment_loss_weight = 0.25,
        num_codebooks = 1,
        keep_num_codebooks_dim = None,
        codebook_scale = 1.,
        frac_per_sample_entropy = 1.,
        soft_clamp_input_value = None,
        channel_first = None,
        experimental_softplus_entropy_loss = False,
        entropy_loss_offset = 5.,
        spherical = True,
        force_quantization_f32 = True,
        inv_temperature = 100.0,
        gamma0=1.0, gamma=1.0, zeta=1.0,
        use_out_phi = False,
        use_out_phi_res = False,
        use_bernoulli = False,
        use_rot_trick = False,
    ):
        super().__init__()

        assert exists(dim) , 'dim must be specified for BSQ'

        codebook_dim = dim
        codebook_dims = codebook_dim * num_codebooks
        self.codebook_dims = codebook_dims

        # Optional learned projection applied to the quantized output.
        self.out_phi = nn.Linear(codebook_dims, codebook_dims) if use_out_phi else nn.Identity()
        self.use_out_phi_res = use_out_phi_res
        if self.use_out_phi_res:
            # Zero-initialised scale, so the residual projection starts as a no-op.
            self.out_phi_scale = nn.Parameter(torch.zeros(codebook_dims), requires_grad=True)

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.channel_first = channel_first

        if not spherical:
            raise ValueError("For BSQ, spherical must be True.")
        self.persample_entropy_compute = 'analytical'
        self.inv_temperature = inv_temperature
        self.gamma0 = gamma0
        self.gamma = gamma
        self.zeta = zeta
        self.use_bernoulli = use_bernoulli
        self.use_rot_trick = use_rot_trick

        assert 0 < frac_per_sample_entropy <= 1.
        self.frac_per_sample_entropy = frac_per_sample_entropy

        # NOTE(review): stored but not read by forward (which uses its entropy_weight arg).
        self.entropy_loss_weight = entropy_loss_weight

        self.codebook_scale = codebook_scale

        self.commitment_loss_weight = commitment_loss_weight

        self.soft_clamp_input_value = soft_clamp_input_value
        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale

        self.entropy_loss_offset = entropy_loss_offset
        self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss

        # Bit weights (MSB first) used to pack sign bits into integer indices (int64).
        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        self.force_quantization_f32 = force_quantization_f32

    @property
    def dtype(self):
        """Floating dtype of this module (tracks .to()/.half() via the `zero` buffer)."""
        return self.zero.dtype

    def bits_to_codes(self, bits):
        """Map bits {0, 1} to codes {-codebook_scale, +codebook_scale}."""
        return bits * self.codebook_scale * 2 - self.codebook_scale

    def indices_to_codes(
        self,
        indices,
        project_out = True
    ):
        """Decode packed integer indices back to L2-normalised ±1 codes.

        ``project_out`` is accepted for API compatibility but is not used; ``out_phi``
        is not applied here.
        """
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        should_transpose = default(self.channel_first, is_img_or_video)

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... -> ... 1')

        # int64 so that bit positions >= 32 are not lost (codebook_dim > 32).
        bits = ((indices[..., None].long() & self.mask) != 0).to(self.dtype)

        codes = self.bits_to_codes(bits)

        codes = l2norm(codes)

        codes = rearrange(codes, '... c d -> ... (c d)')

        if should_transpose:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def quantize(self, z):
        """Sign-quantize ``z`` to ±1/sqrt(codebook_dims) with a straight-through gradient."""
        assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"

        zhat = torch.where(z > 0,
                           torch.tensor(1, dtype=z.dtype, device=z.device),
                           torch.tensor(-1, dtype=z.dtype, device=z.device))

        q_scale = 1. / (self.codebook_dims ** 0.5)
        zhat = q_scale * zhat

        return z + (zhat - z).detach()

    def quantize_new_bernoulli(self, z, prob_z):
        """Sample ±1/sqrt(codebook_dims) codes from Bernoulli(prob_z), straight-through."""
        assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"

        zhat = (torch.bernoulli(prob_z) - 0.5) * 2.0

        q_scale = 1. / (self.codebook_dims ** 0.5)
        zhat = q_scale * zhat

        return z + (zhat - z).detach()

    def rot_quantize(self, z, inference=False):
        """Sign-quantize a 2-D batch ``z`` using the rotation trick for gradients.

        Returns the plain quantized codes when ``inference`` is True.
        """
        assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"
        q_scale = 1. / (self.codebook_dims ** 0.5)
        zhat = torch.where(z > 0,
                           torch.tensor(1, dtype=z.dtype, device=z.device),
                           torch.tensor(-1, dtype=z.dtype, device=z.device)) * q_scale
        if inference:
            return zhat

        w = ((z + zhat) / torch.norm(z + zhat, dim=-1, keepdim=True)).detach()
        z = z.unsqueeze(1) - 2*torch.bmm(torch.bmm(z.unsqueeze(1), w.unsqueeze(-1)), w.unsqueeze(1)) + 2 * torch.bmm(
            torch.bmm(z.unsqueeze(1), z.unsqueeze(-1).detach()), zhat.unsqueeze(1).detach())
        return z.squeeze()

    def soft_entropy_loss(self, z):
        """Soft per-sample and codebook entropies from per-bit sigmoid probabilities.

        Returns ``(per_sample_entropy, codebook_entropy, avg_prob)``.
        """
        if self.persample_entropy_compute != 'analytical':
            raise NotImplementedError(f'{self.persample_entropy_compute=} is not supported')

        p = torch.sigmoid(-4 * z / (self.codebook_dims ** 0.5) * self.inv_temperature)
        prob = torch.stack([p, 1-p], dim=-1)
        per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()

        # Average bit distribution over all samples -> entropy of codebook usage.
        avg_prob = reduce(prob, '... g d ->g d', 'mean')
        codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)

        return per_sample_entropy, codebook_entropy.sum(), avg_prob

    def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
        """Entropy along ``dim``; ``normalize`` first turns counts into probabilities."""
        if normalize:
            probs = (count + eps) / (count + eps).sum(dim=dim, keepdim =True)
        else:
            probs = count
        H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
        return H

    def forward(
        self,
        x,
        return_loss_breakdown = False,
        mask = None,
        entropy_weight=0.1
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        Returns ``Return(quantized, indices, aux_loss)`` and, when
        ``return_loss_breakdown`` is set, also a ``LossBreakdown``.
        """
        is_img_or_video = x.ndim >= 4
        should_transpose = default(self.channel_first, is_img_or_video)

        # Channel-first images/videos are flattened to (b, n, d).
        if should_transpose:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)

        if self.use_bernoulli:
            prob_x = torch.sigmoid(x)

        x = l2norm(x)

        force_f32 = self.force_quantization_f32

        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

        with quantization_context():

            if force_f32:
                x = x.float()

            if self.use_rot_trick:
                x_f = x.flatten(end_dim=-2)
                q_f = self.rot_quantize(x_f, inference= not self.training)
                quantized = q_f.reshape(x.shape)
            elif self.use_bernoulli:
                quantized = self.quantize_new_bernoulli(x, prob_x)
            else:
                quantized = self.quantize(x)

            # Pack sign bits into integer indices in int64; int32 would drop bits >= 32.
            indices = reduce((quantized > 0).long() * self.mask, 'b n c d -> b n c', 'sum')

            if self.training:
                persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(x)
                entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
            else:
                # No entropy penalty outside training.
                entropy_penalty = persample_entropy = cb_entropy = self.zero

            if self.training and self.commitment_loss_weight > 0.:

                commit_loss = F.mse_loss(x, quantized.detach(), reduction = 'none')

                if exists(mask):
                    commit_loss = commit_loss[mask]

                commit_loss = commit_loss.mean()
            else:
                commit_loss = self.zero

        # The quantized tensor is the output; with force_f32 it stays float32.
        x = quantized

        if self.use_out_phi_res:
            x = x + self.out_phi_scale * self.out_phi(x)
        else:
            x = self.out_phi(x)

        x = rearrange(x, 'b n c d -> b n (c d)')

        if should_transpose:
            x = unpack_one(x, ps, 'b * d')
            x = rearrange(x, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        aux_loss = commit_loss * self.commitment_loss_weight + (self.zeta * entropy_penalty / self.inv_temperature)*entropy_weight

        ret = Return(x, indices, aux_loss)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(persample_entropy, cb_entropy, commit_loss)
| |
|
| |
|
| | |
| | |
| |
|
| |
|