| |
|
| | from transformers import PreTrainedModel |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import math |
| |
|
| | from .wavelet import WaveletTransform |
| | from .pfsq import PFSQ |
| | from .config import PLPQConfig |
| |
|
| |
|
class PLPQ(PreTrainedModel):
    """Pyramidal Local Patch Quantizer.

    A convolutional encoder / decoder pair around a ``PFSQ`` quantizer.
    ``quantize`` maps an image batch to integer codes and ``decode`` maps
    codes back to an image batch.
    """

    config_class = PLPQConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # getattr also honours class-level defaults on the config object,
        # which a lookup in config.__dict__ would miss.
        if getattr(config, "use_wavelets", False):
            # Patchify with a wavelet transform, then mix channels with a 1x1 conv.
            wavelets = WaveletTransform(patch_size=config.patch_size)
            wavelet_channels = wavelets.num_transformed_channels(config.num_in_channels)
            in_proj = nn.Sequential(
                wavelets,
                nn.Conv2d(
                    wavelet_channels, config.encoder_blocks[0][1],
                    kernel_size=1, stride=1,
                ),
            )
            # Mirror of in_proj: project back to wavelet channels, then invert the transform.
            out_proj = nn.Sequential(
                nn.Conv2d(
                    config.decoder_blocks[-1][2], wavelet_channels,
                    kernel_size=3, stride=1, padding=1,
                ),
                WaveletTransform(patch_size=config.patch_size, inverse=True),
            )
        else:
            # Patchify with a non-overlapping strided conv (kernel == stride == patch_size).
            in_proj = nn.Conv2d(
                config.num_in_channels, config.encoder_blocks[0][1],
                kernel_size=config.patch_size, stride=config.patch_size,
            )
            out_proj = nn.Conv2d(
                config.decoder_blocks[-1][2], config.num_out_channels,
                kernel_size=3, stride=1, padding=1,
            )

        # Submodule order inside each Sequential is part of the state_dict keys;
        # keep it stable.
        self.encoder = nn.Sequential(
            in_proj,
            nn.SiLU(),
            *[self._build_block(spec, Downsample) for spec in config.encoder_blocks],
        )

        # The quantizer dimension is spec[2] of the last encoder block
        # (its output-channel argument).
        self.quantizer = PFSQ(
            levels=config.levels,
            num_codebooks=config.num_quantizers,
            dim=config.encoder_blocks[-1][2],
        )

        # 1x1 conv used by `decode` when only a single code per token is given.
        # It takes len(config.levels) input channels; presumably that is the
        # width of the first codebook's codes. TODO confirm against PFSQ.
        self.coarse_decoder = nn.Conv2d(
            len(config.levels), config.num_out_channels, kernel_size=1, stride=1
        )

        self.decoder = nn.Sequential(
            *[self._build_block(spec, Upsample) for spec in config.decoder_blocks],
            out_proj,
        )

    @staticmethod
    def _build_block(block_params, resample_cls):
        """Instantiate one block from a config spec tuple.

        ``block_params[0]`` selects the type: ``"ResBlock"`` builds a
        ``PatchResidualConvBlock``, and any other value builds ``resample_cls``.
        The remaining entries are passed positionally to the constructor.
        """
        block_cls = PatchResidualConvBlock if block_params[0] == "ResBlock" else resample_cls
        return block_cls(*block_params[1:])

    def get_num_params(self) -> int:
        """Return the total number of parameters in the model."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Encode and quantize an input batch.

        Parameters:
            x (torch.Tensor): The input tensor of shape (b, c, h, w).
        Returns:
            torch.Tensor: The indices tensor of shape (b, t, n_quantizers),
            where t = h' * w' is the encoder's latent grid flattened row-major.
        """
        # (b, c, h', w') -> (b, h', w', c): channels last, so each spatial
        # position becomes one token.
        z = self.encoder(x).permute(0, 2, 3, 1).contiguous()
        b, h, w, c = z.shape
        z = z.view(b, h * w, c)
        # The quantizer returns (quantized, coarse_quantized, codes); only the
        # codes are needed here.
        _, _, all_codes = self.quantizer(z)
        return all_codes

    @torch.no_grad()
    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode codes produced by ``quantize`` back into an image batch.

        Parameters:
            indices (torch.Tensor): The input codes of shape (b, t, n_quantizers).
                The t tokens must form a square grid (t must be a perfect square).
        Returns:
            torch.Tensor: The decoded tensor of shape (b, c, h, w).
        Raises:
            ValueError: If the number of tokens t is not a perfect square.
        """
        ncodes = indices.shape[-1]
        # Assumes indices_to_codes yields (b, t, dim[, 1]). The squeeze drops
        # a trailing singleton axis if present. TODO confirm against PFSQ.
        emb = self.quantizer.indices_to_codes(indices).squeeze(-1)

        b, num_tokens = emb.size(0), emb.size(1)
        # Integer square root: exact for any size, unlike int(math.sqrt(...)).
        side = math.isqrt(num_tokens)
        if side * side != num_tokens:
            raise ValueError(
                f"decode expects a square token grid, got {num_tokens} tokens"
            )
        # (b, t, dim) -> (b, dim, t) -> (b, dim, side, side), undoing the
        # row-major flattening done in `quantize`.
        emb = emb.permute(0, 2, 1).view(b, -1, side, side).contiguous()

        if ncodes == 1:
            # Coarse path: a single 1x1 conv, so the output stays at the
            # token-grid resolution.
            return self.coarse_decoder(emb)

        return self.decoder(emb)
| |
|
| |
|
| |
|
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis, with an optional additive bias.

    ``nn.LayerNorm`` cannot simply be told ``bias=False``, so this wrapper
    stores its own affine parameters and calls ``F.layer_norm`` directly.

    Args:
        ndim: Size of the normalised (last) dimension.
        bias: If truthy, learn a bias initialised to zeros. Otherwise
            ``self.bias`` is ``None`` and no bias is applied.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )
| |
|
| |
|
| |
|
class PatchResidualConvBlock(nn.Module):
    """Pre-norm residual conv block.

    Computes ``x + SiLU(conv2(dropout(SiLU(conv1(LN(x))))))``, where ``LN``
    normalises over the channel axis. Because the result is added to ``x``,
    the two convolutions must together preserve the channel count
    (``out_dim == in_dim``) and the spatial size.

    Args:
        in_dim: Input channels (also the LayerNorm width).
        out_dim: Output channels of ``conv2``.
        hidden_dim: Channels between ``conv1`` and ``conv2``.
        kernel_size, stride, padding: Shared by both convolutions.
        dorpout: Dropout probability applied after the first activation.
            The name is misspelled but kept so existing callers still work.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, kernel_size, stride, padding, dorpout=0.1) -> None:
        super().__init__()
        # Attribute names and registration order define state_dict keys and
        # the parameters() order, so they are kept exactly as before.
        self.nonlinearity = nn.SiLU()
        self.ln1 = LayerNorm(in_dim, bias=True)
        self.dropout = nn.Dropout(dorpout)
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv1 = nn.Conv2d(in_dim, hidden_dim, **conv_kwargs)
        self.conv2 = nn.Conv2d(hidden_dim, out_dim, **conv_kwargs)

    def forward(self, x):
        batch, channels, height, width = x.shape
        # The LayerNorm acts on the last axis: move channels last and flatten
        # every pixel into its own row, normalise, then restore (b, c, h, w).
        rows = x.permute(0, 2, 3, 1).reshape(batch * height * width, channels)
        normed = self.ln1(rows).reshape(batch, height, width, channels)
        hidden = normed.permute(0, 3, 1, 2).contiguous()

        hidden = self.dropout(self.nonlinearity(self.conv1(hidden)))
        hidden = self.nonlinearity(self.conv2(hidden))
        return hidden + x
| |
|
| |
|
| |
|
class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation, then apply a 3x3 conv.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 3x3 kernel with padding=1 and stride=1 keeps the upsampled size.
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        doubled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(doubled)
| |
|
| |
|
| |
|
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2, 3x3 convolution.

    The input is zero-padded asymmetrically by one pixel on the right and
    bottom before an unpadded convolution. For even H and W the output is
    (H / 2, W / 2).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x):
        # F.pad order for the last two dims is (left, right, top, bottom).
        right_bottom = (0, 1, 0, 1)
        padded = F.pad(x, right_bottom, mode="constant", value=0)
        return self.conv(padded)
| |
|