Spaces:
Build error
Build error
Refactor skin weight calculations to handle division by zero and ensure valid index access in Exporter and SAMPart3DDataset classes
27fa9cc
| from collections import defaultdict | |
| import torch.distributed | |
| import lightning as L | |
| import os | |
| import torch | |
| import numpy as np | |
| from torch import Tensor, FloatTensor, LongTensor | |
| from typing import Dict, Union, List, Literal | |
| from lightning.pytorch.callbacks import BasePredictionWriter | |
| from numpy import ndarray | |
| from scipy.sparse import csr_matrix | |
| from scipy.spatial import cKDTree | |
| from ..data.order import OrderConfig, get_order | |
| from ..data.raw_data import RawSkin, RawData | |
| from ..data.exporter import Exporter | |
| from ..model.spec import ModelSpec | |
| class SkinSystem(L.LightningModule): | |
| def __init__( | |
| self, | |
| steps_per_epoch: int, | |
| model: ModelSpec, | |
| output_path: Union[str, None]=None, | |
| record_res: Union[bool]=False, | |
| val_interval: Union[int, None]=None, | |
| val_start_from: Union[int, None]=None, | |
| ): | |
| super().__init__() | |
| self.save_hyperparameters(ignore="model") | |
| self.steps_per_epoch = steps_per_epoch | |
| self.model = model | |
| self.output_path = output_path | |
| self.record_res = record_res | |
| self.val_interval = val_interval | |
| self.val_start_from = val_start_from | |
| if self.record_res: | |
| assert self.output_path is not None, "record_res is True, but output_path in skin is None" | |
| def predict_step(self, batch, batch_idx, dataloader_idx=None): | |
| res = self.model.predict_step(batch) | |
| if isinstance(res, list): | |
| return { | |
| 'skin_pred': res, | |
| } | |
| elif isinstance(res, dict): | |
| assert 'skin_pred' in res, f"expect key 'skin_pred' in prediction from {self.model.__class__}, found: {res.keys()}" | |
| return res | |
| else: | |
| assert 0, f"expect type of prediction from {self.model.__class__} to be a list or dict, found: {type(res)}" | |
| class SkinWriter(BasePredictionWriter): | |
| def __init__( | |
| self, | |
| output_dir: Union[str, None], | |
| save_name: str, | |
| order_config: Union[OrderConfig, None]=None, | |
| **kwargs | |
| ): | |
| super().__init__('batch') | |
| self.output_dir = output_dir | |
| self.npz_dir = kwargs.get('npz_dir', None) | |
| self.user_mode = kwargs.get('user_mode', False) | |
| self.output_name = kwargs.get('output_name', None) # for a single name | |
| self.save_name = save_name | |
| self.add_num = kwargs.get('add_num', False) | |
| self.export_npz = kwargs.get('export_npz', True) | |
| self.export_fbx = kwargs.get('export_fbx', False) | |
| if order_config is not None: | |
| self.order = get_order(config=order_config) | |
| else: | |
| self.order = None | |
| self._epoch = 0 | |
| def write_on_batch_end(self, trainer, pl_module: SkinSystem, prediction: List[Dict], batch_indices, batch, batch_idx, dataloader_idx): | |
| assert 'path' in batch | |
| paths: List[str] = batch['path'] | |
| data_names: List[str] = batch['data_name'] | |
| joints: FloatTensor = batch['joints'] | |
| num_bones: LongTensor = batch['num_bones'] | |
| num_faces: LongTensor = batch['num_faces'] | |
| num_points: LongTensor = batch['num_points'] | |
| tails: FloatTensor = batch['tails'] | |
| parents_list: LongTensor = batch['parents'] # -1 represents root | |
| vertices: FloatTensor = batch['origin_vertices'] | |
| sampled_vertices: FloatTensor = batch['vertices'] | |
| faces: LongTensor = batch['origin_faces'] | |
| joints = joints.detach().cpu().numpy() | |
| tails = tails.detach().cpu().numpy() | |
| parents_list = parents_list.detach().cpu().numpy() | |
| num_bones = num_bones.detach().cpu().numpy() | |
| num_faces = num_faces.detach().cpu().numpy() | |
| vertices = vertices.detach().cpu().numpy() | |
| faces = faces.detach().cpu().numpy() | |
| skin_pred_list: List = prediction['skin_pred'] | |
| ret_sampled_vertices = prediction.get('sampled_vertices', None) | |
| if ret_sampled_vertices is not None: | |
| assert isinstance(ret_sampled_vertices, Tensor) | |
| sampled_vertices = ret_sampled_vertices | |
| if isinstance(sampled_vertices, Tensor): | |
| sampled_vertices = sampled_vertices.type(torch.float32).detach().cpu().numpy() | |
| for (id, skin_pred) in enumerate(skin_pred_list): | |
| if isinstance(skin_pred, Tensor): | |
| skin_pred = skin_pred.type(torch.float32).detach().cpu().numpy() | |
| # TODO: add custom post-processing here | |
| # resample | |
| N = num_points[id] | |
| J = num_bones[id] | |
| F = num_faces[id] | |
| o_vertices = vertices[id, :N] | |
| _parents = parents_list[id] | |
| parents = [] | |
| for i in range(J): | |
| if _parents[i] == -1: | |
| parents.append(None) | |
| else: | |
| parents.append(_parents[i]) | |
| skin_resampled = reskin( | |
| sampled_vertices=sampled_vertices[id], | |
| vertices=o_vertices, | |
| parents=parents, | |
| faces=faces[id, :F], | |
| sampled_skin=skin_pred, | |
| sample_method='median', | |
| alpha=2.0, | |
| threshold=0.03, | |
| ) | |
| def make_path(save_name: str, suffix: str, trim: bool=False): | |
| if trim: | |
| path = os.path.relpath(paths[id], self.npz_dir) | |
| else: | |
| path = paths[id] | |
| if self.output_dir is not None: | |
| path = os.path.join(self.output_dir, path) | |
| if self.add_num: | |
| path = os.path.join(path, f"{save_name}_{self._epoch}.{suffix}") | |
| else: | |
| path = os.path.join(path, f"{save_name}.{suffix}") | |
| return path | |
| raw_data = RawSkin(skin=skin_pred, vertices=sampled_vertices[id], joints=joints[id, :J]) | |
| if self.export_npz is not None: | |
| raw_data.save(path=make_path(self.export_npz, 'npz')) | |
| if self.export_fbx is not None: | |
| exporter = Exporter() | |
| names = RawData.load(path=os.path.join(paths[id], data_names[id])).names | |
| if names is None: | |
| names = [f"bone_{i}" for i in range(J)] | |
| if self.user_mode: | |
| if self.output_name is not None: | |
| path = self.output_name | |
| else: | |
| path = make_path(self.save_name, 'fbx', trim=True) | |
| else: | |
| path = make_path(self.export_fbx, 'fbx') | |
| exporter._export_fbx( | |
| path=path, | |
| vertices=o_vertices, | |
| joints=joints[id, :J], | |
| skin=skin_resampled, | |
| parents=parents, | |
| names=names, | |
| faces=faces[id, :F], | |
| group_per_vertex=4, | |
| tails=tails[id, :J], | |
| use_extrude_bone=False, | |
| use_connect_unique_child=False, | |
| # do_not_normalize=True, | |
| ) | |
| def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): | |
| self._epoch += 1 | |
| def reskin( | |
| sampled_vertices: ndarray, | |
| vertices: ndarray, | |
| parents: List[Union[None, int]], | |
| faces: ndarray, | |
| sampled_skin: ndarray, | |
| sample_method: Literal['mean', 'median']='mean', | |
| **kwargs, | |
| ) -> ndarray: | |
| nearest_samples = kwargs.get('nearest_samples', 7) | |
| iter_steps = kwargs.get('iter_steps', 1) | |
| threshold = kwargs.get('threshold', 0.01) | |
| alpha = kwargs.get('alpha', 2) | |
| assert sample_method in ['mean', 'median'] | |
| N = vertices.shape[0] | |
| J = sampled_skin.shape[1] | |
| if sample_method == 'mean': | |
| tree = cKDTree(sampled_vertices) | |
| dis, nearest = tree.query(vertices, k=nearest_samples, p=2) | |
| # weighted sum | |
| weights = np.exp(-alpha * dis) # (N, nearest_samples) | |
| weight_sum = weights.sum(axis=1, keepdims=True) | |
| sampled_skin_nearest = sampled_skin[nearest] | |
| skin = (sampled_skin_nearest * weights[..., np.newaxis]).sum(axis=1) / weight_sum | |
| elif sample_method == 'median': | |
| tree = cKDTree(sampled_vertices) | |
| dis, nearest = tree.query(vertices, k=nearest_samples, p=2) | |
| skin = np.median(sampled_skin[nearest], axis=1) | |
| else: | |
| assert 0 | |
| # (from, to) | |
| edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0) | |
| edges = np.concatenate([edges, edges[:, [1, 0]]], axis=0) # (2*F*3, 2) | |
| # diffusion in neighbours | |
| for _ in range(iter_steps): | |
| sum_skin = skin.copy() | |
| for i in reversed(range(J)): | |
| p = parents[i] | |
| if p is None: | |
| continue | |
| sum_skin[:, p] += sum_skin[:, i] | |
| # (2*F*3, J) | |
| # only transfer from hotter to cooler | |
| mask = sum_skin[edges[:, 1]] < sum_skin[edges[:, 0]] | |
| neighbor_skin = np.zeros_like(sum_skin) # (N, J) | |
| neighbor_co = np.zeros((N, J), dtype=np.float32) | |
| dis = np.sqrt(((vertices[edges[:, 1]] - vertices[edges[:, 0]])**2).sum(axis=1, keepdims=True)) | |
| co = np.exp(-dis * alpha) | |
| neighbor_skin[edges[:, 1]] += sum_skin[edges[:, 0]] * co * mask | |
| neighbor_co[edges[:, 1]] += co * mask | |
| sum_skin = (sum_skin + neighbor_skin) / (1. + neighbor_co) | |
| for i in range(J): | |
| p = parents[i] | |
| if p is None: | |
| continue | |
| sum_skin[:, p] -= sum_skin[:, i] | |
| skin = sum_skin / sum_skin.sum(axis=-1, keepdims=True) | |
| # avoid 0-skin | |
| mask = (skin>=threshold).any(axis=-1, keepdims=True) | |
| skin[(skin<threshold)&mask] = 0. | |
| skin = skin / skin.sum(axis=-1, keepdims=True) | |
| return skin |