| from typing import * |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from ...modules import sparse as sp |
| from ..sparse_elastic_mixin import SparseTransformerElasticMixin |
| from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder |
| from pytorch3d.ops import knn_points |
| from .skin_models import SkinEncoder |
| |
|
|
def block_attn_config(self):
    """
    Yield the per-block attention configuration as 5-tuples:
    (attn_mode, window_size, shift_sequence, shift_window, serialize_mode).

    Shifted variants alternate their shift on odd/even block indices so that
    consecutive blocks attend to overlapping windows.
    """
    mode = self.attn_mode
    for block_idx in range(self.num_blocks):
        odd = block_idx % 2
        if mode == "full":
            yield "full", None, None, None, None
        elif mode == "swin":
            yield "windowed", self.window_size, None, self.window_size // 2 * odd, None
        elif mode == "shift_window":
            yield "serialized", self.window_size, 0, (16 * odd,) * 3, sp.SerializeMode.Z_ORDER
        elif mode == "shift_sequence":
            yield "serialized", self.window_size, self.window_size // 2 * odd, (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif mode == "shift_order":
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[block_idx % 4]
|
|
|
|
class FeedForwardNet(nn.Module):
    """
    Two-layer MLP applied to the last dimension: Linear -> tanh-approximated
    GELU -> Linear, with the hidden width scaled by `mlp_ratio`.
    """

    def __init__(self, channels: int, channels_out: int = None, mlp_ratio: float = 4.0):
        super().__init__()
        if channels_out is None:
            channels_out = channels
        hidden_channels = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden_channels),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_channels, channels_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x` (..., channels) to (..., channels_out)."""
        return self.mlp(x)
|
|
|
|
class AniGenSLatEncoder(AniGenSparseTransformerBase):
    """
    Sparse-transformer encoder producing structured latents for animated
    asset generation from three coupled streams:

    * a geometry stream (``x``) over sparse voxels,
    * a skeleton stream (``x_skl``) built from nearest-joint embeddings,
    * a per-voxel skin stream derived from mesh skinning weights.

    The geometry head always emits a Gaussian posterior (mean, logvar); the
    skeleton and skin heads emit either Gaussian posteriors or deterministic
    (optionally L2-normalized) latents depending on ``latent_denoising``.

    Fixes vs. previous revision: the mutable default arguments
    (``skin_encoder_config={}`` — which was mutated in-place, leaking
    ``use_fp16`` into the shared default object and into caller-owned dicts —
    and the default ``modules_to_freeze`` list) are replaced with ``None``
    sentinels; behavior for all existing call sites is unchanged.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,

        model_channels: int,
        model_channels_skl: int,
        model_channels_skin: int,

        latent_channels: int,
        latent_channels_skl: int,
        latent_channels_vertskin: int,

        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,

        num_heads_skl: int = 32,
        num_heads_skin: int = 32,

        skl_pos_embed_freq: int = 10,
        skin_encoder_config: Optional[Dict[str, Any]] = None,
        encode_upsampled_skin_feat: bool = True,
        skin_ae_name: Optional[str] = "SkinAE",

        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,

        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        modules_to_freeze: Optional[List[str]] = None,

        skin_cross_from_geo: bool = True,
        skl_cross_from_geo: bool = True,
        skin_skl_cross: bool = True,

        latent_denoising: bool = True,
        normalize_z: bool = True,
        normalize_scale: float = 1.0,

        jp_residual_fields: bool = False,
        jp_hyper_continuous: bool = False,
    ):
        # Replace mutable defaults and copy the caller's config dict so that
        # the `use_fp16` assignment below never mutates shared/caller state.
        skin_encoder_config = dict(skin_encoder_config) if skin_encoder_config else {}
        if modules_to_freeze is None:
            modules_to_freeze = ["input_layer", "blocks", "out_layer", "skin_encoder"]

        # Set before super().__init__ in case the base constructor reads them.
        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        self.skl_pos_embed_freq = skl_pos_embed_freq
        self.latent_denoising = latent_denoising
        # L2-normalization of latents only applies in the denoising regime.
        self.normalize_latent = normalize_z and latent_denoising
        self.normalize_scale = normalize_scale
        self.jp_residual_fields = jp_residual_fields
        self.jp_hyper_continuous = jp_hyper_continuous

        super().__init__(
            in_channels=in_channels,
            in_channels_skl=model_channels_skl,
            in_channels_skin=model_channels_skin,
            model_channels=model_channels,
            model_channels_skl=model_channels_skl,
            model_channels_skin=model_channels_skin,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_heads_skl=num_heads_skl,
            num_heads_skin=num_heads_skin,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            attn_mode_cross=attn_mode_cross,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
            skin_cross_from_geo=skin_cross_from_geo,
            skl_cross_from_geo=skl_cross_from_geo,
            skin_skl_cross=skin_skl_cross,
        )
        # Names used by the pretrained-checkpoint loader; skin-AE weights are
        # remapped under the "skin_encoder" prefix.
        self.pretrain_class_name = ["AniGenElasticSLatEncoder", skin_ae_name]
        self.pretrain_ckpt_filter_prefix = {skin_ae_name: "skin_encoder"}
        self.resolution = resolution

        self.latent_channels = latent_channels
        self.latent_channels_skl = latent_channels_skl
        self.latent_channels_vertskin = latent_channels_vertskin

        # Propagate the precision flag into the (local copy of the) skin
        # encoder config.
        skin_encoder_config['use_fp16'] = use_fp16
        self.skin_encoder = SkinEncoder(**skin_encoder_config)
        self.encode_upsampled_skin_feat = encode_upsampled_skin_feat
        # With upsampled skin features each voxel gathers 8 sub-voxel samples,
        # so the input width is 8x the per-sample feature width.
        self.in_layer_skin = FeedForwardNet(
            channels=self.skin_encoder.skin_feat_channels * (8 if encode_upsampled_skin_feat else 1),
            channels_out=model_channels_skin,
        )

        # Fourier embedder for joint/parent positions; one extra input channel
        # for the continuity factor when `jp_hyper_continuous` is enabled.
        self.pos_embedder_fourier = FreqPositionalEmbedder(
            in_dim=4 if self.jp_hyper_continuous else 3,
            max_freq_log2=self.skl_pos_embed_freq,
            num_freqs=self.skl_pos_embed_freq,
            include_input=True,
        )
        # Learned embedding substituted for the parent embedding at root joints.
        self.root_embedding = nn.Parameter(torch.zeros(1, self.pos_embedder_fourier.out_dim))

        # Skeleton-stream input projection: a quarter of the channels come from
        # joint/parent position embeddings, the rest from joint skin embeddings.
        self.in_layer_jp_skl = FeedForwardNet(
            channels=2 * self.pos_embedder_fourier.out_dim,
            channels_out=model_channels_skl // 4,
        )
        self.in_layer_skin_skl = FeedForwardNet(
            channels=self.skin_encoder.skin_feat_channels,
            channels_out=model_channels_skl - (model_channels_skl // 4),
        )

        # Geometry head always predicts (mean, logvar); skeleton/skin heads
        # predict only a deterministic latent in the latent-denoising regime.
        self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
        if self.latent_denoising:
            self.out_layer_skl = sp.SparseLinear(model_channels_skl, latent_channels_skl)
            self.out_layer_vertskin = sp.SparseLinear(model_channels_skin, latent_channels_vertskin)
        else:
            self.out_layer_skl = sp.SparseLinear(model_channels_skl, 2 * latent_channels_skl)
            self.out_layer_vertskin = sp.SparseLinear(model_channels_skin, 2 * latent_channels_vertskin)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()
        else:
            self.convert_to_fp32()

        if 'all' in modules_to_freeze:
            # Expand 'all' to the set of top-level parameter/module names.
            modules_to_freeze = list(set([k.split('.')[0] for k in self.state_dict().keys()]))
            print(f"\033[93mFreezing all modules: {modules_to_freeze}\033[0m")
        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            # Freeze the listed pretrained modules, keeping any parameter whose
            # name contains 'lora' trainable.
            for module in modules_to_freeze:
                if hasattr(self, module):
                    mod = getattr(self, module)
                    if isinstance(mod, nn.ModuleList):
                        for m in mod:
                            for name, param in m.named_parameters():
                                if 'lora' not in name:
                                    param.requires_grad = False
                    elif isinstance(mod, nn.Module):
                        for name, param in mod.named_parameters():
                            if 'lora' not in name:
                                param.requires_grad = False
                    elif isinstance(mod, torch.Tensor):
                        if mod.requires_grad:
                            mod.requires_grad = False

    def initialize_weights(self) -> None:
        """
        Initialize weights; the geometry output head is zero-initialized so
        the encoder starts out emitting mean=0, logvar=0 (a unit Gaussian).
        """
        super().initialize_weights()
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def skeleton_embedding(self, x, x_skl, joints_list, parents_list, skin_list, gt_meshes, bvh_list=None):
        """
        Build input features for the skin and skeleton streams.

        Args:
            x, x_skl: per-sample sparse tensors with voxel coordinates in
                ``coords[:, 1:4]`` (column 0 is the batch index, rewritten here).
            joints_list: per-sample joint positions (assumed in the same
                normalized [-0.5, 0.5) space as voxel centers — TODO confirm).
            parents_list: per-sample parent indices; -1 marks a root joint.
            skin_list: per-sample skinning weights passed to the skin encoder.
            gt_meshes: per-sample dicts with 'vertices' (and 'faces' when
                ``bvh_list`` is provided).
            bvh_list: optional per-sample BVH accelerators; when given, vertex
                skin embeddings are interpolated barycentrically on the nearest
                surface triangle instead of copied from the nearest vertex.

        Returns:
            (x_new, x_skl_new, joint_skin_embeds, vert_skin_embeds,
            joints_pos_list): batched sparse skin/skeleton features plus the
            raw skin-encoder outputs and per-voxel nearest-joint positions.
        """
        res = self.resolution
        feats_new = []
        feats_skl_new = []
        coords_new = []
        coords_skl_new = []

        # Per-joint and per-mesh-vertex skin embeddings from the skin encoder.
        joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
        joints_pos_list = []

        for i in range(len(joints_list)):
            parent_idx = parents_list[i].clone()

            # NOTE(review): the batch index is written in place on
            # x[i].coords / x_skl[i].coords (no clone) — presumably intentional;
            # confirm the inputs are not reused elsewhere afterwards.
            coords_new.append(x[i].coords)
            coords_skl_new.append(x_skl[i].coords)
            coords_new[-1][:, 0] = i
            coords_skl_new[-1][:, 0] = i

            # Voxel centers mapped to the normalized [-0.5, 0.5) cube.
            v_pos = (x[i].coords[:, 1:4] + 0.5) / res - 0.5
            v_pos_skl = (x_skl[i].coords[:, 1:4] + 0.5) / res - 0.5
            # Two nearest joints per skeleton voxel; with norm=2, pytorch3d
            # returns squared L2 distances.
            dist_nn_12, joints_nn_idx, _ = knn_points(v_pos_skl[None], joints_list[i][None], K=2, norm=2, return_nn=False)
            joints_nn_idx = joints_nn_idx[0, :, 0]

            # Nearest joint / its parent position, optionally as residuals
            # relative to the querying voxel.
            joints_pos = joints_list[i][joints_nn_idx] - (v_pos_skl if self.jp_residual_fields else 0)
            parents_pos = joints_list[i][parent_idx[joints_nn_idx]] - (v_pos_skl if self.jp_residual_fields else 0)
            if self.jp_hyper_continuous:
                # Extra channel 1 - d1/d2 in [0, 1]: since d1 <= d2 it tends to
                # 0 where the two nearest joints are equidistant, keeping the
                # field continuous across nearest-joint switches.
                factor = (1 - (dist_nn_12[0, :, 0:1] / (dist_nn_12[0, :, 1:2] + 1e-8)).clamp(max=1.0))
                joints_pos = torch.cat([joints_pos, factor], dim=-1)
                parents_pos = torch.cat([parents_pos, factor], dim=-1)
            joints_pos_embed = self.pos_embedder_fourier(joints_pos)
            parents_pos_embed = self.pos_embedder_fourier(parents_pos)
            # Root joints (parent == -1) use the learned root embedding instead
            # of a (meaningless) parent position embedding.
            parents_pos_embed = torch.where(parent_idx[joints_nn_idx][:, None] == -1, self.root_embedding.expand_as(parents_pos_embed), parents_pos_embed)
            jp_pos_embed_nn = torch.cat([joints_pos_embed, parents_pos_embed], dim=-1)
            jp_pos_embed_nn = self.in_layer_jp_skl(jp_pos_embed_nn)

            # Skin embedding of the nearest joint, projected to the remaining
            # skeleton-stream channels.
            j_skin_embed_nn = joint_skin_embeds[i][joints_nn_idx]
            j_skin_embed_nn = self.in_layer_skin_skl(j_skin_embed_nn)

            jp_skl_embed = torch.cat([jp_pos_embed_nn, j_skin_embed_nn], dim=-1)
            feats_skl_new.append(jp_skl_embed)

            if self.encode_upsampled_skin_feat:
                # Query the 8 half-resolution sub-voxel centers of each voxel.
                offsets = torch.tensor([
                    [-1, -1, -1], [-1, -1, 1], [-1, 1, -1], [-1, 1, 1],
                    [1, -1, -1], [1, -1, 1], [1, 1, -1], [1, 1, 1]
                ], device=v_pos.device, dtype=v_pos.dtype) * (0.25 / res)
                query_pos = v_pos.unsqueeze(1) + offsets.unsqueeze(0)
                query_pos_flat = query_pos.view(-1, 3)
            else:
                query_pos_flat = v_pos

            if bvh_list is not None:
                # Barycentric interpolation of vertex skin embeddings on the
                # closest triangle of the ground-truth mesh.
                bvh = bvh_list[i].to(v_pos.device)
                _, face_id, uvw = bvh.unsigned_distance(query_pos_flat, return_uvw=True)
                uvw = uvw.clamp(min=0.0)
                uvw_sum = uvw.sum(dim=-1, keepdim=True).clamp_min(1e-6)
                uvw = uvw / uvw_sum
                face_id = gt_meshes[i]['faces'][face_id]
                voxel_skin_embeds = (vert_skin_embeds[i][face_id] * uvw[..., None]).sum(1)
            else:
                # Fallback: copy the embedding of the nearest mesh vertex.
                gt_mesh_verts = gt_meshes[i]['vertices']
                _, mesh_nn_idx, _ = knn_points(query_pos_flat[None], gt_mesh_verts[None], K=1, norm=2, return_nn=False)
                mesh_nn_idx = mesh_nn_idx[0, :, 0]
                voxel_skin_embeds = vert_skin_embeds[i][mesh_nn_idx]

            # Collapse the (possibly 8-way upsampled) samples per voxel and
            # project to the skin-stream model width.
            voxel_skin_embeds = voxel_skin_embeds.view(v_pos.shape[0], -1)
            voxel_skin_embeds = self.in_layer_skin(voxel_skin_embeds)
            feats_new.append(voxel_skin_embeds)
            joints_pos_list.append(joints_pos)

        x_new = sp.SparseTensor(coords=torch.cat(coords_new, dim=0), feats=torch.cat(feats_new, dim=0))
        x_skl_new = sp.SparseTensor(coords=torch.cat(coords_skl_new, dim=0), feats=torch.cat(feats_skl_new, dim=0))

        return x_new, x_skl_new, joint_skin_embeds, vert_skin_embeds, joints_pos_list

    def encode_sample(self, x: sp.SparseTensor, out_layer: sp.SparseLinear, sample_posterior: bool = True, latent_denoising: bool = False):
        """
        Project features through `out_layer` and (optionally) sample a latent.

        With ``latent_denoising`` the head output is the latent itself
        (optionally L2-normalized and scaled by ``self.normalize_scale``) and
        logvar is fixed at zero; otherwise the output is split into
        (mean, logvar) and reparameterized when ``sample_posterior`` is set.

        Returns:
            (z, mean, logvar) — `z` is a sparse tensor; `mean` is detached in
            the latent-denoising regime so it can serve as a fixed target.
        """
        x = x.type(torch.float32)
        # Parameter-free layer norm over the channel dimension.
        x = x.replace(F.layer_norm(x.feats, x.feats.shape[-1:]))
        x = out_layer(x)
        if latent_denoising:
            if self.normalize_latent:
                # Use the file-wide `F` alias (was `nn.functional.normalize`).
                x = x.replace(F.normalize(x.feats, dim=-1) * self.normalize_scale)
            mean, logvar = x.feats, torch.zeros_like(x.feats)
        else:
            mean, logvar = x.feats.chunk(2, dim=-1)
        if sample_posterior and not latent_denoising:
            # Reparameterization trick: z = mean + std * eps.
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean
        z = x.replace(z)
        if latent_denoising:
            mean = mean.detach()
        return z, mean, logvar

    def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, sample_posterior=True, return_raw=False, return_skin_encoded=False, **kwargs):
        """
        Encode geometry + skeleton inputs into latents.

        kwargs must provide 'gt_joints', 'gt_parents', 'gt_skin' and 'gt_mesh',
        and may provide 'bvh_list'. The per-voxel skin latent is concatenated
        channel-wise onto the geometry latent; the skeleton latent is returned
        separately. ``return_raw`` adds (mean, logvar) pairs;
        ``return_skin_encoded`` additionally returns the pre-transformer
        skin/skeleton input features.
        """
        x_skin, x_skl, joint_skin_embeds, vert_skin_embeds, joints_pos = self.skeleton_embedding(x, x_skl, kwargs.get('gt_joints'), kwargs.get('gt_parents'), kwargs.get('gt_skin'), kwargs.get('gt_mesh'), kwargs.get('bvh_list', None))
        h, h_skl, h_skin = super().forward(x, x_skl, x_skin)

        # Geometry latent is always a sampled Gaussian; skeleton/skin latents
        # follow the configured latent-denoising regime.
        z, mean, logvar = self.encode_sample(h, self.out_layer, sample_posterior, latent_denoising=False)
        z_skl, mean_skl, logvar_skl = self.encode_sample(h_skl, self.out_layer_skl, sample_posterior, latent_denoising=self.latent_denoising)
        z_skin, mean_skin, logvar_skin = self.encode_sample(h_skin, self.out_layer_vertskin, sample_posterior, latent_denoising=self.latent_denoising)

        # Fuse the skin latent into the geometry latent channel-wise.
        z = z.replace(torch.cat([z.feats, z_skin.feats], dim=-1))
        mean, logvar = torch.cat([mean, mean_skin], dim=-1), torch.cat([logvar, logvar_skin], dim=-1)

        if not return_skin_encoded:
            if return_raw:
                return z, mean, logvar, z_skl, mean_skl, logvar_skl, joint_skin_embeds, vert_skin_embeds, joints_pos
            else:
                return z, z_skl, joint_skin_embeds, vert_skin_embeds, joints_pos
        else:
            if return_raw:
                return z, mean, logvar, z_skl, mean_skl, logvar_skl, joint_skin_embeds, vert_skin_embeds, joints_pos, x_skin, x_skl
            else:
                return z, z_skl, joint_skin_embeds, vert_skin_embeds, joints_pos, x_skin, x_skl

    def encode_skin(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor]=None):
        """Run only the skin encoder: per-joint and per-vertex skin embeddings."""
        joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
        return joint_skin_embeds, vert_skin_embeds
|
|
|
|
class AniGenElasticSLatEncoder(SparseTransformerElasticMixin, AniGenSLatEncoder):
    """
    SLat VAE encoder with elastic memory management.

    Same interface and behavior as ``AniGenSLatEncoder``; the
    ``SparseTransformerElasticMixin`` supplies the memory-management
    behavior (exact mechanism defined in the mixin — see its module).
    Used for training with low VRAM.
    """
