# yuangai — commit 849926f: "init space"
import torch
import torch.nn as nn
import math
from einops import rearrange
import torch.nn.functional as F
from collections import defaultdict
def swish(x):
    """Swish (a.k.a. SiLU) activation, applied elementwise: ``x * sigmoid(x)``."""
    return torch.mul(x, x.sigmoid())
class ResBlock(nn.Module):
    """Residual block: [GroupNorm] -> swish -> 3x3 conv -> GroupNorm -> swish -> 3x3 conv.

    The branch output is added to a shortcut. The shortcut is the identity when
    ``in_filters == out_filters``. Otherwise it is a bias-free projection: a 3x3 conv
    (``use_conv_shortcut=True``) or a 1x1 conv.

    Args:
        in_filters: input channel count (GroupNorm(32, ...) needs it divisible by 32).
        out_filters: output channel count (likewise divisible by 32).
        use_conv_shortcut: use a 3x3 instead of a 1x1 projection shortcut.
        use_agn: skip the leading GroupNorm. Presumably this is because the caller
            already applied an adaptive GroupNorm; confirm at call sites.
    """

    def __init__(self,
                 in_filters,
                 out_filters,
                 use_conv_shortcut = False,
                 use_agn = False,
                 ) -> None:
        super().__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.use_conv_shortcut = use_conv_shortcut
        self.use_agn = use_agn

        # The leading norm only exists when no adaptive GroupNorm precedes the block.
        if not use_agn:
            self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
        self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)
        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)

        if in_filters == out_filters:
            return  # identity shortcut, nothing to build
        if use_conv_shortcut:
            self.conv_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
        else:
            self.nin_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False)

    def forward(self, x, **kwargs):
        """Apply the block to ``x`` ([B, in_filters, H, W]); extra kwargs are ignored."""
        skip = x
        h = x if self.use_agn else self.norm1(x)
        h = self.conv1(swish(h))
        h = self.conv2(swish(self.norm2(h)))
        if self.in_filters != self.out_filters:
            project = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            skip = project(skip)
        return h + skip
class Encoder(nn.Module):
    """Convolutional encoder from images to a ``z_channels`` latent map.

    Pipeline:

    * ``conv_in``.
    * One stage per entry of ``ch_mult``. Each stage has ``num_res_blocks``
      ResBlocks at width ``ch * ch_mult[i]``. Every stage except the last then
      applies a stride-2 3x3 conv, which halves H and W (rounding up).
    * A middle stack of ``num_res_blocks`` ResBlocks.
    * GroupNorm -> swish -> 1x1 ``conv_out``.

    ``out_ch`` and ``double_z`` are accepted for config compatibility but unused.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4),
                 resolution=None, double_z=False,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.resolution = resolution
        self.num_res_blocks = num_res_blocks
        self.num_blocks = len(ch_mult)
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=(3, 3), padding=1, bias=False)

        # Stage i enters at the previous stage's width (the stem width `ch` for stage 0).
        entry_mults = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for level_idx in range(self.num_blocks):
            width_in = ch * entry_mults[level_idx]
            width_out = ch * ch_mult[level_idx]
            stage_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                stage_blocks.append(ResBlock(width_in, width_out))
                width_in = width_out
            stage = nn.Module()
            stage.block = stage_blocks
            if level_idx < self.num_blocks - 1:
                stage.downsample = nn.Conv2d(width_out, width_out, kernel_size=(3, 3), stride=(2, 2), padding=1)
            self.down.append(stage)

        self.mid_block = nn.ModuleList(
            [ResBlock(width_in, width_in) for _ in range(self.num_res_blocks)]
        )

        self.norm_out = nn.GroupNorm(32, width_out, eps=1e-6)
        self.conv_out = nn.Conv2d(width_out, z_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Encode ``x`` ([B, in_channels, H, W]) to a [B, z_channels, h, w] latent."""
        h = self.conv_in(x)
        last_level = self.num_blocks - 1
        for level_idx, stage in enumerate(self.down):
            for res_block in stage.block:
                h = res_block(h)
            if level_idx < last_level:
                h = stage.downsample(h)
        for res_block in self.mid_block:
            h = res_block(h)
        return self.conv_out(swish(self.norm_out(h)))
class Decoder(nn.Module):
    """Convolutional decoder from a ``z_channels`` latent map back to ``out_ch`` channels.

    Pipeline:

    * ``conv_in``.
    * ``num_res_blocks`` middle ResBlocks.
    * For each level from deepest (``len(ch_mult) - 1``) to 0: an
      AdaptiveGroupNorm conditioned on the raw input latent, then
      ``num_res_blocks`` ResBlocks. Every level except 0 then applies a 2x
      ``Upsampler``.
    * GroupNorm -> swish -> 3x3 ``conv_out``.

    ``double_z`` is accepted for config compatibility but unused.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4),
                 resolution=None, double_z=False,) -> None:
        super().__init__()
        self.ch = ch
        self.num_blocks = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        width = ch * ch_mult[self.num_blocks - 1]
        self.conv_in = nn.Conv2d(z_channels, width, kernel_size=(3, 3), padding=1, bias=True)
        self.mid_block = nn.ModuleList(
            [ResBlock(width, width) for _ in range(self.num_res_blocks)]
        )

        self.up = nn.ModuleList()
        self.adaptive = nn.ModuleList()
        # Levels are built deepest-first; prepending keeps list index == level index.
        for level_idx in range(self.num_blocks - 1, -1, -1):
            width_out = ch * ch_mult[level_idx]
            self.adaptive.insert(0, AdaptiveGroupNorm(z_channels, width))
            stage_blocks = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                stage_blocks.append(ResBlock(width, width_out))
                width = width_out
            stage = nn.Module()
            stage.block = stage_blocks
            if level_idx > 0:
                stage.upsample = Upsampler(width)
            self.up.insert(0, stage)

        self.norm_out = nn.GroupNorm(32, width, eps=1e-6)
        self.conv_out = nn.Conv2d(width, out_ch, kernel_size=(3, 3), padding=1)

    def forward(self, z):
        """Decode latent ``z`` ([B, z_channels, h, w]) to [B, out_ch, H, W]."""
        style = z.clone()  # the untouched latent conditions every adaptive norm
        h = self.conv_in(z)
        for res_block in self.mid_block:
            h = res_block(h)
        for level_idx in range(self.num_blocks - 1, -1, -1):
            stage = self.up[level_idx]
            h = self.adaptive[level_idx](h, style)
            for res_block in stage.block:
                h = res_block(h)
            if level_idx > 0:
                h = stage.upsample(h)
        h = swish(self.norm_out(h))
        return self.conv_out(h)
def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Move channel blocks into spatial blocks (Depth-to-Space, DCR mode).

    DCR ("depth-column-row") order splits the channel dimension as
    ``(block_size, block_size, C // block_size**2)``, i.e. (row offset,
    column offset, output channel). This matches ONNX DepthToSpace with
    ``mode="DCR"``. It differs from ``torch.nn.functional.pixel_shuffle``,
    which uses CRD order.

    Args:
        x: channels-first tensor ``(*, C, H, W)``. Leading dims are preserved,
            may be empty (CHW input), and may be non-contiguous.
        block_size: side length of each spatial block.

    Returns:
        Tensor of shape ``(*, C // block_size**2, H * block_size, W * block_size)``.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims or C is not divisible by
            ``block_size**2``.
    """
    if x.dim() < 3:
        raise ValueError(
            "Expecting a channels-first (*CHW) tensor of at least 3 dimensions"
        )
    c, h, w = x.shape[-3:]
    s = block_size**2
    if c % s != 0:
        raise ValueError(
            f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels"
        )
    outer_dims = x.shape[:-3]
    # An explicit leading size (not -1) keeps empty batches unambiguous.
    # reshape (not view) copies when the leading dims cannot be merged in place.
    n = outer_dims.numel()
    # Split C into (row offset, column offset, output channel).
    x = x.reshape(n, block_size, block_size, c // s, h, w)
    # -> [n, c//s, h, row_offset, w, col_offset]
    x = x.permute(0, 3, 4, 1, 5, 2)
    # Merge (h, row_offset) -> H*bs and (w, col_offset) -> W*bs.
    return x.contiguous().view(*outer_dims, c // s, h * block_size, w * block_size)
class Upsampler(nn.Module):
    """2x spatial upsampler: a 3x3 conv to ``4 * dim_out`` channels, then DCR depth-to-space.

    Args:
        dim: input channel count.
        dim_out: output channel count. It defaults to ``dim``, which reproduces
            the previous behaviour. This argument used to be silently ignored.
    """

    def __init__(
        self,
        dim,
        dim_out = None
    ):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        # depth_to_space(block_size=2) divides channels by 4, so produce 4x the target width.
        self.conv1 = nn.Conv2d(dim, dim_out * 4, (3, 3), padding=1)
        self.depth2space = depth_to_space

    def forward(self, x):
        """Upsample ``x`` from [B, dim, H, W] to [B, dim_out, 2H, 2W]."""
        out = self.conv1(x)
        out = self.depth2space(out, block_size=2)
        return out
class AdaptiveGroupNorm(nn.Module):
    """GroupNorm whose per-channel scale and shift are predicted from a style tensor.

    ``x`` is normalised by a non-affine GroupNorm. Two linear layers then map
    per-channel statistics of the style tensor to a scale (``gamma``) and a
    bias (``beta``) of width ``in_filters``:

    * the spatial standard deviation produces the scale;
    * the spatial mean produces the bias.

    Args:
        z_channel: channel count of the style tensor given to ``forward``.
        in_filters: channel count of the tensor being normalised.
        num_groups: GroupNorm groups (must divide ``in_filters``). This used to be
            ignored and hard-coded to 32; the default keeps that behaviour.
        eps: GroupNorm epsilon; it is also added to the style variance before the sqrt.
    """

    def __init__(self, z_channel, in_filters, num_groups=32, eps=1e-6):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=in_filters, eps=eps, affine=False)
        self.gamma = nn.Linear(z_channel, in_filters)
        self.beta = nn.Linear(z_channel, in_filters)
        self.eps = eps

    def forward(self, x, quantizer):
        """Normalise ``x`` ([B, C, H, W]) and modulate it with statistics of ``quantizer``.

        Args:
            x: tensor to normalise; C must equal ``in_filters``.
            quantizer: style tensor [B, z_channel, *spatial]. Its spatial size may
                differ from that of ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, C, _, _ = x.shape
        style = quantizer.flatten(start_dim=2)  # [B, z_channel, prod(spatial)]
        # torch.var applies Bessel's (N-1) correction by default, so this is the
        # *unbiased* variance. It is kept as-is so outputs stay unchanged.
        # NOTE: a style map with a single spatial position yields NaN here.
        scale = (style.var(dim=-1) + self.eps).sqrt()
        scale = self.gamma(scale).view(B, C, 1, 1)
        bias = self.beta(style.mean(dim=-1)).view(B, C, 1, 1)
        return scale * self.gn(x) + bias
class GANDecoder(nn.Module):
    """Noise-conditioned decoder variant.

    It is identical in layout to ``Decoder``, except that ``conv_in`` consumes the
    latent concatenated channel-wise with fresh Gaussian noise (``2 * z_channels``
    inputs). Because new noise is drawn on every ``forward`` call, outputs are
    stochastic unless the RNG is seeded.

    ``double_z`` is accepted for config compatibility but unused.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4),
                 resolution=None, double_z=False,) -> None:
        super().__init__()
        self.ch = ch
        self.num_blocks = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        block_in = ch * ch_mult[self.num_blocks - 1]
        # Input is [latent, noise] stacked on channels, hence 2 * z_channels.
        self.conv_in = nn.Conv2d(
            z_channels * 2, block_in, kernel_size=(3, 3), padding=1, bias=True
        )
        self.mid_block = nn.ModuleList()
        for _ in range(self.num_res_blocks):
            self.mid_block.append(ResBlock(block_in, block_in))
        self.up = nn.ModuleList()
        self.adaptive = nn.ModuleList()
        # Levels are built deepest-first; insert(0, ...) keeps list index == level index.
        for i_level in reversed(range(self.num_blocks)):
            block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            self.adaptive.insert(0, AdaptiveGroupNorm(z_channels, block_in))
            for _ in range(self.num_res_blocks):
                block.append(ResBlock(block_in, block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            if i_level > 0:
                up.upsample = Upsampler(block_in)
            self.up.insert(0, up)
        self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=(3, 3), padding=1)

    def forward(self, z):
        """Decode latent ``z`` ([B, z_channels, h, w]) with added noise to [B, out_ch, H, W]."""
        style = z.clone()  # the noise-free latent conditions every adaptive norm
        noise = torch.randn_like(z)  # randn_like already matches z's device and dtype
        z = torch.cat([z, noise], dim=1)
        z = self.conv_in(z)
        for res_block in self.mid_block:
            z = res_block(z)
        for i_level in reversed(range(self.num_blocks)):
            # Each level starts with an adaptive GroupNorm driven by the latent.
            z = self.adaptive[i_level](z, style)
            for res_block in self.up[i_level].block:
                z = res_block(z)
            if i_level > 0:
                z = self.up[i_level].upsample(z)
        z = self.norm_out(z)
        z = swish(z)
        z = self.conv_out(z)
        return z
class VQModel(nn.Module):
def __init__(self,
ddconfig,
checkpoint=None,
gan_decoder = False,
):
super().__init__()
self.encoder = Encoder(**ddconfig)
self.decoder = GANDecoder(**ddconfig) if gan_decoder else Decoder(**ddconfig)
# Load weights from the checkpoint
if checkpoint is not None:
self.load_from_ckpt(checkpoint)
def load_from_ckpt(self, checkpoint):
state = torch.load(checkpoint, mmap=True, map_location="cpu")
log_info = self.load_state_dict(state["state_dict"], strict=False)
has_missing_keys = bool(log_info.missing_keys)
has_unexpected_keys = bool(log_info.unexpected_keys)
if not has_missing_keys:
print(f"Successfully loaded all weights from checkpoint: {checkpoint}")
else:
if has_missing_keys:
print("Missing keys (model layers not in checkpoint):")
for key in log_info.missing_keys:
print(f" - {key}")
if False and has_unexpected_keys:
print("\nUnexpected keys (checkpoint layers not in model):")
for key in log_info.unexpected_keys:
print(f" - {key}")
def encode(self, x):
h = self.encoder(x)
codebook_value = torch.Tensor([1.0]).to(h)
quant_h = torch.where(h > 0, codebook_value, -codebook_value) # higher than 0 filled
return quant_h
# def vt_forward(self, image_list):
# q_list = []
# for x in image_list:
# quant = self.encode(x)
# quant = rearrange(quant.squeeze(0), "c h w -> (h w) c")
# q_list.append(quant)
# return torch.cat(q_list, dim=0)
def vt_forward(self, image_list, max_bs=32, ps=1):
groups = defaultdict(list) # {(H, W): [(idx, image_tensor), ...]}
for i, img in enumerate(image_list):
_, _, H, W = img.shape
groups[(H, W)].append((i, img))
output = [None] * len(image_list)
for (H, W), items in groups.items():
for start in range(0, len(items), max_bs):
chunk = items[start:start + max_bs]
idxs = [x[0] for x in chunk]
imgs = [x[1] for x in chunk]
batch = torch.cat(imgs, dim=0) # [B, 3, H, W]
quant = self.encode(batch) # [B, C, h, w]
for b in range(quant.size(0)):
q = rearrange(quant[b], "c (h p1) (w p2) -> (h w p1 p2) c", p1=ps, p2=ps)
output[idxs[b]] = q
return torch.cat(output, dim=0)
def vt_forward_maxpad(
self,
image_list,
max_bs=32,
stride=32,
min_size=256,
max_size=2048,
max_pixels=1024 * 1024,
normal_buckets=(384, 512, 768, 1024),
):
"""
image_list: list of [1, 3, H, W]
return: Tensor [(sum_i Hi*Wi/stride^2), C]
"""
def is_long_image(H, W):
major = max(H, W)
minor = min(H, W)
return (
major >= 1024 and
minor <= 768 and
major / minor >= 1.5
)
groups = defaultdict(list)
sizes = {}
for idx, img in enumerate(image_list):
_, _, H, W = img.shape
# assert H >= min_size and W >= min_size
# assert H <= max_size and W <= max_size
# assert H * W <= max_pixels, f"image is too large: {H}x{W}"
if is_long_image(H, W):
bucket = "long"
else:
major = max(H, W)
for b in normal_buckets:
if major <= b:
bucket = b
break
else:
bucket = "long"
groups[bucket].append(idx)
sizes[idx] = (H, W)
output = [None] * len(image_list)
for bucket, idxs in groups.items():
imgs = [image_list[i] for i in idxs]
for start in range(0, len(imgs), max_bs):
batch_imgs = imgs[start:start + max_bs]
batch_idxs = idxs[start:start + max_bs]
H_max = max(img.shape[-2] for img in batch_imgs)
W_max = max(img.shape[-1] for img in batch_imgs)
H_pad = math.ceil(H_max / stride) * stride
W_pad = math.ceil(W_max / stride) * stride
padded = []
for img in batch_imgs:
_, _, H, W = img.shape
pad_h = H_pad - H
pad_w = W_pad - W
padded.append(F.pad(img, (0, pad_w, 0, pad_h)))
batch = torch.cat(padded, dim=0) # [B, 3, H_pad, W_pad]
quant = self.encode(batch) # [B, C, h', w']
for i, q in enumerate(quant):
H, W = sizes[batch_idxs[i]]
h_lat = math.ceil(H / stride)
w_lat = math.ceil(W / stride)
q = q[:, :h_lat, :w_lat]
q = rearrange(q, "c h w -> (h w) c")
output[batch_idxs[i]] = q
return torch.cat(output, dim=0)
def decode(self, quant):
dec = self.decoder(quant)
return dec
def forward(self, input):
quant = self.encode(input)
dec = self.decode(quant)
return dec, quant