| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| def make_group_norm( |
| channels: int, max_groups: int = 32, eps: float = 1e-6 |
| ) -> nn.GroupNorm: |
| groups = min(max_groups, channels) |
| while channels % groups != 0 and groups > 1: |
| groups -= 1 |
| return nn.GroupNorm(groups, channels, eps=eps) |
|
|
|
|
| class SinusoidalTimeEmbedding(nn.Module): |
| def __init__(self, dim: int = 128, max_period: int = 10000): |
| super().__init__() |
| self.dim = dim |
| self.max_period = max_period |
|
|
| def forward(self, timesteps: torch.Tensor) -> torch.Tensor: |
| half = self.dim // 2 |
|
|
| freqs = torch.exp( |
| -torch.log(torch.tensor(float(self.max_period), device=timesteps.device)) |
| * torch.arange(half, device=timesteps.device, dtype=timesteps.dtype) |
| / half |
| ) |
| args = timesteps[:, None] * freqs[None] |
| emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) |
|
|
| if self.dim % 2 == 1: |
| emb = F.pad(emb, (0, 1)) |
|
|
| return emb |
|
|
|
|
| class ConditioningEncoder(nn.Module): |
| def __init__(self, time_dim: int = 128, cond_dim: int = 256): |
| super().__init__() |
| self.time_embed = SinusoidalTimeEmbedding(time_dim) |
|
|
| self.time_proj = nn.Sequential( |
| nn.Linear(time_dim, cond_dim), |
| nn.SiLU(), |
| nn.Linear(cond_dim, cond_dim), |
| ) |
|
|
| def forward(self, timestep: torch.Tensor) -> torch.Tensor: |
| time_vec = self.time_proj(self.time_embed(timestep)) |
| return time_vec |
|
|
|
|
| class ConditionedResidualBlock(nn.Module): |
| """ |
| SDXL-style residual block: |
| GN -> SiLU -> Conv |
| + condition (scale/shift) |
| GN -> SiLU -> Dropout -> Conv |
| + skip connection |
| """ |
|
|
| def __init__( |
| self, |
| input_channels: int, |
| output_channels: int, |
| cond_dim: int = 256, |
| dropout: float = 0.0, |
| ): |
| super().__init__() |
|
|
| self.norm1 = make_group_norm(input_channels) |
| self.conv1 = nn.Conv2d( |
| input_channels, output_channels, kernel_size=3, padding=1 |
| ) |
|
|
| self.cond_proj = nn.Sequential( |
| nn.SiLU(), |
| nn.Linear(cond_dim, 2 * output_channels), |
| ) |
|
|
| self.norm2 = make_group_norm(output_channels) |
| self.dropout = nn.Dropout(dropout) |
| self.conv2 = nn.Conv2d( |
| output_channels, output_channels, kernel_size=3, padding=1 |
| ) |
|
|
| if input_channels != output_channels: |
| self.skip = nn.Conv2d( |
| input_channels, output_channels, kernel_size=1, bias=False |
| ) |
| else: |
| self.skip = nn.Identity() |
|
|
| def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: |
| residual = self.skip(x) |
|
|
| h = self.norm1(x) |
| h = F.silu(h) |
| h = self.conv1(h) |
|
|
| scale_shift = self.cond_proj(cond) |
| scale, shift = scale_shift.chunk(2, dim=1) |
|
|
| h = self.norm2(h) |
| h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None] |
| h = F.silu(h) |
| h = self.dropout(h) |
| h = self.conv2(h) |
|
|
| return h + residual |
|
|
|
|
| class DownStage(nn.Module): |
| def __init__( |
| self, |
| input_channels: int, |
| output_channels: int, |
| cond_dim: int = 256, |
| dropout: float = 0.0, |
| num_blocks: int = 1, |
| downsample_first: bool = False, |
| ): |
| super().__init__() |
| self.downsample_first = downsample_first |
|
|
| self.blocks = nn.ModuleList() |
| for i in range(num_blocks): |
| in_ch = input_channels if i == 0 else output_channels |
| self.blocks.append( |
| ConditionedResidualBlock( |
| input_channels=in_ch, |
| output_channels=output_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| ) |
| ) |
|
|
| self.downsample = nn.Conv2d( |
| output_channels, output_channels, kernel_size=3, stride=2, padding=1 |
| ) |
|
|
| def forward(self, x: torch.Tensor, cond: torch.Tensor): |
|
|
| if self.downsample_first: |
| x = self.downsample(x) |
|
|
| for block in self.blocks: |
| x = block(x, cond) |
| skip = x |
|
|
| if not self.downsample_first: |
| x = self.downsample(x) |
|
|
| return x, skip |
|
|
|
|
| class UpStage(nn.Module): |
| def __init__( |
| self, |
| input_channels: int, |
| skip_channels: int, |
| output_channels: int, |
| cond_dim: int = 256, |
| dropout: float = 0.0, |
| num_blocks: int = 1, |
| ): |
| super().__init__() |
|
|
| self.upsample = nn.Upsample( |
| scale_factor=2, mode="bilinear", align_corners=False |
| ) |
|
|
| self.blocks = nn.ModuleList() |
| for i in range(num_blocks): |
| in_ch = (input_channels + skip_channels) if i == 0 else output_channels |
| self.blocks.append( |
| ConditionedResidualBlock( |
| input_channels=in_ch, |
| output_channels=output_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| ) |
| ) |
|
|
| def forward( |
| self, x: torch.Tensor, skip: torch.Tensor, cond: torch.Tensor |
| ) -> torch.Tensor: |
| x = self.upsample(x) |
|
|
| if x.shape[-2:] != skip.shape[-2:]: |
| x = F.interpolate( |
| x, size=skip.shape[-2:], mode="bilinear", align_corners=False |
| ) |
|
|
| x = torch.cat([x, skip], dim=1) |
|
|
| for block in self.blocks: |
| x = block(x, cond) |
|
|
| return x |
|
|
|
|
| class LowResEncoder(nn.Module): |
| def __init__( |
| self, |
| sample_channels: int = 32, |
| base_channels: int = 128, |
| cond_dim: int = 1024, |
| dropout: float = 0.0, |
| ): |
| super().__init__() |
|
|
| self.in_conv = nn.Conv2d( |
| sample_channels, base_channels, kernel_size=1, padding=0 |
| ) |
|
|
| self.block_1 = ConditionedResidualBlock( |
| input_channels=base_channels, |
| output_channels=base_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| ) |
|
|
| self.block_2 = DownStage( |
| input_channels=base_channels, |
| output_channels=base_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| num_blocks=1, |
| downsample_first=True, |
| ) |
|
|
| self.block_3 = DownStage( |
| input_channels=base_channels, |
| output_channels=base_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| num_blocks=1, |
| downsample_first=True, |
| ) |
|
|
| def forward(self, latents_small, cond): |
| x = self.in_conv(latents_small) |
| block_1_out = self.block_1(x, cond) |
| block_2_out, _ = self.block_2(block_1_out, cond) |
| block_3_out, _ = self.block_3(block_2_out, cond) |
|
|
| return block_1_out, block_2_out, block_3_out |
|
|
|
|
| class FilmCond2D(nn.Module): |
| def __init__(self, base_channels: int = 256, cond_channels: int = 256): |
| super().__init__() |
|
|
| self.cond_proj = nn.Sequential( |
| nn.SiLU(), |
| nn.Conv2d(cond_channels, base_channels * 2, kernel_size=1), |
| ) |
|
|
| def forward(self, x, cond): |
| scale_shift = self.cond_proj(cond) |
| scale, shift = scale_shift.chunk(2, dim=1) |
|
|
| x = x * (1 + scale) + shift |
|
|
| return x |
|
|
|
|
| class UpscalerUNet(nn.Module): |
| def __init__( |
| self, |
| sample_channels: int = 32, |
| base_channels: int = 384, |
| time_dim: int = 512, |
| cond_dim: int = 1024, |
| dropout: float = 0.01, |
| ): |
| super().__init__() |
|
|
| self.conditioning = ConditioningEncoder( |
| time_dim=time_dim, |
| cond_dim=cond_dim, |
| ) |
|
|
| self.in_conv = nn.Conv2d( |
| sample_channels, base_channels, kernel_size=1, padding=0 |
| ) |
|
|
| self.low_res_encoder = LowResEncoder(base_channels=base_channels) |
|
|
| self.film_cond_1 = FilmCond2D( |
| base_channels=base_channels, cond_channels=base_channels |
| ) |
| self.film_cond_2 = FilmCond2D( |
| base_channels=base_channels, cond_channels=base_channels |
| ) |
| self.film_cond_3 = FilmCond2D( |
| base_channels=base_channels, cond_channels=base_channels |
| ) |
|
|
| self.down_stages = nn.ModuleList( |
| [ |
| DownStage( |
| input_channels=base_channels, |
| output_channels=base_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| num_blocks=3, |
| ), |
| DownStage( |
| input_channels=base_channels, |
| output_channels=base_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| num_blocks=2, |
| ), |
| ] |
| ) |
|
|
| self.mid_stages = nn.ModuleList( |
| [ |
| ConditionedResidualBlock( |
| input_channels=base_channels, |
| output_channels=base_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| ) |
| for i in range(1) |
| ] |
| ) |
|
|
| self.up_stages = nn.ModuleList( |
| [ |
| UpStage( |
| input_channels=base_channels, |
| skip_channels=base_channels, |
| output_channels=base_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| num_blocks=2, |
| ), |
| UpStage( |
| input_channels=base_channels, |
| skip_channels=base_channels, |
| output_channels=base_channels, |
| cond_dim=cond_dim, |
| dropout=dropout, |
| num_blocks=3, |
| ), |
| ] |
| ) |
|
|
| self.out_conv = nn.Conv2d( |
| base_channels, sample_channels, kernel_size=1, padding=0 |
| ) |
|
|
| def forward( |
| self, sample: torch.Tensor, timestep: torch.Tensor, latents_small: torch.Tensor |
| ) -> torch.Tensor: |
| cond = self.conditioning(timestep) |
|
|
| B, C, H, W = sample.shape |
|
|
| lr_cond_1, lr_cond_2, lr_cond_3 = self.low_res_encoder(latents_small, cond) |
|
|
| lr_cond_1 = torch.nn.functional.interpolate(lr_cond_1, (H, W), mode="bilinear") |
| lr_cond_2 = torch.nn.functional.interpolate( |
| lr_cond_2, (H // 2, W // 2), mode="bilinear" |
| ) |
| lr_cond_3 = torch.nn.functional.interpolate( |
| lr_cond_3, (H // 4, W // 4), mode="bilinear" |
| ) |
|
|
| x = self.in_conv(sample) |
| x = self.film_cond_1(x, lr_cond_1) |
|
|
| skips = [] |
|
|
| x, skip = self.down_stages[0](x, cond) |
| skips.append(skip) |
|
|
| x = self.film_cond_2(x, lr_cond_2) |
|
|
| x, skip = self.down_stages[1](x, cond) |
| skips.append(skip) |
|
|
| x = self.film_cond_3(x, lr_cond_3) |
|
|
| for mid in self.mid_stages: |
| x = mid(x, cond) |
|
|
| for up in self.up_stages: |
| x = up(x, skips.pop(), cond) |
|
|
| x = self.out_conv(x) |
| return x |
|
|