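# Skin-weight VAE spec (inferred from the class names below): wraps the
# FSQ-quantized skin autoencoder (SkinFSQCVAEModel) behind the ModelSpec
# interface, with optional Perceiver resamplers to compress and re-expand
# the latent token sequence.
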
from typing import Dict, List

import torch
from scipy.spatial import cKDTree  # type: ignore
from torch import nn, Tensor

from src.rig_package.info.asset import Asset
from .spec import ModelSpec, ModelInput, VaeInput
from .skin_vae.autoencoders import SkinFSQCVAEModel
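
# FlashAttention compatibility shim: the FlashAttention-3 interface returns an
# (out, lse) tuple, while FlashAttention-2's flash_attn_func returns only the
# output tensor, so the fallback wraps it to match the tuple signature.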
try:
    from flash_attn_interface import flash_attn_func  # type: ignore
except Exception:
    from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func

    def flash_attn_func(*args, **kwargs):
        # Wrap the FA2 entry point so both import paths yield an (out, lse) tuple.
        res = _flash_attn_func(*args, **kwargs)
        return res, None
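

# Perceiver-style resampler: a single cross-attention layer whose learned
# queries pool a variable-length token sequence (B, N, C) into a fixed set
# of output tokens (B, out_tokens, C).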
class Perceiver(nn.Module):
    def __init__(self, channels, out_tokens, num_heads=8):
        super().__init__()
        # Learned queries, stored as (out_tokens // num_heads, num_heads, channels)
        # and reshaped to (out_tokens, num_heads, head_dim) in forward.
        self.q_vec = nn.Parameter(torch.randn(out_tokens // num_heads, num_heads, channels) * 0.02)
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.k_proj = nn.Linear(channels, channels)
        self.v_proj = nn.Linear(channels, channels)
        self.out_proj = nn.Linear(channels, channels)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        k = self.k_proj(x)  # [B, N, C]
        v = self.v_proj(x)  # [B, N, C]
        # Broadcast the learned queries over the batch; flash-attn expects
        # bf16/fp16, so the fp32 parameter is cast explicitly (k and v are
        # assumed to already be low precision, e.g. under autocast).
        q_repeated = self.q_vec.repeat(B, 1, 1, 1)
        q = q_repeated.view(B, -1, self.num_heads, self.head_dim).type(torch.bfloat16)
        k = k.view(B, -1, self.num_heads, self.head_dim)
        v = v.view(B, -1, self.num_heads, self.head_dim)
        hidden_states, _ = flash_attn_func(q, k, v)
        hidden_states = hidden_states.view(B, -1, self.num_heads * self.head_dim)  # type: ignore
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
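

# ModelSpec wrapper around the FSQ skin VAE. When sample_tokens differs from
# compress_tokens, down_perceiver squeezes the latent sequence to
# compress_tokens and up_perceiver expands it back before decoding.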
class SkinVAEModel(ModelSpec):
    def __init__(self, model_config, transform_config, tokenizer_config=None):
        super().__init__(model_config, transform_config, tokenizer_config)
        cfg = self.model_config
        self.cond_tokens = cfg['sample']['cond_tokens']
        self.compress_tokens = cfg['sample']['compress_tokens']
        self.sample_tokens = cfg['sample']['sample_tokens']
        self.only_dense = cfg['sample'].get('only_dense', False)
        self.model_type = cfg.get('type', 'fsqc')
        if self.model_type == 'fsqc':
            self.model = SkinFSQCVAEModel(**cfg['model'], sample_tokens=self.sample_tokens)
        else:
            raise NotImplementedError()
        if self.sample_tokens != self.compress_tokens:
            self.down_perceiver = Perceiver(self.model.latent_channels, self.compress_tokens)
            self.up_perceiver = Perceiver(self.model.latent_channels, self.sample_tokens)

    def compile_model(self):
        self.model.compile_model()

    def vocab_size(self) -> int:
        return self.model.FSQ.codebook_size

    def latent_channels(self) -> int:
        return self.model.latent_channels

    def encode(self, vae_input: VaeInput, num_tokens: int = 4, j: int = 0, full: bool = False, encode_repeat: int = 4, return_cond: bool = True):
        raise NotImplementedError()
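
    # NOTE: `encode` is left unimplemented here even though `predict_step`
    # below depends on it; a concrete implementation presumably lives in a
    # subclass or elsewhere in the codebase.
    # `decode` maps (quantized) latent tokens back to skin logits. With
    # full=True it processes rows in chunks of `encode_repeat` to bound peak
    # memory, re-expanding compressed latents with up_perceiver when needed.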
    def decode(self, z: Tensor, sampled_cond: Tensor, cond_tokens: Tensor, full: bool = False, encode_repeat: int = 4) -> Tensor:
        assert z.shape[0] == sampled_cond.shape[0] == cond_tokens.shape[0]
        if full:
            total = z.shape[0]
            s = []
            for i in range(0, total, encode_repeat):
                t = min(total, i + encode_repeat)
                if self.sample_tokens != self.compress_tokens:
                    _z = self.up_perceiver(z[i:t])
                else:
                    _z = z[i:t]
                logits = self.model._decode(z=_z, cond=cond_tokens[i:t], sampled_points=sampled_cond[i:t])
                s.append(logits)
            return torch.cat(s, dim=0)
        else:
            if self.sample_tokens != self.compress_tokens:
                z = self.up_perceiver(z)
            logits = self.model._decode(z=z, cond=cond_tokens, sampled_points=sampled_cond)
            return logits

    def get_loss_dict(
        self,
        skin_pred: Tensor,
        skin_gt: Tensor,
    ) -> Dict[str, Tensor]:
        raise NotImplementedError()

    def get_input(self, batch: Dict) -> VaeInput:
        vertices: Tensor = batch['vertices'].float()  # (B, N, 3)
        normals: Tensor = batch['normals'].float()  # (B, N, 3)
        uniform_skin: List[Tensor] = batch['uniform_skin']  # [(N, J)]
        dense_skin: List[Tensor] = batch['dense_skin']  # [(J, skin_samples)]
        dense_vertices: List[Tensor] = batch['dense_vertices']  # [(J, skin_samples, 3)]
        dense_normals: List[Tensor] = batch['dense_normals']  # [(J, skin_samples, 3)]
        dense_indices: List[List[int]] = batch['dense_indices']  # [List[J]]
        B = vertices.shape[0]
        uniform_cond = torch.cat([vertices, normals], dim=-1).float()
        dense_cond = [
            torch.cat([dense_vertices[i], dense_normals[i]], dim=-1).float()
            for i in range(B)
        ]
        uniform_skin = [s.float() for s in uniform_skin]
        dense_skin = [s.float() for s in dense_skin]
        return VaeInput(
            dense_cond=dense_cond,
            dense_skin=dense_skin,
            dense_indices=dense_indices,
            uniform_cond=uniform_cond,
            uniform_skin=uniform_skin,
        )

    def training_step(self, batch: Dict) -> Dict:
        raise NotImplementedError()

    def process_fn(self, batch: List[ModelInput], is_train: bool = True) -> List[Dict]:
        res = []
        for b in batch:
            asset = b.asset
            assert asset is not None
            assert asset.sampled_vertex_groups is not None
            assert 'skin' in asset.sampled_vertex_groups
            assert asset.meta is not None
            assert 'dense_indices' in asset.meta
            assert 'dense_skin' in asset.meta
            assert 'dense_vertices' in asset.meta
            assert 'dense_normals' in asset.meta
            _d = {
                'vertices': asset.sampled_vertices,
                'normals': asset.sampled_normals,
                'non': {
                    'uniform_skin': asset.sampled_vertex_groups['skin'],
                    'num_bones': asset.J,
                    'skin_samples': asset.skin_samples,
                    'dense_indices': asset.meta['dense_indices'],
                    'dense_skin': asset.meta['dense_skin'],
                    'dense_vertices': asset.meta['dense_vertices'],
                    'dense_normals': asset.meta['dense_normals'],
                }
            }
            res.append(_d)
        return res

    def forward(self, batch: Dict) -> Dict:
        return self.training_step(batch=batch)
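
    # Inference: encode the sampled point cloud, snap the latents to the FSQ
    # codebook, decode per-bone skin weights at the sampled points, then
    # transfer them to the full-resolution mesh with a nearest-neighbor lookup.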
    def predict_step(self, batch: Dict) -> Dict:
        vertices: Tensor = batch['vertices'].float()  # (B, N, 3)
        num_bones: List[int] = batch['num_bones']
        B = vertices.shape[0]
        N = vertices.shape[1]
        vae_input = self.get_input(batch=batch)
        num_tokens = 4
        z, cond_tokens, indices, _ = self.encode(vae_input=vae_input, num_tokens=num_tokens, full=True, encode_repeat=8)
        assert cond_tokens is not None
        # Round-trip through the FSQ codebook so decoding sees exactly the quantized latents.
        z = self.model.FSQ.indices_to_codes(indices).reshape(z.shape)
        _skin_pred = self.decode(z=z, sampled_cond=vae_input.get_flatten_uniform_cond(), cond_tokens=cond_tokens[vae_input.get_flatten_indices()], full=True, encode_repeat=8)
        _skin_pred = _skin_pred.squeeze(-1)
        tot = 0
        results = []
        for i in range(B):
            asset: Asset = batch['model_input'][i].asset.copy()
            # Scatter per-bone predictions back into a dense (N, J) skin matrix.
            skin_pred = torch.zeros((N, num_bones[i]), dtype=vertices.dtype, device=vertices.device)
            for j in range(vae_input.get_len(i=i)):
                skin_pred[:, vae_input.true_j(i=i, j=j)] = _skin_pred[tot]
                tot += 1
            # Transfer weights from the sampled points to the full mesh via nearest neighbors.
            sampled_vertices = vertices[i].detach().float().cpu().numpy()
            tree = cKDTree(sampled_vertices)
            _, nn_indices = tree.query(asset.vertices)
            sampled_skin = skin_pred.detach().float().cpu().numpy()[nn_indices]
            asset.skin = sampled_skin
            results.append(asset)
        return {
            'results': results,
        }