| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from collections.abc import Sequence |
| from dataclasses import dataclass |
| from enum import Enum |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| from torch import Tensor |
| from diffusers.models import ModelMixin |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
| from ..vae.hunyuanvideo_15_vae import ( |
| CausalConv3d, |
| ResnetBlock, |
| RMS_norm, |
| forward_with_checkpointing, |
| swish, |
| ) |
|
|
|
|
class UpsamplerType(Enum):
    """Closed set of upsampler variants.

    Used as the type of ``UpsamplerConfig.model_type``. Each member's value is
    its lower-case string form, so a member can be recovered from that string
    with ``UpsamplerType("learned")`` and so on.
    """

    # NOTE(review): the exact meaning of each variant is defined by the code
    # that consumes this enum, which is not visible here.
    LEARNED = "learned"
    FIXED = "fixed"
    NONE = "none"
    LEARNED_FIXED = "learned_fixed"
|
|
|
|
@dataclass
class UpsamplerConfig:
    """Settings bundle describing an upsampler.

    Only ``load_from`` is required; every other field has a default. With the
    defaults the upsampler is disabled and ``model_type`` is
    ``UpsamplerType.NONE``.
    """

    # Required. Presumably the location weights are loaded from -- confirm
    # against the code that consumes this config.
    load_from: str
    # Off by default.
    enable: bool = False
    # NOTE(review): hidden_channels / num_blocks look like network width and
    # depth, but how they are applied is decided by the consumer of this config.
    hidden_channels: int = 128
    num_blocks: int = 16
    # Which upsampler variant to use; defaults to NONE.
    model_type: UpsamplerType = UpsamplerType.NONE
    # Free-form string, "720p" by default. Presumably selects between the
    # 720p and 1080p upsampler classes -- TODO confirm.
    version: str = "720p"
|
|
|
|
class SRResidualCausalBlock3D(nn.Module):
    """Residual unit made of three ``CausalConv3d`` layers.

    The layers are stored in ``self.block`` as
    ``conv -> SiLU -> conv -> SiLU -> conv``; every convolution maps
    ``channels`` to ``channels`` with ``kernel_size=3``. ``forward`` adds the
    output of that stack back onto its input.
    """

    def __init__(self, channels: int):
        """Build the convolution stack.

        Args:
            channels: Number of input and output channels of every convolution.
        """
        super().__init__()
        num_convs = 3
        layers = []
        for position in range(num_convs):
            if position > 0:
                # An in-place SiLU sits between consecutive convolutions, so
                # the convolutions land at indices 0, 2 and 4 of the Sequential.
                layers.append(nn.SiLU(inplace=True))
            layers.append(CausalConv3d(channels, channels, kernel_size=3))
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the convolution stack applied to ``x``.

        The addition requires ``self.block`` to produce a tensor that is
        shape-compatible with ``x``.
        """
        delta = self.block(x)
        return x + delta
|
|
|
|
class SRTo720pUpsampler(ModelMixin, ConfigMixin):
    """Stack of causal 3D residual blocks between an input and an output conv.

    Layout: ``in_conv`` -> ``num_blocks`` x ``SRResidualCausalBlock3D`` ->
    ``out_conv``. When ``global_residual`` is set, the network input is added
    to the output, but only if both tensors have exactly the same shape.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int | None = None,
        num_blocks: int = 6,
        global_residual: bool = False,
    ):
        """Create the layers.

        Args:
            in_channels: Channels of the input tensor.
            out_channels: Channels produced by ``out_conv``.
            hidden_channels: Width of the residual trunk; ``None`` means 64.
            num_blocks: Number of residual blocks in the trunk.
            global_residual: Whether ``forward`` adds the input to the output
                when their shapes match.
        """
        super().__init__()
        width = 64 if hidden_channels is None else hidden_channels

        self.in_conv = CausalConv3d(in_channels, width, kernel_size=3)

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(SRResidualCausalBlock3D(width))

        self.out_conv = CausalConv3d(width, out_channels, kernel_size=3)
        self.global_residual = bool(global_residual)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the result of ``out_conv``.

        If ``global_residual`` is enabled and the output has the same shape as
        ``x``, ``x`` is added to the output in place before it is returned.
        A shape mismatch silently skips the residual.
        """
        features = self.in_conv(x)
        for stage in self.blocks:
            features = stage(features)
        out = self.out_conv(features)

        # Skip connection is off, or cannot be applied because shapes differ.
        if not self.global_residual or out.shape != x.shape:
            return out

        out += x
        return out
|
|
|
|
class SRTo1080pUpsampler(ModelMixin, ConfigMixin):
    """Causal 3D convolutional network operating on ``(B, C, T, H, W)`` tensors.

    Layout: ``conv_in`` (plus a channel-repeated shortcut of the input) ->
    one group of ``num_res_blocks + 1`` ``ResnetBlock`` modules per entry of
    ``block_out_channels`` -> ``norm_out`` -> swish -> ``conv_out``.

    ``forward`` can first resize the input spatially (bilinear, frame by
    frame) to a requested ``(H, W)``.

    NOTE(review): ``gradient_checkpointing`` and ``is_residual`` are stored on
    the instance but are not consulted by ``forward`` in this class.
    """

    @register_to_config
    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: tuple[int, ...],
        num_res_blocks: int = 2,
        is_residual: bool = False,
    ):
        """Create the layers.

        Args:
            z_channels: Channels of the input tensor.
            out_channels: Channels produced by ``conv_out``.
            block_out_channels: Output width of each block group. The first
                entry is also the width of ``conv_in``; because ``forward``
                adds a channel-repeated copy of the input to the ``conv_in``
                output, it must be a multiple of ``z_channels`` for the shapes
                to line up.
            num_res_blocks: Each group holds ``num_res_blocks + 1`` blocks.
            is_residual: Stored as ``self.is_residual``.
        """
        super().__init__()
        self.num_res_blocks = num_res_blocks
        self.block_out_channels = block_out_channels
        self.z_channels = z_channels

        block_in = block_out_channels[0]
        self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3)

        # One plain container module per level so parameters are registered
        # under ``up.<level>.block.<index>``.
        self.up = nn.ModuleList()
        for block_out in block_out_channels:
            block = nn.ModuleList()
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                # Only the first block of a level changes the channel count.
                block_in = block_out
            up = nn.Module()
            up.block = block
            self.up.append(up)

        self.norm_out = RMS_norm(block_in, images=False)
        self.conv_out = CausalConv3d(block_in, out_channels, kernel_size=3)

        self.gradient_checkpointing = False
        self.is_residual = is_residual

    def forward(self, z: Tensor, target_shape: Sequence[int] | None = None) -> Tensor:
        """Run the network.

        Args:
            z: Input of shape ``(B, C, T, H, W)``.
            target_shape: Optional ``(H, W)``. When given and different from
                the spatial size of ``z``, every frame of ``z`` is resized to
                it with bilinear interpolation before the network is applied.

        Returns:
            The output of ``conv_out``.
        """
        if target_shape is not None:
            # Compare as plain tuples: ``z.shape[-2:]`` is a ``torch.Size``
            # (a tuple subclass), which never equals a list, so a list-valued
            # ``target_shape`` would otherwise always trigger a resize.
            target_hw = tuple(target_shape)
            if tuple(z.shape[-2:]) != target_hw:
                bsz = z.shape[0]
                # Fold frames into the batch so the 2D interpolation sees
                # (N, C, H, W), then unfold again.
                z = rearrange(z, "b c f h w -> (b f) c h w")
                z = F.interpolate(z, size=target_hw, mode="bilinear", align_corners=False)
                z = rearrange(z, "(b f) c h w -> b c f h w", b=bsz)

        # Shortcut around conv_in: repeat each input channel so the channel
        # count matches the conv_in output.
        repeats = self.block_out_channels[0] // self.z_channels
        h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

        for level in self.up:
            for res_block in level.block:
                # The activation is handed over inside a one-element list and
                # the local name is dropped first, so this frame keeps no other
                # reference to it. Presumably this lets the callee release it
                # early -- confirm against ResnetBlock.
                x_list = [h]
                del h
                h = res_block(x_list)
            if hasattr(level, "upsample"):
                x_list = [h]
                del h
                h = level.upsample(x_list)

        h = self.norm_out(h).to(z.dtype)
        h = swish(h)
        h = self.conv_out(h)
        return h
|
|