# anigen/models/structured_latent_vae/anigen_decoder.py
# AniGen - Animatable 3D Generation
from typing import *
import contextlib
import torch
import torch.nn as nn
from ...modules.sparse.transformer import SparseTransformerMultiContextCrossBlock, SparseTransformerBlock
from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
from ...modules import sparse as sp
from ...representations import MeshExtractResult
from ...representations.mesh import AniGenSparseFeatures2Mesh, AniGenSklFeatures2Skeleton
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
from pytorch3d.ops import knn_points
from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
from .skin_models import SKIN_MODEL_DICT
import torch.nn.functional as F
from ...representations.skeleton.grouping import GROUPING_STRATEGIES
class SparseSubdivideBlock3d(nn.Module):
    """
    Sparse residual block with optional 2x spatial subdivision.

    Args:
        channels: number of input channels.
        resolution: input resolution; only used to derive the spconv
            ``indice_key`` names, so it must be distinct per hierarchy level.
        out_channels: number of output channels (defaults to ``channels``).
        num_groups: group count for the GroupNorm layers.
        sub_divide: when True, spatially subdivide (double the resolution).
        conv_as_residual: when True, force a 1x1 conv on the skip path even
            if the channel count is unchanged.
    """
    def __init__(
        self,
        channels: int,
        resolution: int,
        out_channels: Optional[int] = None,
        num_groups: int = 32,
        sub_divide: bool = True,
        conv_as_residual: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.resolution = resolution
        self.out_resolution = resolution * 2 if sub_divide else resolution
        self.out_channels = out_channels or channels
        self.sub_divide = sub_divide
        self.conv_as_residual = conv_as_residual
        # Shared spconv rulebook key for all convs at the output resolution.
        indice_key = f"res_{self.out_resolution}"
        self.act_layers = nn.Sequential(
            sp.SparseGroupNorm32(num_groups, channels),
            sp.SparseSiLU(),
        )
        self.sub = sp.SparseSubdivide() if sub_divide else nn.Identity()
        self.out_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, indice_key=indice_key),
            sp.SparseGroupNorm32(num_groups, self.out_channels),
            sp.SparseSiLU(),
            # Zero-init last conv so the block starts as (near) identity.
            zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=indice_key)),
        )
        if self.out_channels != channels or self.conv_as_residual:
            self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=indice_key)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """
        Apply norm+activation, optional subdivision, the conv stack, and the
        residual skip connection.

        Args:
            x: an [N x C x ...] sparse tensor of features.
        Returns:
            an [N x C x ...] sparse tensor of outputs.
        """
        branch = self.sub(self.act_layers(x))
        identity = self.sub(x)
        out = self.out_layers(branch)
        return out + self.skip_connection(identity)
class SparseDownsampleWithCache(nn.Module):
    """SparseDownsample that stores upsample caches under a unique suffix.

    This avoids cache-key collisions when stacking multiple down/up stages.
    Must be paired with `SparseUpsampleWithCache` using the same suffix.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], cache_suffix: str):
        super().__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
        self.cache_suffix = cache_suffix
        self._down = sp.SparseDownsample(self.factor)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Downsample `x` and re-key its upsample caches with `cache_suffix`.

        Raises:
            ValueError: if `SparseDownsample` did not populate its cache.
        """
        out = self._down(x)
        dim = out.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * dim
        k_coords = f'upsample_{factor}_coords'
        k_layout = f'upsample_{factor}_layout'
        k_idx = f'upsample_{factor}_idx'
        coords = out.get_spatial_cache(k_coords)
        layout = out.get_spatial_cache(k_layout)
        idx = out.get_spatial_cache(k_idx)
        if any(v is None for v in [coords, layout, idx]):
            raise ValueError('Downsample cache not found after SparseDownsample.')
        # spconv expects int32 indices; SparseDownsample produces int64 coords.
        if out.coords.dtype != torch.int32:
            out = sp.SparseTensor(
                out.feats,
                out.coords.to(torch.int32),
                out.shape,
                out.layout,
                scale=out._scale,
                spatial_cache=out._spatial_cache,
            )
        out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_coords', coords)
        out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_layout', layout)
        out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_idx', idx)
        # Remove unsuffixed keys to prevent later stages overwriting them.
        # Each key is dropped independently: the previous single try/except
        # aborted the whole cleanup on the first missing key and silently
        # swallowed every exception type.
        for stale_key in (k_coords, k_layout, k_idx):
            with contextlib.suppress(KeyError, TypeError):
                del out._spatial_cache[stale_key]
        return out
class SparseUpsampleWithCache(nn.Module):
    """SparseUpsample that reads upsample caches under a unique suffix."""

    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], cache_suffix: str):
        super().__init__()
        if isinstance(factor, (list, tuple)):
            self.factor = tuple(factor)
        else:
            self.factor = factor
        self.cache_suffix = cache_suffix

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Scatter features back onto the coordinates cached at downsample time."""
        spatial_dim = x.coords.shape[-1] - 1
        if isinstance(self.factor, tuple):
            factor = self.factor
        else:
            factor = (self.factor,) * spatial_dim
        cache_prefix = f'upsample_{factor}_{self.cache_suffix}'
        new_coords = x.get_spatial_cache(f'{cache_prefix}_coords')
        new_layout = x.get_spatial_cache(f'{cache_prefix}_layout')
        idx = x.get_spatial_cache(f'{cache_prefix}_idx')
        if new_coords is None or new_layout is None or idx is None:
            raise ValueError('Upsample cache not found. Must be paired with SparseDownsampleWithCache.')
        # spconv expects int32 coordinates.
        if new_coords.dtype != torch.int32:
            new_coords = new_coords.to(torch.int32)
        out = sp.SparseTensor(x.feats[idx], new_coords, x.shape, new_layout)
        out._scale = tuple(s * f for s, f in zip(x._scale, factor))
        out._spatial_cache = x._spatial_cache
        return out
class SparseSkinUNetNLevel(nn.Module):
    """A simple N-down/N-up sparse UNet for local smoothing.
    Note: `SparseSubdivideBlock3d` uses `resolution` only to name spconv `indice_key`s.
    We must provide distinct (and stage-appropriate) values per hierarchy to avoid
    rulebook collisions across different coordinate sets.
    """
    def __init__(self, channels: int, base_resolution: int, num_groups: int = 32, num_levels: int = 3):
        """
        Args:
            channels: feature width, kept constant across all UNet stages.
            base_resolution: resolution of the input stage; halved per level.
            num_groups: group count for the GroupNorm in each residual block.
            num_levels: number of downsample/upsample stages (must be >= 1).
        Raises:
            ValueError: if `num_levels < 1`.
        """
        super().__init__()
        if num_levels < 1:
            raise ValueError(f"num_levels must be >= 1, got {num_levels}")
        self.channels = channels
        self.base_resolution = int(base_resolution)
        self.num_groups = num_groups
        self.num_levels = int(num_levels)
        # Residual block factory; `resolution` only disambiguates indice keys.
        def res_block(resolution: int):
            return SparseSubdivideBlock3d(
                channels=channels,
                resolution=resolution,
                out_channels=channels,
                sub_divide=False,
                conv_as_residual=True,
                num_groups=num_groups,
            )
        # resolutions[i] corresponds to the i-th encoder stage (before downsample)
        resolutions: List[int] = [max(1, self.base_resolution // (2 ** i)) for i in range(self.num_levels)]
        bottom_resolution = max(1, self.base_resolution // (2 ** self.num_levels))
        self.enc = nn.ModuleList([res_block(r) for r in resolutions])
        # Down/up stages carry per-level cache suffixes to avoid key collisions.
        self.down = nn.ModuleList([SparseDownsampleWithCache(2, f'unet{i}') for i in range(self.num_levels)])
        self.mid = res_block(bottom_resolution)
        # Decoder blocks operate at the same resolutions as encoder blocks.
        self.up = nn.ModuleList([SparseUpsampleWithCache(2, f'unet{i}') for i in range(self.num_levels)])
        # Skip fusion: concat(decoder feats, encoder skip) -> back to `channels`.
        self.fuse = nn.ModuleList([sp.SparseLinear(channels * 2, channels) for _ in range(self.num_levels)])
        self.dec = nn.ModuleList([res_block(r) for r in resolutions])
    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Run the N-level UNet in fp32; output dtype matches input feats dtype."""
        in_dtype = x.feats.dtype
        # spconv expects int32 coordinates; rebuild the tensor if needed.
        if x.coords.dtype != torch.int32:
            x = sp.SparseTensor(
                x.feats,
                x.coords.to(torch.int32),
                x.shape,
                x.layout,
                scale=x._scale,
                spatial_cache=x._spatial_cache,
            )
        # spconv implicit_gemm has a runtime tuner that can fail for some sparse
        # rulebooks under AMP + fp16/bf16. Running UNet convs in fp32 avoids that.
        if hasattr(torch, 'autocast'):
            autocast_ctx = torch.autocast(device_type=x.device.type, enabled=False)
        else:
            # Older torch fallback
            autocast_ctx = torch.cuda.amp.autocast(enabled=False) if x.device.type == 'cuda' else contextlib.nullcontext()
        with autocast_ctx:
            x_fp32 = x if x.feats.dtype == torch.float32 else x.replace(x.feats.float())
            skips: List[sp.SparseTensor] = []
            h = x_fp32
            # Encoder: residual block, remember skip, then downsample.
            for i in range(self.num_levels):
                s = self.enc[i](h)
                skips.append(s)
                h = self.down[i](s)
            h = self.mid(h)
            # Decoder: upsample, fuse with skip, residual block.
            for i in reversed(range(self.num_levels)):
                h_up = self.up[i](h)
                s = skips[i]
                h = self.fuse[i](h_up.replace(torch.cat([h_up.feats, s.feats], dim=-1)))
                h = self.dec[i](h)
        u0 = h
        # Cast back so callers see the dtype they passed in (e.g. fp16 under AMP).
        if in_dtype != u0.feats.dtype:
            u0 = u0.replace(u0.feats.to(dtype=in_dtype))
        return u0
class AniGenSLatMeshDecoder(AniGenSparseTransformerBase):
    """
    Structured-latent decoder that jointly predicts meshes, skeletons and
    skinning weights from sparse latents.

    The geometry latent `x` carries the vertex-skin latent in its trailing
    channels (beyond `latent_channels`); `x_skl` is the skeleton latent. All
    three streams run through the shared sparse-transformer backbone, then
    through per-branch upsample/output heads.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        model_channels_skl: int,
        model_channels_skin: int,
        latent_channels: int,
        latent_channels_skl: int,
        latent_channels_vertskin: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        num_heads_skl: int = 32,
        num_heads_skin: int = 32,
        skin_cross_from_groupped: bool = False,
        h_skin_unet_num_levels: int = 4,
        # None (not {}) to avoid a shared mutable default; resolved below.
        skin_decoder_config: Optional[Dict[str, Any]] = None,
        upsample_skl: bool = False,
        skl_defined_on_center: bool = True,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: Optional[dict] = None,
        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        # None (not a list literal) to avoid a shared mutable default.
        modules_to_freeze: Optional[List[str]] = None,
        skin_cross_from_geo: bool = False,
        skl_cross_from_geo: bool = False,
        skin_skl_cross: bool = False,
        skin_ae_name: str = "SkinAE",
        normalize_z: bool = False,
        normalize_scale: float = 1.0,
        jp_residual_fields: bool = True,
        jp_hyper_continuous: bool = True,
        grouping_strategy: Literal["mean_shift", "threshold"] = "mean_shift",
        vertex_skin_feat_interp_sparse: bool = False,
        vertex_skin_feat_interp_nearest: bool = False,
        vertex_skin_feat_interp_use_deformed_grid: bool = False,
        vertex_skin_feat_interp_trilinear: bool = False,
        flexicube_disable_deform: bool = False,
        vertex_skin_feat_nodeform_trilinear: bool = False,
    ):
        super().__init__(
            in_channels=latent_channels,
            in_channels_skl=latent_channels_skl,
            in_channels_skin=latent_channels_vertskin,
            model_channels=model_channels,
            model_channels_skl=model_channels_skl,
            model_channels_skin=model_channels_skin,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_heads_skl=num_heads_skl,
            num_heads_skin=num_heads_skin,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            attn_mode_cross=attn_mode_cross,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
            skin_cross_from_geo=skin_cross_from_geo,
            skl_cross_from_geo=skl_cross_from_geo,
            skin_skl_cross=skin_skl_cross,
        )
        # Work on a copy: we write `use_fp16` into the config and pop
        # `model_type` below, which previously mutated the caller's dict
        # (and the shared default dict).
        skin_decoder_config = dict(skin_decoder_config) if skin_decoder_config is not None else {}
        if modules_to_freeze is None:
            modules_to_freeze = ["blocks", "upsample", "out_layer", "skin_decoder"]
        self.pretrain_class_name = ["AniGenElasticSLatMeshDecoder", skin_ae_name]
        self.pretrain_ckpt_filter_prefix = {skin_ae_name: "skin_decoder"}
        self.latent_channels = latent_channels
        self.latent_channels_skl = latent_channels_skl
        self.latent_channels_vertskin = latent_channels_vertskin
        self.jp_residual_fields = jp_residual_fields
        self.jp_hyper_continuous = jp_hyper_continuous
        self.grouping_func = GROUPING_STRATEGIES[grouping_strategy]
        self.skin_cross_from_groupped = skin_cross_from_groupped
        self.normalize_z = normalize_z
        self.normalize_scale = normalize_scale
        skin_decoder_config['use_fp16'] = use_fp16
        self.skin_decoder = SKIN_MODEL_DICT[skin_decoder_config.pop('model_type')](**skin_decoder_config)
        self.skin_feat_channels = self.skin_decoder.skin_feat_channels
        # Optional local smoothing UNet on h_skin (independent of grouped cross-attn).
        # If `h_skin_unet_num_levels < 1`, the UNet is disabled.
        self.h_skin_unet_num_levels = int(h_skin_unet_num_levels)
        if self.h_skin_unet_num_levels >= 1:
            self.h_skin_unet = SparseSkinUNetNLevel(
                model_channels_skin,
                base_resolution=resolution,
                num_levels=self.h_skin_unet_num_levels,
            )
        else:
            self.h_skin_unet = None
        if self.skin_cross_from_groupped:
            # Trainable parent feature for root joints (where parent_idx < 0).
            self.root_parent_feat = nn.Parameter(torch.zeros(self.skin_feat_channels))
            # Joint feature preprocessing: [joint_skin, fourier(joint_xyz), parent_skin] -> proj -> self-attn
            self.joints_pos_embedder = FreqPositionalEmbedder(
                in_dim=3,
                include_input=True,
                max_freq_log2=6,
                num_freqs=6,
                log_sampling=True,
            )
            joints_pe_dim = self.joints_pos_embedder.out_dim
            joints_in_dim = self.skin_feat_channels + joints_pe_dim + self.skin_feat_channels
            self.joints_ctx_channels = model_channels_skin
            self.joints_in_proj = nn.Sequential(
                nn.Linear(joints_in_dim, self.joints_ctx_channels, bias=True),
                nn.SiLU(),
                nn.LayerNorm(self.joints_ctx_channels, elementwise_affine=True),
            )
            self.joints_self_attn = nn.ModuleList([
                SparseTransformerBlock(
                    self.joints_ctx_channels,
                    num_heads=num_heads_skin,
                    mlp_ratio=self.mlp_ratio,
                    attn_mode="full",
                    window_size=None,
                    use_checkpoint=self.use_checkpoint,
                    use_rope=False,
                    qk_rms_norm=self.qk_rms_norm,
                    ln_affine=True,
                ) for _ in range(4)
            ])
            # Coordinate PE for h_skin before cross-attn: coords in [-1, 1] -> Fourier PE -> proj(C), concat, fuse back to C.
            self.h_skin_coord_embedder = FreqPositionalEmbedder(
                in_dim=3,
                include_input=True,
                max_freq_log2=6,
                num_freqs=6,
                log_sampling=True,
            )
            h_skin_pe_dim = self.h_skin_coord_embedder.out_dim
            self.h_skin_coord_proj = nn.Linear(h_skin_pe_dim, model_channels_skin, bias=True)
            self.h_skin_coord_fuse = sp.SparseLinear(model_channels_skin * 2, model_channels_skin)
            self.skin_cross_groupped_net = SparseTransformerMultiContextCrossBlock(
                model_channels_skin,
                # Context includes processed joint tokens + raw joint skin feats (skip connection).
                ctx_channels=[self.joints_ctx_channels + self.skin_feat_channels],
                num_heads=num_heads_skin,
                mlp_ratio=self.mlp_ratio,
                attn_mode="full",
                attn_mode_cross="full",
                cross_attn_cache_suffix='_skin_cross_from_groupped',
            )
        self.resolution = resolution
        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        self.upsample_skl = upsample_skl
        self.rep_config = representation_config
        self.mesh_extractor = AniGenSparseFeatures2Mesh(
            res=self.resolution*4,
            use_color=self.rep_config.get('use_color', False),
            skin_feat_channels=self.skin_feat_channels,
            predict_skin=True,
            vertex_skin_feat_interp_sparse=vertex_skin_feat_interp_sparse,
            vertex_skin_feat_interp_nearest=vertex_skin_feat_interp_nearest,
            vertex_skin_feat_interp_use_deformed_grid=vertex_skin_feat_interp_use_deformed_grid,
            vertex_skin_feat_interp_trilinear=vertex_skin_feat_interp_trilinear,
            flexicube_disable_deform=flexicube_disable_deform,
            vertex_skin_feat_nodeform_trilinear=vertex_skin_feat_nodeform_trilinear,
        )
        self.out_channels = self.mesh_extractor.feats_channels
        # Geometry branch: two 2x subdivisions (resolution -> 4x resolution).
        self.upsample = nn.ModuleList([
            SparseSubdivideBlock3d(
                channels=model_channels,
                resolution=resolution,
                out_channels=model_channels // 4
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 4,
                resolution=resolution * 2,
                out_channels=model_channels // 8
            )
        ])
        upsample_skin_blocks = []
        upsample_skin_blocks.extend([
            SparseSubdivideBlock3d(
                channels=model_channels_skin,
                resolution=resolution,
                out_channels=model_channels // 4
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 4,
                resolution=resolution * 2,
                out_channels=model_channels // 8
            )
        ])
        self.upsample_skin_net = nn.ModuleList(upsample_skin_blocks)
        self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
        self.out_layer_skin = sp.SparseLinear(model_channels // 8, self.skin_feat_channels*8)
        self.out_layer_skl_skin = sp.SparseLinear(model_channels // 8 if upsample_skl else model_channels_skl, self.skin_feat_channels if skl_defined_on_center else self.skin_feat_channels * 8)
        self.use_conf_jp = self.rep_config.get('use_conf_jp', False) or self.jp_hyper_continuous
        self.use_conf_skin = self.rep_config.get('use_conf_skin', False)
        res_skl = self.resolution * 4 if self.upsample_skl else self.resolution
        self.skeleton_extractor = AniGenSklFeatures2Skeleton(skin_feat_channels=self.skin_feat_channels, device=self.device, res=res_skl, use_conf_jp=self.use_conf_jp, use_conf_skin=self.use_conf_skin, predict_skin=True, defined_on_center=skl_defined_on_center, jp_hyper_continuous=self.jp_hyper_continuous, jp_residual_fields=self.jp_residual_fields)
        self.out_channels_skl = self.skeleton_extractor.feats_channels
        if self.upsample_skl:
            self.upsample_skl_net = nn.ModuleList([
                SparseSubdivideBlock3d(
                    channels=model_channels_skl,
                    resolution=resolution,
                    out_channels=model_channels // 4
                ),
                SparseSubdivideBlock3d(
                    channels=model_channels // 4,
                    resolution=resolution * 2,
                    out_channels=model_channels // 8
                )
            ])
            self.out_layer_skl = sp.SparseLinear(model_channels // 8, self.out_channels_skl)
        else:
            self.out_layer_skl = sp.SparseLinear(model_channels_skl, self.out_channels_skl)
        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()
        else:
            self.convert_to_fp32()
        # Freeze pretrained modules (LoRA params stay trainable).
        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            for module in modules_to_freeze:
                if hasattr(self, module):
                    mod = getattr(self, module)
                    if isinstance(mod, nn.ModuleList):
                        for m in mod:
                            for name, param in m.named_parameters():
                                if 'lora' not in name:
                                    param.requires_grad = False
                    elif isinstance(mod, nn.Module):
                        for name, param in mod.named_parameters():
                            if 'lora' not in name:
                                param.requires_grad = False
                    elif isinstance(mod, torch.Tensor):
                        if mod.requires_grad:
                            mod.requires_grad = False

    def initialize_weights(self) -> None:
        """Initialize output heads with small Kaiming weights and zero bias."""
        super().initialize_weights()
        scale = 1e-4
        # Kaiming initialization for output layers (better for ReLU/SiLU-like activations)
        nn.init.kaiming_normal_(self.out_layer.weight, mode='fan_in', nonlinearity='relu')
        self.out_layer.weight.data.mul_(scale)
        nn.init.constant_(self.out_layer.bias, 0)
        nn.init.kaiming_normal_(self.out_layer_skl.weight, mode='fan_in', nonlinearity='relu')
        self.out_layer_skl.weight.data.mul_(scale)
        nn.init.constant_(self.out_layer_skl.bias, 0)
        # Initialize skin layer:
        self.skin_decoder.initialize_weights()
        nn.init.kaiming_normal_(self.out_layer_skin.weight, mode='fan_in', nonlinearity='relu')
        self.out_layer_skin.weight.data.mul_(scale)
        nn.init.constant_(self.out_layer_skin.bias, 0)
        nn.init.kaiming_normal_(self.out_layer_skl_skin.weight, mode='fan_in', nonlinearity='relu')
        self.out_layer_skl_skin.weight.data.mul_(scale)
        nn.init.constant_(self.out_layer_skl_skin.bias, 0)

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        super().convert_to_fp16()
        self.upsample.apply(convert_module_to_f16)
        self.upsample_skin_net.apply(convert_module_to_f16)
        if self.upsample_skl:
            self.upsample_skl_net.apply(convert_module_to_f16)
        if self.skin_cross_from_groupped:
            # Joint preprocessing and cross-attn should match model dtype.
            self.root_parent_feat.data = self.root_parent_feat.data.half()
            self.joints_in_proj.apply(convert_module_to_f16)
            self.joints_self_attn.apply(convert_module_to_f16)
            # `convert_module_to_f16` doesn't include `nn.LayerNorm`, so cast LN params explicitly.
            for _m in self.joints_in_proj.modules():
                if isinstance(_m, nn.LayerNorm):
                    if _m.weight is not None:
                        _m.weight.data = _m.weight.data.half()
                    if _m.bias is not None:
                        _m.bias.data = _m.bias.data.half()
            # IMPORTANT: `SparseTransformerBlock` uses `LayerNorm32` which internally
            # normalizes in fp32 (`x.float()`), so its parameters must stay fp32.
            for _m in self.joints_self_attn.modules():
                if isinstance(_m, nn.LayerNorm):
                    if _m.weight is not None:
                        _m.weight.data = _m.weight.data.float()
                    if _m.bias is not None:
                        _m.bias.data = _m.bias.data.float()
            self.skin_cross_groupped_net.apply(convert_module_to_f16)
            self.h_skin_coord_proj.apply(convert_module_to_f16)
            self.h_skin_coord_fuse.apply(convert_module_to_f16)
        # UNet is executed in fp32 (see `SparseSkinUNetNLevel.forward`), so keep its
        # weights in fp32 to avoid dtype mismatches inside spconv.
        if self.h_skin_unet is not None:
            self.h_skin_unet.apply(convert_module_to_f32)
        self.skin_decoder.convert_to_fp16()

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        super().convert_to_fp32()
        self.upsample.apply(convert_module_to_f32)
        self.upsample_skin_net.apply(convert_module_to_f32)
        if self.upsample_skl:
            self.upsample_skl_net.apply(convert_module_to_f32)
        if self.skin_cross_from_groupped:
            self.root_parent_feat.data = self.root_parent_feat.data.float()
            self.joints_in_proj.apply(convert_module_to_f32)
            self.joints_self_attn.apply(convert_module_to_f32)
            for _m in self.joints_in_proj.modules():
                if isinstance(_m, nn.LayerNorm):
                    if _m.weight is not None:
                        _m.weight.data = _m.weight.data.float()
                    if _m.bias is not None:
                        _m.bias.data = _m.bias.data.float()
            for _m in self.joints_self_attn.modules():
                if isinstance(_m, nn.LayerNorm):
                    if _m.weight is not None:
                        _m.weight.data = _m.weight.data.float()
                    if _m.bias is not None:
                        _m.bias.data = _m.bias.data.float()
            self.skin_cross_groupped_net.apply(convert_module_to_f32)
            self.h_skin_coord_proj.apply(convert_module_to_f32)
            self.h_skin_coord_fuse.apply(convert_module_to_f32)
        if self.h_skin_unet is not None:
            self.h_skin_unet.apply(convert_module_to_f32)
        self.skin_decoder.convert_to_fp32()

    def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        """
        Convert a batch of network outputs to 3D representations.
        Args:
            x: The [N x * x C] sparse tensor output by the network.
        Returns:
            list of representations
        """
        ret = []
        for i in range(x.shape[0]):
            mesh = self.mesh_extractor(x[i], training=self.training)
            ret.append(mesh)
        return ret

    def to_representation_skl(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        """
        Convert a batch of network outputs to skeleton representations.
        Args:
            x: The [N x * x C] sparse tensor output by the network.
        Returns:
            list of skeleton representations
        """
        ret = []
        for i in range(x.shape[0]):
            skl = self.skeleton_extractor(x[i], training=self.training)
            ret.append(skl)
        return ret

    def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, gt_joints=None, gt_parents=None) -> Tuple[List[MeshExtractResult], List[Any]]:
        """Decode latents into meshes and skeletons (with skinning).

        Args:
            x: geometry latent; channels beyond `latent_channels` hold the
                vertex-skin latent.
            x_skl: skeleton latent.
            gt_joints: optional ground-truth joints used instead of the
                predicted grouping (one tensor per batch element).
            gt_parents: optional ground-truth parent indices matching `gt_joints`.
        Returns:
            (meshes, skeletons): one entry per batch element. Skin predictions
            are attached to the skeleton representations in-place.
        """
        x0 = x
        # Split packed latent: head -> geometry, tail -> vertex skin.
        x_skin = sp.SparseTensor(feats=x0.feats[:, self.latent_channels:], coords=x0.coords.clone())
        x = x0.replace(x0.feats[:, :self.latent_channels])
        if self.normalize_z:
            x_skin = x_skin.replace(F.normalize(x_skin.feats, dim=-1))
            x_skl = x_skl.replace(F.normalize(x_skl.feats, dim=-1))
        # Backbone forward
        h, h_skl, h_skin = super().forward(x, x_skl, x_skin)
        # Optional smoothing on h_skin.
        if self.h_skin_unet is not None:
            h_skin = self.h_skin_unet(h_skin)
        # Skeleton prediction
        if self.upsample_skl:
            for block_skl in self.upsample_skl_net:
                h_skl = block_skl(h_skl)
        h_skl_middle = h_skl.type(x_skl.dtype)
        h_skl = self.out_layer_skl(h_skl_middle)
        h_skl_skin = self.out_layer_skl_skin(h_skl_middle)
        h_skl = h_skl.replace(torch.cat([h_skl.feats, h_skl_skin.feats], dim=-1))
        skeletons = self.to_representation_skl(h_skl)
        skin_feats_joints_list = self.skeleton_grouping(skeletons, gt_joints=gt_joints, gt_parents=gt_parents)
        # Skin cross with grouped joint features
        if self.skin_cross_from_groupped:
            # Voxel centers mapped to [-1, 1] for the Fourier coordinate PE.
            coords_xyz = h_skin.coords[:, 1:].to(device=h_skin.device, dtype=torch.float32)
            coords_norm = (coords_xyz + 0.5) / self.resolution * 2.0 - 1.0
            coords_pe = self.h_skin_coord_embedder(coords_norm)
            coords_pe = coords_pe.to(device=h_skin.device, dtype=h_skin.feats.dtype)
            coords_pe = self.h_skin_coord_proj(coords_pe)
            h_skin = h_skin.replace(torch.cat([h_skin.feats, coords_pe], dim=-1))
            h_skin = self.h_skin_coord_fuse(h_skin)
            joints_ctx = self._build_processed_joints_context(
                skeletons,
                skin_feats_joints_list,
                device=h_skin.device,
                dtype=h_skin.feats.dtype,
            )
            h_skin = self.skin_cross_groupped_net(h_skin, [joints_ctx])
        for block in self.upsample_skin_net:
            h_skin = block(h_skin)
        h_skin = h_skin.type(x.dtype)
        h_skin = self.out_layer_skin(h_skin)
        # Mesh prediction
        for block in self.upsample:
            h = block(h)
        h_middle = h.type(x.dtype)
        h = self.out_layer(h_middle)
        h = h.replace(torch.cat([h.feats, h_skin.feats], dim=-1))
        meshes = self.to_representation(h)
        self.skinweight_forward(meshes, skeletons, gt_joints=gt_joints, gt_parents=gt_parents)
        return meshes, skeletons

    def _joints_feats_list_to_sparse(
        self,
        joints_feats_list: List[torch.Tensor],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> sp.SparseTensor:
        """Pack per-batch joint feature tensors into one sparse tensor.

        Each batch element's joints become tokens of that batch item.
        """
        if device is None:
            device = self.device
        if dtype is None:
            dtype = self.dtype
        feats_per_batch: List[torch.Tensor] = []
        for joints_feats in joints_feats_list:
            joints_feats = joints_feats.to(device=device, dtype=dtype)
            feats_per_batch.append(joints_feats)
        feats = torch.cat(feats_per_batch, dim=0)
        # Coords are [batch, x, y, z]. We encode token index into x and keep y/z = 0.
        batch_indices: List[torch.Tensor] = []
        x_indices: List[torch.Tensor] = []
        for bi, joints_feats in enumerate(feats_per_batch):
            ji = int(joints_feats.shape[0])
            batch_indices.append(torch.full((ji,), bi, device=device, dtype=torch.int32))
            x_indices.append(torch.arange(ji, device=device, dtype=torch.int32))
        b = torch.cat(batch_indices, dim=0)
        x = torch.cat(x_indices, dim=0)
        yz = torch.zeros((x.shape[0], 2), device=device, dtype=torch.int32)
        coords = torch.cat([b[:, None], x[:, None], yz], dim=1)
        return sp.SparseTensor(feats=feats, coords=coords)

    def _build_processed_joints_context(
        self,
        skeletons: List[Any],
        skin_feats_joints_list: List[torch.Tensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> sp.SparseTensor:
        """Build the joint-token context for the grouped skin cross-attention.

        Each token = proj([joint_skin, fourier(joint_xyz), parent_skin]),
        refined with self-attention, then concatenated with the raw joint
        skin features as a skip connection.
        """
        processed: List[torch.Tensor] = []
        raw_skin: List[torch.Tensor] = []
        for rep_skl, skin_feats_joints in zip(skeletons, skin_feats_joints_list):
            joints = rep_skl.joints_grouped
            parents = rep_skl.parents_grouped
            if joints is None or parents is None:
                raise ValueError('Expected grouped joints/parents for skin_cross_from_groupped.')
            joints = joints.to(device=device, dtype=dtype)
            parents = parents.to(device=device)
            skin_feats_joints = skin_feats_joints.to(device=device, dtype=dtype)
            raw_skin.append(skin_feats_joints)
            pe = self.joints_pos_embedder(joints).to(device=device, dtype=dtype)
            # Parent skin features (root uses trainable parameter)
            parent_idx = parents.to(torch.long)
            valid = parent_idx >= 0
            root_feat = self.root_parent_feat.to(device=device, dtype=dtype)
            parent_feat_root = root_feat.unsqueeze(0).expand(skin_feats_joints.shape[0], -1)
            parent_feat_gather = skin_feats_joints[parent_idx.clamp(min=0)]
            parent_feat = torch.where(valid.unsqueeze(1), parent_feat_gather, parent_feat_root)
            joint_in = torch.cat([skin_feats_joints, pe, parent_feat], dim=-1)
            joint_h = self.joints_in_proj(joint_in)
            processed.append(joint_h)
        joints_ctx = self._joints_feats_list_to_sparse(processed, device=device, dtype=dtype)
        for blk in self.joints_self_attn:
            joints_ctx = blk(joints_ctx)
        # Skip connection: concatenate original joint skin feats after self-attn.
        joints_skip = self._joints_feats_list_to_sparse(raw_skin, device=device, dtype=dtype)
        joints_ctx = joints_ctx.replace(torch.cat([joints_ctx.feats, joints_skip.feats], dim=-1))
        return joints_ctx

    def skeleton_grouping(self, reps_skl, gt_joints=None, gt_parents=None, skin_feats_skl_list=None, return_skin_pred_only=False):
        """Group predicted joints and pool per-point skin features onto them.

        Uses ground-truth joints/parents when provided, otherwise the grouping
        strategy (no grad). Joint features are the confidence-weighted mean of
        the skin features of the points nearest to each joint; a variance term
        is stored as a regularization loss unless `return_skin_pred_only`.

        Returns:
            list of [num_joints, C] joint skin-feature tensors, one per skeleton.
        """
        skin_feats_joints_list = []
        for i, rep_skl in enumerate(reps_skl):
            if gt_joints is not None:
                joints_grouped = gt_joints[i]
                parents_grouped = gt_parents[i]
            elif rep_skl.joints_grouped is None:
                with torch.no_grad():
                    joints_grouped, parents_grouped = self.grouping_func(joints=rep_skl.joints, parents=rep_skl.parents, joints_conf=rep_skl.conf_j, parents_conf=rep_skl.conf_p)
            else:
                joints_grouped = rep_skl.joints_grouped
                parents_grouped = rep_skl.parents_grouped
            if not return_skin_pred_only:
                rep_skl.joints_grouped = joints_grouped
                rep_skl.parents_grouped = parents_grouped
            # Calculate NN indices for joints
            positions_skl = rep_skl.positions
            _, joints_nn_idx, _ = knn_points(positions_skl[None], joints_grouped[None].detach(), K=1, norm=2, return_nn=False)
            joints_nn_idx = joints_nn_idx[0, :, 0]
            skin_feats_skl = rep_skl.skin_feats if skin_feats_skl_list is None else skin_feats_skl_list[i]
            # Average the predicted joint features
            conf_skin = torch.sigmoid(rep_skl.conf_skin) if rep_skl.conf_skin is not None else torch.ones_like(skin_feats_skl[:, :1])
            skin_feats_joints = torch.zeros([joints_grouped.shape[0], skin_feats_skl.shape[-1]], device=self.device, dtype=skin_feats_skl.dtype)
            skin_feats_square_joints = skin_feats_joints.clone()
            skin_conf_joints = torch.zeros([joints_grouped.shape[0], 1], device=self.device, dtype=skin_feats_skl.dtype)
            skin_feats_joints.scatter_add_(0, joints_nn_idx[:, None].expand(-1, skin_feats_skl.shape[-1]), skin_feats_skl * conf_skin)
            skin_feats_square_joints.scatter_add_(0, joints_nn_idx[:, None].expand(-1, skin_feats_skl.shape[-1]), skin_feats_skl.square() * conf_skin)
            skin_conf_joints.scatter_add_(0, joints_nn_idx[:, None], conf_skin)
            skin_feats_joints = skin_feats_joints / skin_conf_joints.clamp(min=1e-6)
            skin_feats_square_joints = skin_feats_square_joints / skin_conf_joints.clamp(min=1e-6)
            # Var = E[x^2] - E[x]^2, per joint, averaged into a scalar loss.
            skin_feats_joints_var = skin_feats_square_joints - skin_feats_joints.square()
            skin_feats_joints_var_loss = skin_feats_joints_var.mean()
            if not return_skin_pred_only:
                rep_skl.skin_feats_joints_var_loss = skin_feats_joints_var_loss
                rep_skl.skin_feats_joints = skin_feats_joints
            skin_feats_joints_list.append(skin_feats_joints)
        return skin_feats_joints_list

    def skinweight_forward(self, reps, reps_skl, gt_joints=None, gt_parents=None, return_skin_pred_only=False, skin_feats_verts_list=None, skin_feats_skl_list=None, *args, **kwargs):
        """Predict per-vertex skin weights with the skin decoder.

        When `return_skin_pred_only` is True, returns the list of predictions
        without mutating the skeleton reps; otherwise attaches each prediction
        to `reps_skl[i].skin_pred` and returns None.
        """
        if return_skin_pred_only:
            skin_preds = []
        if reps_skl[0].parents_grouped is None or return_skin_pred_only:
            skin_feats_joints_list = self.skeleton_grouping(reps_skl, gt_joints=gt_joints, gt_parents=gt_parents, skin_feats_skl_list=skin_feats_skl_list, return_skin_pred_only=return_skin_pred_only)
        else:
            skin_feats_joints_list = [rep_skl.skin_feats_joints for rep_skl in reps_skl]
        for i, (rep, rep_skl) in enumerate(zip(reps, reps_skl)):
            # Joint skinning features
            skin_feats_joints = skin_feats_joints_list[i]
            # Vertex skinning features
            skin_feats_verts = rep.vertex_skin_feats if skin_feats_verts_list is None else skin_feats_verts_list[i]
            # Predict skin weights
            parents_grouped = rep_skl.parents_grouped
            skin_pred = self.skin_decoder(skin_feats_verts[None], skin_feats_joints[None], parents_grouped[None])
            skin_pred = skin_pred[0]
            if return_skin_pred_only:
                skin_preds.append(skin_pred)
            else:
                reps_skl[i].skin_pred = skin_pred
        if return_skin_pred_only:
            return skin_preds
class AniGenElasticSLatMeshDecoder(SparseTransformerElasticMixin, AniGenSLatMeshDecoder):
    """
    Slat VAE Mesh decoder with elastic memory management.
    Used for training with low VRAM.

    Decoding behavior is inherited unchanged from ``AniGenSLatMeshDecoder``;
    ``SparseTransformerElasticMixin`` presumably overrides memory-related
    hooks only (NOTE(review): confirm against the mixin's definition).
    """
    pass