| from typing import * |
| import contextlib |
| import torch |
| import torch.nn as nn |
| from ...modules.sparse.transformer import SparseTransformerMultiContextCrossBlock, SparseTransformerBlock |
| from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 |
| from ...modules import sparse as sp |
| from ...representations import MeshExtractResult |
| from ...representations.mesh import AniGenSparseFeatures2Mesh, AniGenSklFeatures2Skeleton |
| from ..sparse_elastic_mixin import SparseTransformerElasticMixin |
| from pytorch3d.ops import knn_points |
| from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder |
| from .skin_models import SKIN_MODEL_DICT |
| import torch.nn.functional as F |
| from ...representations.skeleton.grouping import GROUPING_STRATEGIES |
|
|
|
|
class SparseSubdivideBlock3d(nn.Module):
    """
    Residual sparse-convolution block with an optional subdivision step.

    The input is normalised and activated, optionally subdivided, and passed
    through two sparse convolutions (the second one wrapped in
    ``zero_module``). The result is added to a skip branch computed from the
    input after the same subdivision step.

    Args:
        channels: number of channels of the input features.
        resolution: resolution of the input; used here only to build the
            ``indice_key`` names of the sparse convolutions.
        out_channels: number of output channels; falls back to ``channels``
            when not given (or falsy).
        num_groups: number of groups for the group norms.
        sub_divide: when True, ``sp.SparseSubdivide`` is applied and
            ``out_resolution`` is twice ``resolution``; otherwise the
            subdivision step is an identity.
        conv_as_residual: when True, the skip branch is always a sparse
            convolution with kernel size 1, even if the channel count does
            not change.
    """
    def __init__(
        self,
        channels: int,
        resolution: int,
        out_channels: Optional[int] = None,
        num_groups: int = 32,
        sub_divide: bool = True,
        conv_as_residual: bool = False,
    ):
        super().__init__()
        width_out = out_channels or channels
        res_out = resolution * 2 if sub_divide else resolution
        # All convolutions of this block operate on the output coordinate set,
        # so they share one indice key.
        conv_key = f"res_{res_out}"

        self.channels = channels
        self.resolution = resolution
        self.out_resolution = res_out
        self.out_channels = width_out
        self.sub_divide = sub_divide
        self.conv_as_residual = conv_as_residual

        self.act_layers = nn.Sequential(
            sp.SparseGroupNorm32(num_groups, channels),
            sp.SparseSiLU(),
        )

        if sub_divide:
            self.sub = sp.SparseSubdivide()
        else:
            self.sub = nn.Identity()

        first_conv = sp.SparseConv3d(channels, width_out, 3, indice_key=conv_key)
        mid_norm = sp.SparseGroupNorm32(num_groups, width_out)
        mid_act = sp.SparseSiLU()
        last_conv = zero_module(sp.SparseConv3d(width_out, width_out, 3, indice_key=conv_key))
        self.out_layers = nn.Sequential(first_conv, mid_norm, mid_act, last_conv)

        needs_projection = conv_as_residual or width_out != channels
        if needs_projection:
            self.skip_connection = sp.SparseConv3d(channels, width_out, 1, indice_key=conv_key)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """
        Run the residual block on a sparse tensor.

        Args:
            x: input sparse tensor with ``channels`` feature channels.

        Returns:
            The sum of the convolution branch and the skip branch.
        """
        activated = self.act_layers(x)
        # The subdivision step is applied to both branches so they can be added.
        branch_in = self.sub(activated)
        skip_in = self.sub(x)
        branch_out = self.out_layers(branch_in)
        return branch_out + self.skip_connection(skip_in)
|
|
|
|
class SparseDownsampleWithCache(nn.Module):
    """SparseDownsample that stores upsample caches under a unique suffix.

    This avoids cache-key collisions when stacking multiple down/up stages.

    Args:
        factor: downsampling factor; a list/tuple is stored as a tuple, any
            other value is stored as given and later repeated once per
            spatial coordinate column.
        cache_suffix: string inserted into the cache keys written by this
            module. The matching ``SparseUpsampleWithCache`` has to be built
            with the same suffix.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], cache_suffix: str):
        super().__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
        self.cache_suffix = cache_suffix
        self._down = sp.SparseDownsample(self.factor)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """
        Downsample ``x`` and re-register the upsample cache under suffixed keys.

        Args:
            x: sparse tensor to downsample.

        Returns:
            The downsampled sparse tensor (with int32 coordinates) carrying
            the ``upsample_{factor}_{cache_suffix}_*`` cache entries.

        Raises:
            ValueError: if the wrapped ``sp.SparseDownsample`` did not leave
                the expected ``upsample_{factor}_*`` cache entries.
        """
        out = self._down(x)

        # Number of coordinate columns minus one (presumably the batch column).
        dim = out.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * dim
        k_coords = f'upsample_{factor}_coords'
        k_layout = f'upsample_{factor}_layout'
        k_idx = f'upsample_{factor}_idx'
        coords = out.get_spatial_cache(k_coords)
        layout = out.get_spatial_cache(k_layout)
        idx = out.get_spatial_cache(k_idx)
        if any(v is None for v in [coords, layout, idx]):
            raise ValueError('Downsample cache not found after SparseDownsample.')

        # Rebuild the tensor when the coordinates are not int32, keeping the
        # scale and the spatial cache of the downsampled tensor.
        if out.coords.dtype != torch.int32:
            out = sp.SparseTensor(
                out.feats,
                out.coords.to(torch.int32),
                out.shape,
                out.layout,
                scale=out._scale,
                spatial_cache=out._spatial_cache,
            )

        out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_coords', coords)
        out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_layout', layout)
        out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_idx', idx)

        # Best-effort removal of the un-suffixed entries. Each key is handled
        # on its own so one missing key does not skip the others, and only
        # the errors a missing key / non-mapping cache can raise are ignored.
        # NOTE(review): if `_spatial_cache` is nested (e.g. keyed by scale)
        # these top-level deletions are no-ops - confirm the cache layout.
        for stale_key in (k_coords, k_layout, k_idx):
            with contextlib.suppress(KeyError, TypeError, AttributeError):
                del out._spatial_cache[stale_key]
        return out
|
|
|
|
class SparseUpsampleWithCache(nn.Module):
    """
    Upsampling counterpart of ``SparseDownsampleWithCache``.

    Reads the coordinates, layout and gather index stored in the spatial cache
    under ``upsample_{factor}_{cache_suffix}_*`` and uses them to expand the
    features of the input tensor.

    Args:
        factor: upsampling factor; a list/tuple is stored as a tuple, any
            other value is repeated once per spatial coordinate column.
        cache_suffix: suffix identifying which cached entries to read.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], cache_suffix: str):
        super().__init__()
        if isinstance(factor, (list, tuple)):
            factor = tuple(factor)
        self.factor = factor
        self.cache_suffix = cache_suffix

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """
        Expand ``x`` using the cached upsample entries.

        Raises:
            ValueError: if any of the three cache entries is missing.
        """
        n_spatial = x.coords.shape[-1] - 1
        if isinstance(self.factor, tuple):
            factor = self.factor
        else:
            factor = (self.factor,) * n_spatial

        key_prefix = f'upsample_{factor}_{self.cache_suffix}'
        new_coords = x.get_spatial_cache(f'{key_prefix}_coords')
        new_layout = x.get_spatial_cache(f'{key_prefix}_layout')
        idx = x.get_spatial_cache(f'{key_prefix}_idx')
        if new_coords is None or new_layout is None or idx is None:
            raise ValueError('Upsample cache not found. Must be paired with SparseDownsampleWithCache.')

        if new_coords.dtype != torch.int32:
            new_coords = new_coords.to(torch.int32)

        # Gather the coarse features for every cached output coordinate.
        gathered = x.feats[idx]
        out = sp.SparseTensor(gathered, new_coords, x.shape, new_layout)
        out._scale = tuple(s * f for s, f in zip(x._scale, factor))
        out._spatial_cache = x._spatial_cache
        return out
|
|
|
|
class SparseSkinUNetNLevel(nn.Module):
    """Sparse UNet with ``num_levels`` down stages and as many up stages.

    Every stage is a ``SparseSubdivideBlock3d`` built without subdivision and
    with a convolutional skip branch. Because that block uses ``resolution``
    to name its spconv ``indice_key``s, each hierarchy level is given its own
    resolution value (``base_resolution`` halved per level, never below 1).

    Args:
        channels: feature width used throughout the UNet.
        base_resolution: resolution assigned to the finest level.
        num_groups: number of groups for the group norms of every block.
        num_levels: number of down/up stages; must be at least 1.

    Raises:
        ValueError: if ``num_levels`` is smaller than 1.
    """
    def __init__(self, channels: int, base_resolution: int, num_groups: int = 32, num_levels: int = 3):
        super().__init__()

        if num_levels < 1:
            raise ValueError(f"num_levels must be >= 1, got {num_levels}")

        depth = int(num_levels)
        finest = int(base_resolution)
        self.channels = channels
        self.base_resolution = finest
        self.num_groups = num_groups
        self.num_levels = depth

        def make_block(level_resolution: int) -> SparseSubdivideBlock3d:
            # Same-resolution residual block with a convolutional skip.
            return SparseSubdivideBlock3d(
                channels=channels,
                resolution=level_resolution,
                out_channels=channels,
                num_groups=num_groups,
                sub_divide=False,
                conv_as_residual=True,
            )

        level_resolutions: List[int] = [max(1, finest // (1 << lvl)) for lvl in range(depth)]
        coarsest = max(1, finest // (1 << depth))

        # Encoder path.
        self.enc = nn.ModuleList([make_block(res) for res in level_resolutions])
        self.down = nn.ModuleList([SparseDownsampleWithCache(2, f'unet{lvl}') for lvl in range(depth)])
        self.mid = make_block(coarsest)

        # Decoder path; the cache suffixes pair up[i] with down[i].
        self.up = nn.ModuleList([SparseUpsampleWithCache(2, f'unet{lvl}') for lvl in range(depth)])
        self.fuse = nn.ModuleList([sp.SparseLinear(channels * 2, channels) for _ in range(depth)])
        self.dec = nn.ModuleList([make_block(res) for res in level_resolutions])

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """
        Run the UNet with autocast disabled on float32 features.

        Args:
            x: input sparse tensor.

        Returns:
            A sparse tensor whose features are cast back to the dtype of the
            input features.
        """
        in_dtype = x.feats.dtype
        if x.coords.dtype != torch.int32:
            x = sp.SparseTensor(
                x.feats,
                x.coords.to(torch.int32),
                x.shape,
                x.layout,
                scale=x._scale,
                spatial_cache=x._spatial_cache,
            )

        # Pick a context manager that switches autocast off for this device.
        if hasattr(torch, 'autocast'):
            no_autocast = torch.autocast(device_type=x.device.type, enabled=False)
        elif x.device.type == 'cuda':
            no_autocast = torch.cuda.amp.autocast(enabled=False)
        else:
            no_autocast = contextlib.nullcontext()

        with no_autocast:
            if x.feats.dtype == torch.float32:
                h = x
            else:
                h = x.replace(x.feats.float())

            # Encoder: keep the pre-downsample activations for the skips.
            skips: List[sp.SparseTensor] = []
            for encoder, downsample in zip(self.enc, self.down):
                skip = encoder(h)
                skips.append(skip)
                h = downsample(skip)

            h = self.mid(h)

            # Decoder: coarsest level first, concatenating the matching skip.
            for level in range(self.num_levels - 1, -1, -1):
                upsampled = self.up[level](h)
                merged = torch.cat([upsampled.feats, skips[level].feats], dim=-1)
                h = self.fuse[level](upsampled.replace(merged))
                h = self.dec[level](h)

            result = h

        if in_dtype != result.feats.dtype:
            result = result.replace(result.feats.to(dtype=in_dtype))
        return result
|
|
|
|
class AniGenSLatMeshDecoder(AniGenSparseTransformerBase):
    """
    Decoder turning structured latents into meshes, skeletons and skinning.

    ``forward`` splits the input features into a geometry part (the first
    ``latent_channels`` channels) and a skin part (the remaining channels),
    runs the base-class trunk on the geometry / skeleton / skin streams and
    then:

    * decodes the skeleton stream into skeleton representations and groups
      the predicted joints,
    * optionally passes the skin stream through ``SparseSkinUNetNLevel`` and
      lets it cross-attend to the grouped joints,
    * upsamples geometry and skin features and extracts meshes,
    * runs the skin decoder and stores its prediction on every skeleton.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        model_channels_skl: int,
        model_channels_skin: int,

        latent_channels: int,
        latent_channels_skl: int,
        latent_channels_vertskin: int,

        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,

        num_heads_skl: int = 32,
        num_heads_skin: int = 32,

        skin_cross_from_groupped: bool = False,
        h_skin_unet_num_levels: int = 4,

        skin_decoder_config: Optional[Dict[str, Any]] = None,

        upsample_skl: bool = False,
        skl_defined_on_center: bool = True,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: Optional[dict] = None,

        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        modules_to_freeze: Optional[List[str]] = ["blocks", "upsample", "out_layer", "skin_decoder"],

        skin_cross_from_geo: bool = False,
        skl_cross_from_geo: bool = False,
        skin_skl_cross: bool = False,
        skin_ae_name: str = "SkinAE",

        normalize_z: bool = False,
        normalize_scale: float = 1.0,

        jp_residual_fields: bool = True,
        jp_hyper_continuous: bool = True,

        grouping_strategy: Literal["mean_shift", "threshold"] = "mean_shift",

        vertex_skin_feat_interp_sparse: bool = False,
        vertex_skin_feat_interp_nearest: bool = False,
        vertex_skin_feat_interp_use_deformed_grid: bool = False,
        vertex_skin_feat_interp_trilinear: bool = False,
        flexicube_disable_deform: bool = False,
        vertex_skin_feat_nodeform_trilinear: bool = False,
    ):
        """
        Build the decoder.

        Only the arguments whose handling is visible in this class are
        described; the remaining ones are forwarded unchanged to the base
        class or to the mesh / skeleton extractors.

        Args:
            resolution: resolution of the latent grid. The mesh extractor is
                built at ``resolution * 4``; the skeleton extractor at
                ``resolution * 4`` when ``upsample_skl`` else ``resolution``.
            latent_channels: number of leading input channels treated as the
                geometry latent in ``forward``.
            skin_cross_from_groupped: build (and use) the joint-context
                modules that let skin features cross-attend to grouped joints.
            h_skin_unet_num_levels: number of levels of the skin UNet; values
                below 1 disable it.
            skin_decoder_config: keyword arguments for the skin decoder. Must
                contain ``'model_type'`` (a key of ``SKIN_MODEL_DICT``). The
                mapping is copied, so the caller's object is left untouched.
            upsample_skl: also upsample the skeleton stream before its heads.
            skl_defined_on_center: selects the width of the skeleton skin
                head (``skin_feat_channels`` vs. ``skin_feat_channels * 8``)
                and is forwarded to the skeleton extractor.
            use_fp16: convert the model with ``convert_to_fp16`` (otherwise
                ``convert_to_fp32``) at the end of construction.
            representation_config: optional dict; the keys ``'use_color'``,
                ``'use_conf_jp'`` and ``'use_conf_skin'`` are read from it.
            use_pretrain_branch / freeze_pretrain_branch: when both are true,
                parameters of ``modules_to_freeze`` whose name does not
                contain ``'lora'`` get ``requires_grad = False``.
            modules_to_freeze: attribute names to freeze; ``None`` freezes
                nothing. The default list is never mutated.
            normalize_z: L2-normalise the skin and skeleton latents in
                ``forward``.
            grouping_strategy: key into ``GROUPING_STRATEGIES``.

        Raises:
            KeyError: if ``skin_decoder_config`` lacks ``'model_type'`` or the
                type / grouping strategy is unknown.
        """
        super().__init__(
            in_channels=latent_channels,
            in_channels_skl=latent_channels_skl,
            in_channels_skin=latent_channels_vertskin,
            model_channels=model_channels,
            model_channels_skl=model_channels_skl,
            model_channels_skin=model_channels_skin,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_heads_skl=num_heads_skl,
            num_heads_skin=num_heads_skin,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            attn_mode_cross=attn_mode_cross,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
            skin_cross_from_geo=skin_cross_from_geo,
            skl_cross_from_geo=skl_cross_from_geo,
            skin_skl_cross=skin_skl_cross,
        )
        self.pretrain_class_name = ["AniGenElasticSLatMeshDecoder", skin_ae_name]
        self.pretrain_ckpt_filter_prefix = {skin_ae_name: "skin_decoder"}
        self.latent_channels = latent_channels
        self.latent_channels_skl = latent_channels_skl
        self.latent_channels_vertskin = latent_channels_vertskin
        self.jp_residual_fields = jp_residual_fields
        self.jp_hyper_continuous = jp_hyper_continuous
        self.grouping_func = GROUPING_STRATEGIES[grouping_strategy]
        self.skin_cross_from_groupped = skin_cross_from_groupped

        self.normalize_z = normalize_z
        self.normalize_scale = normalize_scale

        # Work on a private copy: the config is modified below, which must
        # neither leak into the caller's dict nor into a shared default.
        skin_decoder_config = dict(skin_decoder_config) if skin_decoder_config is not None else {}
        skin_decoder_config['use_fp16'] = use_fp16
        skin_model_cls = SKIN_MODEL_DICT[skin_decoder_config.pop('model_type')]
        self.skin_decoder = skin_model_cls(**skin_decoder_config)
        self.skin_feat_channels = self.skin_decoder.skin_feat_channels

        # Optional UNet applied to the skin stream right after the trunk.
        self.h_skin_unet_num_levels = int(h_skin_unet_num_levels)
        if self.h_skin_unet_num_levels >= 1:
            self.h_skin_unet = SparseSkinUNetNLevel(
                model_channels_skin,
                base_resolution=resolution,
                num_levels=self.h_skin_unet_num_levels,
            )
        else:
            self.h_skin_unet = None

        # NOTE(review): the extent of this branch was inferred from where the
        # modules are used (all uses are guarded by the same flag) - confirm.
        if self.skin_cross_from_groupped:
            # Feature standing in for the parent of joints without a parent.
            self.root_parent_feat = nn.Parameter(torch.zeros(self.skin_feat_channels))

            # Per-joint input: own skin feature + positional encoding of the
            # joint position + parent skin feature.
            self.joints_pos_embedder = FreqPositionalEmbedder(
                in_dim=3,
                include_input=True,
                max_freq_log2=6,
                num_freqs=6,
                log_sampling=True,
            )
            joints_pe_dim = self.joints_pos_embedder.out_dim
            joints_in_dim = self.skin_feat_channels + joints_pe_dim + self.skin_feat_channels
            self.joints_ctx_channels = model_channels_skin
            self.joints_in_proj = nn.Sequential(
                nn.Linear(joints_in_dim, self.joints_ctx_channels, bias=True),
                nn.SiLU(),
                nn.LayerNorm(self.joints_ctx_channels, elementwise_affine=True),
            )
            self.joints_self_attn = nn.ModuleList([
                SparseTransformerBlock(
                    self.joints_ctx_channels,
                    num_heads=num_heads_skin,
                    mlp_ratio=self.mlp_ratio,
                    attn_mode="full",
                    window_size=None,
                    use_checkpoint=self.use_checkpoint,
                    use_rope=False,
                    qk_rms_norm=self.qk_rms_norm,
                    ln_affine=True,
                ) for _ in range(4)
            ])

            # Positional encoding of the voxel coordinates of the skin stream.
            self.h_skin_coord_embedder = FreqPositionalEmbedder(
                in_dim=3,
                include_input=True,
                max_freq_log2=6,
                num_freqs=6,
                log_sampling=True,
            )
            h_skin_pe_dim = self.h_skin_coord_embedder.out_dim
            self.h_skin_coord_proj = nn.Linear(h_skin_pe_dim, model_channels_skin, bias=True)
            self.h_skin_coord_fuse = sp.SparseLinear(model_channels_skin * 2, model_channels_skin)

            self.skin_cross_groupped_net = SparseTransformerMultiContextCrossBlock(
                model_channels_skin,
                # Context = processed joint features + raw joint skin features.
                ctx_channels=[self.joints_ctx_channels + self.skin_feat_channels],
                num_heads=num_heads_skin,
                mlp_ratio=self.mlp_ratio,
                attn_mode="full",
                attn_mode_cross="full",
                cross_attn_cache_suffix='_skin_cross_from_groupped',
            )

        self.resolution = resolution
        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        self.upsample_skl = upsample_skl
        # A missing representation config behaves like an empty one.
        self.rep_config = representation_config if representation_config is not None else {}
        self.mesh_extractor = AniGenSparseFeatures2Mesh(
            res=self.resolution * 4,
            use_color=self.rep_config.get('use_color', False),
            skin_feat_channels=self.skin_feat_channels,
            predict_skin=True,
            vertex_skin_feat_interp_sparse=vertex_skin_feat_interp_sparse,
            vertex_skin_feat_interp_nearest=vertex_skin_feat_interp_nearest,
            vertex_skin_feat_interp_use_deformed_grid=vertex_skin_feat_interp_use_deformed_grid,
            vertex_skin_feat_interp_trilinear=vertex_skin_feat_interp_trilinear,
            flexicube_disable_deform=flexicube_disable_deform,
            vertex_skin_feat_nodeform_trilinear=vertex_skin_feat_nodeform_trilinear,
        )
        self.out_channels = self.mesh_extractor.feats_channels
        self.upsample = self._make_subdivide_stack(model_channels, model_channels, resolution)
        self.upsample_skin_net = self._make_subdivide_stack(model_channels_skin, model_channels, resolution)
        self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
        self.out_layer_skin = sp.SparseLinear(model_channels // 8, self.skin_feat_channels * 8)
        self.out_layer_skl_skin = sp.SparseLinear(
            model_channels // 8 if upsample_skl else model_channels_skl,
            self.skin_feat_channels if skl_defined_on_center else self.skin_feat_channels * 8,
        )
        self.use_conf_jp = self.rep_config.get('use_conf_jp', False) or self.jp_hyper_continuous
        self.use_conf_skin = self.rep_config.get('use_conf_skin', False)

        res_skl = self.resolution * 4 if self.upsample_skl else self.resolution
        self.skeleton_extractor = AniGenSklFeatures2Skeleton(
            skin_feat_channels=self.skin_feat_channels,
            device=self.device,
            res=res_skl,
            use_conf_jp=self.use_conf_jp,
            use_conf_skin=self.use_conf_skin,
            predict_skin=True,
            defined_on_center=skl_defined_on_center,
            jp_hyper_continuous=self.jp_hyper_continuous,
            jp_residual_fields=self.jp_residual_fields,
        )

        self.out_channels_skl = self.skeleton_extractor.feats_channels
        if self.upsample_skl:
            self.upsample_skl_net = self._make_subdivide_stack(model_channels_skl, model_channels, resolution)
            self.out_layer_skl = sp.SparseLinear(model_channels // 8, self.out_channels_skl)
        else:
            self.out_layer_skl = sp.SparseLinear(model_channels_skl, self.out_channels_skl)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()
        else:
            self.convert_to_fp32()

        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            self._freeze_named_attributes(modules_to_freeze or ())

    @staticmethod
    def _make_subdivide_stack(in_channels: int, model_channels: int, resolution: int) -> nn.ModuleList:
        """Build the two subdividing blocks mapping ``in_channels`` to ``model_channels // 8``."""
        return nn.ModuleList([
            SparseSubdivideBlock3d(
                channels=in_channels,
                resolution=resolution,
                out_channels=model_channels // 4,
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 4,
                resolution=resolution * 2,
                out_channels=model_channels // 8,
            ),
        ])

    def _freeze_named_attributes(self, attr_names: Iterable[str]) -> None:
        """Disable gradients for the named attributes, sparing parameters whose name contains 'lora'."""
        for attr_name in attr_names:
            target = getattr(self, attr_name, None)
            if isinstance(target, nn.Module):
                # Covers nn.ModuleList as well; named_parameters() recurses.
                for param_name, param in target.named_parameters():
                    if 'lora' not in param_name:
                        param.requires_grad = False
            elif isinstance(target, torch.Tensor):
                if target.requires_grad:
                    target.requires_grad = False

    @staticmethod
    def _init_small_output_head(layer: nn.Module, scale: float = 1e-4) -> None:
        """Kaiming-initialise ``layer.weight``, shrink it by ``scale`` and zero ``layer.bias``."""
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        layer.weight.data.mul_(scale)
        nn.init.constant_(layer.bias, 0)

    @staticmethod
    def _cast_layernorm_params(root: nn.Module, dtype: torch.dtype) -> None:
        """Cast weight and bias of every ``nn.LayerNorm`` found under ``root`` to ``dtype``."""
        for module in root.modules():
            if isinstance(module, nn.LayerNorm):
                if module.weight is not None:
                    module.weight.data = module.weight.data.to(dtype)
                if module.bias is not None:
                    module.bias.data = module.bias.data.to(dtype)

    def initialize_weights(self) -> None:
        """Run the base-class init, then give every output head small weights and zero bias."""
        super().initialize_weights()
        # The call order is kept fixed so random initialisation stays reproducible.
        self._init_small_output_head(self.out_layer)
        self._init_small_output_head(self.out_layer_skl)
        self.skin_decoder.initialize_weights()
        self._init_small_output_head(self.out_layer_skin)
        self._init_small_output_head(self.out_layer_skl_skin)

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.

        The skin UNet is converted with ``convert_module_to_f32`` instead; its
        forward pass runs on float32 features.
        """
        super().convert_to_fp16()
        self.upsample.apply(convert_module_to_f16)
        self.upsample_skin_net.apply(convert_module_to_f16)
        if self.upsample_skl:
            self.upsample_skl_net.apply(convert_module_to_f16)
        if self.skin_cross_from_groupped:
            self.root_parent_feat.data = self.root_parent_feat.data.half()
            self.joints_in_proj.apply(convert_module_to_f16)
            self.joints_self_attn.apply(convert_module_to_f16)

            # The plain LayerNorm of the input projection is cast to half
            # explicitly, while the LayerNorms inside the self-attention
            # blocks are kept in float32.
            # NOTE(review): presumably convert_module_to_f16 skips LayerNorm
            # and the transformer blocks normalise in fp32 - confirm.
            self._cast_layernorm_params(self.joints_in_proj, torch.float16)
            self._cast_layernorm_params(self.joints_self_attn, torch.float32)

            self.skin_cross_groupped_net.apply(convert_module_to_f16)
            self.h_skin_coord_proj.apply(convert_module_to_f16)
            self.h_skin_coord_fuse.apply(convert_module_to_f16)

        if self.h_skin_unet is not None:
            self.h_skin_unet.apply(convert_module_to_f32)
        self.skin_decoder.convert_to_fp16()

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        super().convert_to_fp32()
        self.upsample.apply(convert_module_to_f32)
        self.upsample_skin_net.apply(convert_module_to_f32)
        if self.upsample_skl:
            self.upsample_skl_net.apply(convert_module_to_f32)
        if self.skin_cross_from_groupped:
            self.root_parent_feat.data = self.root_parent_feat.data.float()
            self.joints_in_proj.apply(convert_module_to_f32)
            self.joints_self_attn.apply(convert_module_to_f32)

            self._cast_layernorm_params(self.joints_in_proj, torch.float32)
            self._cast_layernorm_params(self.joints_self_attn, torch.float32)

            self.skin_cross_groupped_net.apply(convert_module_to_f32)
            self.h_skin_coord_proj.apply(convert_module_to_f32)
            self.h_skin_coord_fuse.apply(convert_module_to_f32)
        if self.h_skin_unet is not None:
            self.h_skin_unet.apply(convert_module_to_f32)
        self.skin_decoder.convert_to_fp32()

    def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        return [self.mesh_extractor(x[i], training=self.training) for i in range(x.shape[0])]

    def to_representation_skl(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        """
        Convert a batch of network outputs to skeleton representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of skeleton representations
        """
        return [self.skeleton_extractor(x[i], training=self.training) for i in range(x.shape[0])]

    def forward(
        self,
        x: sp.SparseTensor,
        x_skl: sp.SparseTensor,
        gt_joints=None,
        gt_parents=None,
    ) -> Tuple[List[MeshExtractResult], List[Any]]:
        """
        Decode latents into meshes and skeletons.

        Args:
            x: sparse latent; the first ``latent_channels`` feature channels
                are used as geometry latent, the rest as skin latent.
            x_skl: sparse skeleton latent.
            gt_joints: optional per-sample joints used instead of grouping
                the predicted ones.
            gt_parents: optional per-sample parent indices going with
                ``gt_joints``.

        Returns:
            ``(meshes, skeletons)``: one extracted mesh and one skeleton
            representation per batch element. The skin prediction is stored
            on each skeleton as ``skin_pred``.
        """
        x0 = x
        x_skin = sp.SparseTensor(feats=x0.feats[:, self.latent_channels:], coords=x0.coords.clone())
        x = x0.replace(x0.feats[:, :self.latent_channels])
        if self.normalize_z:
            x_skin = x_skin.replace(F.normalize(x_skin.feats, dim=-1))
            x_skl = x_skl.replace(F.normalize(x_skl.feats, dim=-1))

        h, h_skl, h_skin = super().forward(x, x_skl, x_skin)

        # Optional UNet pass over the skin stream.
        if self.h_skin_unet is not None:
            h_skin = self.h_skin_unet(h_skin)

        # Skeleton branch.
        if self.upsample_skl:
            for block_skl in self.upsample_skl_net:
                h_skl = block_skl(h_skl)
        h_skl_middle = h_skl.type(x_skl.dtype)
        h_skl = self.out_layer_skl(h_skl_middle)
        h_skl_skin = self.out_layer_skl_skin(h_skl_middle)
        h_skl = h_skl.replace(torch.cat([h_skl.feats, h_skl_skin.feats], dim=-1))
        skeletons = self.to_representation_skl(h_skl)
        skin_feats_joints_list = self.skeleton_grouping(skeletons, gt_joints=gt_joints, gt_parents=gt_parents)

        # Skin branch: optionally inject voxel positions and attend to joints.
        # NOTE(review): extent of this branch inferred from which modules
        # exist only when the flag is set - confirm.
        if self.skin_cross_from_groupped:
            coords_xyz = h_skin.coords[:, 1:].to(device=h_skin.device, dtype=torch.float32)
            # Map voxel indices to cell centres in [-1, 1].
            coords_norm = (coords_xyz + 0.5) / self.resolution * 2.0 - 1.0
            coords_pe = self.h_skin_coord_embedder(coords_norm)
            coords_pe = coords_pe.to(device=h_skin.device, dtype=h_skin.feats.dtype)
            coords_pe = self.h_skin_coord_proj(coords_pe)
            h_skin = h_skin.replace(torch.cat([h_skin.feats, coords_pe], dim=-1))
            h_skin = self.h_skin_coord_fuse(h_skin)
            joints_ctx = self._build_processed_joints_context(
                skeletons,
                skin_feats_joints_list,
                device=h_skin.device,
                dtype=h_skin.feats.dtype,
            )
            h_skin = self.skin_cross_groupped_net(h_skin, [joints_ctx])
        for block in self.upsample_skin_net:
            h_skin = block(h_skin)
        h_skin = h_skin.type(x.dtype)
        h_skin = self.out_layer_skin(h_skin)

        # Geometry branch; skin features are appended before mesh extraction.
        for block in self.upsample:
            h = block(h)
        h_middle = h.type(x.dtype)
        h = self.out_layer(h_middle)
        h = h.replace(torch.cat([h.feats, h_skin.feats], dim=-1))
        meshes = self.to_representation(h)

        self.skinweight_forward(meshes, skeletons, gt_joints=gt_joints, gt_parents=gt_parents)
        return meshes, skeletons

    def _joints_feats_list_to_sparse(
        self,
        joints_feats_list: List[torch.Tensor],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> sp.SparseTensor:
        """
        Pack per-sample joint features into one sparse tensor.

        Joint ``j`` of sample ``b`` gets the coordinate ``(b, j, 0, 0)``.

        Args:
            joints_feats_list: one ``[num_joints, C]`` tensor per sample.
            device: target device, defaults to ``self.device``.
            dtype: target feature dtype, defaults to ``self.dtype``.
        """
        if device is None:
            device = self.device
        if dtype is None:
            dtype = self.dtype
        feats_per_batch: List[torch.Tensor] = [
            joints_feats.to(device=device, dtype=dtype) for joints_feats in joints_feats_list
        ]
        feats = torch.cat(feats_per_batch, dim=0)

        batch_indices: List[torch.Tensor] = []
        x_indices: List[torch.Tensor] = []
        for bi, joints_feats in enumerate(feats_per_batch):
            ji = int(joints_feats.shape[0])
            batch_indices.append(torch.full((ji,), bi, device=device, dtype=torch.int32))
            x_indices.append(torch.arange(ji, device=device, dtype=torch.int32))
        b = torch.cat(batch_indices, dim=0)
        x = torch.cat(x_indices, dim=0)
        yz = torch.zeros((x.shape[0], 2), device=device, dtype=torch.int32)
        coords = torch.cat([b[:, None], x[:, None], yz], dim=1)
        return sp.SparseTensor(feats=feats, coords=coords)

    def _build_processed_joints_context(
        self,
        skeletons: List[Any],
        skin_feats_joints_list: List[torch.Tensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> sp.SparseTensor:
        """
        Build the joint context attended to by the skin stream.

        For every joint the skin feature, the positional encoding of its
        position and the skin feature of its parent (``root_parent_feat`` for
        joints with a negative parent index) are projected, refined by the
        joint self-attention blocks and finally concatenated with the raw
        joint skin features.

        Raises:
            ValueError: if a skeleton has no grouped joints or parents.
        """
        processed: List[torch.Tensor] = []
        raw_skin: List[torch.Tensor] = []
        for rep_skl, skin_feats_joints in zip(skeletons, skin_feats_joints_list):
            joints = rep_skl.joints_grouped
            parents = rep_skl.parents_grouped
            if joints is None or parents is None:
                raise ValueError('Expected grouped joints/parents for skin_cross_from_groupped.')
            joints = joints.to(device=device, dtype=dtype)
            parents = parents.to(device=device)
            skin_feats_joints = skin_feats_joints.to(device=device, dtype=dtype)
            raw_skin.append(skin_feats_joints)

            pe = self.joints_pos_embedder(joints).to(device=device, dtype=dtype)

            # Gather parent features; negative parent indices are clamped for
            # the gather and then replaced by the learned root feature.
            parent_idx = parents.to(torch.long)
            valid = parent_idx >= 0
            root_feat = self.root_parent_feat.to(device=device, dtype=dtype)
            parent_feat_root = root_feat.unsqueeze(0).expand(skin_feats_joints.shape[0], -1)
            parent_feat_gather = skin_feats_joints[parent_idx.clamp(min=0)]
            parent_feat = torch.where(valid.unsqueeze(1), parent_feat_gather, parent_feat_root)

            joint_in = torch.cat([skin_feats_joints, pe, parent_feat], dim=-1)
            joint_h = self.joints_in_proj(joint_in)
            processed.append(joint_h)

        joints_ctx = self._joints_feats_list_to_sparse(processed, device=device, dtype=dtype)
        for blk in self.joints_self_attn:
            joints_ctx = blk(joints_ctx)

        # Skip connection: append the unprocessed joint skin features.
        joints_skip = self._joints_feats_list_to_sparse(raw_skin, device=device, dtype=dtype)
        joints_ctx = joints_ctx.replace(torch.cat([joints_ctx.feats, joints_skip.feats], dim=-1))
        return joints_ctx

    def skeleton_grouping(self, reps_skl, gt_joints=None, gt_parents=None, skin_feats_skl_list=None, return_skin_pred_only=False):
        """
        Group joints and pool skeleton skin features per grouped joint.

        Every skeleton position is assigned to its nearest grouped joint; the
        skin features of the assigned positions are averaged with the
        (sigmoid) skin confidence as weight.

        Args:
            reps_skl: skeleton representations, one per sample.
            gt_joints: optional per-sample joints that replace grouping.
            gt_parents: per-sample parents going with ``gt_joints``.
            skin_feats_skl_list: optional per-sample skin features used
                instead of ``rep_skl.skin_feats``.
            return_skin_pred_only: when True nothing is written back onto the
                skeleton representations.

        Returns:
            List with one ``[num_grouped_joints, C]`` tensor per sample.
        """
        skin_feats_joints_list = []
        for i, rep_skl in enumerate(reps_skl):
            if gt_joints is not None:
                joints_grouped = gt_joints[i]
                parents_grouped = gt_parents[i]
            elif rep_skl.joints_grouped is None:
                with torch.no_grad():
                    joints_grouped, parents_grouped = self.grouping_func(
                        joints=rep_skl.joints,
                        parents=rep_skl.parents,
                        joints_conf=rep_skl.conf_j,
                        parents_conf=rep_skl.conf_p,
                    )
            else:
                joints_grouped = rep_skl.joints_grouped
                parents_grouped = rep_skl.parents_grouped

            if not return_skin_pred_only:
                rep_skl.joints_grouped = joints_grouped
                rep_skl.parents_grouped = parents_grouped

            # Index of the nearest grouped joint for every skeleton position.
            positions_skl = rep_skl.positions
            _, joints_nn_idx, _ = knn_points(positions_skl[None], joints_grouped[None].detach(), K=1, norm=2, return_nn=False)
            joints_nn_idx = joints_nn_idx[0, :, 0]
            skin_feats_skl = rep_skl.skin_feats if skin_feats_skl_list is None else skin_feats_skl_list[i]

            # Without a confidence head every position gets weight 1.
            conf_skin = torch.sigmoid(rep_skl.conf_skin) if rep_skl.conf_skin is not None else torch.ones_like(skin_feats_skl[:, :1])

            # The accumulators must live on the device of the scattered
            # features; scatter_add_ does not mix devices.
            num_joints = joints_grouped.shape[0]
            feat_dim = skin_feats_skl.shape[-1]
            skin_feats_joints = torch.zeros([num_joints, feat_dim], device=skin_feats_skl.device, dtype=skin_feats_skl.dtype)
            skin_feats_square_joints = skin_feats_joints.clone()
            skin_conf_joints = torch.zeros([num_joints, 1], device=skin_feats_skl.device, dtype=skin_feats_skl.dtype)

            scatter_idx = joints_nn_idx[:, None].expand(-1, feat_dim)
            skin_feats_joints.scatter_add_(0, scatter_idx, skin_feats_skl * conf_skin)
            skin_feats_square_joints.scatter_add_(0, scatter_idx, skin_feats_skl.square() * conf_skin)
            skin_conf_joints.scatter_add_(0, joints_nn_idx[:, None], conf_skin)

            # Weighted mean and variance (E[x^2] - E[x]^2) per joint.
            weight_sum = skin_conf_joints.clamp(min=1e-6)
            skin_feats_joints = skin_feats_joints / weight_sum
            skin_feats_square_joints = skin_feats_square_joints / weight_sum
            skin_feats_joints_var = skin_feats_square_joints - skin_feats_joints.square()
            skin_feats_joints_var_loss = skin_feats_joints_var.mean()

            if not return_skin_pred_only:
                rep_skl.skin_feats_joints_var_loss = skin_feats_joints_var_loss
                rep_skl.skin_feats_joints = skin_feats_joints
            skin_feats_joints_list.append(skin_feats_joints)
        return skin_feats_joints_list

    def skinweight_forward(self, reps, reps_skl, gt_joints=None, gt_parents=None, return_skin_pred_only=False, skin_feats_verts_list=None, skin_feats_skl_list=None, *args, **kwargs):
        """
        Run the skin decoder for every (mesh, skeleton) pair.

        Args:
            reps: mesh representations providing ``vertex_skin_feats``.
            reps_skl: skeleton representations.
            gt_joints / gt_parents: forwarded to ``skeleton_grouping`` when
                grouping has to be (re)computed.
            return_skin_pred_only: when True the predictions are returned and
                nothing is stored on the skeleton representations.
            skin_feats_verts_list: optional per-sample vertex skin features
                used instead of ``rep.vertex_skin_feats``.
            skin_feats_skl_list: optional per-sample skeleton skin features
                forwarded to ``skeleton_grouping``.

        Returns:
            The list of skin predictions when ``return_skin_pred_only`` is
            True; otherwise ``None`` (predictions are stored as
            ``reps_skl[i].skin_pred``).
        """
        skin_preds = []
        # Grouping is recomputed when only predictions are requested or when
        # the skeletons have not been grouped yet. An empty batch is routed
        # through grouping as well so that reps_skl[0] is never indexed.
        needs_grouping = return_skin_pred_only or not reps_skl or reps_skl[0].parents_grouped is None
        if needs_grouping:
            skin_feats_joints_list = self.skeleton_grouping(
                reps_skl,
                gt_joints=gt_joints,
                gt_parents=gt_parents,
                skin_feats_skl_list=skin_feats_skl_list,
                return_skin_pred_only=return_skin_pred_only,
            )
        else:
            skin_feats_joints_list = [rep_skl.skin_feats_joints for rep_skl in reps_skl]

        for i, (rep, rep_skl) in enumerate(zip(reps, reps_skl)):
            skin_feats_joints = skin_feats_joints_list[i]
            skin_feats_verts = rep.vertex_skin_feats if skin_feats_verts_list is None else skin_feats_verts_list[i]
            parents_grouped = rep_skl.parents_grouped
            # The skin decoder is called on a batch of one.
            skin_pred = self.skin_decoder(skin_feats_verts[None], skin_feats_joints[None], parents_grouped[None])
            skin_pred = skin_pred[0]
            if return_skin_pred_only:
                skin_preds.append(skin_pred)
            else:
                reps_skl[i].skin_pred = skin_pred
        if return_skin_pred_only:
            return skin_preds
| |
|
|
class AniGenElasticSLatMeshDecoder(SparseTransformerElasticMixin, AniGenSLatMeshDecoder):
    """
    ``AniGenSLatMeshDecoder`` combined with ``SparseTransformerElasticMixin``.

    The class adds no members of its own; the mixin supplies the elastic
    memory management, intended for training when VRAM is limited.
    """
|
|