| from typing import * |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from ..modules.norm import GroupNorm32, ChannelLayerNorm32 |
| from ..modules.spatial import pixel_shuffle_3d |
| from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 |
| from ..modules.transformer import FeedForwardNet, TransformerBlock, TransformerCrossBlock, AbsolutePositionEmbedder |
|
|
|
|
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
    """
    Build a normalization layer of the requested kind.

    Args:
        norm_type: "group" for a 32-group GroupNorm32, or "layer" for
            ChannelLayerNorm32. Remaining args/kwargs are forwarded to the
            layer constructor.

    Raises:
        ValueError: If ``norm_type`` is not a supported type.
    """
    if norm_type == "group":
        return GroupNorm32(32, *args, **kwargs)
    if norm_type == "layer":
        return ChannelLayerNorm32(*args, **kwargs)
    raise ValueError(f"Invalid norm type {norm_type}")
|
|
|
|
class ResBlock3d(nn.Module):
    """
    3D residual block: two (norm -> SiLU -> conv) stages plus a shortcut.

    The second conv is zero-initialized via ``zero_module`` so the block
    initially behaves as the skip path alone. The shortcut is a 1x1x1 conv
    when the channel count changes, otherwise an identity.
    """

    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.norm1 = norm_layer(norm_type, channels)
        self.norm2 = norm_layer(norm_type, self.out_channels)
        self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
        self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
        # Only project the residual when the width changes.
        if channels != self.out_channels:
            self.skip_connection = nn.Conv3d(channels, self.out_channels, 1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.skip_connection(x)
        out = self.conv1(F.silu(self.norm1(x)))
        out = self.conv2(F.silu(self.norm2(out)))
        return out + residual
|
|
|
|
class DownsampleBlock3d(nn.Module):
    """
    Halve each spatial dimension of a 3D feature map.

    "conv" mode uses a strided 2x2x2 convolution (and may change the channel
    count); "avgpool" mode uses 2x2x2 average pooling and requires
    ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "avgpool"] = "conv",
    ):
        assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
        elif mode == "avgpool":
            # Pooling cannot change the channel count.
            assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        conv = getattr(self, "conv", None)
        if conv is None:
            return F.avg_pool3d(x, 2)
        return conv(x)
|
|
|
class UpsampleBlock3d(nn.Module):
    """
    Double each spatial dimension of a 3D feature map.

    "conv" mode applies a 3x3x3 conv producing 8x the output channels, then
    folds them into a 2x larger grid with ``pixel_shuffle_3d`` (2**3 == 8
    sub-voxels per cell). "nearest" mode uses nearest-neighbor interpolation
    and requires ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "nearest"] = "conv",
    ):
        assert mode in ["conv", "nearest"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
        elif mode == "nearest":
            assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not hasattr(self, "conv"):
            return F.interpolate(x, scale_factor=2, mode="nearest")
        return pixel_shuffle_3d(self.conv(x), 2)
|
|
class AniGenSparseStructureEncoder(nn.Module):
    """
    Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).

    Runs two parallel conv downsampling pyramids: a main ("ss") branch —
    intended to carry pretrained weights (see ``use_pretrain_branch``) — and
    a skeleton ("skl") branch. When ``encode_global`` / ``encode_global_skl``
    is set, the corresponding branch pools its final feature grid into a
    fixed number of learned global tokens via a cross-attention head instead
    of a conv output layer.

    Args:
        in_channels (int): Channels of the main-branch input.
        in_channels_skl (int): Channels of the skeleton-branch input.
        latent_channels (int): Channels of the main latent representation.
        latent_channels_skl (int): Channels of the skeleton latent representation.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the encoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to run the torso in FP16.
        encode_global (bool): Pool the main branch into global tokens.
        global_token_num (int): Number of global tokens (main branch).
        encode_global_skl (bool): Pool the skeleton branch into global tokens.
        global_token_num_skl (int): Number of global tokens (skeleton branch).
        use_pretrain_branch (bool): Whether the main branch is pretrained.
        freeze_pretrain_branch (bool): Freeze the pretrained-branch modules.
        modules_to_freeze (Optional[List[str]]): Attribute names to freeze.
        latent_denoising (bool): Emit a deterministic latent for the main
            branch (no logvar head; mean is detached before returning).
        latent_denoising_skl (bool): Same for the skeleton branch.
        normalize_z (bool): L2-normalize the main latent; only effective
            together with ``latent_denoising``.
        normalize_z_skl (bool): Same for the skeleton branch.
        normalize_scale (float): Scale applied after normalization.
    """
    def __init__(
        self,
        in_channels: int,
        in_channels_skl: int,
        latent_channels: int,
        latent_channels_skl: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
        encode_global: bool = False,
        global_token_num: int = 1024,
        encode_global_skl: bool = True,
        global_token_num_skl: int = 1024,
        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        # NOTE(review): mutable default argument — harmless here only because
        # the list is never mutated; prefer a None sentinel.
        modules_to_freeze: Optional[List[str]] = ["input_layer", "blocks", "middle_block", "out_layer"],
        latent_denoising: bool = False,
        latent_denoising_skl: bool = True,
        normalize_z: bool = False,
        normalize_z_skl: bool = True,
        normalize_scale: float = 1.0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.in_channels_skl = in_channels_skl
        self.latent_channels = latent_channels
        self.latent_channels_skl = latent_channels_skl
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.encode_global = encode_global
        self.global_token_num = global_token_num
        self.encode_global_skl = encode_global_skl
        self.global_token_num_skl = global_token_num_skl
        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        self.latent_denoising = latent_denoising
        self.latent_denoising_skl = latent_denoising_skl
        # Latent normalization only takes effect in denoising mode.
        self.normalize_latent = normalize_z and latent_denoising
        self.normalize_latent_skl = normalize_z_skl and latent_denoising_skl
        self.normalize_scale = normalize_scale

        self.input_layer = nn.Conv3d(self.in_channels, channels[0], 3, padding=1)
        self.input_layer_skl = nn.Conv3d(self.in_channels_skl, channels[0], 3, padding=1)

        # Twin downsampling pyramids: num_res_blocks ResBlocks per level,
        # a DownsampleBlock between consecutive levels.
        self.blocks = nn.ModuleList([])
        self.blocks_skl = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            self.blocks_skl.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    DownsampleBlock3d(ch, channels[i+1])
                )
                self.blocks_skl.append(
                    DownsampleBlock3d(ch, channels[i+1])
                )

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[-1], channels[-1])
            for _ in range(num_res_blocks_middle)
        ])
        self.middle_block_skl = nn.Sequential(*[
            # NOTE(review): both arms of this conditional are identical —
            # the conditional is redundant.
            ResBlock3d(channels[-1] if _ == 0 else channels[-1], channels[-1])
            for _ in range(num_res_blocks_middle)
        ])

        if self.encode_global:
            # Learned query tokens the grid features are pooled into.
            self.init_tokens_ss = nn.Parameter(torch.zeros(1, global_token_num, channels[-1]))
            # ``self.device`` is usable here because conv parameters above
            # are already registered.
            pos_embedder = AbsolutePositionEmbedder(channels[-1], 1)
            coords = torch.arange(global_token_num, device=self.device).reshape(-1, 1)
            tokens_pos_emb = pos_embedder(coords)
            self.register_buffer('tokens_pos_emb_ss', tokens_pos_emb)
            # 3-D positional embedding for the flattened feature grid,
            # assuming a 64^3 input downsampled by the pyramid.
            upsample_factor = 2 ** (len(channels) - 1)
            self.base_size_ss = 64 // upsample_factor
            pos_embedder = AbsolutePositionEmbedder(channels[-1], 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size_ss] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            grid_pos_emb = pos_embedder(coords)
            self.register_buffer("grid_pos_emb_ss", grid_pos_emb)
            # Projects [token | token pos-emb] back to the working width.
            self.token_proj_ss = nn.Linear(channels[-1]*2, channels[-1])

            # Cross-attention readout (tokens query the grid, ctx width is
            # feature + grid pos-emb = 2*C), then self-attention blocks,
            # then an FFN head producing the latent (x2 when variational).
            self.out_layer = nn.ModuleList(
                [TransformerCrossBlock(
                    channels=channels[-1],
                    ctx_channels=channels[-1]*2,
                    out_channels=channels[-1],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                    x_is_query=False)] +
                [TransformerBlock(
                    channels=channels[-1],
                    out_channels=channels[-1],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                ) for _ in range(4)] +
                [FeedForwardNet(
                    channels=channels[-1],
                    out_channels=latent_channels*2 if not self.latent_denoising else latent_channels)]
            )
        else:
            # Conv head; 2x latent channels for (mean, logvar) unless the
            # branch is deterministic (latent_denoising).
            self.out_layer = nn.Sequential(
                norm_layer(norm_type, channels[-1]),
                nn.SiLU(),
                nn.Conv3d(channels[-1], latent_channels*2 if not self.latent_denoising else latent_channels, 3, padding=1)
            )

        # Skeleton-branch analog of the block above.
        if self.encode_global_skl:
            self.init_tokens = nn.Parameter(torch.zeros(1, global_token_num_skl, channels[-1]))
            pos_embedder = AbsolutePositionEmbedder(channels[-1], 1)
            coords = torch.arange(global_token_num_skl, device=self.device).reshape(-1, 1)
            tokens_pos_emb = pos_embedder(coords)
            self.register_buffer('tokens_pos_emb', tokens_pos_emb)
            upsample_factor = 2 ** (len(channels) - 1)
            self.base_size = 64 // upsample_factor
            pos_embedder = AbsolutePositionEmbedder(channels[-1], 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            grid_pos_emb = pos_embedder(coords)
            self.register_buffer("grid_pos_emb", grid_pos_emb)
            self.token_proj = nn.Linear(channels[-1]*2, channels[-1])

            self.out_layer_skl = nn.ModuleList(
                [TransformerCrossBlock(
                    channels=channels[-1],
                    ctx_channels=channels[-1]*2,
                    out_channels=channels[-1],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                    x_is_query=False)] +
                [TransformerBlock(
                    channels=channels[-1],
                    out_channels=channels[-1],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                ) for _ in range(4)] +
                [FeedForwardNet(
                    channels=channels[-1],
                    out_channels=latent_channels_skl*2 if not self.latent_denoising_skl else latent_channels_skl)]
            )
        else:
            self.out_layer_skl = nn.Sequential(
                norm_layer(norm_type, channels[-1]),
                nn.SiLU(),
                nn.Conv3d(channels[-1], latent_channels_skl*2 if not self.latent_denoising_skl else latent_channels_skl, 3, padding=1)
            )

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

        # Freeze the (pretrained) main-branch modules by attribute name.
        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            for module in modules_to_freeze:
                if hasattr(self, module):
                    mod = getattr(self, module)
                    if isinstance(mod, nn.ModuleList):
                        for m in mod:
                            for param in m.parameters():
                                param.requires_grad = False
                    else:
                        for param in mod.parameters():
                            param.requires_grad = False

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        # Input/output conv layers of the non-global paths stay fp32.
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.blocks_skl.apply(convert_module_to_f16)
        self.middle_block_skl.apply(convert_module_to_f16)
        if self.encode_global_skl:
            self.token_proj.apply(convert_module_to_f16)
            self.out_layer_skl.apply(convert_module_to_f16)
        if self.encode_global:
            self.token_proj_ss.apply(convert_module_to_f16)
            self.out_layer.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.blocks_skl.apply(convert_module_to_f32)
        self.middle_block_skl.apply(convert_module_to_f32)
        if self.encode_global_skl:
            self.token_proj.apply(convert_module_to_f32)
            self.out_layer_skl.apply(convert_module_to_f32)
        if self.encode_global:
            self.token_proj_ss.apply(convert_module_to_f32)
            self.out_layer.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Kaiming-init all Linear weights; zero their biases."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight, nonlinearity='linear')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: torch.Tensor, x_skl: Optional[torch.Tensor] = None, sample_posterior: bool = False, return_raw: bool = False) -> Tuple[torch.Tensor, ...]:
        """
        Encode both inputs into latents.

        Args:
            x: Main-branch input grid, (B, in_channels, D, H, W).
            x_skl: Skeleton-branch input grid.
                NOTE(review): defaults to None but is used unconditionally
                below — callers must always pass it.
            sample_posterior: Sample z from the predicted Gaussian posterior
                (ignored for a branch in latent-denoising mode).
            return_raw: Also return raw (mean, logvar) for both branches.

        Returns:
            (z, z_skl), or
            (z, mean, logvar, z_skl, mean_skl, logvar_skl) when return_raw.
        """
        # Input convs run in the inputs' dtype; cast to the torso dtype after.
        h = self.input_layer(x)
        h = h.type(self.dtype)
        h_skl = self.input_layer_skl(x_skl)
        h_skl = h_skl.type(self.dtype)

        # The two pyramids are built identically, so zip pairs levels 1:1.
        for block, block_skl in zip(self.blocks, self.blocks_skl):
            h_skl = block_skl(h_skl)
            h = block(h)
        h_skl = self.middle_block_skl(h_skl)
        h = self.middle_block(h)

        if self.encode_global:
            # Flatten the grid to tokens and append the 3-D pos-emb as extra
            # context channels for the cross-attention readout.
            B, C, D, H, W = h.shape
            h = h.view(B, C, D*H*W).permute(0, 2, 1)
            h = torch.cat([h, self.grid_pos_emb_ss[None].expand(B, -1, -1)], dim=-1).type(h.dtype)
            init_tokens = torch.cat([self.init_tokens_ss, self.tokens_pos_emb_ss[None].expand_as(self.init_tokens_ss)], dim=-1).type(h.dtype)
            tokens = self.token_proj_ss(init_tokens.expand(B, -1, -1))
            # Cross block: learned tokens query the grid; then self-attn + FFN head.
            h = self.out_layer[0](tokens, h)
            for layer in self.out_layer[1:]:
                h = layer(h)
            h = h.type(x.dtype)
            if self.latent_denoising:
                # Deterministic latent; logvar is a zero placeholder.
                if self.normalize_latent:
                    h = nn.functional.normalize(h, dim=-1) * self.normalize_scale
                mean = h
                logvar = torch.zeros_like(h)
            else:
                # Token layout (B, N, 2*latent): split along the channel axis.
                mean, logvar = h.chunk(2, dim=2)
            if sample_posterior and not self.latent_denoising:
                # Reparameterization trick.
                std = torch.exp(0.5 * logvar)
                z = mean + std * torch.randn_like(std)
            else:
                z = mean
        else:
            # Conv head path; note the cast back to x.dtype happens BEFORE
            # the head here (after it in the global path above).
            h = h.type(x.dtype)
            h = self.out_layer(h)
            if self.latent_denoising:
                if self.normalize_latent:
                    h = nn.functional.normalize(h, dim=1) * self.normalize_scale
                mean = h
                logvar = torch.zeros_like(h)
            else:
                # Conv layout (B, 2*latent, D, H, W): split along channels.
                mean, logvar = h.chunk(2, dim=1)
            if sample_posterior and not self.latent_denoising:
                std = torch.exp(0.5 * logvar)
                z = mean + std * torch.randn_like(std)
            else:
                z = mean

        # Skeleton branch: same structure as the main branch above.
        if self.encode_global_skl:
            B, C, D, H, W = h_skl.shape
            h_skl = h_skl.view(B, C, D*H*W).permute(0, 2, 1)
            h_skl = torch.cat([h_skl, self.grid_pos_emb[None].expand(B, -1, -1)], dim=-1).type(h_skl.dtype)
            init_tokens = torch.cat([self.init_tokens, self.tokens_pos_emb[None].expand_as(self.init_tokens)], dim=-1).type(h_skl.dtype)
            tokens = self.token_proj(init_tokens.expand(B, -1, -1))
            h_skl = self.out_layer_skl[0](tokens, h_skl)
            for layer in self.out_layer_skl[1:]:
                h_skl = layer(h_skl)
            h_skl = h_skl.type(x_skl.dtype)
            if self.latent_denoising_skl:
                if self.normalize_latent_skl:
                    h_skl = nn.functional.normalize(h_skl, dim=-1) * self.normalize_scale
                mean_skl = h_skl
                logvar_skl = torch.zeros_like(h_skl)
            else:
                mean_skl, logvar_skl = h_skl.chunk(2, dim=2)
            if sample_posterior and not self.latent_denoising_skl:
                std_skl = torch.exp(0.5 * logvar_skl)
                z_skl = mean_skl + std_skl * torch.randn_like(std_skl)
            else:
                z_skl = mean_skl
        else:
            h_skl = h_skl.type(x_skl.dtype)
            h_skl = self.out_layer_skl(h_skl)
            if self.latent_denoising_skl:
                if self.normalize_latent_skl:
                    h_skl = nn.functional.normalize(h_skl, dim=1) * self.normalize_scale
                mean_skl = h_skl
                logvar_skl = torch.zeros_like(h_skl)
            else:
                mean_skl, logvar_skl = h_skl.chunk(2, dim=1)
            if sample_posterior and not self.latent_denoising_skl:
                std_skl = torch.exp(0.5 * logvar_skl)
                z_skl = mean_skl + std_skl * torch.randn_like(std_skl)
            else:
                z_skl = mean_skl

        # In denoising mode the raw mean is detached — presumably so the
        # encoder receives no gradient through the denoising target; confirm
        # against the training loop.
        if self.latent_denoising:
            mean = mean.detach()
        if self.latent_denoising_skl:
            mean_skl = mean_skl.detach()

        if return_raw:
            return z, mean, logvar, z_skl, mean_skl, logvar_skl
        return z, z_skl
| |
|
|
class AniGenSparseStructureDecoder(nn.Module):
    r"""
    Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).

    Mirrors ``AniGenSparseStructureEncoder``: two parallel branches — the
    main ("ss") branch and a skeleton ("skl") branch — each consisting of an
    input stage, a middle stack of residual blocks, an upsampling pyramid,
    and a conv output head. When ``encode_global`` / ``encode_global_skl``
    is set, the corresponding latent is a sequence of global tokens that is
    decoded onto a learned base grid via cross-attention; otherwise the
    latent is a dense feature grid fed through a conv input layer.

    Args:
        out_channels (int): Channels of the main-branch output.
        out_channels_skl (int): Channels of the skeleton-branch output.
        latent_channels (int): Channels of the main latent representation.
        latent_channels_skl (int): Channels of the skeleton latent.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the decoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to run the torso in FP16.
        encode_global (bool): Main latent is a global-token sequence.
        global_token_num (int): Number of global tokens (main branch).
        encode_global_skl (bool): Skeleton latent is a global-token sequence.
        global_token_num_skl (int): Number of global tokens (skeleton branch).
        use_pretrain_branch (bool): Whether the main branch is pretrained.
        freeze_pretrain_branch (bool): Freeze the pretrained-branch modules.
        modules_to_freeze (Optional[List[str]]): Attribute names to freeze;
            ``None`` selects the default main-branch module names.
        normalize_z (bool): L2-normalize the main latent before decoding.
        normalize_z_skl (bool): L2-normalize the skeleton latent.
        normalize_scale (float): Scale applied after normalization.
    """
    def __init__(
        self,
        out_channels: int,
        out_channels_skl: int,
        latent_channels: int,
        latent_channels_skl: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
        encode_global: bool = False,
        global_token_num: int = 1024,
        encode_global_skl: bool = True,
        global_token_num_skl: int = 1024,
        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        modules_to_freeze: Optional[List[str]] = None,
        normalize_z: bool = False,
        normalize_z_skl: bool = True,
        normalize_scale: float = 1.0,
    ):
        super().__init__()
        # Fix: avoid a shared mutable default argument; None selects the
        # same default list the original hard-coded.
        if modules_to_freeze is None:
            modules_to_freeze = ["input_layer", "blocks", "middle_block", "out_layer"]
        self.out_channels = out_channels
        self.out_channels_skl = out_channels_skl
        self.latent_channels = latent_channels
        self.latent_channels_skl = latent_channels_skl
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.encode_global = encode_global
        self.global_token_num = global_token_num
        self.encode_global_skl = encode_global_skl
        self.global_token_num_skl = global_token_num_skl
        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        self.normalize_z = normalize_z
        self.normalize_z_skl = normalize_z_skl
        self.normalize_scale = normalize_scale

        if self.encode_global:
            # Learned base grid the global tokens are decoded onto; sized for
            # a 64^3 output downsampled by the pyramid depth.
            upsample_factor = 2 ** (len(channels) - 1)
            self.base_size_ss = 64 // upsample_factor
            # (1, base^3, C) zeros — registered before the first use of
            # ``self.device`` below so the model has at least one parameter.
            self.init_grids_ss = nn.Parameter(torch.zeros(1, self.base_size_ss**3, channels[0]))
            pos_embedder = AbsolutePositionEmbedder(channels[0], 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size_ss] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            self.register_buffer("grid_pos_emb_ss", pos_embedder(coords))
            # 1-D positional embedding for the incoming token sequence.
            pos_embedder = AbsolutePositionEmbedder(channels[0], 1)
            coords = torch.arange(global_token_num, device=self.device).reshape(-1, 1)
            self.register_buffer('tokens_pos_emb_ss', pos_embedder(coords))
            # Projects [init grid | grid pos-emb] down to the working width.
            self.token_proj_ss = nn.Linear(channels[0]*2, channels[0])

            # Self-attention stack over the tokens (first block is wider to
            # absorb the appended pos-emb), then a cross-attention block in
            # which the base grid queries the tokens.
            self.input_layer = nn.ModuleList(
                [TransformerBlock(
                    channels=channels[0] if _ != 0 else latent_channels + channels[0],
                    out_channels=channels[0],
                    num_heads=4 if _ == 0 else 16,
                    attn_mode="full",
                    qkv_bias=False,
                ) for _ in range(4)] +
                [TransformerCrossBlock(
                    channels=channels[0],
                    ctx_channels=channels[0],
                    out_channels=channels[0],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                    x_is_query=False)]
            )
        else:
            self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)

        # Skeleton-branch analog of the block above.
        if self.encode_global_skl:
            upsample_factor = 2 ** (len(channels) - 1)
            self.base_size = 64 // upsample_factor
            self.init_grids = nn.Parameter(torch.zeros(1, self.base_size**3, channels[0]))
            pos_embedder = AbsolutePositionEmbedder(channels[0], 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            self.register_buffer("grid_pos_emb", pos_embedder(coords))
            pos_embedder = AbsolutePositionEmbedder(channels[0], 1)
            coords = torch.arange(global_token_num_skl, device=self.device).reshape(-1, 1)
            self.register_buffer('tokens_pos_emb', pos_embedder(coords))
            self.token_proj = nn.Linear(channels[0]*2, channels[0])

            self.input_layer_skl = nn.ModuleList(
                [TransformerBlock(
                    channels=channels[0] if _ != 0 else latent_channels_skl + channels[0],
                    out_channels=channels[0],
                    num_heads=4 if _ == 0 else 16,
                    attn_mode="full",
                    qkv_bias=False,
                ) for _ in range(4)] +
                [TransformerCrossBlock(
                    channels=channels[0],
                    ctx_channels=channels[0],
                    out_channels=channels[0],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                    x_is_query=False)]
            )
        else:
            self.input_layer_skl = nn.Conv3d(latent_channels_skl, channels[0], 3, padding=1)

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[0], channels[0])
            for _ in range(num_res_blocks_middle)
        ])
        self.middle_block_skl = nn.Sequential(*[
            ResBlock3d(channels[0], channels[0])
            for _ in range(num_res_blocks_middle)
        ])

        # Twin upsampling pyramids: ResBlocks per level, an UpsampleBlock
        # between consecutive levels.
        self.blocks = nn.ModuleList([])
        self.blocks_skl = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            self.blocks_skl.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    UpsampleBlock3d(ch, channels[i+1])
                )
                self.blocks_skl.append(
                    UpsampleBlock3d(ch, channels[i+1])
                )

        self.out_layer = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], self.out_channels, 3, padding=1)
        )
        self.out_layer_skl = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], self.out_channels_skl, 3, padding=1)
        )

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

        # Freeze the (pretrained) main-branch modules by attribute name.
        # ``Module.parameters()`` already recurses into ModuleList children,
        # so no special-casing is needed.
        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            for module_name in modules_to_freeze:
                if hasattr(self, module_name):
                    for param in getattr(self, module_name).parameters():
                        param.requires_grad = False

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        # Output conv heads stay fp32; input layers are converted only in
        # their global (transformer) form, mirroring the encoder.
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.blocks_skl.apply(convert_module_to_f16)
        self.middle_block_skl.apply(convert_module_to_f16)
        if self.encode_global_skl:
            self.token_proj.apply(convert_module_to_f16)
            self.input_layer_skl.apply(convert_module_to_f16)
        if self.encode_global:
            self.token_proj_ss.apply(convert_module_to_f16)
            self.input_layer.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.blocks_skl.apply(convert_module_to_f32)
        self.middle_block_skl.apply(convert_module_to_f32)
        if self.encode_global_skl:
            self.token_proj.apply(convert_module_to_f32)
            self.input_layer_skl.apply(convert_module_to_f32)
        if self.encode_global:
            self.token_proj_ss.apply(convert_module_to_f32)
            self.input_layer.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Kaiming-init all Linear weights; zero their biases."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight, nonlinearity='linear')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: torch.Tensor, x_skl: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode the two latents into dense voxel grids.

        Args:
            x: Main-branch latent — (B, N, latent_channels) tokens when
                ``encode_global``, else a (B, latent_channels, D, H, W) grid.
            x_skl: Skeleton-branch latent, analogous layout.

        Returns:
            Tuple of (main output grid, skeleton output grid).
        """
        # Optional re-normalization of the latents before decoding.
        # NOTE(review): dim=1 is the channel axis of the conv layout
        # (B, C, D, H, W); for global-token latents (B, N, C) the channel
        # axis is -1 (the encoder normalizes dim=-1 there) — confirm intent
        # before combining normalize_z with encode_global.
        h = F.normalize(x, dim=1) * self.normalize_scale if self.normalize_z else x
        h_skl = F.normalize(x_skl, dim=1) * self.normalize_scale if self.normalize_z_skl else x_skl

        if self.encode_global:
            B, _, _ = h.shape
            # Append 1-D positional embeddings to the incoming tokens and
            # run the self-attention stack.
            h = torch.cat([h, self.tokens_pos_emb_ss[None].expand(B, -1, -1)], dim=-1).type(self.dtype)
            for layer in self.input_layer[:-1]:
                h = layer(h)
            # Learned base grid (+3-D pos-emb) queries the processed tokens.
            init_grids = torch.cat([self.init_grids_ss, self.grid_pos_emb_ss[None].expand_as(self.init_grids_ss)], dim=-1).type(self.dtype)
            grids = self.token_proj_ss(init_grids.expand(B, -1, -1))
            h = self.input_layer[-1](grids, h)
            # BUGFIX: this branch defines ``base_size_ss``; the original read
            # ``self.base_size``, which only exists when encode_global_skl is
            # also True (AttributeError otherwise; same value when both set).
            h = h.permute(0, 2, 1).view(B, -1, self.base_size_ss, self.base_size_ss, self.base_size_ss)
        else:
            h = self.input_layer(h)
            h = h.type(self.dtype)

        if self.encode_global_skl:
            B, _, _ = h_skl.shape
            h_skl = torch.cat([h_skl, self.tokens_pos_emb[None].expand(B, -1, -1)], dim=-1).type(self.dtype)
            for layer in self.input_layer_skl[:-1]:
                h_skl = layer(h_skl)
            init_grids = torch.cat([self.init_grids, self.grid_pos_emb[None].expand_as(self.init_grids)], dim=-1).type(self.dtype)
            grids = self.token_proj(init_grids.expand(B, -1, -1))
            h_skl = self.input_layer_skl[-1](grids, h_skl)
            h_skl = h_skl.permute(0, 2, 1).view(B, -1, self.base_size, self.base_size, self.base_size)
        else:
            h_skl = self.input_layer_skl(h_skl)
            h_skl = h_skl.type(self.dtype)

        h = self.middle_block(h)
        h_skl = self.middle_block_skl(h_skl)
        # The two pyramids are built identically, so zip pairs levels 1:1.
        for block, block_skl in zip(self.blocks, self.blocks_skl):
            h = block(h)
            h_skl = block_skl(h_skl)
        h = self.out_layer(h.type(x.dtype))
        # Consistency fix: cast the skeleton branch back to its OWN input
        # dtype (the original used ``x.dtype``; identical whenever both
        # inputs share a dtype, which is the supported configuration).
        h_skl = self.out_layer_skl(h_skl.type(x_skl.dtype))
        return h, h_skl
|
|