"""Transformer-based image compression (TIC) with importance/similarity gating.

The module defines the :class:`TIC` network together with a few small helpers.
Entropy coding relies on CompressAI's range-ANS bindings and on an
autoregressive context model, so ``compress*`` / ``decompress*`` visit the
latent one spatial position at a time and are slow.
"""
import itertools
import math
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from compressai.ans import BufferedRansEncoder, RansDecoder
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.layers import AttentionBlock
from compressai.models.utils import conv, deconv, update_registered_buffers
from matplotlib.colors import ListedColormap
from PIL import Image
from timm.models.layers import trunc_normal_

from .RSTB import RSTB, CausalAttentionModule

# From Balle's tensorflow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64

# NOTE(review): not referenced anywhere in this module; kept in case other
# modules import it. The model code below uses the device of its own
# parameters instead.
device = "cuda"


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return ``levels`` scales spaced logarithmically between ``min`` and ``max``.

    The parameter names shadow builtins; they are kept unchanged because
    callers may pass them by keyword.
    """
    return torch.exp(torch.linspace(math.log(min), math.log(max), levels))


def _binary_mask(values, threshold, high, low):
    """Return a tensor holding ``high`` where ``values > threshold``, else ``low``.

    The two scalar operands are created on ``values.device`` so the call does
    not depend on implicit CPU-scalar promotion.
    """
    return torch.where(
        values > threshold,
        torch.tensor(high, device=values.device),
        torch.tensor(low, device=values.device),
    )


class Binarizer(torch.autograd.Function):
    """Element-wise two-level quantizer with a straight-through gradient.

    Forward: values strictly greater than 0.9 become 1.0, all others 0.2.
    Backward: the incoming gradient is passed through unchanged.

    NOTE(review): the previous docstring described a 0.5 threshold with 0/1
    outputs ("Equation (1) in paper"); the code uses 0.9 and 1.0/0.2. The code
    is left as is -- confirm which one is intended.
    """

    @staticmethod
    def forward(ctx, i):
        return _binary_mask(i, 0.9, 1.0, 0.2)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


def bin_values(x):
    """Apply :class:`Binarizer` to ``x``."""
    return Binarizer.apply(x)


class TIC(nn.Module):
    """Neural image compression framework from
    Lu Ming and Guo, Peiyao and Shi, Huiqing and Cao, Chuntong and Ma, Zhan:
    `"Transformer-based Image Compression"`, (DCC 2022), extended with an
    importance-map branch (``sub_impor_net``) and a similarity-map branch
    (``simi_net``) whose outputs scale the latent before entropy coding.

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    # Spatial scaling factor between the hyper-latent z and the latent y.
    _LATENT_SCALE = 4
    # Side length of the square window handed to the context model.
    _CONTEXT_KERNEL = 5

    def __init__(self, N=192, M=192):
        super().__init__()

        depths = [1, 2, 3, 1, 1]
        num_heads = [4, 8, 16, 16, 16]
        window_size = 8
        drop_path_rate = 0.2

        # Stochastic depth: the drop-path rate grows linearly over all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        def rstb(stage, resolution, window, stage_depths, stage_heads):
            """Build the RSTB of ``stage`` using that stage's slice of ``dpr``."""
            first = sum(stage_depths[:stage])
            last = sum(stage_depths[:stage + 1])
            return RSTB(
                dim=N,
                input_resolution=(resolution, resolution),
                depth=stage_depths[stage],
                num_heads=stage_heads[stage],
                window_size=window,
                mlp_ratio=4.,
                qkv_bias=True,
                qk_scale=None,
                drop=0.,
                attn_drop=0.,
                drop_path=dpr[first:last],
                norm_layer=nn.LayerNorm,
                use_checkpoint=False,
            )

        self.align_corners = True

        # Analysis transform. g_a6 is applied separately from g_a() because
        # the importance branch consumes the features that precede it.
        self.g_a0 = conv(3, N, kernel_size=5, stride=2)
        self.g_a1 = rstb(0, 128, window_size, depths, num_heads)
        self.g_a2 = conv(N, N, kernel_size=3, stride=2)
        self.g_a3 = rstb(1, 64, window_size, depths, num_heads)
        self.g_a4 = conv(N, N, kernel_size=3, stride=2)
        self.g_a5 = rstb(2, 32, window_size, depths, num_heads)
        self.g_a6 = conv(N, M, kernel_size=3, stride=2)

        # Hyper-analysis transform.
        self.h_a0 = conv(M, N, kernel_size=3, stride=1)
        self.h_a1 = rstb(3, 16, window_size // 2, depths, num_heads)
        self.h_a2 = conv(N, N, kernel_size=3, stride=2)
        self.h_a3 = rstb(4, 8, window_size // 2, depths, num_heads)
        self.h_a4 = conv(N, N, kernel_size=3, stride=2)

        # The synthesis side walks the stage configuration in reverse order.
        depths = depths[::-1]
        num_heads = num_heads[::-1]

        # Hyper-synthesis transform.
        self.h_s0 = deconv(N, N, kernel_size=3, stride=2)
        self.h_s1 = rstb(0, 8, window_size // 2, depths, num_heads)
        self.h_s2 = deconv(N, N, kernel_size=3, stride=2)
        self.h_s3 = rstb(1, 16, window_size // 2, depths, num_heads)
        self.h_s4 = conv(N, M * 2, kernel_size=3, stride=1)

        # Synthesis transform.
        self.g_s0 = deconv(M, N, kernel_size=3, stride=2)
        self.g_s1 = rstb(2, 32, window_size, depths, num_heads)
        self.g_s2 = deconv(N, N, kernel_size=3, stride=2)
        self.g_s3 = rstb(3, 64, window_size, depths, num_heads)
        self.g_s4 = deconv(N, N, kernel_size=3, stride=2)
        self.g_s5 = rstb(4, 128, window_size, depths, num_heads)
        self.g_s6 = deconv(N, 3, kernel_size=5, stride=2)

        self.entropy_bottleneck = EntropyBottleneck(N)
        self.gaussian_conditional = GaussianConditional(None)
        self.context_prediction = CausalAttentionModule(M, M * 2)

        # Fuses hyperprior parameters (2M channels) with context features
        # (2M channels) into scales and means (M channels each).
        self.entropy_parameters = nn.Sequential(
            nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
            nn.GELU(),
            nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
            nn.GELU(),
            nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
        )

        # Importance-map branch, see sub_impor_net().
        self.sub_net_leaky = nn.Sequential(
            conv(N, N, kernel_size=3, stride=2),
            nn.LeakyReLU(),
        )
        self.sub_net0 = nn.Sequential(
            conv(N, 64, kernel_size=1, stride=1),
            nn.ReLU(),
        )
        self.sub_net1 = nn.Sequential(
            conv(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        self.sub_net2 = conv(64, N, kernel_size=1, stride=1)
        self.sub_net_channel = conv(N, M, kernel_size=1, stride=1)

        # Similarity branch: maps a 1-channel map to M channels through three
        # stride-2 convolutions.
        self.simi_net = nn.Sequential(
            conv(1, 64, kernel_size=3, stride=2),
            nn.ReLU(),
            conv(64, 128, kernel_size=3, stride=2),
            nn.ReLU(),
            conv(128, M, kernel_size=3, stride=2),
        )

        self.cosine_similarity = torch.nn.CosineSimilarity(dim=1)
        self.tanh = nn.Tanh()
        self.softsign = nn.Softsign()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.apply(self._init_weights)

    def g_a(self, x, x_size=None):
        """Run the analysis transform up to, but excluding, ``g_a6``.

        Args:
            x: input image batch.
            x_size: (height, width) of the input image; defaults to the
                spatial size of ``x``.
        """
        if x_size is None:
            x_size = x.shape[2:4]
        x = self.g_a0(x)
        x = self.g_a1(x, (x_size[0] // 2, x_size[1] // 2))
        x = self.g_a2(x)
        x = self.g_a3(x, (x_size[0] // 4, x_size[1] // 4))
        x = self.g_a4(x)
        x = self.g_a5(x, (x_size[0] // 8, x_size[1] // 8))
        return x

    def g_s(self, x, x_size=None):
        """Run the synthesis transform on a latent ``x``.

        ``x_size`` is the (height, width) of the image to reconstruct and
        defaults to 16 times the spatial size of ``x``.
        """
        if x_size is None:
            x_size = (x.shape[2] * 16, x.shape[3] * 16)
        x = self.g_s0(x)
        x = self.g_s1(x, (x_size[0] // 8, x_size[1] // 8))
        x = self.g_s2(x)
        x = self.g_s3(x, (x_size[0] // 4, x_size[1] // 4))
        x = self.g_s4(x)
        x = self.g_s5(x, (x_size[0] // 2, x_size[1] // 2))
        x = self.g_s6(x)
        return x

    def h_a(self, x, x_size=None):
        """Run the hyper-analysis transform on a latent ``x``.

        ``x_size`` is the (height, width) of the original image and defaults
        to 16 times the spatial size of ``x``.
        """
        if x_size is None:
            x_size = (x.shape[2] * 16, x.shape[3] * 16)
        x = self.h_a0(x)
        x = self.h_a1(x, (x_size[0] // 16, x_size[1] // 16))
        x = self.h_a2(x)
        x = self.h_a3(x, (x_size[0] // 32, x_size[1] // 32))
        x = self.h_a4(x)
        return x

    def h_s(self, x, x_size=None):
        """Run the hyper-synthesis transform on a hyper-latent ``x``.

        ``x_size`` is the (height, width) of the original image and defaults
        to 64 times the spatial size of ``x``.
        """
        if x_size is None:
            x_size = (x.shape[2] * 64, x.shape[3] * 64)
        x = self.h_s0(x)
        x = self.h_s1(x, (x_size[0] // 32, x_size[1] // 32))
        x = self.h_s2(x)
        x = self.h_s3(x, (x_size[0] // 16, x_size[1] // 16))
        x = self.h_s4(x)
        return x

    def _sub_net_residual(self, x):
        """Apply the ``sub_net0 -> sub_net1 -> sub_net2`` bottleneck to ``x``."""
        return self.sub_net2(self.sub_net1(self.sub_net0(x)))

    def sub_impor_net(self, x):
        """Compute the importance map from the pre-``g_a6`` features.

        The input is downsampled by ``sub_net_leaky``; the same bottleneck
        (shared weights) is then applied three times with a skip connection
        each time, and ``sub_net_channel`` projects the result to M channels.
        """
        x1 = self.sub_net_leaky(x)
        x2 = x1 + self._sub_net_residual(x1)
        x3 = x2 + self._sub_net_residual(x2)
        x_out = self._sub_net_residual(x3) + x3
        return self.sub_net_channel(x_out)

    def aux_loss(self):
        """Return the aggregated loss over the auxiliary entropy bottleneck
        module(s).
        """
        aux_loss = sum(
            m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck)
        )
        return aux_loss

    def _init_weights(self, m):
        """Initialise Linear and LayerNorm sub-modules (used via ``self.apply``)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return parameter-name keywords that should be excluded from weight decay."""
        return {'relative_position_bias_table'}

    def forward(self, x, similarity):
        """Training / evaluation forward pass.

        Args:
            x: input image batch.
            similarity: single-channel similarity map (``simi_net`` starts with
                a 1-input-channel convolution). Presumably given at half the
                image resolution so that three stride-2 convolutions land on
                the resolution of ``y`` -- TODO confirm against the caller.

        Returns:
            dict with ``"x_hat"``, ``"y_hat"``, ``"y"``, ``"similarity"`` (the
            binarised map, upsampled x2 and repeated to 3 channels) and
            ``"likelihoods"`` for ``y`` and ``z``.
        """
        x_size = (x.shape[2], x.shape[3])

        # One binarised map feeds both the returned (upsampled) map and the
        # similarity branch; neither consumer modifies it in place.
        similarity_mask = _binary_mask(similarity, 0.85, 1.0, 0.01)
        similarity_up = F.interpolate(
            similarity_mask, scale_factor=2, mode='bilinear'
        )
        similarity_up_repeated = similarity_up.repeat(1, 3, 1, 1)
        similarities_sigmoid = torch.sigmoid(self.simi_net(similarity_mask))

        y_codec = self.g_a(x, x_size)
        y_codec_a6 = self.g_a6(y_codec)
        y_soft = self.softsign(self.tanh(self.sub_impor_net(y_codec)))
        # Latent scaled by importance map + similarity gate.
        y = y_codec_a6 * (y_soft + similarities_sigmoid)

        z = self.h_a(y, x_size)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat, x_size)

        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )
        ctx_params = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s(y_hat, x_size)

        return {
            "y_hat": y_hat,
            "y": y,
            "similarity": similarity_up_repeated,
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    def update(self, scale_table=None, force=False):
        """Updates the entropy bottleneck(s) CDF values.

        Needs to be called once after training to be able to later perform the
        evaluation with an actual entropy coder.

        Args:
            scale_table: scales for the Gaussian conditional; when None,
                :func:`get_scale_table` is used (default: None)
            force (bool): overwrite previous values (default: False)

        Returns:
            updated (bool): True if one of the EntropyBottlenecks was updated.
        """
        if scale_table is None:
            scale_table = get_scale_table()
        self.gaussian_conditional.update_scale_table(scale_table, force=force)

        updated = False
        for m in self.children():
            if not isinstance(m, EntropyBottleneck):
                continue
            rv = m.update(force=force)
            updated |= rv
        return updated

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` after resizing the CDF-related buffers to match it."""
        # Dynamically update the entropy bottleneck buffers related to the CDFs
        update_registered_buffers(
            self.entropy_bottleneck,
            "entropy_bottleneck",
            ["_quantized_cdf", "_offset", "_cdf_length"],
            state_dict,
        )
        update_registered_buffers(
            self.gaussian_conditional,
            "gaussian_conditional",
            ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
            state_dict,
        )
        super().load_state_dict(state_dict, strict=strict)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        N = state_dict["g_a0.weight"].size(0)
        M = state_dict["g_a6.weight"].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net

    # ------------------------------------------------------------------
    # Private helpers shared by the compress*/decompress* entry points.
    # ------------------------------------------------------------------

    def _module_device(self):
        """Return the device that holds the model parameters."""
        return next(self.parameters()).device

    def _entropy_coder_tables(self):
        """Return ``(cdf, cdf_lengths, offsets)`` as plain lists for the rANS bindings."""
        # pylint: disable=protected-access
        cdf = self.gaussian_conditional._quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
        offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
        # pylint: enable=protected-access
        return cdf, cdf_lengths, offsets

    def _predict_at(self, y_hat, params, i, h, w):
        """Predict the entropy parameters of image ``i`` at position ``(h, w)``.

        ``y_hat`` is the padded latent. Returns ``(y_crop, means_hat, indexes)``
        where ``y_crop`` is the window of ``y_hat`` whose centre is ``(h, w)``.
        """
        kernel_size = self._CONTEXT_KERNEL
        centre = (kernel_size - 1) // 2
        y_crop = y_hat[i: i + 1, :, h: h + kernel_size, w: w + kernel_size]
        ctx_p = self.context_prediction(y_crop)
        # 1x1 conv for the entropy parameters prediction network, so
        # we only keep the elements in the "center".
        # y_crop holds a single image, so its batch axis is not indexed by i.
        p = params[i: i + 1, :, h: h + 1, w: w + 1]
        gaussian_params = self.entropy_parameters(
            torch.cat(
                (p, ctx_p[:, :, centre: centre + 1, centre: centre + 1]), dim=1
            )
        )
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        return y_crop, means_hat, indexes

    @torch.no_grad()
    def _quantize_latent(self, y, params, height, width):
        """Quantise ``y`` position by position with the context model.

        Gradients are disabled: the results are only converted to byte strings.

        Returns:
            ``(symbols, indexes)``, both shaped ``(B, C, height, width)``:
            the rounded residuals and the matching CDF indexes.
        """
        padding = (self._CONTEXT_KERNEL - 1) // 2
        y_hat = F.pad(y, (padding, padding, padding, padding))
        symbols = y.new_zeros((y.size(0), y.size(1), height, width))
        indexes = torch.zeros(symbols.shape, dtype=torch.int32, device=y.device)

        # Warning, this is slow: one context-model call per spatial position.
        for i in range(y.size(0)):
            for h, w in itertools.product(range(height), range(width)):
                y_crop, means_hat, idx = self._predict_at(y_hat, params, i, h, w)
                y_q = torch.round(y_crop - means_hat)
                # Write back the dequantised value so later positions see
                # exactly what the decoder will reconstruct.
                y_hat[i, :, h + padding, w + padding] = (
                    y_q + means_hat
                )[0, :, padding, padding]
                symbols[i, :, h, w] = y_q[0, :, padding, padding]
                indexes[i, :, h, w] = idx[0, :, 0, 0]
        return symbols, indexes

    def _code_hyperprior(self, y):
        """Code the hyper-latent of ``y`` and quantise ``y`` itself.

        Returns ``(z, z_strings, symbols, indexes)``.
        """
        z = self.h_a(y)
        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
        params = self.h_s(z_hat)
        scale = self._LATENT_SCALE
        symbols, indexes = self._quantize_latent(
            y, params, z_hat.size(2) * scale, z_hat.size(3) * scale
        )
        return z, z_strings, symbols, indexes

    @staticmethod
    def _raster(values):
        """Flatten a ``(C, H, W)`` tensor to an int list in (row, column, channel) order."""
        return values.permute(1, 2, 0).reshape(-1).int().tolist()

    @staticmethod
    def _rans_encode(symbols, indexes, tables):
        """Encode one image's symbols with the rANS coder and return the bytes."""
        cdf, cdf_lengths, offsets = tables
        encoder = BufferedRansEncoder()
        encoder.encode_with_indexes(symbols, indexes, cdf, cdf_lengths, offsets)
        return encoder.flush()

    def _compress_latent(self, y):
        """Entropy-code every channel of the latent ``y``."""
        z, z_strings, symbols, indexes = self._code_hyperprior(y)
        tables = self._entropy_coder_tables()
        y_strings = [
            self._rans_encode(
                self._raster(symbols[i]), self._raster(indexes[i]), tables
            )
            for i in range(y.size(0))
        ]
        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}

    def _decompress_latent(self, strings, shape, channels=None):
        """Decode ``strings`` into an image.

        Args:
            strings: ``[y_strings, z_strings]``.
            shape: spatial shape of the hyper-latent.
            channels: optional 1-D array of the channel indexes present in the
                y streams; the remaining channels are decoded as zero
                residuals. ``None`` means every channel was coded.
        """
        assert isinstance(strings, list) and len(strings) == 2
        # FIXME: we don't respect the default entropy coder and directly call the
        # range ANS decoder
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        params = self.h_s(z_hat)

        scale = self._LATENT_SCALE
        padding = (self._CONTEXT_KERNEL - 1) // 2
        y_height = z_hat.size(2) * scale
        y_width = z_hat.size(3) * scale
        # h_s4 outputs 2 * M channels, so this is M for any model size.
        num_channels = params.size(1) // 2

        # initialize y_hat to zeros, and pad it so we can directly work with
        # sub-tensors of size (N, C, kernel size, kernel_size)
        y_hat = torch.zeros(
            (z_hat.size(0), num_channels,
             y_height + 2 * padding, y_width + 2 * padding),
            device=z_hat.device,
        )
        if channels is not None:
            channels = torch.as_tensor(
                channels, dtype=torch.long, device=z_hat.device
            ).reshape(-1)

        decoder = RansDecoder()
        cdf, cdf_lengths, offsets = self._entropy_coder_tables()

        # Warning: this is slow due to the auto-regressive nature of the
        # decoding... See more recent publication where they use an
        # auto-regressive module on chunks of channels for faster decoding...
        for i, y_string in enumerate(strings[0]):
            decoder.set_stream(y_string)
            for h, w in itertools.product(range(y_height), range(y_width)):
                # only perform the context model on a cropped tensor
                # centered in (h, w)
                _, means_hat, indexes = self._predict_at(y_hat, params, i, h, w)
                indexes = indexes[0, :, 0, 0]
                if channels is not None:
                    indexes = indexes[channels]
                decoded = decoder.decode_stream(
                    indexes.int().tolist(), cdf, cdf_lengths, offsets
                )
                decoded = torch.tensor(
                    decoded, dtype=means_hat.dtype, device=means_hat.device
                )
                # Channels missing from the stream keep a zero residual.
                residual = torch.zeros_like(means_hat)
                if channels is None:
                    residual[0, :, 0, 0] = decoded
                else:
                    residual[0, channels, 0, 0] = decoded
                # pylint: disable=protected-access
                rv = self.gaussian_conditional._dequantize(residual, means_hat)
                # pylint: enable=protected-access
                y_hat[
                    i,
                    :,
                    h + padding: h + padding + 1,
                    w + padding: w + padding + 1,
                ] = rv

        y_hat = y_hat[:, :, padding:-padding, padding:-padding]
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}

    # ------------------------------------------------------------------
    # Public entropy-coding entry points.
    # ------------------------------------------------------------------

    def compress(self, x):
        """Compress ``x``, skipping latent channels that quantise to all zeros.

        This variant codes the raw ``g_a6`` output; no importance or similarity
        weighting is applied. Decode the result with :meth:`decompress`.

        Returns:
            dict with ``"strings"`` (``[y_strings, z_strings]``), ``"shape"``
            (spatial shape of ``z``) and ``"flag"``: a float NumPy array with
            one entry per latent channel, 1 where the channel was coded. The
            flag is shared by all images of the batch; ``int(flag.sum())`` is
            the number of coded channels.
        """
        x = x.to(self._module_device())
        x_size = (x.shape[2], x.shape[3])
        y = self.g_a6(self.g_a(x, x_size))

        z, z_strings, symbols, indexes = self._code_hyperprior(y)

        # A channel is coded when any of its symbols is non-zero.
        active_mask = symbols.abs().sum(dim=(0, 2, 3)) > 0
        active = torch.nonzero(active_mask, as_tuple=False).reshape(-1)
        flag = active_mask.cpu().numpy().astype(np.float64)

        tables = self._entropy_coder_tables()
        y_strings = [
            self._rans_encode(
                self._raster(symbols[i, active]),
                self._raster(indexes[i, active]),
                tables,
            )
            for i in range(y.size(0))
        ]
        return {
            "strings": [y_strings, z_strings],
            "shape": z.size()[-2:],
            "flag": flag,
        }

    def compress_1(self, x, similarity):
        """Compress ``x`` with the same latent weighting as :meth:`forward`.

        The similarity map is binarised at 0.9 (1.0 / 1e-4) and the importance
        map goes through ``softsign(tanh(.))``. Decode the result with
        :meth:`decompress_1`.
        """
        target = self._module_device()
        x = x.to(target)
        similarity = similarity.to(target)
        x_size = (x.shape[2], x.shape[3])

        similarity_mask = _binary_mask(similarity, 0.9, 1.0, 1e-4)
        similarities_repeated = torch.sigmoid(self.simi_net(similarity_mask))

        y_codec = self.g_a(x, x_size)
        y_soft = self.softsign(self.tanh(self.sub_impor_net(y_codec)))
        y_codec_a6 = self.g_a6(y_codec)
        # Latent scaled by similarity gate + importance map.
        y = y_codec_a6 * (similarities_repeated + y_soft)

        return self._compress_latent(y)

    def compress_2(self, x, similarity):
        """Compress ``x`` using a ``tanh`` importance map (no softsign).

        The similarity map is binarised at 0.9 (1.0 / 0.1). Decode the result
        with :meth:`decompress_1`.
        """
        target = self._module_device()
        x = x.to(target)
        # Keep the similarity map on the same device as the image.
        similarity = similarity.to(target)
        x_size = (x.shape[2], x.shape[3])

        similarity_mask = _binary_mask(similarity, 0.9, 1.0, 0.1)
        similarities_repeated = torch.sigmoid(self.simi_net(similarity_mask))

        y_codec = self.g_a(x, x_size)
        y_tanh = self.tanh(self.sub_impor_net(y_codec))
        y_codec_a6 = self.g_a6(y_codec)
        # Latent scaled by similarity gate + importance map.
        y = y_codec_a6 * (similarities_repeated + y_tanh)

        return self._compress_latent(y)

    def decompress(self, strings, shape, flag):
        """Decode the output of :meth:`compress`.

        Args:
            strings: ``[y_strings, z_strings]``.
            shape: spatial shape of the hyper-latent.
            flag: per-channel array as returned by :meth:`compress`; non-zero
                entries mark the channels present in the y streams.

        Returns:
            dict with ``"x_hat"``, the reconstruction clamped to [0, 1].
        """
        channels = np.flatnonzero(np.asarray(flag))
        return self._decompress_latent(strings, shape, channels)

    def decompress_1(self, strings, shape):
        """Decode streams in which every latent channel was coded
        (:meth:`compress_1` / :meth:`compress_2`).

        Returns:
            dict with ``"x_hat"``, the reconstruction clamped to [0, 1].
        """
        return self._decompress_latent(strings, shape)