import torch
from torch import nn, FloatTensor, LongTensor, Tensor
import torch.nn.functional as F
import numpy as np
from torch.nn.functional import pad
from typing import Dict, List, Tuple
from transformers import AutoModelForCausalLM, AutoConfig
import math
import torch_scatter
from flash_attn.modules.mha import MHA

from .spec import ModelSpec, ModelInput
from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder
from ..data.utils import linear_blend_skinning


class FrequencyPositionalEmbedding(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension `x[..., i]` into:
        [
            x[..., i],  # only present if include_input is True
            sin(f_0 * x[..., i]),
            sin(f_1 * x[..., i]),
            ...,
            sin(f_{N-1} * x[..., i]),
            cos(f_0 * x[..., i]),
            cos(f_1 * x[..., i]),
            ...,
            cos(f_{N-1} * x[..., i]),
        ],
    where f_i is the i-th of N = num_freqs frequencies.

    If logspace is True, the frequencies are f_i = 2^i for i in [0, num_freqs);
    otherwise they are linearly spaced between 1.0 and 2^(num_freqs - 1).
    If include_pi is True, every frequency is additionally multiplied by pi.

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): whether the frequencies are log-spaced (f_i = 2^i) rather than
            linearly spaced between 1.0 and 2^(num_freqs - 1), default is True;
        input_dim (int): the input dimension, default is 3;
        include_input (bool): whether to include the raw input tensor, default is True;
        include_pi (bool): whether to multiply every frequency by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): the frequency tensor described above;
        out_dim (int): the embedding size; input_dim * (num_freqs * 2 + 1) if
            include_input is True, otherwise input_dim * num_freqs * 2.
    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """The initialization"""
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs
        self.out_dim = self._get_dims(input_dim)

    def _get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)
        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)],
                where temp is 1 if include_input is True and 0 otherwise.
        """
        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device)).view(
                *x.shape[:-1], -1
            )
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x
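
# Usage sketch (illustrative, not part of the model): with the defaults
# num_freqs=6, input_dim=3, include_input=True, the embedding size is
# out_dim = 3 * (6 * 2 + 1) = 39, so:
#
#     embed = FrequencyPositionalEmbedding()
#     x = torch.rand(2, 1024, 3)   # e.g. a batch of 3D coordinates
#     y = embed(x)                 # -> (2, 1024, 39), with y[..., :3] == x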


class ResidualCrossAttn(nn.Module):
    def __init__(self, feat_dim: int, num_heads: int):
        super().__init__()
        assert feat_dim % num_heads == 0, "feat_dim must be divisible by num_heads"
        self.norm1 = nn.LayerNorm(feat_dim)
        self.norm2 = nn.LayerNorm(feat_dim)
        # self.attention = nn.MultiheadAttention(embed_dim=feat_dim, num_heads=num_heads, batch_first=True)
        self.attention = MHA(embed_dim=feat_dim, num_heads=num_heads, cross_attn=True)
        self.ffn = nn.Sequential(
            nn.Linear(feat_dim, feat_dim * 4),
            nn.GELU(),
            nn.Linear(feat_dim * 4, feat_dim),
        )

    def forward(self, q, kv):
        residual = q
        attn_output = self.attention(q, x_kv=kv)
        x = self.norm1(residual + attn_output)
        x = self.norm2(x + self.ffn(x))
        return x
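
# Dataflow sketch (illustrative): with q of shape (B, N_q, feat_dim) and kv of
# shape (B, N_kv, feat_dim), the block computes a post-norm residual update,
#     h = norm1(q + CrossAttn(q, kv));  out = norm2(h + FFN(h)),
# returning a tensor of shape (B, N_q, feat_dim). The flash_attn MHA requires
# a CUDA device; the commented-out nn.MultiheadAttention above is a
# shape-compatible CPU alternative, though its forward signature differs
# (it takes (q, kv, kv) and returns (output, attn_weights)).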


class BoneEncoder(nn.Module):
    def __init__(
        self,
        feat_bone_dim: int,
        feat_dim: int,
        embed_dim: int,
        num_heads: int,
        num_attn: int,
    ):
        super().__init__()
        self.feat_bone_dim = feat_bone_dim
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.num_attn = num_attn
        self.position_embed = FrequencyPositionalEmbedding(input_dim=self.feat_bone_dim)
        self.bone_encoder = nn.Sequential(
            self.position_embed,
            nn.Linear(self.position_embed.out_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim * 4),
            nn.LayerNorm(embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, feat_dim),
            nn.LayerNorm(feat_dim),
            nn.GELU(),
        )
        self.attn = nn.ModuleList()
        for _ in range(self.num_attn):
            self.attn.append(ResidualCrossAttn(feat_dim, self.num_heads))

    def forward(
        self,
        base_bone: FloatTensor,
        num_bones: LongTensor,
        parents: LongTensor,
        min_coord: FloatTensor,
        global_latents: FloatTensor,
    ):
        # base_bone: (B, J, C); num_bones and parents are accepted for API
        # compatibility but are not used here
        B = base_bone.shape[0]
        J = base_bone.shape[1]
        x = self.bone_encoder((base_bone - min_coord[:, None, :]).reshape(-1, base_bone.shape[-1])).reshape(B, J, -1)
        latents = torch.cat([x, global_latents], dim=1)
        for attn in self.attn:
            x = attn(x, latents)
        return x
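
# Encoding sketch (illustrative): joints are shifted by the per-mesh minimum
# coordinate, lifted by the frequency embedding + MLP to (B, J, feat_dim) bone
# tokens, then refined by num_attn rounds of cross-attention in which the bone
# tokens attend to a fixed key/value set: the concatenation of the initial
# bone tokens and the global shape latents, of shape (B, J + seq_len, feat_dim).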


class SkinweightPred(nn.Module):
    def __init__(self, in_dim, mlp_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, mlp_dim),
            nn.LayerNorm(mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.LayerNorm(mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.LayerNorm(mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.LayerNorm(mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, 1),
        )

    def forward(self, x):
        return self.net(x)


class UniRigSkin(ModelSpec):

    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        max_bones = 0
        for b in batch:
            max_bones = max(max_bones, b.asset.J)
        res = []
        current_offset = 0
        for b in batch:
            vertex_groups = b.asset.sampled_vertex_groups
            current_offset += b.vertices.shape[0]
            # (N, J) -> (N, max_bones), zero-padded over the bone axis
            voxel_skin = vertex_groups['voxel_skin']
            voxel_skin = np.pad(voxel_skin, ((0, 0), (0, max_bones - b.asset.J)), 'constant', constant_values=0.0)
            res.append({
                'voxel_skin': voxel_skin,
                'offset': current_offset,
            })
        return res
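
    # Padding sketch (illustrative): with two assets of J=3 and J=5 bones,
    # max_bones=5, so the first asset's (N, 3) voxel_skin gains two all-zero
    # columns to become (N, 5); 'offset' is the running vertex count, consumed
    # later as the per-mesh boundary for the sparse point-cloud encoder.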

    def __init__(self, mesh_encoder, global_encoder, **kwargs):
        super().__init__()
        self.num_train_vertex = kwargs['num_train_vertex']
        self.feat_dim = kwargs['feat_dim']
        self.num_heads = kwargs['num_heads']
        self.grid_size = kwargs['grid_size']
        self.mlp_dim = kwargs['mlp_dim']
        self.num_bone_attn = kwargs['num_bone_attn']
        self.num_mesh_bone_attn = kwargs['num_mesh_bone_attn']
        self.bone_embed_dim = kwargs['bone_embed_dim']
        self.voxel_mask = kwargs.get('voxel_mask', 2)

        self.mesh_encoder = get_mesh_encoder(**mesh_encoder)
        self.global_encoder = get_mesh_encoder(**global_encoder)

        if isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj):
            self.feat_map = nn.Sequential(
                nn.Linear(mesh_encoder['enc_channels'][-1], self.feat_dim),
                nn.LayerNorm(self.feat_dim),
                nn.GELU(),
            )
        else:
            raise NotImplementedError()

        if isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder):
            self.out_proj = nn.Sequential(
                nn.Linear(self.global_encoder.width, self.feat_dim),
                nn.LayerNorm(self.feat_dim),
                nn.GELU(),
            )
        else:
            raise NotImplementedError()

        self.bone_encoder = BoneEncoder(
            feat_bone_dim=3,
            feat_dim=self.feat_dim,
            embed_dim=self.bone_embed_dim,
            num_heads=self.num_heads,
            num_attn=self.num_bone_attn,
        )
        self.downscale = nn.Sequential(
            nn.Linear(2 * self.num_heads, self.num_heads),
            nn.LayerNorm(self.num_heads),
            nn.GELU(),
        )
        self.skinweight_pred = SkinweightPred(
            self.num_heads,
            self.mlp_dim,
        )
        self.mesh_bone_attn = nn.ModuleList()
        self.mesh_bone_attn.extend([
            ResidualCrossAttn(self.feat_dim, self.num_heads) for _ in range(self.num_mesh_bone_attn)
        ])
        self.qmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads)
        self.kmesh = nn.Linear(self.feat_dim, self.feat_dim * self.num_heads)
        self.voxel_skin_embed = nn.Linear(1, self.num_heads)
        self.voxel_skin_norm = nn.LayerNorm(self.num_heads)
        self.attn_skin_norm = nn.LayerNorm(self.num_heads)
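
    # Pipeline sketch (summarizing _get_predict below): per-vertex features
    # from the PTv3 mesh encoder cross-attend to [bone features; global
    # Michelangelo latents]; per-head vertex-bone attention maps are fused
    # with an embedded voxel-skin prior through `downscale`, scored by
    # `skinweight_pred`, softmax-normalized over bones, and (at inference)
    # reweighted by an ancestor-accumulated voxel-skin mask raised to the
    # `voxel_mask` power.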

    def encode_mesh_cond(self, vertices: FloatTensor, normals: FloatTensor) -> FloatTensor:
        assert not torch.isnan(vertices).any()
        assert not torch.isnan(normals).any()
        if isinstance(self.global_encoder, MAP_MESH_ENCODER.michelangelo_encoder):
            if len(vertices.shape) == 3:
                shape_embed, latents, token_num, pre_pc = self.global_encoder.encode_latents(pc=vertices, feats=normals)
            else:
                shape_embed, latents, token_num, pre_pc = self.global_encoder.encode_latents(pc=vertices.unsqueeze(0), feats=normals.unsqueeze(0))
            latents = self.out_proj(latents)
            return latents
        else:
            raise NotImplementedError()

    def _get_predict(self, batch: Dict) -> Tuple[FloatTensor, LongTensor]:
        '''
        Return the predicted skin weights and the vertex indices they cover.
        '''
        num_bones: Tensor = batch['num_bones']
        vertices: FloatTensor = batch['vertices']  # (B, N, 3)
        normals: FloatTensor = batch['normals']
        joints: FloatTensor = batch['joints']
        tails: FloatTensor = batch['tails']
        voxel_skin: FloatTensor = batch['voxel_skin']
        parents: LongTensor = batch['parents']

        # cast inputs to the model's dtype
        dtype = next(self.parameters()).dtype
        vertices = vertices.type(dtype)
        normals = normals.type(dtype)
        joints = joints.type(dtype)
        tails = tails.type(dtype)
        voxel_skin = voxel_skin.type(dtype)

        B = vertices.shape[0]
        N = vertices.shape[1]
        J = joints.shape[1]
        assert vertices.dim() == 3
        assert normals.dim() == 3

        # per-mesh minimum coordinate via a CSR-style segmented reduction
        part_offset = torch.tensor([(i + 1) * N for i in range(B)], dtype=torch.int64, device=vertices.device)
        idx_ptr = torch.nn.functional.pad(part_offset, (1, 0), value=0)
        min_coord = torch_scatter.segment_csr(vertices.reshape(-1, 3), idx_ptr, reduce="min")

        pack = []
        if self.training:
            train_indices = torch.randperm(N)[:self.num_train_vertex]
            pack.append(train_indices)
        else:
            for i in range((N + self.num_train_vertex - 1) // self.num_train_vertex):
                pack.append(torch.arange(i * self.num_train_vertex, min((i + 1) * self.num_train_vertex, N)))
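
        # Chunking sketch (illustrative): at inference every vertex is scored
        # in fixed-size chunks, e.g. N=10 with num_train_vertex=4 yields index
        # chunks [0..3], [4..7], [8..9]; during training a single random
        # subset of num_train_vertex vertices is used instead.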

        # (B, seq_len, feat_dim)
        global_latents = self.encode_mesh_cond(vertices, normals)

        bone_feat = self.bone_encoder(
            base_bone=joints,
            num_bones=num_bones,
            parents=parents,
            min_coord=min_coord,
            global_latents=global_latents,
        )

        if isinstance(self.mesh_encoder, MAP_MESH_ENCODER.ptv3obj):
            feat = torch.cat([vertices, normals, torch.zeros_like(vertices)], dim=-1)
            ptv3_input = {
                'coord': vertices.reshape(-1, 3),
                'feat': feat.reshape(-1, 9),
                'offset': batch['offset'].detach().clone(),
                'grid_size': self.grid_size,
            }
            if not self.training:
                # must cast to float32 to avoid sparse-conv precision bugs
                with torch.autocast(device_type='cuda', dtype=torch.float32):
                    mesh_feat = self.mesh_encoder(ptv3_input).feat
                    mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim)
            else:
                mesh_feat = self.mesh_encoder(ptv3_input).feat
                mesh_feat = self.feat_map(mesh_feat).view(B, N, self.feat_dim)
            mesh_feat = mesh_feat.type(dtype)
        else:
            raise NotImplementedError()

        # (B, J + seq_len, feat_dim)
        latents = torch.cat([bone_feat, global_latents], dim=1)
        # (B, N, feat_dim)
        for block in self.mesh_bone_attn:
            mesh_feat = block(
                q=mesh_feat,
                kv=latents,
            )

        # project bone features to per-head keys: (B, num_heads, J, feat_dim)
        bone_feat = self.kmesh(bone_feat).view(B, J, self.num_heads, self.feat_dim).transpose(1, 2)

        skin_pred_list = []
        if not self.training:
            # accumulate each bone's voxel-skin prior into its parent so that
            # ancestors keep support for their subtrees
            skin_mask = voxel_skin.clone()
            for b in range(B):
                num = num_bones[b]
                for i in range(num):
                    p = parents[b, i]
                    if p < 0:
                        continue
                    skin_mask[b, :, p] += skin_mask[b, :, i]

        for indices in pack:
            cur_N = len(indices)
            # project mesh features to per-head queries: (B, num_heads, cur_N, feat_dim)
            cur_mesh_feat = self.qmesh(mesh_feat[:, indices]).view(B, cur_N, self.num_heads, self.feat_dim).transpose(1, 2)
            # attn_weight shape: (B, num_heads, cur_N, J)
            attn_weight = F.softmax(torch.bmm(
                cur_mesh_feat.reshape(B * self.num_heads, cur_N, -1),
                bone_feat.transpose(-2, -1).reshape(B * self.num_heads, -1, J)
            ) / math.sqrt(self.feat_dim), dim=-1, dtype=dtype)
            # (B, num_heads, cur_N, J) -> (B, cur_N, J, num_heads)
            attn_weight = attn_weight.reshape(B, self.num_heads, cur_N, J).permute(0, 2, 3, 1)
            attn_weight = self.attn_skin_norm(attn_weight)
            embed_voxel_skin = self.voxel_skin_embed(voxel_skin[:, indices].reshape(B, cur_N, J, 1))
            embed_voxel_skin = self.voxel_skin_norm(embed_voxel_skin)
            # fuse the attention features with the voxel-skin prior:
            # (B, cur_N, J, 2 * num_heads) -> (B, cur_N, J, num_heads)
            attn_weight = torch.cat([attn_weight, embed_voxel_skin], dim=-1)
            attn_weight = self.downscale(attn_weight)
            # score per vertex-bone pair: (B, cur_N, J, num_heads) -> (B, cur_N, J)
            skin_pred = torch.zeros(B, cur_N, J).to(attn_weight.device, dtype)
            for i in range(B):
                # (cur_N * num_bones[i], num_heads)
                input_features = attn_weight[i, :, :num_bones[i], :].reshape(-1, attn_weight.shape[-1])
                pred = self.skinweight_pred(input_features).reshape(cur_N, num_bones[i])
                skin_pred[i, :, :num_bones[i]] = F.softmax(pred, dim=-1)
            skin_pred_list.append(skin_pred)
        skin_pred_list = torch.cat(skin_pred_list, dim=1)
        if not self.training:
            for i in range(B):
                n = num_bones[i]
                skin_pred_list[i, :, :n] = skin_pred_list[i, :, :n] * torch.pow(skin_mask[i, :, :n], self.voxel_mask)
                # clamp the denominator (epsilon chosen to stay representable
                # in half precision) so an all-zero mask cannot divide by zero
                skin_pred_list[i, :, :n] = skin_pred_list[i, :, :n] / skin_pred_list[i, :, :n].sum(dim=-1, keepdim=True).clamp_min(1e-6)
        return skin_pred_list, torch.cat(pack, dim=0)

    def predict_step(self, batch: Dict):
        with torch.no_grad():
            num_bones: Tensor = batch['num_bones']
            skin_pred, _ = self._get_predict(batch=batch)
            outputs = []
            for i in range(skin_pred.shape[0]):
                outputs.append(skin_pred[i, :, :num_bones[i]])
            return outputs
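

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module). It only
    # exercises the CPU-friendly submodules: ResidualCrossAttn and UniRigSkin
    # need flash_attn / CUDA plus encoder configs, so they are not covered.
    # Because of the relative imports above, run this with
    # `python -m <package>.<module>` rather than executing the file directly.
    embed = FrequencyPositionalEmbedding()
    x = torch.rand(2, 16, 3)
    y = embed(x)
    assert embed.out_dim == 39           # 3 * (6 * 2 + 1)
    assert y.shape == (2, 16, embed.out_dim)

    pred = SkinweightPred(in_dim=8, mlp_dim=32)
    scores = pred(torch.rand(5, 8))
    assert scores.shape == (5, 1)
    print('smoke test passed')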