# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations


def bake_weight_norm(model: nn.Module) -> nn.Module:
"""Remove weight_norm parametrizations, baking normalized weights into regular tensors.
This is required for torch.compile/CUDA graph compatibility since weight_norm
performs in-place updates during forward passes.
"""
    for module in model.modules():
        parametrizations = getattr(module, "parametrizations", None)
        if parametrizations is not None and "weight" in parametrizations:
            remove_parametrizations(module, "weight", leave_parametrized=True)
return model
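
# A minimal usage sketch (assuming `model` is any nn.Module built from the
# weight-normalized convs below): bake the norms in before compiling, since
# weight_norm's in-place updates break torch.compile/CUDA graph capture.
#
#     model = bake_weight_norm(model).eval()
#     model = torch.compile(model)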


# === General Blocks ===
def WeightNormConv2d(*args, **kwargs):
    """A Conv2d wrapped in a weight_norm parametrization."""
    return weight_norm(nn.Conv2d(*args, **kwargs))


class ResBlock(nn.Module):
    """Inverted-bottleneck residual block: 1x1 expand, grouped 3x3, 1x1 project back."""

    def __init__(self, ch):
        super().__init__()
hidden = 2 * ch
# 16 channels per group (matches checkpoint shapes like [128,16,3,3] when ch=64)
n_grps = max(1, hidden // 16)
self.conv1 = WeightNormConv2d(ch, hidden, 1, 1, 0)
self.conv2 = WeightNormConv2d(hidden, hidden, 3, 1, 1, groups=n_grps)
self.conv3 = WeightNormConv2d(hidden, ch, 1, 1, 0, bias=False)
self.act1 = nn.LeakyReLU(inplace=False)
self.act2 = nn.LeakyReLU(inplace=False)
def forward(self, x):
h = self.conv1(x)
h = self.act1(h)
h = self.conv2(h)
h = self.act2(h)
h = self.conv3(h)
return x + h
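
# Sketch: ResBlock is shape-preserving, e.g. ResBlock(64) maps [b, 64, h, w] -> [b, 64, h, w]
# through a hidden width of 128 split into 8 groups of 16 channels.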


# === Encoder ===
class LandscapeToSquare(nn.Module):
    """Bicubic-resize to 512x512, then a 3x3 conv; strictly assumes 360p (640x360) input."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
def forward(self, x):
x = F.interpolate(x, (512, 512), mode='bicubic')
x = self.proj(x)
return x


class Downsample(nn.Module):
    """Bicubic 0.5x spatial downsample followed by a 1x1 channel projection."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.proj = WeightNormConv2d(ch_in, ch_out, 1, 1, 0, bias=False)
def forward(self, x):
x = F.interpolate(x, scale_factor=0.5, mode='bicubic')
x = self.proj(x)
return x


class DownBlock(nn.Module):
    """Runs num_res ResBlocks at ch_in, then downsamples 2x to ch_out."""

    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()
self.down = Downsample(ch_in, ch_out)
blocks = []
for _ in range(num_res):
blocks.append(ResBlock(ch_in))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.down(x)
return x


class SpaceToChannel(nn.Module):
    """3x3 conv to ch_out // 4 channels, then pixel_unshuffle(2): halves H and W, yields ch_out channels."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
self.proj = WeightNormConv2d(ch_in, ch_out // 4, 3, 1, 1)
def forward(self, x):
x = self.proj(x)
x = F.pixel_unshuffle(x, 2).contiguous()
return x


class ChannelAverage(nn.Module):
    """Projects ch_in -> ch_out, with a residual that averages groups of ch_in // ch_out channels."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
self.grps = ch_in // ch_out
self.scale = (self.grps) ** 0.5
def forward(self, x):
res = x
x = self.proj(x.contiguous()) # [b, ch_out, h, w]
# Residual goes through channel avg
res = res.view(res.shape[0], self.grps, res.shape[1] // self.grps, res.shape[2], res.shape[3]).contiguous()
res = res.mean(dim=1) * self.scale # [b, ch_out, h, w]
return res + x
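
# Note: averaging g = ch_in // ch_out independent channels scales their variance by 1/g,
# so the sqrt(g) factor above restores the residual's variance before it is added.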


# === Decoder ===
class SquareToLandscape(nn.Module):
    """3x3 conv, then bicubic-resize back to 360x640; the inverse of LandscapeToSquare."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
def forward(self, x):
        x = self.proj(x)  # TODO: this proj/interpolate ordering is wrong for both this block and LandscapeToSquare
x = F.interpolate(x, (360, 640), mode='bicubic')
return x


class Upsample(nn.Module):
    """Optional 1x1 channel projection followed by a bicubic 2x spatial upsample."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.proj = nn.Identity() if ch_in == ch_out else WeightNormConv2d(
ch_in, ch_out, 1, 1, 0, bias=False
)
def forward(self, x):
x = self.proj(x)
x = F.interpolate(x, scale_factor=2.0, mode='bicubic')
return x


class UpBlock(nn.Module):
    """Upsamples 2x from ch_in to ch_out, then runs num_res ResBlocks at ch_out."""

    def __init__(self, ch_in, ch_out, num_res=1):
        super().__init__()
self.up = Upsample(ch_in, ch_out)
blocks = []
for _ in range(num_res):
blocks.append(ResBlock(ch_out))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
x = self.up(x)
for block in self.blocks:
x = block(x)
return x


class ChannelToSpace(nn.Module):
    """3x3 conv to ch_out * 4 channels, then pixel_shuffle(2): doubles H and W, yields ch_out channels."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
self.proj = WeightNormConv2d(ch_in, ch_out * 4, 3, 1, 1)
def forward(self, x):
x = self.proj(x)
x = F.pixel_shuffle(x, 2).contiguous()
return x


class ChannelDuplication(nn.Module):
    """Projects ch_in -> ch_out, with a residual that repeats each channel ch_out // ch_in times."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
self.proj = WeightNormConv2d(ch_in, ch_out, 3, 1, 1)
self.reps = ch_out // ch_in
self.scale = (self.reps) ** -0.5
def forward(self, x):
res = x
x = self.proj(x.contiguous())
b, c, h, w = res.shape
res = res.unsqueeze(2) # [b, c, 1, h, w]
res = res.expand(b, c, self.reps, h, w) # [b, c, reps, h, w]
res = res.reshape(b, c * self.reps, h, w).contiguous()
res = res * self.scale
return res + x
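
# Note: the 1 / sqrt(reps) factor mirrors ChannelAverage's sqrt(grps) factor, presumably
# to keep the duplicated residual on the same scale as the learned projection path.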


# === Main AE ===
class Encoder(nn.Module):
    """360p frame -> square latent: LandscapeToSquare, then per-stage DownBlock +
    SpaceToChannel residual pairs, then ChannelAverage down to latent_channels."""

    def __init__(self, config):
        super().__init__()
self.conv_in = LandscapeToSquare(config.channels, config.ch_0)
blocks = []
residuals = []
ch = config.ch_0
for block_count in config.encoder_blocks_per_stage:
next_ch = min(ch*2, config.ch_max)
blocks.append(DownBlock(ch, next_ch, block_count))
residuals.append(SpaceToChannel(ch, next_ch))
ch = next_ch
self.blocks = nn.ModuleList(blocks)
self.residuals = nn.ModuleList(residuals)
self.conv_out = ChannelAverage(ch, config.latent_channels)
self.skip_logvar = bool(getattr(config, "skip_logvar", False))
if not self.skip_logvar:
# Checkpoint expects a 1-channel logvar head: [1, ch, 3, 3]
self.conv_out_logvar = WeightNormConv2d(ch, 1, 3, 1, 1)
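            # This head is not called in forward; it exists so checkpoints load cleanly.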
def forward(self, x):
x = self.conv_in(x)
for block, residual in zip(self.blocks, self.residuals):
x = block(x) + residual(x)
return self.conv_out(x)
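
# Sketch: with a 640x360 input and, e.g., four encoder stages, the latent is
# [b, latent_channels, 512 // 2**4, 512 // 2**4] = [b, latent_channels, 32, 32].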


class Decoder(nn.Module):
    """Square latent -> 360p frame: ChannelDuplication up to ch_max, per-stage UpBlock +
    ChannelToSpace residual pairs, SiLU, then SquareToLandscape back to 640x360."""

    def __init__(self, config):
        super().__init__()
        self.conv_in = ChannelDuplication(config.latent_channels, config.ch_max)
blocks = []
residuals = []
ch = config.ch_0
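        # Stages are built from the output end (ch_0) upward, then reversed below so
        # the forward pass runs from ch_max down to ch_0.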
for block_count in reversed(config.decoder_blocks_per_stage):
next_ch = min(ch*2, config.ch_max)
blocks.append(UpBlock(next_ch, ch, block_count))
residuals.append(ChannelToSpace(next_ch, ch))
ch = next_ch
self.blocks = nn.ModuleList(reversed(blocks))
self.residuals = nn.ModuleList(reversed(residuals))
self.act_out = nn.SiLU()
self.conv_out = SquareToLandscape(config.ch_0, config.channels)
def forward(self, x):
x = self.conv_in(x)
for block, residual in zip(self.blocks, self.residuals):
x = block(x) + residual(x)
x = self.act_out(x)
return self.conv_out(x)
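

# A minimal smoke-test sketch. Every config field here is an assumption inferred from
# the attribute accesses above (channels, ch_0, ch_max, latent_channels, per-stage
# block counts); the real checkpoint config may differ.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        channels=3,                      # RGB input (assumed)
        ch_0=64,                         # base width (assumed; matches the ResBlock group comment)
        ch_max=512,                      # channel cap (assumed)
        latent_channels=16,              # latent width (assumed)
        encoder_blocks_per_stage=[1, 1, 1, 1],
        decoder_blocks_per_stage=[1, 1, 1, 1],
        skip_logvar=True,                # skip the logvar head for this sketch
    )
    enc, dec = Encoder(config), Decoder(config)
    x = torch.randn(1, 3, 360, 640)      # one 360p landscape frame
    z = enc(x)                           # [1, 16, 32, 32] under the assumed config
    y = dec(z)                           # back to [1, 3, 360, 640]
    print(z.shape, y.shape)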