import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


class WaveletTransform(nn.Module):
    '''
    Multi-level 2D Haar wavelet transform (or its inverse) used as a
    patchifier: every `patch_size` x `patch_size` block of pixels is turned
    into `patch_size**2` wavelet coefficients per input channel.

    `patchwise` in forward/invert makes *no difference* for inputs whose
    height and width are multiples of `patch_size`; the result is
    numerically identical either way. It is still enabled by default.
    `reshape` is pretty much useless.
    TODO: Clean up these options.
    '''

    def __init__(self, patch_size: int, inverse: bool = False):
        '''
        ### Parameters:
        `patch_size`: Side length of a patch; must be a positive power of two
            because each Haar step halves the spatial resolution.
        `inverse`: If True, `forward` runs `invert` instead of `transform`.

        ### Raises:
        `ValueError` if `patch_size` is not a positive power of two.
        '''
        super().__init__()
        size = int(patch_size)
        # A non-power-of-two patch size used to be silently truncated to the
        # next lower power of two, which produced tensors whose shapes did
        # not match the documented contract. Fail loudly instead.
        if size != patch_size or size < 1 or size & (size - 1):
            raise ValueError(
                f'patch_size must be a positive power of two, got {patch_size!r}'
            )
        self.patch_size = patch_size
        self.inverse = inverse
        # From https://github.com/NVIDIA/Cosmos-Tokenizer/blob/3584ae752ce8ebdbe06a420bf60d7513c0e878cc/cosmos_tokenizer/modules/patching.py#L33
        self.haar = torch.tensor([0.7071067811865476, 0.7071067811865476])
        self.arange = torch.arange(len(self.haar))
        # Number of halving steps, i.e. log2(patch_size) (exact for powers of two).
        self.steps = size.bit_length() - 1

    def num_transformed_channels(self, in_channels: int = 3) -> int:
        '''
        Returns the number of channels to expect in the transformed image
        given the channels in the input image. Every Haar step multiplies
        the channel count by 4 (LL, LH, HL, HH sub-bands).
        '''
        return in_channels * (4 ** self.steps)

    def forward(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        '''
        Dispatches to `invert` when the module was built with `inverse=True`,
        otherwise to `transform`. `reshape` is forwarded as `from_reshaped`
        in the inverse direction.
        '''
        if self.inverse:
            return self.invert(x, patchwise=patchwise, from_reshaped=reshape)
        return self.transform(x, patchwise=patchwise, reshape=reshape)

    def transform(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        '''
        ### Parameters:
        `x`: Images with shape (B C H W). When `patchwise`, H and W must be
            multiples of `patch_size` (the patch reshape fails otherwise).
        `patchwise`: Whether to compute independently on patches
        `reshape`: Reshape the results to match the input HxW

        ### Returns:
        If `reshape`, returns (B C H W)
        otherwise, returns (B C*patch_size**2 H/patch_size W/patch_size)
        '''
        p = self.patch_size

        if patchwise:
            # Place patches into the batch dimension.
            b, c, h, w = x.shape
            init_b = b
            # (B C H W) -> (B C LH P LW P) -> (B C LH LW P P)
            x = x.reshape(b, c, h // p, p, w // p, p).moveaxis(4, 3)
            # (B C LH LW P P) -> (B LH LW C P P) -> (B' C P P)
            x = x.moveaxis(1, 3).reshape(-1, c, p, p)

        # Each step halves H and W and quadruples the channel count.
        for _ in range(self.steps):
            x = self.dwt(x)

        if patchwise:
            # Extract patches from the batch dimension.
            # (B' C' 1 1) -> (B LH LW C') -> (B C' LH LW)
            x = x.reshape(init_b, h // p, w // p, -1).moveaxis(3, 1)

        if reshape:
            # (B C*patch_size**2 H/patch_size W/patch_size) -> (B C H W)
            b, cp2, hdp, wdp = x.shape
            c, h, w = cp2 // (p ** 2), hdp * p, wdp * p
            # Sub-band indices are the slow channel axes, the original
            # channel index is the fastest one.
            x = x.reshape(b, p, p, c, hdp, wdp)
            # -> (B C P P H/P W/P) -> (B C P H/P P W/P) -> (B C H W)
            x = x.moveaxis(3, 1).moveaxis(3, 4).reshape(b, c, h, w).contiguous()

        return x

    def invert(self, x: torch.Tensor, patchwise: bool = True, from_reshaped: bool = False) -> torch.Tensor:
        '''
        ### Parameters:
        `x`: Wavelet-space input of either (B C H W) (when `from_reshaped=True`)
            or (B C*patch_size**2 H/patch_size W/patch_size)
        `patchwise`: Whether to compute independently on patches
        `from_reshaped`: Determines the shape of `x`; should match the value
            of `reshape` used when calling `transform`

        ### Returns:
        Reconstructed tensor of shape (B C H W).
        '''
        p = self.patch_size

        if from_reshaped:
            # Undo the layout produced by `transform(..., reshape=True)`.
            # (B C H W) -> (B C*patch_size**2 H/patch_size W/patch_size)
            b, c, h, w = x.shape
            cp2, hdp, wdp = c * p ** 2, h // p, w // p
            x = x.reshape(b, c, p, hdp, p, wdp)
            # -> (B C P P H/P W/P) -> (B P P C H/P W/P) -> (B C' H/P W/P)
            x = x.moveaxis(4, 3).moveaxis(1, 3).reshape(b, cp2, hdp, wdp)

        if patchwise:
            # Put patches into the batch dimension.
            # (B C' LH LW) -> (B LH LW C') -> (B' C' 1 1)
            init_b, channels, lh, lw = x.shape
            x = x.moveaxis(1, 3).reshape(-1, channels, 1, 1)

        # Each step doubles H and W and divides the channel count by 4.
        for _ in range(self.steps):
            x = self.idwt(x)

        if patchwise:
            # Extract patches from the batch dimension and expand.
            # (B' C P P) -> (B LH LW C P P) -> (B C LH LW P P)
            x = x.reshape(init_b, lh, lw, *x.shape[1:]).moveaxis(3, 1)
            # (B C LH LW P P) -> (B C LH P LW P) -> (B C H W)
            batch, channels = x.shape[:2]
            x = x.moveaxis(3, 4).reshape(batch, channels, lh * p, lw * p)

        return x

    def _filters(self, groups: int, device, dtype):
        '''
        Builds the low-pass and high-pass Haar filter banks, each of shape
        (groups, 1, len(haar)), on the requested device and dtype.
        '''
        h = self.haar
        low = h.flip(0).reshape(1, 1, -1).repeat(groups, 1, 1)
        # Alternating signs turn the averaging filter into a differencing one.
        high = (h * ((-1) ** self.arange)).reshape(1, 1, -1).repeat(groups, 1, 1)
        low = low.to(device=device, dtype=dtype)
        high = high.to(device=device, dtype=dtype)
        return low, high

    def dwt(self, x: torch.Tensor) -> torch.Tensor:
        '''
        One level of the 2D Haar transform.

        (B C H W) -> (B 4*C H/2 W/2) for even H and W, with output channels
        ordered as the concatenation [LL, LH, HL, HH] of C channels each.
        '''
        dtype = x.dtype
        n = self.haar.shape[0]
        g = x.shape[1]
        hl, hh = self._filters(g, x.device, dtype)

        # Pads one reflected row/column at the bottom/right. For even spatial
        # sizes the stride-2 filters never read the padded samples.
        x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode='reflect').to(dtype)

        # Filter along the width first (kernel shape 1 x n) ...
        xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
        xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
        # ... then along the height (kernel shape n x 1).
        xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
        xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))

        # The 0.5 scale is undone by the factor 2.0 in `idwt`.
        return 0.5 * torch.cat([xll, xlh, xhl, xhh], dim=1)

    def idwt(self, x: torch.Tensor) -> torch.Tensor:
        '''
        One level of the inverse 2D Haar transform.

        (B 4*C H W) -> (B C 2*H 2*W); expects the channel layout produced by
        `dwt`, i.e. four equally sized chunks [LL, LH, HL, HH].
        '''
        dtype = x.dtype
        n = self.haar.shape[0]
        g = x.shape[1] // 4
        hl, hh = self._filters(g, x.device, dtype)

        xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)

        # Undo the height filtering (kernel shape n x 1).
        rows = dict(groups=g, stride=(2, 1), padding=(n - 2, 0))
        yl = (
            F.conv_transpose2d(xll, hl.unsqueeze(3), **rows)
            + F.conv_transpose2d(xlh, hh.unsqueeze(3), **rows)
        )
        yh = (
            F.conv_transpose2d(xhl, hl.unsqueeze(3), **rows)
            + F.conv_transpose2d(xhh, hh.unsqueeze(3), **rows)
        )

        # Undo the width filtering (kernel shape 1 x n).
        cols = dict(groups=g, stride=(1, 2), padding=(0, n - 2))
        y = (
            F.conv_transpose2d(yl, hl.unsqueeze(2), **cols)
            + F.conv_transpose2d(yh, hh.unsqueeze(2), **cols)
        )

        # Compensates the 0.5 scale applied in `dwt`.
        return 2.0 * y