# anigen/models/structured_latent_vae/anigen_encoder.py
# (Initial commit: AniGen - Animatable 3D Generation, 6b92ff7)
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
from pytorch3d.ops import knn_points
from .skin_models import SkinEncoder
def block_attn_config(self):
    """
    Yield one attention configuration per transformer block.

    Each yielded tuple is
    ``(mode, window_size, shift_sequence, shift_window, serialize_mode)``;
    alternating blocks (even/odd index) get shifted windows/sequences/orders
    depending on ``self.attn_mode``.
    """
    mode = self.attn_mode
    for idx in range(self.num_blocks):
        odd = idx % 2  # shift every other block
        if mode == "full":
            yield "full", None, None, None, None
        elif mode == "swin":
            yield "windowed", self.window_size, None, self.window_size // 2 * odd, None
        elif mode == "shift_window":
            yield "serialized", self.window_size, 0, (16 * odd,) * 3, sp.SerializeMode.Z_ORDER
        elif mode == "shift_sequence":
            yield "serialized", self.window_size, self.window_size // 2 * odd, (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif mode == "shift_order":
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[idx % 4]
class FeedForwardNet(nn.Module):
    """Two-layer MLP with tanh-approximated GELU.

    Args:
        channels: Input feature width.
        channels_out: Output feature width; defaults to ``channels`` when None.
            (Annotation fixed: the default is ``None``, so the type is
            ``Optional[int]``, not ``int``.)
        mlp_ratio: Hidden width is ``int(channels * mlp_ratio)``.
    """
    def __init__(self, channels: int, channels_out: Optional[int] = None, mlp_ratio: float = 4.0):
        super().__init__()
        channels_out = channels if channels_out is None else channels_out
        self.mlp = nn.Sequential(
            nn.Linear(channels, int(channels * mlp_ratio)),
            nn.GELU(approximate="tanh"),
            nn.Linear(int(channels * mlp_ratio), channels_out),
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP along the last dimension of ``x``."""
        return self.mlp(x)
class AniGenSLatEncoder(AniGenSparseTransformerBase):
    """
    Structured-latent VAE encoder for AniGen.

    Encodes three coupled sparse streams:
      * geometry features ``x``     -> ``latent_channels`` latent (VAE posterior),
      * skeleton stream ``x_skl``   -> ``latent_channels_skl`` latent,
      * per-voxel skin features     -> ``latent_channels_vertskin`` latent.

    The geometry latent is always sampled from a Gaussian posterior
    (mean/logvar). When ``latent_denoising`` is enabled, the skeleton and
    vert-skin latents are instead emitted deterministically (optionally
    L2-normalized to ``normalize_scale``) with an all-zero logvar.

    Fix vs. original: ``skin_encoder_config`` and ``modules_to_freeze`` used
    mutable default arguments, and the config dict was mutated in place
    (``skin_encoder_config['use_fp16'] = ...``) — modifying both the shared
    default object and any caller-supplied dict. They now default to ``None``
    and the config is defensively copied.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        model_channels_skl: int,
        model_channels_skin: int,
        latent_channels: int,
        latent_channels_skl: int,
        latent_channels_vertskin: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        num_heads_skl: int = 32,
        num_heads_skin: int = 32,
        skl_pos_embed_freq: int = 10,
        skin_encoder_config: Optional[Dict[str, Any]] = None,
        encode_upsampled_skin_feat: bool = True,
        skin_ae_name: Optional[str] = "SkinAE",
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        modules_to_freeze: Optional[List[str]] = None,
        skin_cross_from_geo: bool = True,
        skl_cross_from_geo: bool = True,
        skin_skl_cross: bool = True,
        latent_denoising: bool = True,
        normalize_z: bool = True,
        normalize_scale: float = 1.0,
        jp_residual_fields: bool = False,
        jp_hyper_continuous: bool = False,
    ):
        # Replace mutable defaults with None sentinels; copy the config so the
        # caller's dict is never mutated by the 'use_fp16' injection below.
        skin_encoder_config = dict(skin_encoder_config) if skin_encoder_config else {}
        if modules_to_freeze is None:
            modules_to_freeze = ["input_layer", "blocks", "out_layer", "skin_encoder"]
        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        self.skl_pos_embed_freq = skl_pos_embed_freq
        self.latent_denoising = latent_denoising
        # Latent normalization only applies in the denoising configuration.
        self.normalize_latent = normalize_z and latent_denoising
        self.normalize_scale = normalize_scale
        self.jp_residual_fields = jp_residual_fields
        self.jp_hyper_continuous = jp_hyper_continuous
        super().__init__(
            in_channels=in_channels,
            in_channels_skl=model_channels_skl,
            in_channels_skin=model_channels_skin,
            model_channels=model_channels,
            model_channels_skl=model_channels_skl,
            model_channels_skin=model_channels_skin,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_heads_skl=num_heads_skl,
            num_heads_skin=num_heads_skin,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            attn_mode_cross=attn_mode_cross,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
            skin_cross_from_geo=skin_cross_from_geo,
            skl_cross_from_geo=skl_cross_from_geo,
            skin_skl_cross=skin_skl_cross,
        )
        # Names used when loading pretrained checkpoints; the skin-AE weights
        # are filtered down to the 'skin_encoder' submodule prefix.
        self.pretrain_class_name = ["AniGenElasticSLatEncoder", skin_ae_name]
        self.pretrain_ckpt_filter_prefix = {skin_ae_name: "skin_encoder"}
        self.resolution = resolution
        self.latent_channels = latent_channels
        self.latent_channels_skl = latent_channels_skl
        self.latent_channels_vertskin = latent_channels_vertskin
        skin_encoder_config['use_fp16'] = use_fp16
        self.skin_encoder = SkinEncoder(**skin_encoder_config)
        self.encode_upsampled_skin_feat = encode_upsampled_skin_feat
        # With upsampled skin features, 8 sub-voxel samples are concatenated
        # per voxel, so the input width is 8x the per-sample feature width.
        self.in_layer_skin = FeedForwardNet(channels=self.skin_encoder.skin_feat_channels * (8 if encode_upsampled_skin_feat else 1), channels_out=model_channels_skin)
        # 4th input dim carries the hyper-continuity factor when enabled.
        self.pos_embedder_fourier = FreqPositionalEmbedder(in_dim=4 if self.jp_hyper_continuous else 3, max_freq_log2=self.skl_pos_embed_freq, num_freqs=self.skl_pos_embed_freq, include_input=True)
        # Learned embedding substituted for the parent embedding at root joints.
        self.root_embedding = nn.Parameter(torch.zeros(1, self.pos_embedder_fourier.out_dim))
        # Channel Balance: 1/4 of the skeleton channels come from joint/parent
        # positions, the remaining 3/4 from joint skin embeddings.
        self.in_layer_jp_skl = FeedForwardNet(channels=2 * self.pos_embedder_fourier.out_dim, channels_out=model_channels_skl//4)
        self.in_layer_skin_skl = FeedForwardNet(channels=self.skin_encoder.skin_feat_channels, channels_out=model_channels_skl-(model_channels_skl//4))
        # Geometry head always emits mean+logvar (2x channels).
        self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
        if self.latent_denoising:
            # Deterministic latents: single set of channels, no logvar.
            self.out_layer_skl = sp.SparseLinear(model_channels_skl, latent_channels_skl)
            self.out_layer_vertskin = sp.SparseLinear(model_channels_skin, latent_channels_vertskin)
        else:
            self.out_layer_skl = sp.SparseLinear(model_channels_skl, 2 * latent_channels_skl)
            self.out_layer_vertskin = sp.SparseLinear(model_channels_skin, 2 * latent_channels_vertskin)
        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()
        else:
            self.convert_to_fp32()
        if 'all' in modules_to_freeze:
            # Expand 'all' to every top-level submodule/parameter name.
            modules_to_freeze = list(set([k.split('.')[0] for k in self.state_dict().keys()]))
            print(f"\033[93mFreezing all modules: {modules_to_freeze}\033[0m")
        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            # Freeze listed modules but keep any LoRA parameters trainable.
            for module in modules_to_freeze:
                if hasattr(self, module):
                    mod = getattr(self, module)
                    if isinstance(mod, nn.ModuleList):
                        for m in mod:
                            for name, param in m.named_parameters():
                                if 'lora' not in name:
                                    param.requires_grad = False
                    elif isinstance(mod, nn.Module):
                        for name, param in mod.named_parameters():
                            if 'lora' not in name:
                                param.requires_grad = False
                    elif isinstance(mod, torch.Tensor):
                        if mod.requires_grad:
                            mod.requires_grad = False
    def initialize_weights(self) -> None:
        """Initialize weights; zero the geometry output head so the initial latent is zero."""
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)
        # NOTE(review): out_layer_skl / out_layer_vertskin keep their default
        # initialization — confirm this asymmetry is intentional.
    def skeleton_embedding(self, x, x_skl, joints_list, parents_list, skin_list, gt_meshes, bvh_list=None):
        """
        Build skeleton- and skin-conditioned input features per batch item.

        Returns ``(x_new, x_skl_new, joint_skin_embeds, vert_skin_embeds,
        joints_pos_list)`` where ``x_new`` carries per-voxel skin features and
        ``x_skl_new`` carries joint/parent positional + joint-skin features.
        Assumes sparse coords are (N, 4) with the batch index in column 0 and
        voxel indices in columns 1:4 — TODO confirm against sp.SparseTensor.
        """
        res = self.resolution
        feats_new = []
        feats_skl_new = []
        coords_new = []
        coords_skl_new = []
        joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
        joints_pos_list = []
        for i in range(len(joints_list)):
            parent_idx = parents_list[i].clone()
            coords_new.append(x[i].coords)
            coords_skl_new.append(x_skl[i].coords)
            # NOTE(review): writes the batch index into the coords tensor in
            # place — presumably x[i].coords is already a per-item copy; verify.
            coords_new[-1][:, 0] = i
            coords_skl_new[-1][:, 0] = i
            # Voxel centers mapped to [-0.5, 0.5).
            v_pos = (x[i].coords[:, 1:4] + 0.5) / res - 0.5
            v_pos_skl = (x_skl[i].coords[:, 1:4] + 0.5) / res - 0.5
            # 2-NN against joints: idx 0 drives the embedding, the 1st/2nd
            # distance ratio drives the hyper-continuity factor below.
            dist_nn_12, joints_nn_idx, _ = knn_points(v_pos_skl[None], joints_list[i][None], K=2, norm=2, return_nn=False)
            joints_nn_idx = joints_nn_idx[0, :, 0]
            # Skeleton positional embedding (absolute, or residual to the voxel
            # position when jp_residual_fields is set).
            joints_pos = joints_list[i][joints_nn_idx] - (v_pos_skl if self.jp_residual_fields else 0)
            parents_pos = joints_list[i][parent_idx[joints_nn_idx]] - (v_pos_skl if self.jp_residual_fields else 0)
            if self.jp_hyper_continuous:
                # 1 - d1/d2 in [0, 1]: -> 0 on the bisector between the two
                # nearest joints, -> 1 right at the nearest joint.
                factor = (1 - (dist_nn_12[0, :, 0:1] / (dist_nn_12[0, :, 1:2] + 1e-8)).clamp(max=1.0))
                joints_pos = torch.cat([joints_pos, factor], dim=-1)
                parents_pos = torch.cat([parents_pos, factor], dim=-1)
            joints_pos_embed = self.pos_embedder_fourier(joints_pos)
            parents_pos_embed = self.pos_embedder_fourier(parents_pos)
            # parent_idx == -1 (root) wrapped to the last joint above; the
            # bogus embedding is overwritten with the learned root embedding.
            parents_pos_embed = torch.where(parent_idx[joints_nn_idx][:, None] == -1, self.root_embedding.expand_as(parents_pos_embed), parents_pos_embed)
            jp_pos_embed_nn = torch.cat([joints_pos_embed, parents_pos_embed], dim=-1)
            jp_pos_embed_nn = self.in_layer_jp_skl(jp_pos_embed_nn)
            # Skeleton skin embedding
            j_skin_embed_nn = joint_skin_embeds[i][joints_nn_idx]
            j_skin_embed_nn = self.in_layer_skin_skl(j_skin_embed_nn)
            # Concatenate
            jp_skl_embed = torch.cat([jp_pos_embed_nn, j_skin_embed_nn], dim=-1)
            feats_skl_new.append(jp_skl_embed)
            if self.encode_upsampled_skin_feat:
                # Create 8 sub-voxel points (corners of a half-resolution cell).
                offsets = torch.tensor([
                    [-1, -1, -1], [-1, -1, 1], [-1, 1, -1], [-1, 1, 1],
                    [1, -1, -1], [1, -1, 1], [1, 1, -1], [1, 1, 1]
                ], device=v_pos.device, dtype=v_pos.dtype) * (0.25 / res)
                query_pos = v_pos.unsqueeze(1) + offsets.unsqueeze(0)  # (N, 8, 3)
                query_pos_flat = query_pos.view(-1, 3)
            else:
                query_pos_flat = v_pos
            if bvh_list is not None:
                # BVH path: project to the closest triangle and interpolate
                # vertex skin embeddings with (clamped, renormalized)
                # barycentric weights.
                bvh = bvh_list[i].to(v_pos.device)
                _, face_id, uvw = bvh.unsigned_distance(query_pos_flat, return_uvw=True)
                uvw = uvw.clamp(min=0.0)
                uvw_sum = uvw.sum(dim=-1, keepdim=True).clamp_min(1e-6)
                uvw = uvw / uvw_sum
                face_id = gt_meshes[i]['faces'][face_id]
                voxel_skin_embeds = (vert_skin_embeds[i][face_id] * uvw[..., None]).sum(1)
            else:
                # Fallback: nearest mesh vertex lookup.
                gt_mesh_verts = gt_meshes[i]['vertices']
                _, mesh_nn_idx, _ = knn_points(query_pos_flat[None], gt_mesh_verts[None], K=1, norm=2, return_nn=False)
                mesh_nn_idx = mesh_nn_idx[0, :, 0]
                voxel_skin_embeds = vert_skin_embeds[i][mesh_nn_idx]
            # Concatenate the (1 or 8) per-voxel samples into one feature row.
            voxel_skin_embeds = voxel_skin_embeds.view(v_pos.shape[0], -1)
            voxel_skin_embeds = self.in_layer_skin(voxel_skin_embeds)
            feats_new.append(voxel_skin_embeds)
            joints_pos_list.append(joints_pos)
        x_new = sp.SparseTensor(coords=torch.cat(coords_new, dim=0), feats=torch.cat(feats_new, dim=0))
        x_skl_new = sp.SparseTensor(coords=torch.cat(coords_skl_new, dim=0), feats=torch.cat(feats_skl_new, dim=0))
        return x_new, x_skl_new, joint_skin_embeds, vert_skin_embeds, joints_pos_list
    def encode_sample(self, x: sp.SparseTensor, out_layer: sp.SparseLinear, sample_posterior: bool = True, latent_denoising: bool = False):
        """
        Project features through ``out_layer`` and produce a latent.

        Returns ``(z, mean, logvar)``. In denoising mode the latent is
        deterministic (optionally L2-normalized * normalize_scale), logvar is
        all-zero, and the returned mean is detached; otherwise the head output
        is split into mean/logvar and optionally reparameterization-sampled.
        """
        x = x.type(torch.float32)
        # LayerNorm over the feature dimension before projecting.
        x = x.replace(F.layer_norm(x.feats, x.feats.shape[-1:]))
        x = out_layer(x)
        if latent_denoising:
            if self.normalize_latent:
                x = x.replace(nn.functional.normalize(x.feats, dim=-1) * self.normalize_scale)
            mean, logvar = x.feats, torch.zeros_like(x.feats)
        else:
            mean, logvar = x.feats.chunk(2, dim=-1)
        if sample_posterior and not latent_denoising:
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean
        z = x.replace(z)
        if latent_denoising:
            mean = mean.detach()
        return z, mean, logvar
    def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, sample_posterior=True, return_raw=False, return_skin_encoded=False, **kwargs):
        """
        Encode geometry + skeleton streams into latents.

        Required kwargs: ``gt_joints``, ``gt_parents``, ``gt_skin``,
        ``gt_mesh``; optional ``bvh_list``. The returned ``z`` concatenates the
        geometry latent with the vert-skin latent along the feature dimension;
        ``z_skl`` is the skeleton latent. ``return_raw`` adds mean/logvar;
        ``return_skin_encoded`` appends the embedded inputs for inspection.
        """
        x_skin, x_skl, joint_skin_embeds, vert_skin_embeds, joints_pos = self.skeleton_embedding(x, x_skl, kwargs.get('gt_joints'), kwargs.get('gt_parents'), kwargs.get('gt_skin'), kwargs.get('gt_mesh'), kwargs.get('bvh_list', None))
        h, h_skl, h_skin = super().forward(x, x_skl, x_skin)
        # Geometry latent is always a sampled VAE posterior; skeleton/skin
        # latents follow the latent_denoising configuration.
        z, mean, logvar = self.encode_sample(h, self.out_layer, sample_posterior, latent_denoising=False)
        z_skl, mean_skl, logvar_skl = self.encode_sample(h_skl, self.out_layer_skl, sample_posterior, latent_denoising=self.latent_denoising)
        z_skin, mean_skin, logvar_skin = self.encode_sample(h_skin, self.out_layer_vertskin, sample_posterior, latent_denoising=self.latent_denoising)
        z = z.replace(torch.cat([z.feats, z_skin.feats], dim=-1))
        mean, logvar = torch.cat([mean, mean_skin], dim=-1), torch.cat([logvar, logvar_skin], dim=-1)
        if not return_skin_encoded:
            # Ordinary return without skin encoded features
            if return_raw:
                return z, mean, logvar, z_skl, mean_skl, logvar_skl, joint_skin_embeds, vert_skin_embeds, joints_pos
            else:
                return z, z_skl, joint_skin_embeds, vert_skin_embeds, joints_pos
        else:
            # Return skin encoded features as well for checking
            if return_raw:
                return z, mean, logvar, z_skl, mean_skl, logvar_skl, joint_skin_embeds, vert_skin_embeds, joints_pos, x_skin, x_skl
            else:
                return z, z_skl, joint_skin_embeds, vert_skin_embeds, joints_pos, x_skin, x_skl
    def encode_skin(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor]=None):
        """Run only the skin encoder; returns (joint_skin_embeds, vert_skin_embeds)."""
        joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
        return joint_skin_embeds, vert_skin_embeds
class AniGenElasticSLatEncoder(SparseTransformerElasticMixin, AniGenSLatEncoder):
    """
    SLat VAE encoder with elastic memory management.

    Same interface and behavior as ``AniGenSLatEncoder``; the
    ``SparseTransformerElasticMixin`` base (defined elsewhere) supplies the
    memory-elastic behavior. Used for training with low VRAM.
    """