|
|
import math
|
|
|
import time
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
|
|
|
from .RSTB import RSTB, CausalAttentionModule
|
|
|
from compressai.ans import BufferedRansEncoder, RansDecoder
|
|
|
from timm.models.layers import trunc_normal_
|
|
|
from compressai.models.utils import conv, deconv, update_registered_buffers
|
|
|
from compressai.layers import AttentionBlock
|
|
|
from PIL import Image
|
|
|
import numpy as np
|
|
|
import matplotlib
|
|
|
import matplotlib.pyplot as plt
|
|
|
from matplotlib.colors import ListedColormap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import itertools
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Bounds of the scale table used by the Gaussian conditional entropy model.
# These are the standard CompressAI defaults: 64 log-spaced scales in
# [0.11, 256] (see get_scale_table below).
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64

# Default device string; the compress* methods below move inputs to CUDA.
device = "cuda"
|
|
|
|
|
|
def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return `levels` scales log-uniformly spaced over [min, max]."""
    log_lo = math.log(min)
    log_hi = math.log(max)
    return torch.exp(torch.linspace(log_lo, log_hi, levels))
|
|
|
|
|
|
|
|
|
class Binarizer(torch.autograd.Function):
|
|
|
"""
|
|
|
An elementwise function that bins values
|
|
|
to 0 or 1 depending on a threshold of
|
|
|
0.5
|
|
|
|
|
|
Input: a tensor with values in range(0,1)
|
|
|
|
|
|
Returns: a tensor with binary values: 0 or 1
|
|
|
based on a threshold of 0.5
|
|
|
|
|
|
Equation(1) in paper
|
|
|
"""
|
|
|
@staticmethod
|
|
|
def forward(ctx, i):
|
|
|
result = torch.where(i > 0.9, torch.tensor(1.0), torch.tensor(0.2))
|
|
|
return result
|
|
|
|
|
|
@staticmethod
|
|
|
def backward(ctx, grad_output):
|
|
|
return grad_output
|
|
|
|
|
|
def bin_values(x):
    """Apply the straight-through Binarizer elementwise to `x`."""
    binarize = Binarizer.apply
    return binarize(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TIC(nn.Module):
|
|
|
"""Neural image compression framework from
|
|
|
Lu Ming and Guo, Peiyao and Shi, Huiqing and Cao, Chuntong and Ma, Zhan:
|
|
|
`"Transformer-based Image Compression" <https://arxiv.org/abs/2111.06707>`, (DCC 2022).
|
|
|
|
|
|
Args:
|
|
|
N (int): Number of channels
|
|
|
M (int): Number of channels in the expansion layers (last layer of the
|
|
|
encoder and last layer of the hyperprior decoder)
|
|
|
input_resolution (int): Just used for window partition decision
|
|
|
"""
|
|
|
|
|
|
def __init__(self, N=192, M=192):
    """Build the analysis/synthesis transforms, hyper transforms, entropy
    models and the importance-map sub-networks.

    Args:
        N (int): number of channels of the internal feature maps.
        M (int): number of channels of the latent (output of g_a6).
    """
    super().__init__()

    # Swin-style transformer configuration shared by all RSTB stages.
    depths = [1, 2, 3, 1, 1]
    num_heads = [4, 8, 16, 16, 16]
    window_size = 8
    mlp_ratio = 4.
    qkv_bias = True
    qk_scale = None
    drop_rate = 0.
    attn_drop_rate = 0.
    drop_path_rate = 0.2
    norm_layer = nn.LayerNorm
    use_checkpoint = False

    # Stochastic-depth schedule: one rate per transformer block, linearly
    # increasing up to drop_path_rate; sliced per stage below.
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
    self.align_corners = True

    # ---- Analysis transform g_a: image (3ch) -> latent, downsample x16.
    # input_resolution values assume a nominal 256x256 input; RSTB is also
    # given the actual resolution at call time (see g_a()).
    self.g_a0 = conv(3, N, kernel_size=5, stride=2)
    self.g_a1 = RSTB(dim=N,
                     input_resolution=(128, 128),
                     depth=depths[0],
                     num_heads=num_heads[0],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:0]):sum(depths[:1])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_a2 = conv(N, N, kernel_size=3, stride=2)
    self.g_a3 = RSTB(dim=N,
                     input_resolution=(64, 64),
                     depth=depths[1],
                     num_heads=num_heads[1],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:1]):sum(depths[:2])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_a4 = conv(N, N, kernel_size=3, stride=2)
    self.g_a5 = RSTB(dim=N,
                     input_resolution=(32, 32),
                     depth=depths[2],
                     num_heads=num_heads[2],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:2]):sum(depths[:3])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    # Final analysis conv producing the M-channel latent.
    self.g_a6 = conv(N, M, kernel_size=3, stride=2)

    # ---- Hyper-analysis h_a: latent -> hyper-latent, downsample x4.
    self.h_a0 = conv(M, N, kernel_size=3, stride=1)
    self.h_a1 = RSTB(dim=N,
                     input_resolution=(16, 16),
                     depth=depths[3],
                     num_heads=num_heads[3],
                     window_size=window_size // 2,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:3]):sum(depths[:4])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.h_a2 = conv(N, N, kernel_size=3, stride=2)
    self.h_a3 = RSTB(dim=N,
                     input_resolution=(8, 8),
                     depth=depths[4],
                     num_heads=num_heads[4],
                     window_size=window_size // 2,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:4]):sum(depths[:5])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.h_a4 = conv(N, N, kernel_size=3, stride=2)

    # Synthesis side mirrors the analysis side, so reverse the schedules.
    depths = depths[::-1]
    num_heads = num_heads[::-1]

    # ---- Hyper-synthesis h_s: hyper-latent -> entropy parameters (2M ch).
    self.h_s0 = deconv(N, N, kernel_size=3, stride=2)
    self.h_s1 = RSTB(dim=N,
                     input_resolution=(8, 8),
                     depth=depths[0],
                     num_heads=num_heads[0],
                     window_size=window_size // 2,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:0]):sum(depths[:1])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.h_s2 = deconv(N, N, kernel_size=3, stride=2)
    self.h_s3 = RSTB(dim=N,
                     input_resolution=(16, 16),
                     depth=depths[1],
                     num_heads=num_heads[1],
                     window_size=window_size // 2,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:1]):sum(depths[:2])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.h_s4 = conv(N, M * 2, kernel_size=3, stride=1)

    # ---- Synthesis transform g_s: latent -> image, upsample x16.
    self.g_s0 = deconv(M, N, kernel_size=3, stride=2)
    self.g_s1 = RSTB(dim=N,
                     input_resolution=(32, 32),
                     depth=depths[2],
                     num_heads=num_heads[2],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:2]):sum(depths[:3])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_s2 = deconv(N, N, kernel_size=3, stride=2)
    self.g_s3 = RSTB(dim=N,
                     input_resolution=(64, 64),
                     depth=depths[3],
                     num_heads=num_heads[3],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:3]):sum(depths[:4])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_s4 = deconv(N, N, kernel_size=3, stride=2)
    self.g_s5 = RSTB(dim=N,
                     input_resolution=(128, 128),
                     depth=depths[4],
                     num_heads=num_heads[4],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr[sum(depths[:4]):sum(depths[:5])],
                     norm_layer=norm_layer,
                     use_checkpoint=use_checkpoint,
                     )
    self.g_s6 = deconv(N, 3, kernel_size=5, stride=2)

    # ---- Entropy models and autoregressive context model.
    self.entropy_bottleneck = EntropyBottleneck(N)
    self.gaussian_conditional = GaussianConditional(None)
    self.context_prediction = CausalAttentionModule(M, M * 2)

    # Fuses hyper-prior params (2M ch) with context params (2M ch) into
    # per-element (scale, mean): 4M -> 10M/3 -> 8M/3 -> 2M channels.
    self.entropy_parameters = nn.Sequential(
        nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
        nn.GELU(),
        nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
        nn.GELU(),
        nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
    )

    # ---- Importance sub-network (see sub_impor_net): a strided stem
    # followed by a residual bottleneck (sub_net0/1/2) applied three times.
    self.sub_net_leaky = nn.Sequential(
        conv(N, N, kernel_size=3, stride=2),
        nn.LeakyReLU()
    )
    self.sub_net0 = nn.Sequential(
        conv(N, 64, kernel_size=1, stride=1),
        nn.ReLU()
    )
    self.sub_net1 = nn.Sequential(
        conv(64, 64, kernel_size=3, stride=1),
        nn.ReLU()
    )
    self.sub_net2 = conv(64, N, kernel_size=1, stride=1)
    self.sub_net_channel = conv(N, M, kernel_size=1, stride=1)

    # Maps a 1-channel similarity map to M channels at 1/8 resolution.
    self.simi_net = nn.Sequential(
        conv(1, 64, kernel_size=3, stride=2),
        nn.ReLU(),
        conv(64, 128, kernel_size=3, stride=2),
        nn.ReLU(),
        conv(128, M, kernel_size=3, stride=2),
    )

    self.cosine_similarity = torch.nn.CosineSimilarity(dim=1)

    # Shared activations used by the importance branch.
    self.tanh = nn.Tanh()
    self.softsign = nn.Softsign()
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()

    self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
|
|
def g_a(self, x, x_size=None):
|
|
|
if x_size is None:
|
|
|
x_size = x.shape[2:4]
|
|
|
x = self.g_a0(x)
|
|
|
x = self.g_a1(x, (x_size[0] // 2, x_size[1] // 2))
|
|
|
x = self.g_a2(x)
|
|
|
x = self.g_a3(x, (x_size[0] // 4, x_size[1] // 4))
|
|
|
x = self.g_a4(x)
|
|
|
x = self.g_a5(x, (x_size[0] // 8, x_size[1] // 8))
|
|
|
|
|
|
return x
|
|
|
|
|
|
def g_s(self, x, x_size=None):
|
|
|
if x_size is None:
|
|
|
x_size = (x.shape[2] * 16, x.shape[3] * 16)
|
|
|
x = self.g_s0(x)
|
|
|
x = self.g_s1(x, (x_size[0] // 8, x_size[1] // 8))
|
|
|
x = self.g_s2(x)
|
|
|
x = self.g_s3(x, (x_size[0] // 4, x_size[1] // 4))
|
|
|
x = self.g_s4(x)
|
|
|
x = self.g_s5(x, (x_size[0] // 2, x_size[1] // 2))
|
|
|
x = self.g_s6(x)
|
|
|
return x
|
|
|
|
|
|
def h_a(self, x, x_size=None):
|
|
|
if x_size is None:
|
|
|
x_size = (x.shape[2] * 16, x.shape[3] * 16)
|
|
|
x = self.h_a0(x)
|
|
|
x = self.h_a1(x, (x_size[0] // 16, x_size[1] // 16))
|
|
|
x = self.h_a2(x)
|
|
|
x = self.h_a3(x, (x_size[0] // 32, x_size[1] // 32))
|
|
|
x = self.h_a4(x)
|
|
|
return x
|
|
|
|
|
|
def h_s(self, x, x_size=None):
|
|
|
if x_size is None:
|
|
|
x_size = (x.shape[2] * 64, x.shape[3] * 64)
|
|
|
x = self.h_s0(x)
|
|
|
x = self.h_s1(x, (x_size[0] // 32, x_size[1] // 32))
|
|
|
x = self.h_s2(x)
|
|
|
x = self.h_s3(x, (x_size[0] // 16, x_size[1] // 16))
|
|
|
x = self.h_s4(x)
|
|
|
return x
|
|
|
|
|
|
|
|
|
def sub_impor_net(self,x):
|
|
|
x1 = self.sub_net_leaky(x)
|
|
|
|
|
|
x2 = self.sub_net0(x1)
|
|
|
x2 = self.sub_net1(x2)
|
|
|
x2 = self.sub_net2(x2)
|
|
|
|
|
|
x2 = x1 + x2
|
|
|
x3 = self.sub_net0(x2)
|
|
|
x3 = self.sub_net1(x3)
|
|
|
x3 = self.sub_net2(x3)
|
|
|
|
|
|
x3 = x2 + x3
|
|
|
x4 = self.sub_net0(x3)
|
|
|
x4 = self.sub_net1(x4)
|
|
|
x4 = self.sub_net2(x4)
|
|
|
|
|
|
x_out = x4 + x3
|
|
|
x_out = self.sub_net_channel(x_out)
|
|
|
|
|
|
return x_out
|
|
|
|
|
|
|
|
|
|
|
|
def aux_loss(self):
    """Return the aggregated loss over the auxiliary entropy bottleneck
    module(s).
    """
    total = 0
    for module in self.modules():
        if isinstance(module, EntropyBottleneck):
            total = total + module.loss()
    return total
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
if isinstance(m, nn.Linear):
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
elif isinstance(m, nn.LayerNorm):
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
nn.init.constant_(m.weight, 1.0)
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
def no_weight_decay_keywords(self):
    """Parameter-name keywords to exclude from weight decay."""
    keywords = {'relative_position_bias_table'}
    return keywords
|
|
|
|
|
|
def forward(self, x, similarity):
    """Training/eval forward pass.

    Args:
        x: input image batch (B, 3, H, W).
        similarity: single-channel similarity map used to gate importance.
            Assumed shape (B, 1, H/2, W/2) so that simi_net's x8 downsample
            matches the latent resolution — TODO confirm against caller.

    Returns:
        dict with the quantized latent, the (importance-weighted) latent,
        the up-sampled similarity gate, the reconstruction, and the
        y/z likelihoods for the rate term.
    """
    x_size = (x.shape[2], x.shape[3])

    # Hard gate on the similarity map: 1.0 where similar enough, else 0.01.
    # full_like keeps the constants on similarity's device/dtype (the
    # previous CPU torch.tensor(...) constants failed for CUDA inputs).
    # The original computed this identical expression twice
    # (similarity_loss / similarity_imp); compute it once.
    similarity_gate = torch.where(
        similarity > 0.85,
        torch.full_like(similarity, 1.0),
        torch.full_like(similarity, 0.01),
    )

    # Up-sampled, 3-channel copy returned for use in the loss.
    similarity_up = F.interpolate(similarity_gate, scale_factor=2, mode='bilinear')
    similarity_up_repeated = similarity_up.repeat(1, 3, 1, 1)

    # Channel-wise importance derived from the gated similarity map.
    similarities_channel = self.simi_net(similarity_gate)
    similarities_sigmoid = torch.sigmoid(similarities_channel)

    # Analysis transform and latent.
    y_codec = self.g_a(x, x_size)
    y_codec_a6 = self.g_a6(y_codec)

    # Spatial importance branch on the pre-latent features.
    y_import = self.sub_impor_net(y_codec)
    y_tanh = self.tanh(y_import)
    y_soft = self.softsign(y_tanh)

    # Combine both importance signals and weight the latent.
    y_imp = y_soft + similarities_sigmoid
    y = y_codec_a6 * y_imp

    # Hyper-prior branch.
    z = self.h_a(y, x_size)
    z_hat, z_likelihoods = self.entropy_bottleneck(z)
    params = self.h_s(z_hat, x_size)

    # Quantize latent (additive noise during training), then predict
    # per-element Gaussian (scale, mean) from hyper + causal context.
    y_hat = self.gaussian_conditional.quantize(
        y, "noise" if self.training else "dequantize"
    )
    ctx_params = self.context_prediction(y_hat)
    gaussian_params = self.entropy_parameters(
        torch.cat((params, ctx_params), dim=1)
    )
    scales_hat, means_hat = gaussian_params.chunk(2, 1)
    _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
    x_hat = self.g_s(y_hat, x_size)

    return {
        "y_hat": y_hat,
        "y": y,
        "similarity": similarity_up_repeated,
        "x_hat": x_hat,
        "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
    }
|
|
|
|
|
|
|
|
|
def update(self, scale_table=None, force=False):
    """Updates the entropy bottleneck(s) CDF values.

    Needs to be called once after training to be able to later perform the
    evaluation with an actual entropy coder.

    Args:
        scale_table (torch.Tensor): scale table for the Gaussian
            conditional (default: None, uses ``get_scale_table()``)
        force (bool): overwrite previous values (default: False)

    Returns:
        updated (bool): True if any of the entropy models was updated.
    """
    if scale_table is None:
        scale_table = get_scale_table()
    # Propagate the Gaussian conditional's "updated" flag instead of
    # discarding it (the original dropped this return value, so the method
    # under-reported updates compared to CompressAI's CompressionModel).
    updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)

    for m in self.children():
        if not isinstance(m, EntropyBottleneck):
            continue
        rv = m.update(force=force)
        updated |= rv
    return updated
|
|
|
|
|
|
def load_state_dict(self, state_dict, strict=True):
    """Load weights, first resizing the entropy models' registered buffers
    (CDF tables) to match the sizes stored in `state_dict`.
    """
    buffer_specs = (
        (self.entropy_bottleneck, "entropy_bottleneck",
         ["_quantized_cdf", "_offset", "_cdf_length"]),
        (self.gaussian_conditional, "gaussian_conditional",
         ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"]),
    )
    for module, prefix, buffer_names in buffer_specs:
        update_registered_buffers(module, prefix, buffer_names, state_dict)
    super().load_state_dict(state_dict, strict=strict)
|
|
|
|
|
|
@classmethod
def from_state_dict(cls, state_dict):
    """Return a new model instance from `state_dict`."""
    # Infer channel counts from the stored conv weights:
    # g_a0 outputs N channels, g_a6 outputs M channels.
    n_channels = state_dict["g_a0.weight"].size(0)
    m_channels = state_dict["g_a6.weight"].size(0)
    model = cls(n_channels, m_channels)
    model.load_state_dict(state_dict)
    return model
|
|
|
|
|
|
|
|
|
def compress(self, x):
    """Encode image `x` into rANS bitstrings, skipping all-zero channels.

    Unlike forward(), no importance map is applied to the latent here.
    After serial quantization a per-channel `flag` marks channels with any
    nonzero quantized value; only flagged channels are entropy-coded, and
    the flag is returned so decompress() can restore channel positions.

    NOTE(review): the flag-pruned indexing below (`y_q_[:, h, w]`) only
    lines up for batch size 1 — verify callers never pass larger batches.

    Returns:
        dict with "strings" ([y_strings, z_strings]), "shape" (hyper-latent
        H, W) and "flag" (numpy 0/1 mask over latent channels).
    """
    x = x.cuda()
    x_size = (x.shape[2], x.shape[3])
    # Wall-clock timing of the whole encode (t2 below is unused; kept for
    # ad-hoc profiling).
    start = time.time()

    y_codec = self.g_a(x, x_size)
    y_codec_a6 = self.g_a6(y_codec)
    y = y_codec_a6

    # Hyper branch: compress z, then decode it back so the encoder uses the
    # exact same z_hat the decoder will see.
    z = self.h_a(y)
    z_strings = self.entropy_bottleneck.compress(z)
    z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
    params = self.h_s(z_hat)

    s = 4  # latent is 4x the hyper-latent resolution
    kernel_size = 5  # causal context window
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # Pad so the 5x5 causal context window is valid at the borders.
    y_hat = F.pad(y, (padding, padding, padding, padding))

    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()

    y_strings = []
    for i in range(y.size(0)):
        encoder = BufferedRansEncoder()
        symbols_list = []
        indexes_list = []
        # Per-position quantized symbols / scale indexes collected first so
        # that all-zero channels can be pruned before entropy coding.
        y_q_ = torch.zeros_like(y)
        indexes_ = torch.zeros_like(y)
        # Serial raster-scan quantization: each position's context uses the
        # already-quantized causal neighborhood written back into y_hat.
        for h in range(y_height):
            for w in range(y_width):
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                # Center pixel (2, 2) of the 5x5 context output.
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)
                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                y_q = torch.round(y_crop - means_hat)
                # Write the dequantized center back for later contexts.
                # NOTE(review): y_crop/y_q have batch dim 1, so indexing them
                # with `i` is only valid for i == 0 (batch size 1).
                y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
                    i, :, padding, padding
                ]
                y_q_[i,:, h, w] = y_q[i, :, padding, padding]
                indexes_[i,:, h, w] = indexes[i, :,0,0]

        # Channel pruning: flag[c] = 1 iff channel c has any nonzero symbol.
        flag = np.array(np.zeros(y_q_.shape[1]))
        for idx in range(y_q_.shape[1]):
            if torch.sum(torch.abs(y_q_[:, idx, :, :])) > 0:
                flag[idx] = 1
        # Keep flagged channels only; squeeze drops the batch dim
        # (assumes batch size 1 — TODO confirm).
        y_q_ = y_q_[:,np.nonzero(flag),...].squeeze()
        indexes_ = indexes_[:,np.nonzero(flag),...].squeeze()
        # Second pass: emit the pruned symbols in the same raster order.
        for h in range(y_height):
            for w in range(y_width):
                symbols_list.extend(y_q_[:,h,w].int().tolist())
                indexes_list.extend(indexes_[:,h,w].squeeze().int().tolist())
        encoder.encode_with_indexes(
            symbols_list, indexes_list, cdf, cdf_lengths, offsets
        )
        string = encoder.flush()
        y_strings.append(string)
        print(flag.sum())

    torch.cuda.synchronize()
    t2 = time.time() - start

    return {"strings": [y_strings, z_strings], "shape": z.size()[-2:],"flag":flag}
|
|
|
|
|
|
|
|
|
|
|
|
def compress_1(self, x,similarity):
    """Encode `x` with a similarity-driven importance map (no channel
    pruning; decode with decompress_1).

    The similarity map is gated (1e-4 where exactly zero, else 1.0),
    down-sampled x2, fed through simi_net and combined with the learned
    spatial importance (tanh -> softsign), mirroring forward().
    """
    x = x.cuda()
    x_size = (x.shape[2], x.shape[3])

    similarity = similarity.cuda()

    # NOTE(review): these torch.tensor constants are created on the CPU
    # while `similarity` was just moved to CUDA — likely a device mismatch
    # at runtime; confirm on a CUDA machine.
    similarity_down_1 = torch.where(similarity == 0, torch.tensor(1e-4), torch.tensor(1.0))

    # Halve resolution so simi_net's x8 downsample matches the latent grid
    # (assumes `similarity` comes in at full image resolution — TODO confirm).
    similarity_down_1 = F.interpolate(similarity_down_1, scale_factor=0.5, mode='bilinear')
    similarities_repeated = self.simi_net(similarity_down_1)
    similarities_repeated = torch.sigmoid(similarities_repeated)

    y_codec = self.g_a(x, x_size)

    # Learned spatial importance branch (same as forward()).
    y_import = self.sub_impor_net(y_codec)
    y_tanh = self.tanh(y_import)
    y_soft = self.softsign(y_tanh)

    y_codec_a6 = self.g_a6(y_codec)

    # Combine both importance signals and weight the latent.
    y_imp = similarities_repeated + y_soft
    y = y_codec_a6 * y_imp

    # Hyper branch: compress z, decode it back so encoder and decoder see
    # the same z_hat.
    z = self.h_a(y)
    z_strings = self.entropy_bottleneck.compress(z)
    z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
    params = self.h_s(z_hat)

    s = 4  # latent is 4x the hyper-latent resolution
    kernel_size = 5  # causal context window
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # Pad so the 5x5 context window is valid at the borders.
    y_hat = F.pad(y, (padding, padding, padding, padding))

    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()

    y_strings = []
    for i in range(y.size(0)):
        encoder = BufferedRansEncoder()
        symbols_list = []
        indexes_list = []
        # Serial raster-scan quantization: each position's context uses the
        # already-quantized causal neighborhood written back into y_hat.
        for h in range(y_height):
            for w in range(y_width):
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                # Center pixel (2, 2) of the 5x5 context output.
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                y_q = torch.round(y_crop - means_hat)
                # NOTE(review): y_crop/y_q have batch dim 1, so indexing with
                # `i` is only valid for i == 0 (batch size 1).
                y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
                    i, :, padding, padding
                ]

                symbols_list.extend(y_q[i, :, padding, padding].int().tolist())
                indexes_list.extend(indexes[i, :].squeeze().int().tolist())

        encoder.encode_with_indexes(
            symbols_list, indexes_list, cdf, cdf_lengths, offsets
        )

        string = encoder.flush()
        y_strings.append(string)

    return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
|
|
|
|
|
|
def compress_2(self, x,similarity):
    """Encode `x` with a thresholded similarity importance map.

    Variant of compress_1: the similarity map is hard-gated at 0.9
    (1.0 / 0.1) with no down-sampling, and the spatial importance uses
    tanh only (no softsign). Decode with decompress_1.
    """
    x = x.cuda()
    x_size = (x.shape[2], x.shape[3])

    # NOTE(review): unlike compress_1, `similarity` is not moved to CUDA
    # here, and the torch.tensor constants are CPU scalars — verify device
    # handling when x is on CUDA.
    similarity_down_1 = torch.where(similarity > 0.9, torch.tensor(1.0), torch.tensor(0.1))
    similarities_repeated = self.simi_net(similarity_down_1)
    similarities_repeated = torch.sigmoid(similarities_repeated)

    y_codec = self.g_a(x, x_size)

    # Learned spatial importance branch (tanh only in this variant).
    y_import = self.sub_impor_net(y_codec)
    y_tanh = self.tanh(y_import)

    y_codec_a6 = self.g_a6(y_codec)

    # Combine both importance signals and weight the latent.
    y_imp = similarities_repeated + y_tanh

    y = y_codec_a6 * y_imp

    # Hyper branch: compress z, decode it back so encoder and decoder see
    # the same z_hat.
    z = self.h_a(y)
    z_strings = self.entropy_bottleneck.compress(z)
    z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
    params = self.h_s(z_hat)

    s = 4  # latent is 4x the hyper-latent resolution
    kernel_size = 5  # causal context window
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # Pad so the 5x5 context window is valid at the borders.
    y_hat = F.pad(y, (padding, padding, padding, padding))

    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()

    y_strings = []
    for i in range(y.size(0)):
        encoder = BufferedRansEncoder()
        symbols_list = []
        indexes_list = []
        # Serial raster-scan quantization over the latent grid.
        for h in range(y_height):
            for w in range(y_width):
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                # Center pixel (2, 2) of the 5x5 context output.
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                y_q = torch.round(y_crop - means_hat)
                # NOTE(review): y_crop/y_q have batch dim 1, so indexing with
                # `i` is only valid for i == 0 (batch size 1).
                y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
                    i, :, padding, padding
                ]

                symbols_list.extend(y_q[i, :, padding, padding].int().tolist())
                indexes_list.extend(indexes[i, :].squeeze().int().tolist())

        encoder.encode_with_indexes(
            symbols_list, indexes_list, cdf, cdf_lengths, offsets
        )

        string = encoder.flush()
        y_strings.append(string)

    return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
|
|
|
|
|
|
def decompress(self, strings, shape, flag):
    """Decode bitstrings produced by compress(), restoring pruned channels.

    Args:
        strings: [y_strings, z_strings] as returned by compress().
        shape: hyper-latent spatial size (H, W).
        flag: 0/1 numpy mask over latent channels; only flagged channels
            were entropy-coded, the rest are reconstructed as zeros.
    """
    # Indices of the channels that were actually coded.
    flag = np.nonzero(flag)
    assert isinstance(strings, list) and len(strings) == 2

    z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
    params = self.h_s(z_hat)

    s = 4  # latent is 4x the hyper-latent resolution
    kernel_size = 5  # causal context window
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # Padded canvas filled serially; border padding keeps the 5x5 context
    # window valid everywhere.
    # NOTE(review): channel count 192 is hard-coded here (should match M).
    y_hat = torch.zeros(
        (z_hat.size(0), 192, y_height + 2 * padding, y_width + 2 * padding),
        device=z_hat.device,
    )
    decoder = RansDecoder()

    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()

    for i, y_string in enumerate(strings[0]):
        decoder.set_stream(y_string)

        # Serial raster scan mirroring the encoder: each position's context
        # is the already-decoded causal neighborhood.
        for h in range(y_height):
            for w in range(y_width):
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                # Center pixel (2, 2) of the 5x5 context output.
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                # Decode symbols for the flagged channels only.
                rv = decoder.decode_stream(
                    indexes[i, flag].squeeze().int().tolist(),
                    cdf,
                    cdf_lengths,
                    offsets,
                )

                rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
                # Scatter decoded channels into a full-width vector; the
                # pruned channels stay zero.
                # NOTE(review): tmp is created on the CPU while y_hat may be
                # on CUDA — verify device handling; 192 hard-coded again.
                tmp = torch.zeros((1, 192, 1, 1))
                tmp[:, flag, ...] = rv
                rv = self.gaussian_conditional._dequantize(tmp, means_hat)

                y_hat[
                    i,
                    :,
                    h + padding: h + padding + 1,
                    w + padding: w + padding + 1,
                ] = rv

    # Strip the context padding before synthesis.
    y_hat = y_hat[:, :, padding:-padding, padding:-padding]

    x_hat = self.g_s(y_hat).clamp_(0, 1)
    return {"x_hat": x_hat,}
|
|
|
|
|
|
def decompress_1(self, strings, shape):
    """Decode bitstrings produced by compress_1/compress_2 (all channels
    coded, no pruning flag).

    Args:
        strings: [y_strings, z_strings].
        shape: hyper-latent spatial size (H, W).
    """
    assert isinstance(strings, list) and len(strings) == 2

    z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
    params = self.h_s(z_hat)

    s = 4  # latent is 4x the hyper-latent resolution
    kernel_size = 5  # causal context window
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # Padded canvas filled serially; border padding keeps the 5x5 context
    # window valid everywhere.
    # NOTE(review): channel count 192 is hard-coded here (should match M).
    y_hat = torch.zeros(
        (z_hat.size(0), 192, y_height + 2 * padding, y_width + 2 * padding),
        device=z_hat.device,
    )
    decoder = RansDecoder()

    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()

    for i, y_string in enumerate(strings[0]):
        decoder.set_stream(y_string)

        # Serial raster scan mirroring the encoder: each position's context
        # is the already-decoded causal neighborhood.
        for h in range(y_height):
            for w in range(y_width):
                y_crop = y_hat[
                    i: i + 1, :, h: h + kernel_size, w: w + kernel_size
                ]
                ctx_p = self.context_prediction(y_crop)
                p = params[i: i + 1, :, h: h + 1, w: w + 1]
                # Center pixel (2, 2) of the 5x5 context output.
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p[i: i + 1, :, 2: 3, 2: 3]), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)

                rv = decoder.decode_stream(
                    indexes[i, :].squeeze().int().tolist(),
                    cdf,
                    cdf_lengths,
                    offsets,
                )
                rv = torch.Tensor(rv).reshape(1, -1, 1, 1)

                rv = self.gaussian_conditional._dequantize(rv, means_hat)

                y_hat[
                    i,
                    :,
                    h + padding: h + padding + 1,
                    w + padding: w + padding + 1,
                ] = rv
    # Strip the context padding before synthesis.
    y_hat = y_hat[:, :, padding:-padding, padding:-padding]

    x_hat = self.g_s(y_hat).clamp_(0, 1)
    return {"x_hat": x_hat}