|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from torch.nn.utils.parametrizations import weight_norm |
|
|
from torch.nn.utils.parametrize import remove_parametrizations |
|
|
|
|
|
|
|
|
def bake_weight_norm(model: nn.Module) -> nn.Module: |
|
|
"""Remove weight_norm parametrizations, baking normalized weights into regular tensors. |
|
|
|
|
|
This is required for torch.compile/CUDA graph compatibility since weight_norm |
|
|
performs in-place updates during forward passes. |
|
|
""" |
|
|
for module in model.modules(): |
|
|
if hasattr(module, "parametrizations") and "weight" in getattr(module, "parametrizations", {}): |
|
|
remove_parametrizations(module, "weight", leave_parametrized=True) |
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def WeightNormConv2d(*args, **kwargs): |
|
|
return weight_norm(nn.Conv2d(*args, **kwargs)) |
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block: 1x1 expand -> grouped 3x3 -> 1x1 project, plus skip."""

    def __init__(self, ch):
        super().__init__()

        expanded = 2 * ch
        # One group per 16 hidden channels, never fewer than one.
        groups = max(1, expanded // 16)

        self.conv1 = WeightNormConv2d(ch, expanded, 1, 1, 0)
        self.conv2 = WeightNormConv2d(expanded, expanded, 3, 1, 1, groups=groups)
        self.conv3 = WeightNormConv2d(expanded, ch, 1, 1, 0, bias=False)

        # Out-of-place activations keep the block safe for graph capture.
        self.act1 = nn.LeakyReLU(inplace=False)
        self.act2 = nn.LeakyReLU(inplace=False)

    def forward(self, x):
        h = self.act1(self.conv1(x))
        h = self.act2(self.conv2(h))
        return x + self.conv3(h)
|
|
|
|
|
|
|
|
|
|
|
class LandscapeToSquare(nn.Module):
    """Resize the input to a square resolution, then project channels with a 3x3 conv.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        size: target (height, width) of the square output. Defaults to
            ``(512, 512)``, preserving the previously hard-coded behavior.
    """

    def __init__(self, ch_in, ch_out, size=(512, 512)):
        super().__init__()

        self.size = tuple(size)
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)

    def forward(self, x):
        # NOTE(review): bicubic without antialias=True can alias when this
        # resize shrinks a dimension; kept as-is to preserve behavior.
        x = F.interpolate(x, self.size, mode='bicubic')
        x = self.proj(x)
        return x
|
|
|
|
|
class Downsample(nn.Module):
    """Halve spatial resolution (bicubic), then apply a 1x1 channel projection."""

    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)

    def forward(self, x):
        half = F.interpolate(x, scale_factor=0.5, mode='bicubic')
        return self.proj(half)
|
|
|
|
|
class DownBlock(nn.Module):
    """Run ResBlocks at the input resolution, then downsample to the next stage."""

    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()

        self.down = Downsample(ch_in, ch_out)
        # Residual blocks operate at ch_in since they precede the downsample.
        self.blocks = nn.ModuleList(ResBlock(ch_in) for _ in range(num_res))

    def forward(self, x):
        for res_block in self.blocks:
            x = res_block(x)
        return self.down(x)
|
|
|
|
|
class SpaceToChannel(nn.Module):
    """Downsample by folding 2x2 spatial neighborhoods into channels.

    Projects to ``ch_out // 4`` channels first so that pixel_unshuffle's
    4x channel expansion lands exactly on ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out // 4, 3, 1, 1)

    def forward(self, x):
        projected = self.proj(x)
        return F.pixel_unshuffle(projected, 2).contiguous()
|
|
|
|
|
class ChannelAverage(nn.Module):
    """Reduce channels with a learned 3x3 projection plus a scaled group-mean skip."""

    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        self.grps = ch_in // ch_out
        # sqrt(groups) compensates for the variance shrink caused by averaging.
        self.scale = self.grps ** 0.5

    def forward(self, x):
        skip = x
        out = self.proj(x.contiguous())

        # Average groups of `grps` channels down to ch_out channels.
        b, c, h, w = skip.shape
        skip = skip.view(b, self.grps, c // self.grps, h, w).contiguous()
        skip = skip.mean(dim=1) * self.scale

        return skip + out
|
|
|
|
|
|
|
|
|
|
|
class SquareToLandscape(nn.Module):
    """Project channels with a 3x3 conv, then resize to a landscape resolution.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        size: target (height, width) of the output. Defaults to ``(360, 640)``,
            preserving the previously hard-coded behavior.
    """

    def __init__(self, ch_in, ch_out, size=(360, 640)):
        super().__init__()

        self.size = tuple(size)
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)

    def forward(self, x):
        x = self.proj(x)
        # NOTE(review): bicubic without antialias=True can alias when this
        # resize shrinks a dimension; kept as-is to preserve behavior.
        x = F.interpolate(x, self.size, mode='bicubic')
        return x
|
|
|
|
|
class Upsample(nn.Module):
    """Optional 1x1 channel projection followed by 2x bicubic upsampling."""

    def __init__(self, ch_in, ch_out):
        super().__init__()

        # Skip the projection entirely when channel counts already match.
        if ch_in == ch_out:
            self.proj = nn.Identity()
        else:
            self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)

    def forward(self, x):
        projected = self.proj(x)
        return F.interpolate(projected, scale_factor=2.0, mode='bicubic')
|
|
|
|
|
class UpBlock(nn.Module):
    """2x upsample to the next stage, then run ResBlocks at the new resolution."""

    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()

        self.up = Upsample(ch_in, ch_out)
        # Residual blocks operate at ch_out since they follow the upsample.
        self.blocks = nn.ModuleList(ResBlock(ch_out) for _ in range(num_res))

    def forward(self, x):
        x = self.up(x)
        for res_block in self.blocks:
            x = res_block(x)
        return x
|
|
|
|
|
class ChannelToSpace(nn.Module):
    """Upsample by unfolding channels into 2x2 spatial neighborhoods.

    Projects to ``ch_out * 4`` channels first so that pixel_shuffle's
    4x channel reduction lands exactly on ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out * 4, 3, 1, 1)

    def forward(self, x):
        projected = self.proj(x)
        return F.pixel_shuffle(projected, 2).contiguous()
|
|
|
|
|
class ChannelDuplication(nn.Module):
    """Expand channels with a learned 3x3 projection plus a scaled channel-repeat skip."""

    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
        self.reps = ch_out // ch_in
        # 1/sqrt(reps) keeps the repeated skip at roughly unit variance.
        self.scale = self.reps ** -0.5

    def forward(self, x):
        skip = x
        out = self.proj(x.contiguous())

        # Interleave each source channel `reps` times:
        # (b, c, h, w) -> (b, c, reps, h, w) -> (b, c*reps, h, w).
        b, c, h, w = skip.shape
        skip = skip.unsqueeze(2).expand(b, c, self.reps, h, w)
        skip = skip.reshape(b, c * self.reps, h, w).contiguous() * self.scale

        return skip + out
|
|
|
|
|
|
|
|
|
|
|
class Encoder(nn.Module):
    """Downsampling encoder: landscape stem, staged DownBlocks with
    SpaceToChannel skip paths, and a ChannelAverage latent head."""

    def __init__(self, config):
        super().__init__()

        self.conv_in = LandscapeToSquare(config.channels, config.ch_0)

        blocks = []
        residuals = []

        # Each stage doubles channels up to the ch_max cap.
        ch = config.ch_0
        for block_count in config.encoder_blocks_per_stage:
            next_ch = min(ch * 2, config.ch_max)
            blocks.append(DownBlock(ch, next_ch, block_count))
            residuals.append(SpaceToChannel(ch, next_ch))
            ch = next_ch

        self.blocks = nn.ModuleList(blocks)
        self.residuals = nn.ModuleList(residuals)
        self.conv_out = ChannelAverage(ch, config.latent_channels)

        self.skip_logvar = bool(getattr(config, "skip_logvar", False))
        if not self.skip_logvar:
            # NOTE(review): this head is not used by forward() below —
            # presumably consumed elsewhere (e.g. a VAE objective); confirm.
            self.conv_out_logvar = WeightNormConv2d(ch, 1, 3, 1, 1)

    def forward(self, x):
        x = self.conv_in(x)
        # Each stage sums the main path with its space-to-channel skip.
        for stage, skip in zip(self.blocks, self.residuals):
            x = stage(x) + skip(x)
        return self.conv_out(x)
|
|
|
|
|
class Decoder(nn.Module):
    """Upsampling decoder: ChannelDuplication latent stem, staged UpBlocks with
    ChannelToSpace skip paths, and a landscape output head."""

    def __init__(self, config):
        super().__init__()

        self.conv_in = ChannelDuplication(config.latent_channels, config.ch_max)

        blocks = []
        residuals = []

        # Walk the stage list backwards so channel widths mirror the encoder
        # (doubling up to ch_max), then reverse so modules run widest-first.
        ch = config.ch_0
        for block_count in reversed(config.decoder_blocks_per_stage):
            next_ch = min(ch * 2, config.ch_max)
            blocks.append(UpBlock(next_ch, ch, block_count))
            residuals.append(ChannelToSpace(next_ch, ch))
            ch = next_ch

        self.blocks = nn.ModuleList(reversed(blocks))
        self.residuals = nn.ModuleList(reversed(residuals))

        self.act_out = nn.SiLU()
        self.conv_out = SquareToLandscape(config.ch_0, config.channels)

    def forward(self, x):
        x = self.conv_in(x)
        # Each stage sums the main path with its channel-to-space skip.
        for stage, skip in zip(self.blocks, self.residuals):
            x = stage(x) + skip(x)
        x = self.act_out(x)
        return self.conv_out(x)
|
|
|