# FlowUpscaler / upscaler_unet.py
# Uploaded by TensorForger ("uploaded weights", commit bfc01ab, 11.2 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_group_norm(
    channels: int, max_groups: int = 32, eps: float = 1e-6
) -> nn.GroupNorm:
    """Build a GroupNorm over ``channels`` using the largest valid group count.

    The group count starts at ``min(max_groups, channels)`` and is decremented
    until it divides ``channels`` evenly, bottoming out at a single group.

    Args:
        channels: Number of channels to normalise.
        max_groups: Upper bound on the number of groups.
        eps: Numerical-stability epsilon forwarded to ``nn.GroupNorm``.

    Returns:
        A freshly constructed ``nn.GroupNorm``.
    """
    num_groups = min(max_groups, channels)
    # GroupNorm requires channels % groups == 0; one group always qualifies.
    while channels % num_groups and num_groups > 1:
        num_groups -= 1
    return nn.GroupNorm(num_groups, channels, eps=eps)
class SinusoidalTimeEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of per-sample timesteps.

    With ``half = dim // 2`` and frequencies
    ``f_k = exp(-log(max_period) * k / half)`` for ``k = 0 .. half - 1``, the
    embedding is ``[sin(t * f_k), cos(t * f_k)]`` concatenated on the last
    axis.  When ``dim`` is odd a single zero column is appended.

    Args:
        dim: Output embedding width.
        max_period: Sets the lowest frequency of the sinusoids.
    """

    def __init__(self, dim: int = 128, max_period: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed ``timesteps``.

        Args:
            timesteps: 1-D tensor of shape ``(B,)``; a 0-dim scalar is treated
                as a batch of one. Integer or floating dtype.

        Returns:
            Tensor of shape ``(B, dim)``. Floating inputs keep their dtype;
            integer inputs produce float32.
        """
        if timesteps.ndim == 0:
            # The [:, None] indexing below needs a batch axis.
            timesteps = timesteps[None]
        # Compute in at least float32: in fp16/bf16 both the frequencies and
        # t * f_k lose most of their precision, which corrupts sin/cos for
        # large t. float64 inputs stay float64; integers compute in float32.
        compute_dtype = torch.promote_types(timesteps.dtype, torch.float32)
        out_dtype = (
            timesteps.dtype if timesteps.is_floating_point() else compute_dtype
        )
        device = timesteps.device
        half = self.dim // 2
        freqs = torch.exp(
            -torch.log(torch.tensor(float(self.max_period), device=device))
            * torch.arange(half, device=device, dtype=compute_dtype)
            / half
        )
        args = timesteps.to(compute_dtype)[:, None] * freqs[None]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb.to(out_dtype)
class ConditioningEncoder(nn.Module):
    """Turn timesteps into the global conditioning vector shared by all blocks.

    Pipeline: ``SinusoidalTimeEmbedding(time_dim)`` -> ``Linear(time_dim,
    cond_dim)`` -> ``SiLU`` -> ``Linear(cond_dim, cond_dim)``.

    Args:
        time_dim: Width of the sinusoidal timestep embedding.
        cond_dim: Width of the produced conditioning vector.
    """

    def __init__(self, time_dim: int = 128, cond_dim: int = 256):
        super().__init__()
        self.time_embed = SinusoidalTimeEmbedding(time_dim)
        mlp_layers = [
            nn.Linear(time_dim, cond_dim),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim),
        ]
        self.time_proj = nn.Sequential(*mlp_layers)

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        """Return the projected embedding of ``timestep`` (last dim ``cond_dim``)."""
        embedded = self.time_embed(timestep)
        return self.time_proj(embedded)
class ConditionedResidualBlock(nn.Module):
    """Residual conv block modulated by a global conditioning vector.

    Data path (``cond`` enters as a per-channel scale/shift applied right
    after the second GroupNorm)::

        x -> GN -> SiLU -> Conv3x3
          -> GN -> * (1 + scale) + shift -> SiLU -> Dropout -> Conv3x3
          + skip(x)   # 1x1 conv (no bias) if channel counts differ, else identity

    Args:
        input_channels: Channels of the input feature map.
        output_channels: Channels of the returned feature map.
        cond_dim: Width of the conditioning vector ``cond``.
        dropout: Dropout probability applied before the second conv.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        cond_dim: int = 256,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Construction order is kept stable so seeded initialisation and
        # state_dict keys stay the same.
        self.norm1 = make_group_norm(input_channels)
        self.conv1 = nn.Conv2d(
            input_channels, output_channels, kernel_size=3, padding=1
        )
        # Predicts (scale, shift) for each output channel from cond.
        self.cond_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 2 * output_channels),
        )
        self.norm2 = make_group_norm(output_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(
            output_channels, output_channels, kernel_size=3, padding=1
        )
        self.skip = (
            nn.Conv2d(input_channels, output_channels, kernel_size=1, bias=False)
            if input_channels != output_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: 4-D feature map with ``input_channels`` channels.
            cond: Conditioning vector, expected ``(B, cond_dim)``.

        Returns:
            Feature map with ``output_channels`` channels, same spatial size.
        """
        shortcut = self.skip(x)
        hidden = self.conv1(F.silu(self.norm1(x)))
        scale, shift = self.cond_proj(cond).chunk(2, dim=1)
        hidden = self.norm2(hidden)
        # Broadcast the per-channel modulation over height and width.
        hidden = hidden * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        hidden = self.conv2(self.dropout(F.silu(hidden)))
        return hidden + shortcut
class DownStage(nn.Module):
    """Encoder stage: a stack of ConditionedResidualBlocks plus a stride-2 downsample.

    With ``downsample_first=False`` (default) the blocks run at the input
    resolution. Their output is returned as ``skip``, and the main output is
    that tensor downsampled. With ``downsample_first=True`` the input is
    downsampled first and then passed through the blocks. In that case the
    returned ``skip`` is the same tensor as the main output.

    The downsample is a 3x3 conv with stride 2 and padding 1, so each spatial
    size becomes ``ceil(size / 2)``.

    Args:
        input_channels: Channels of the stage input.
        output_channels: Channels produced by the residual blocks.
        cond_dim: Width of the conditioning vector.
        dropout: Dropout probability inside each residual block.
        num_blocks: Number of residual blocks.
        downsample_first: Downsample before (True) or after (False) the blocks.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        cond_dim: int = 256,
        dropout: float = 0.0,
        num_blocks: int = 1,
        downsample_first: bool = False,
    ):
        super().__init__()
        self.downsample_first = downsample_first
        self.blocks = nn.ModuleList(
            [
                ConditionedResidualBlock(
                    input_channels=input_channels if i == 0 else output_channels,
                    output_channels=output_channels,
                    cond_dim=cond_dim,
                    dropout=dropout,
                )
                for i in range(num_blocks)
            ]
        )
        # The downsample conv sees the raw input when it runs first, and the
        # blocks' output otherwise. Size it for that tensor; previously it
        # always used output_channels and crashed when downsample_first=True
        # with input_channels != output_channels.
        down_channels = input_channels if downsample_first else output_channels
        self.downsample = nn.Conv2d(
            down_channels, down_channels, kernel_size=3, stride=2, padding=1
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """Run the stage.

        Returns:
            ``(x, skip)``. ``x`` is the stage output. ``skip`` is the blocks'
            output, taken before any trailing downsample.
        """
        if self.downsample_first:
            x = self.downsample(x)
        for block in self.blocks:
            x = block(x, cond)
        skip = x
        if not self.downsample_first:
            x = self.downsample(x)
        return x, skip
class UpStage(nn.Module):
    """Decoder stage: upsample 2x, concatenate the encoder skip, run residual blocks.

    Args:
        input_channels: Channels of the incoming (lower-resolution) tensor.
        skip_channels: Channels of the encoder skip tensor.
        output_channels: Channels produced by the residual blocks.
        cond_dim: Width of the conditioning vector.
        dropout: Dropout probability inside each residual block.
        num_blocks: Number of residual blocks. The first one consumes the
            concatenation of ``input_channels + skip_channels``.
    """

    def __init__(
        self,
        input_channels: int,
        skip_channels: int,
        output_channels: int,
        cond_dim: int = 256,
        dropout: float = 0.0,
        num_blocks: int = 1,
    ):
        super().__init__()
        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )
        concat_channels = input_channels + skip_channels
        self.blocks = nn.ModuleList(
            [
                ConditionedResidualBlock(
                    input_channels=concat_channels if idx == 0 else output_channels,
                    output_channels=output_channels,
                    cond_dim=cond_dim,
                    dropout=dropout,
                )
                for idx in range(num_blocks)
            ]
        )

    def forward(
        self, x: torch.Tensor, skip: torch.Tensor, cond: torch.Tensor
    ) -> torch.Tensor:
        """Upsample ``x`` to ``skip``'s spatial size, fuse it with ``skip``, and refine.

        Returns:
            Feature map with ``output_channels`` channels at ``skip``'s spatial size.
        """
        upsampled = self.upsample(x)
        target_hw = skip.shape[-2:]
        if upsampled.shape[-2:] != target_hw:
            # A plain 2x upsample can miss the skip's size by a pixel, e.g.
            # when an encoder stride-2 conv rounded an odd size up.
            upsampled = F.interpolate(
                upsampled, size=target_hw, mode="bilinear", align_corners=False
            )
        hidden = torch.cat([upsampled, skip], dim=1)
        for block in self.blocks:
            hidden = block(hidden, cond)
        return hidden
class LowResEncoder(nn.Module):
    """Encode ``latents_small`` into a three-level feature pyramid.

    - Level 1 is produced at the input resolution: a 1x1 projection to
      ``base_channels`` followed by one residual block.
    - Levels 2 and 3 are each one ``DownStage`` with ``downsample_first=True``,
      i.e. a stride-2 conv followed by one residual block.

    Every level has ``base_channels`` channels and is modulated by ``cond``.

    Args:
        sample_channels: Channels of ``latents_small``.
        base_channels: Channel width of every pyramid level.
        cond_dim: Width of the conditioning vector.
        dropout: Dropout probability inside the residual blocks.
    """

    def __init__(
        self,
        sample_channels: int = 32,
        base_channels: int = 128,
        cond_dim: int = 1024,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_conv = nn.Conv2d(
            sample_channels, base_channels, kernel_size=1, padding=0
        )
        self.block_1 = ConditionedResidualBlock(
            input_channels=base_channels,
            output_channels=base_channels,
            cond_dim=cond_dim,
            dropout=dropout,
        )
        pyramid_kwargs = dict(
            input_channels=base_channels,
            output_channels=base_channels,
            cond_dim=cond_dim,
            dropout=dropout,
            num_blocks=1,
            downsample_first=True,
        )
        self.block_2 = DownStage(**pyramid_kwargs)
        self.block_3 = DownStage(**pyramid_kwargs)

    def forward(self, latents_small, cond):
        """Return ``(level_1, level_2, level_3)``, finest level first."""
        level_1 = self.block_1(self.in_conv(latents_small), cond)
        # DownStage returns (output, skip); with downsample_first the two are
        # the same tensor, so only the first element is needed.
        level_2 = self.block_2(level_1, cond)[0]
        level_3 = self.block_3(level_2, cond)[0]
        return level_1, level_2, level_3
class FilmCond2D(nn.Module):
    """Spatial FiLM: modulate ``x`` with a per-pixel scale and shift derived from ``cond``.

    ``cond`` goes through SiLU and a 1x1 conv to ``2 * base_channels``
    channels. The result is split into ``scale`` and ``shift``, and the output
    is ``x * (1 + scale) + shift``. ``cond`` must therefore broadcast
    spatially against ``x``.

    Args:
        base_channels: Channels of ``x``.
        cond_channels: Channels of ``cond``.
    """

    def __init__(self, base_channels: int = 256, cond_channels: int = 256):
        super().__init__()
        self.cond_proj = nn.Sequential(
            nn.SiLU(),
            nn.Conv2d(cond_channels, base_channels * 2, kernel_size=1),
        )

    def forward(self, x, cond):
        """Return ``x`` modulated by the scale/shift predicted from ``cond``."""
        gamma, beta = torch.chunk(self.cond_proj(cond), 2, dim=1)
        return x * (1 + gamma) + beta
class UpscalerUNet(nn.Module):
    """Two-level conditioned U-Net over ``sample``, guided by ``timestep`` and ``latents_small``.

    Conditioning:
      * ``timestep`` -> ConditioningEncoder -> global vector ``cond``, which is
        fed to every residual block (low-res encoder included).
      * ``latents_small`` -> LowResEncoder -> three feature maps. Each is
        bilinearly resized to the activation it modulates and injected via
        FilmCond2D at three points: after ``in_conv``, after the first down
        stage, and after the second down stage (before the mid block).

    Args:
        sample_channels: Channels of ``sample`` and of the returned tensor.
        base_channels: Channel width used at every U-Net level.
        time_dim: Width of the sinusoidal timestep embedding.
        cond_dim: Width of the conditioning vector.
        dropout: Dropout probability in the U-Net residual blocks. It is not
            forwarded to the low-res encoder, which keeps its own default.
    """

    def __init__(
        self,
        sample_channels: int = 32,
        base_channels: int = 384,
        time_dim: int = 512,
        cond_dim: int = 1024,
        dropout: float = 0.01,
    ):
        super().__init__()
        # Submodule construction order is unchanged so seeded initialisation
        # and state_dict keys stay the same.
        self.conditioning = ConditioningEncoder(
            time_dim=time_dim,
            cond_dim=cond_dim,
        )
        self.in_conv = nn.Conv2d(
            sample_channels, base_channels, kernel_size=1, padding=0
        )
        # cond_dim must match `cond`; previously it was left at
        # LowResEncoder's default, which crashed for any other cond_dim.
        # NOTE(review): sample_channels is not forwarded, so latents_small
        # must have LowResEncoder's default channel count — confirm whether
        # it should track sample_channels.
        self.low_res_encoder = LowResEncoder(
            base_channels=base_channels, cond_dim=cond_dim
        )
        self.film_cond_1 = FilmCond2D(
            base_channels=base_channels, cond_channels=base_channels
        )
        self.film_cond_2 = FilmCond2D(
            base_channels=base_channels, cond_channels=base_channels
        )
        self.film_cond_3 = FilmCond2D(
            base_channels=base_channels, cond_channels=base_channels
        )
        self.down_stages = nn.ModuleList(
            [
                DownStage(
                    input_channels=base_channels,
                    output_channels=base_channels,
                    cond_dim=cond_dim,
                    dropout=dropout,
                    num_blocks=3,
                ),
                DownStage(
                    input_channels=base_channels,
                    output_channels=base_channels,
                    cond_dim=cond_dim,
                    dropout=dropout,
                    num_blocks=2,
                ),
            ]
        )
        self.mid_stages = nn.ModuleList(
            [
                ConditionedResidualBlock(
                    input_channels=base_channels,
                    output_channels=base_channels,
                    cond_dim=cond_dim,
                    dropout=dropout,
                )
                for _ in range(1)
            ]
        )
        self.up_stages = nn.ModuleList(
            [
                UpStage(
                    input_channels=base_channels,
                    skip_channels=base_channels,
                    output_channels=base_channels,
                    cond_dim=cond_dim,
                    dropout=dropout,
                    num_blocks=2,
                ),
                UpStage(
                    input_channels=base_channels,
                    skip_channels=base_channels,
                    output_channels=base_channels,
                    cond_dim=cond_dim,
                    dropout=dropout,
                    num_blocks=3,
                ),
            ]
        )
        self.out_conv = nn.Conv2d(
            base_channels, sample_channels, kernel_size=1, padding=0
        )

    @staticmethod
    def _resize_to(feature: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``feature`` to the spatial size of ``reference``."""
        return F.interpolate(
            feature, size=reference.shape[-2:], mode="bilinear", align_corners=False
        )

    def forward(
        self, sample: torch.Tensor, timestep: torch.Tensor, latents_small: torch.Tensor
    ) -> torch.Tensor:
        """Run the U-Net.

        Args:
            sample: 4-D tensor ``(B, sample_channels, H, W)``.
            timestep: Timesteps for ConditioningEncoder; presumably ``(B,)``.
            latents_small: Low-resolution conditioning input for
                LowResEncoder. Assumed to be ``(B, C_lr, h, w)`` with ``C_lr``
                equal to LowResEncoder's ``sample_channels``.

        Returns:
            Tensor of shape ``(B, sample_channels, H, W)``.
        """
        cond = self.conditioning(timestep)
        lr_cond_1, lr_cond_2, lr_cond_3 = self.low_res_encoder(latents_small, cond)

        x = self.in_conv(sample)
        # Resize each low-res feature map to the activation it actually
        # modulates. The stride-2 downsamples produce ceil(size / 2), so
        # fixed H // 2, H // 4 targets mismatched (and crashed) whenever H or
        # W was not divisible by 4. The result is identical otherwise.
        x = self.film_cond_1(x, self._resize_to(lr_cond_1, x))
        skips = []
        x, skip = self.down_stages[0](x, cond)
        skips.append(skip)
        x = self.film_cond_2(x, self._resize_to(lr_cond_2, x))
        x, skip = self.down_stages[1](x, cond)
        skips.append(skip)
        x = self.film_cond_3(x, self._resize_to(lr_cond_3, x))
        for mid in self.mid_stages:
            x = mid(x, cond)
        # Consume skips deepest-first to mirror the encoder.
        for up in self.up_stages:
            x = up(x, skips.pop(), cond)
        return self.out_conv(x)