SkinTokens / src /model /skin_vae_model.py
pookiefoof's picture
Public release: SkinTokens · TokenRig demo
9d7cf7f
from dataclasses import asdict, dataclass
from omegaconf import OmegaConf
from scipy.spatial import cKDTree # type: ignore
from torch import nn, Tensor
from typing import Dict, List
import math
import numpy as np
import random
import torch
import torch.nn.functional as F
from src.rig_package.info.asset import Asset
from .spec import ModelSpec, ModelInput, VaeInput
from .skin_vae.autoencoders import SkinFSQCVAEModel
# Resolve `flash_attn_func` from whichever flash-attention module is importable.
# Callers in this file unpack the result as a 2-tuple (`out, _ = flash_attn_func(...)`),
# so the fallback below is wrapped to return a 2-tuple as well.
# NOTE(review): `flash_attn_interface` is presumably the FlashAttention-3 build,
# whose function already returns a tuple -- confirm against the installed version.
try:
    from flash_attn_interface import flash_attn_func  # type: ignore
except Exception:
    # Deliberately broad: any failure while importing the preferred module
    # (not only ImportError) triggers the fallback to the `flash_attn` package.
    from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func

    def flash_attn_func(*args, **kwargs):
        """Call the fallback kernel and return ``(output, None)``."""
        return _flash_attn_func(*args, **kwargs), None
class Perceiver(nn.Module):
    """Cross-attention resampler that maps a token sequence onto learned queries.

    A learned query tensor attends over the input tokens (projected to keys
    and values) through ``flash_attn_func``; the attended result is passed
    through an output projection. The output length is fixed by the learned
    queries and does not depend on the input length.

    Args:
        channels: Feature width of the input and output tokens. Must be a
            multiple of ``num_heads``.
        out_tokens: Requested number of output tokens. The query parameter
            holds ``(out_tokens // num_heads) * num_heads`` tokens, so pass a
            multiple of ``num_heads`` to get exactly ``out_tokens``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``channels`` is not divisible by ``num_heads``.
    """

    def __init__(self, channels, out_tokens, num_heads=8):
        super().__init__()
        if channels % num_heads != 0:
            # Without this check the mismatch only surfaces later, as a shape
            # error at `out_proj` during the first forward pass.
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        # The parameter layout is kept exactly as
        # (out_tokens // num_heads, num_heads, channels) so that its shape in
        # the state dict is unchanged; forward() reinterprets it per head.
        self.q_vec = nn.Parameter(torch.randn(out_tokens // num_heads, num_heads, channels) * 0.02)
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.k_proj = nn.Linear(channels, channels)
        self.v_proj = nn.Linear(channels, channels)
        self.out_proj = nn.Linear(channels, channels)

    def forward(self, x: Tensor) -> Tensor:
        """Resample ``x`` of shape ``[B, N, C]`` to ``[B, T, C]``.

        ``T`` is the number of learned query tokens (see the class docstring).
        """
        B, _, _ = x.shape  # also enforces a 3-D input
        k = self.k_proj(x).view(B, -1, self.num_heads, self.head_dim)  # [B, N, H, D]
        v = self.v_proj(x).view(B, -1, self.num_heads, self.head_dim)  # [B, N, H, D]
        # Cast the learned queries to the dtype of the projected keys instead
        # of a hard-coded bfloat16: identical under bf16 autocast, and keeps
        # q/k/v dtypes consistent under other precisions (e.g. fp16 autocast).
        q = self.q_vec.to(k.dtype).view(1, -1, self.num_heads, self.head_dim)
        q = q.repeat(B, 1, 1, 1)  # [B, T, H, D], same queries for every sample
        # The second element of the returned tuple is ignored.
        hidden_states, _ = flash_attn_func(q, k, v)
        hidden_states = hidden_states.view(B, -1, self.num_heads * self.head_dim)  # type: ignore
        return self.out_proj(hidden_states)
class SkinVAEModel(ModelSpec):
    """Wrapper around the skinning FSQ-CVAE used for decoding and prediction.

    Builds a ``SkinFSQCVAEModel`` from ``model_config`` and, when the configured
    ``sample_tokens`` and ``compress_tokens`` differ, a pair of ``Perceiver``
    resamplers between the two token counts. ``encode``, ``get_loss_dict`` and
    ``training_step`` raise ``NotImplementedError`` in this class as shown.
    """

    def __init__(self, model_config, transform_config, tokenizer_config=None):
        super().__init__(model_config, transform_config, tokenizer_config)
        cfg = self.model_config
        sample_cfg = cfg['sample']
        self.cond_tokens = sample_cfg['cond_tokens']
        self.compress_tokens = sample_cfg['compress_tokens']
        self.sample_tokens = sample_cfg['sample_tokens']
        self.only_dense = sample_cfg.get('only_dense', False)

        self.model_type = cfg.get('type', 'fsqc')
        if self.model_type != 'fsqc':
            # Only the 'fsqc' backbone is supported.
            raise NotImplementedError()
        self.model = SkinFSQCVAEModel(**cfg['model'], sample_tokens=self.sample_tokens)

        if self.sample_tokens != self.compress_tokens:
            # Resamplers between the two token counts. Only `up_perceiver` is
            # referenced by the methods visible in this class.
            self.down_perceiver = Perceiver(self.model.latent_channels, self.compress_tokens)
            self.up_perceiver = Perceiver(self.model.latent_channels, self.sample_tokens)

    def compile_model(self):
        """Delegate compilation to the wrapped model."""
        self.model.compile_model()

    @property
    def vocab_size(self) -> int:
        """Codebook size of the wrapped model's FSQ quantizer."""
        return self.model.FSQ.codebook_size

    @property
    def latent_channels(self) -> int:
        """Latent channel width reported by the wrapped model."""
        return self.model.latent_channels

    def encode(self, vae_input: VaeInput, num_tokens: int=4, j: int=0, full: bool=False, encode_repeat: int=4, return_cond: bool=True):
        """Not implemented in this class.

        ``predict_step`` unpacks the result as ``(z, cond_tokens, indices, _)``.
        """
        raise NotImplementedError()

    def decode(self, z: Tensor, sampled_cond: Tensor, cond_tokens: Tensor, full: bool=False, encode_repeat: int=4) -> Tensor:
        """Decode latents into logits with the wrapped model.

        Args:
            z: Latents; the first dimension must match the other two inputs.
            sampled_cond: Passed to the model as ``sampled_points``.
            cond_tokens: Passed to the model as ``cond``.
            full: If true, decode in chunks of ``encode_repeat`` rows along the
                first dimension and concatenate the chunk outputs on dim 0;
                otherwise decode everything in a single call.
            encode_repeat: Chunk size used when ``full`` is true.

        Returns:
            The logits returned by ``self.model._decode``.
        """
        assert z.shape[0] == sampled_cond.shape[0] == cond_tokens.shape[0]
        needs_upsample = self.sample_tokens != self.compress_tokens

        if not full:
            latent = self.up_perceiver(z) if needs_upsample else z
            return self.model._decode(z=latent, cond=cond_tokens, sampled_points=sampled_cond)

        total = z.shape[0]
        chunks = []
        for start in range(0, total, encode_repeat):
            stop = min(total, start + encode_repeat)
            latent = z[start:stop]
            if needs_upsample:
                latent = self.up_perceiver(latent)
            chunks.append(
                self.model._decode(
                    z=latent,
                    cond=cond_tokens[start:stop],
                    sampled_points=sampled_cond[start:stop],
                )
            )
        return torch.cat(chunks, dim=0)

    def get_loss_dict(
        self,
        skin_pred: Tensor,
        skin_gt: Tensor,
    ) -> Dict[str, Tensor]:
        """Not implemented in this class."""
        raise NotImplementedError()

    def get_input(self, batch: Dict) -> VaeInput:
        """Assemble a ``VaeInput`` from a collated batch.

        Positions and normals are concatenated on the last dimension to form
        the conditioning tensors, and all tensors are converted with
        ``.float()``. Shape comments below are carried over from the original
        annotations.
        """
        vertices: Tensor = batch['vertices'].float()  # (B, N, 3)
        normals: Tensor = batch['normals'].float()  # (B, N, 3)
        uniform_skin: List[Tensor] = batch['uniform_skin']  # [(N, J)]
        dense_skin: List[Tensor] = batch['dense_skin']  # [(J, skin_samples)]
        dense_vertices: List[Tensor] = batch['dense_vertices']  # [(J, skin_samples, 3)]
        dense_normals: List[Tensor] = batch['dense_normals']  # [(J, skin_samples, 3)]
        dense_indices: List[List[int]] = batch['dense_indices']  # [List[J]]

        B = vertices.shape[0]
        uniform_cond = torch.cat([vertices, normals], dim=-1).float()
        dense_cond = [
            torch.cat([dense_vertices[i], dense_normals[i]], dim=-1).float()
            for i in range(B)
        ]
        return VaeInput(
            dense_cond=dense_cond,
            dense_skin=[s.float() for s in dense_skin],
            dense_indices=dense_indices,
            uniform_cond=uniform_cond,
            uniform_skin=[s.float() for s in uniform_skin],
        )

    @torch.autocast(device_type='cuda', dtype=torch.bfloat16)
    def training_step(self, batch: Dict) -> Dict:
        """Not implemented in this class."""
        raise NotImplementedError()

    def process_fn(self, batch: List[ModelInput], is_train: bool = True) -> List[Dict]:
        """Turn each ``ModelInput`` into a dict of fields for collation.

        Every asset must provide sampled 'skin' vertex groups and the dense
        sampling entries in ``asset.meta``; this is checked with assertions.
        ``is_train`` is accepted but not used here.
        """
        required_meta = ('dense_indices', 'dense_skin', 'dense_vertices', 'dense_normals')
        out = []
        for item in batch:
            asset = item.asset
            assert asset is not None
            assert asset.sampled_vertex_groups is not None
            assert 'skin' in asset.sampled_vertex_groups
            assert asset.meta is not None
            for key in required_meta:
                assert key in asset.meta
            meta = asset.meta
            # NOTE(review): entries under 'non' are presumably exempt from
            # tensor stacking by the collate function -- confirm with the caller.
            out.append({
                'vertices': asset.sampled_vertices,
                'normals': asset.sampled_normals,
                'non': {
                    'uniform_skin': asset.sampled_vertex_groups['skin'],
                    'num_bones': asset.J,
                    'skin_samples': asset.skin_samples,
                    'dense_indices': meta['dense_indices'],
                    'dense_skin': meta['dense_skin'],
                    'dense_vertices': meta['dense_vertices'],
                    'dense_normals': meta['dense_normals'],
                },
            })
        return out

    def forward(self, batch: Dict) -> Dict:
        """Alias for ``training_step``."""
        return self.training_step(batch=batch)

    @torch.autocast('cuda', dtype=torch.bfloat16)
    def predict_step(self, batch: Dict) -> Dict:
        """Encode, quantize, decode, and write predicted skin onto asset copies.

        The latents are rebuilt from the FSQ indices before decoding. Decoded
        rows are scattered into an ``(N, num_bones[i])`` matrix per sample, and
        each original asset vertex takes the row of its nearest sampled vertex
        (cKDTree query).

        Returns:
            ``{'results': [...]}`` with one copied asset per batch element,
            each with ``skin`` set to the looked-up prediction.
        """
        vertices: Tensor = batch['vertices'].float()  # (B, N, 3)
        num_bones: List[int] = batch['num_bones']
        B = vertices.shape[0]
        N = vertices.shape[1]

        vae_input = self.get_input(batch=batch)
        z, cond_tokens, indices, _ = self.encode(vae_input=vae_input, num_tokens=4, full=True, encode_repeat=8)
        assert cond_tokens is not None
        # Replace the encoder latents with the codes recovered from the indices.
        z = self.model.FSQ.indices_to_codes(indices).reshape(z.shape)

        flat_pred = self.decode(
            z=z,
            sampled_cond=vae_input.get_flatten_uniform_cond(),
            cond_tokens=cond_tokens[vae_input.get_flatten_indices()],
            full=True,
            encode_repeat=8,
        ).squeeze(-1)

        row = 0  # running index into the flattened predictions
        results = []
        for i in range(B):
            asset: Asset = batch['model_input'][i].asset.copy()
            weights = torch.zeros((N, num_bones[i]), dtype=vertices.dtype, device=vertices.device)
            for j in range(vae_input.get_len(i=i)):
                # presumably true_j maps the j-th encoded entry of sample i to
                # its bone column -- verify against VaeInput
                weights[:, vae_input.true_j(i=i, j=j)] = flat_pred[row]
                row += 1
            sampled_vertices = vertices[i].detach().float().cpu().numpy()
            # Nearest sampled point for every original mesh vertex.
            _, nearest = cKDTree(sampled_vertices).query(asset.vertices)
            asset.skin = weights.detach().float().cpu().numpy()[nearest]
            results.append(asset)
        return {
            'results': results,
        }