| """ |
| MotionVQVAE - Motion Vector Quantized VAE for HuggingFace |
| |
| Load and use the MotionVQVAE model for motion tokenization and reconstruction. |
| |
| Usage: |
| from motion_vqvae_hf import MotionVQVAE |
| |
| # Load from HuggingFace Hub |
| model = MotionVQVAE.from_pretrained("khania/motion-vqvae") |
| |
| # Encode motion to tokens |
| tokens = model.encode(motion_array) # (B, T, 272) -> (num_groups, B, T') |
| |
| # Decode tokens back to motion |
| motion_recon = model.decode(tokens) # (num_groups, B, T') -> (B, T, 272) |
| |
| # Full forward pass |
| motion_recon, tokens = model(motion_array) |
| """ |
|
|
| import os |
| import json |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import List, Union, Optional, Dict, Any, Tuple |
| from pathlib import Path |
|
|
| try: |
| from huggingface_hub import snapshot_download |
| HF_HUB_AVAILABLE = True |
| except ImportError: |
| HF_HUB_AVAILABLE = False |
|
|
|
|
| |
| |
| |
|
|
class ResConv1DBlock(nn.Module):
    """Pre-activation residual block: (norm -> act -> conv) twice, plus skip.

    Sub-module names (norm1/norm2, activation1/activation2, conv1/conv2)
    mirror the original models/resnet.py so pretrained state_dicts load
    without any key remapping.
    """

    def __init__(self, n_in: int, n_state: int, dilation: int = 1,
                 activation: str = 'relu', norm: str = None, kernel_size: int = 3):
        super().__init__()
        self.norm = norm
        pad = dilation * (kernel_size - 1) // 2

        def build_norm() -> nn.Module:
            # "LN" normalizes the channel axis (applied on a transposed view
            # in forward); "GN"/"BN" operate on (B, C, T) directly; any other
            # value means no normalization.
            if norm == "LN":
                return nn.LayerNorm(n_in)
            if norm == "GN":
                return nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
            if norm == "BN":
                return nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
            return nn.Identity()

        self.norm1 = build_norm()
        self.norm2 = build_norm()

        # Unrecognized activation names deliberately fall back to ReLU,
        # matching the original implementation.
        act_cls = {'silu': nn.SiLU, 'gelu': nn.GELU, 'relu': nn.ReLU}.get(activation, nn.ReLU)
        self.activation1 = act_cls()
        self.activation2 = act_cls()

        # Dilated "wide" conv followed by a 1x1 projection back to n_in.
        self.conv1 = nn.Conv1d(n_in, n_state, kernel_size, 1, pad, dilation)
        self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0)

    def _normed(self, norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Apply a norm layer; LayerNorm needs the channel axis last."""
        if self.norm == "LN":
            return norm_layer(x.transpose(-2, -1)).transpose(-2, -1)
        return norm_layer(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        h = self.activation1(self._normed(self.norm1, x))
        h = self.conv1(h)
        h = self.activation2(self._normed(self.norm2, h))
        h = self.conv2(h)
        return h + residual
|
|
|
|
class Resnet1D(nn.Module):
    """Stack of ResConv1DBlock layers with exponentially growing dilation.

    Blocks live in ``self.model`` (an ``nn.Sequential``) so state_dict keys
    line up with checkpoints produced by the original models/resnet.py.
    """

    def __init__(self, n_in: int, n_depth: int, dilation_growth_rate: int = 1,
                 reverse_dilation: bool = False, activation: str = 'relu',
                 norm: str = None, kernel_size: int = 3):
        super().__init__()

        dilations = [dilation_growth_rate ** d for d in range(n_depth)]
        if reverse_dilation:
            # Decoder stacks mirror the encoder: dilations run largest-first.
            dilations.reverse()

        self.model = nn.Sequential(*(
            ResConv1DBlock(n_in, n_in, dilation=d, activation=activation,
                           norm=norm, kernel_size=kernel_size)
            for d in dilations
        ))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
|
|
|
|
class Encoder(nn.Module):
    """Temporal 1D CNN encoder.

    Layer layout is kept identical to the original models/encdec.py so that
    checkpoint keys line up:
      - model.0: Conv1d input projection
      - model.1: ReLU
      - model.2 .. model.(1+down_t): Sequential(strided Conv1d, Resnet1D)
      - last:    Conv1d output projection
    Overall temporal downsampling factor is stride_t ** down_t.
    """

    def __init__(
        self,
        input_emb_width: int = 272,
        output_emb_width: int = 512,
        down_t: int = 2,
        stride_t: int = 2,
        width: int = 512,
        depth: int = 3,
        dilation_growth_rate: int = 3,
        activation: str = 'relu',
        norm: str = None,
        kernel_size: int = 3
    ):
        super().__init__()

        filter_t = stride_t * 2
        pad_t = stride_t // 2
        same_pad = (kernel_size - 1) // 2

        layers = [
            nn.Conv1d(input_emb_width, width, kernel_size, 1, same_pad),
            nn.ReLU(),
        ]

        # Each stage halves (by stride_t) the temporal resolution, then
        # refines with a residual stack.
        for _ in range(down_t):
            layers.append(nn.Sequential(
                nn.Conv1d(width, width, filter_t, stride_t, pad_t),
                Resnet1D(width, depth, dilation_growth_rate,
                         activation=activation, norm=norm, kernel_size=kernel_size),
            ))

        layers.append(nn.Conv1d(width, output_emb_width, kernel_size, 1, same_pad))

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, input_emb_width, T) -> (B, output_emb_width, T // stride_t**down_t)."""
        return self.model(x)
|
|
|
|
class Decoder(nn.Module):
    """Temporal 1D CNN decoder (mirror of Encoder).

    Layer layout matches the original models/encdec.py so checkpoint keys
    line up:
      - model.0: Conv1d input projection (latent -> width)
      - model.1: ReLU
      - model.2 .. model.(1+down_t): Sequential(Resnet1D, Upsample x2, Conv1d)
      - final:   Conv1d, ReLU, Conv1d output projection
    Note: ``input_emb_width`` is the *motion* dimension this decoder emits;
    the argument naming mirrors the Encoder's.
    """

    def __init__(
        self,
        input_emb_width: int = 272,
        output_emb_width: int = 512,
        down_t: int = 2,
        stride_t: int = 2,
        width: int = 512,
        depth: int = 3,
        dilation_growth_rate: int = 3,
        activation: str = 'relu',
        norm: str = None,
        kernel_size: int = 3
    ):
        super().__init__()

        same_pad = (kernel_size - 1) // 2

        layers = [
            nn.Conv1d(output_emb_width, width, kernel_size, 1, same_pad),
            nn.ReLU(),
        ]

        # Each stage doubles the temporal resolution: residual stack with
        # reversed dilations, linear upsample, then a smoothing conv.
        for _ in range(down_t):
            layers.append(nn.Sequential(
                Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True,
                         activation=activation, norm=norm, kernel_size=kernel_size),
                nn.Upsample(scale_factor=2, mode='linear', align_corners=False),
                nn.Conv1d(width, width, 3, 1, 1),
            ))

        layers.append(nn.Conv1d(width, width, kernel_size, 1, same_pad))
        layers.append(nn.ReLU())
        layers.append(nn.Conv1d(width, input_emb_width, kernel_size, 1, same_pad))

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, output_emb_width, T') -> (B, input_emb_width, T' * 2**down_t)."""
        return self.model(x)
|
|
|
|
| |
| |
| |
|
|
class VectorQuantizerEMA(nn.Module):
    """Single vector quantizer (nearest-neighbour codebook lookup).

    Despite the name, no EMA codebook update is performed in this module;
    ``decay``/``epsilon`` are kept only for signature compatibility with the
    training code.  The codebook is an ``nn.Parameter`` so the checkpoint
    key ('codebook') matches the original implementation.
    """

    def __init__(
        self,
        num_embeddings: int = 512,
        embedding_dim: int = 8,
        decay: float = 0.99,
        epsilon: float = 1e-5
    ):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.decay = decay
        self.epsilon = epsilon

        self.codebook = nn.Parameter(torch.randn(num_embeddings, embedding_dim))

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            z: (B, D, T) latent features for this group
        Returns:
            z_q: (B, D, T) quantized features (straight-through)
            indices: (B, T) codebook indices
        """
        batch, dim, steps = z.shape

        # Fold time into the batch axis: (B*T, D).
        flat = z.permute(0, 2, 1).reshape(-1, dim)

        # Squared Euclidean distances via the expansion
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        # (avoids materializing a (B*T, K, D) broadcast tensor).
        dist = (
            torch.sum(flat ** 2, dim=1, keepdim=True)
            + torch.sum(self.codebook ** 2, dim=1)
            - 2 * torch.matmul(flat, self.codebook.t())
        )

        codes = torch.argmin(dist, dim=1)

        # Look up code vectors and restore the (B, D, T) layout.
        picked = F.embedding(codes, self.codebook)
        z_q = picked.reshape(batch, steps, dim).permute(0, 2, 1)

        # Straight-through estimator: forward uses quantized values,
        # backward passes gradients to z unchanged.
        z_q = z + (z_q - z).detach()

        return z_q, codes.reshape(batch, steps)

    def decode_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            indices: (B, T) codebook indices
        Returns:
            z_q: (B, D, T) quantized features
        """
        batch, steps = indices.shape
        picked = F.embedding(indices.reshape(-1), self.codebook)
        return picked.reshape(batch, steps, -1).permute(0, 2, 1)
|
|
|
|
class MultiGroupVectorQuantizer(nn.Module):
    """Multi-Group Vector Quantizer - splits latent channels into groups.

    Each of the ``num_groups`` channel groups is quantized independently
    against its own codebook (a ``VectorQuantizerEMA``).  Uses
    ``self.quantizers = nn.ModuleList`` to match checkpoint keys like
    'quantizer.quantizers.0.codebook', 'quantizer.quantizers.1.codebook'.
    """

    def __init__(
        self,
        num_groups: int = 64,
        num_embeddings: int = 512,
        embedding_dim: int = 512,
        decay: float = 0.99,
        epsilon: float = 1e-5
    ):
        super().__init__()

        # Fail fast on an uneven split: torch.chunk would otherwise hand
        # some groups a different channel count than their codebooks expect,
        # and the error would only surface later as an opaque matmul shape
        # mismatch inside the per-group quantizer.
        if embedding_dim % num_groups != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by "
                f"num_groups ({num_groups})"
            )

        self.num_groups = num_groups
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Channels quantized by each group's codebook.
        self.group_dim = embedding_dim // num_groups

        self.quantizers = nn.ModuleList([
            VectorQuantizerEMA(num_embeddings, self.group_dim, decay, epsilon)
            for _ in range(num_groups)
        ])

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            z: (B, D, T) latent features
        Returns:
            z_q: (B, D, T) quantized features
            indices: (num_groups, B, T) codebook indices per group
        """
        # Split channels into equal groups; each is quantized independently.
        z_groups = z.chunk(self.num_groups, dim=1)

        z_q_groups = []
        indices_list = []
        for z_g, quantizer in zip(z_groups, self.quantizers):
            z_q_g, idx_g = quantizer(z_g)
            z_q_groups.append(z_q_g)
            indices_list.append(idx_g)

        # Reassemble the channel axis and stack per-group index maps.
        z_q = torch.cat(z_q_groups, dim=1)
        indices = torch.stack(indices_list, dim=0)

        return z_q, indices

    def decode_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            indices: (num_groups, B, T) codebook indices
        Returns:
            z_q: (B, D, T) quantized features
        """
        z_q_groups = [
            quantizer.decode_indices(indices[i])
            for i, quantizer in enumerate(self.quantizers)
        ]
        return torch.cat(z_q_groups, dim=1)
|
|
|
|
| |
| |
| |
|
|
class MotionVQVAE(nn.Module):
    """Motion Vector Quantized VAE for HuggingFace.

    Pipeline: (B, T, motion_dim) motion -> normalize -> Encoder (1D CNN,
    temporal downsample) -> MultiGroupVectorQuantizer -> Decoder ->
    denormalize.

    Architecture matches the original training code exactly to ensure
    checkpoint compatibility.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Build encoder/decoder/quantizer from a (possibly partial) config.

        Args:
            config: Hyperparameter dict.  Missing keys fall back to the
                defaults below; legacy key aliases ('input_dim', 'code_dim',
                'nb_code') are also honored.
        """
        super().__init__()

        if config is None:
            config = {}

        self.config = config

        # Primary config keys, with fallbacks to legacy names used by
        # older configs.
        self.motion_dim = config.get('motion_dim', config.get('input_dim', 272))
        self.latent_dim = config.get('latent_dim', config.get('code_dim', 512))
        self.num_groups = config.get('num_groups', 64)
        self.num_codes = config.get('num_codes', config.get('nb_code', 512))
        self.down_t = config.get('down_t', 2)
        self.stride_t = config.get('stride_t', 2)
        self.width = config.get('width', 512)
        self.depth = config.get('depth', 3)
        self.dilation_growth_rate = config.get('dilation_growth_rate', 3)
        self.activation = config.get('activation', 'relu')
        self.kernel_size = config.get('kernel_size', 3)

        # Per-feature normalization statistics.  Registered as buffers so
        # they travel with state_dict and .to(device) but are not trained.
        self.register_buffer('mean', torch.zeros(self.motion_dim))
        self.register_buffer('std', torch.ones(self.motion_dim))

        self.encoder = Encoder(
            input_emb_width=self.motion_dim,
            output_emb_width=self.latent_dim,
            down_t=self.down_t,
            stride_t=self.stride_t,
            width=self.width,
            depth=self.depth,
            dilation_growth_rate=self.dilation_growth_rate,
            activation=self.activation,
            kernel_size=self.kernel_size
        )

        self.decoder = Decoder(
            input_emb_width=self.motion_dim,
            output_emb_width=self.latent_dim,
            down_t=self.down_t,
            stride_t=self.stride_t,
            width=self.width,
            depth=self.depth,
            dilation_growth_rate=self.dilation_growth_rate,
            activation=self.activation,
            kernel_size=self.kernel_size
        )

        self.quantizer = MultiGroupVectorQuantizer(
            num_groups=self.num_groups,
            num_embeddings=self.num_codes,
            embedding_dim=self.latent_dim
        )

    def normalize(self, motion: torch.Tensor) -> torch.Tensor:
        """Normalize motion data using mean and std.

        Accepts (B, T, D) input, or channel-first (B, D, T); the layout is
        inferred from the size of the last axis.
        """
        # Broadcast the per-feature stats over (B, T, D).
        mean = self.mean.view(1, 1, -1)
        std = self.std.view(1, 1, -1)
        # Guard against (near-)zero std on almost-constant features.
        std_safe = torch.clamp(std, min=0.01)

        if motion.shape[-1] != self.motion_dim:
            # Last axis is not the feature axis -> assume channel-first
            # (B, D, T) and broadcast the stats along dim 1 instead.
            # NOTE(review): assumes input is one of these two layouts.
            mean = mean.permute(0, 2, 1)
            std_safe = std_safe.permute(0, 2, 1)

        normalized = (motion - mean) / std_safe
        # Clip extreme outliers so encoder inputs stay in a sane range.
        return torch.clamp(normalized, -20, 20)

    def denormalize(self, motion: torch.Tensor) -> torch.Tensor:
        """Denormalize motion data using mean and std (inverse of normalize)."""
        mean = self.mean.view(1, 1, -1)
        std = self.std.view(1, 1, -1)
        # Same clamp as in normalize() so the round trip is consistent.
        std_safe = torch.clamp(std, min=0.01)

        if motion.shape[-1] != self.motion_dim:
            # Channel-first (B, D, T) input -- see normalize().
            mean = mean.permute(0, 2, 1)
            std_safe = std_safe.permute(0, 2, 1)

        return motion * std_safe + mean

    def encode(self, motion: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """
        Encode motion to discrete tokens.

        Args:
            motion: (B, T, D) motion data where D=272
            normalize: whether to normalize input

        Returns:
            tokens: (num_groups, B, T') discrete tokens where
                T' = T // (stride_t ** down_t) (T // 4 with the defaults)
        """
        if normalize:
            motion = self.normalize(motion)

        # Conv1d expects channel-first: (B, T, D) -> (B, D, T).
        x = motion.permute(0, 2, 1)

        z = self.encoder(x)

        # Only the token ids are needed here; the quantized latents are
        # discarded.
        _, indices = self.quantizer(z)

        return indices

    def decode(self, tokens: torch.Tensor, denormalize: bool = True) -> torch.Tensor:
        """
        Decode discrete tokens to motion.

        Args:
            tokens: (num_groups, B, T') discrete tokens
            denormalize: whether to denormalize output

        Returns:
            motion: (B, T, D) reconstructed motion
        """
        # Token ids -> quantized latents (B, latent_dim, T').
        z_q = self.quantizer.decode_indices(tokens)

        x_recon = self.decoder(z_q)

        # Back to (B, T, D).
        motion = x_recon.permute(0, 2, 1)

        if denormalize:
            motion = self.denormalize(motion)

        return motion

    def forward(
        self,
        motion: torch.Tensor,
        normalize: bool = True,
        denormalize: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Full forward pass: encode to tokens and decode back.

        Unlike chaining encode()+decode(), this path keeps the
        straight-through quantized latents, so gradients can flow end to
        end during training.

        Args:
            motion: (B, T, D) motion data
            normalize: whether to normalize input
            denormalize: whether to denormalize output

        Returns:
            motion_recon: (B, T, D) reconstructed motion
            tokens: (num_groups, B, T') discrete tokens
        """
        if normalize:
            motion_normalized = self.normalize(motion)
        else:
            motion_normalized = motion

        # (B, T, D) -> (B, D, T) for the 1D convolutions.
        x = motion_normalized.permute(0, 2, 1)

        z = self.encoder(x)

        # Straight-through quantization: z_q is differentiable w.r.t. z.
        z_q, indices = self.quantizer(z)

        x_recon = self.decoder(z_q)

        # Back to (B, T, D).
        motion_recon = x_recon.permute(0, 2, 1)

        if denormalize:
            motion_recon = self.denormalize(motion_recon)

        return motion_recon, indices

    @classmethod
    def from_pretrained(
        cls,
        pretrained_path: str,
        device: Optional[str] = None,
        **kwargs
    ) -> "MotionVQVAE":
        """
        Load pretrained model from HuggingFace Hub or local path.

        Args:
            pretrained_path: HuggingFace repo ID (e.g., "khania/motion-vqvae")
                or local directory path
            device: Device to load model on ('cuda', 'cpu', or None for auto)
            **kwargs: Additional arguments passed to model initialization

        Returns:
            Loaded MotionVQVAE model
        """
        # Resolve to a local directory, downloading from the Hub if needed.
        if os.path.isdir(pretrained_path):
            model_dir = pretrained_path
        elif HF_HUB_AVAILABLE:
            model_dir = snapshot_download(repo_id=pretrained_path)
        else:
            raise ValueError(
                f"Path {pretrained_path} is not a local directory and "
                "huggingface_hub is not installed. Install with: pip install huggingface_hub"
            )

        # Architecture hyperparameters; a missing config.json means defaults.
        config_path = os.path.join(model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
        else:
            config = {}

        # Caller-supplied kwargs override the stored config.
        config.update(kwargs)

        model = cls(config)

        # Load weights.  NOTE(review): torch.load unpickles arbitrary
        # objects; only load checkpoints from trusted sources (consider
        # weights_only=True on torch >= 1.13).
        weights_path = os.path.join(model_dir, "pytorch_model.bin")
        if os.path.exists(weights_path):
            state_dict = torch.load(weights_path, map_location='cpu')

            # strict=False: older checkpoints may lack the mean/std buffers.
            missing, unexpected = model.load_state_dict(state_dict, strict=False)

            # mean/std are handled separately below, so don't warn on them.
            missing_filtered = [k for k in missing if k not in ['mean', 'std']]

            if missing_filtered:
                print(f"Warning: Missing keys in state_dict: {missing_filtered}")
            if unexpected:
                print(f"Warning: Unexpected keys in state_dict: {unexpected}")
        else:
            raise ValueError(f"Model weights not found at {weights_path}")

        # Fall back to standalone .npy files for the normalization stats.
        # state_dict is always bound here: the branch above either loaded
        # it or raised.
        mean_path = os.path.join(model_dir, "mean.npy")
        std_path = os.path.join(model_dir, "std.npy")

        if 'mean' not in state_dict and os.path.exists(mean_path):
            model.mean = torch.from_numpy(np.load(mean_path)).float()
        if 'std' not in state_dict and os.path.exists(std_path):
            model.std = torch.from_numpy(np.load(std_path)).float()

        # Move to the requested (or auto-detected) device; inference mode.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)
        model.eval()

        return model

    def save_pretrained(self, save_dir: str):
        """
        Save model to directory in HuggingFace format.

        Writes config.json, mean.npy/std.npy and pytorch_model.bin -- the
        same layout that from_pretrained() expects.

        Args:
            save_dir: Directory to save model to
        """
        os.makedirs(save_dir, exist_ok=True)

        config_path = os.path.join(save_dir, "config.json")
        with open(config_path, 'w') as f:
            json.dump(self.config, f, indent=2)

        # Normalization stats as standalone .npy files (they are also in
        # the state_dict as buffers).
        mean_path = os.path.join(save_dir, "mean.npy")
        std_path = os.path.join(save_dir, "std.npy")
        np.save(mean_path, self.mean.cpu().numpy())
        np.save(std_path, self.std.cpu().numpy())

        weights_path = os.path.join(save_dir, "pytorch_model.bin")
        torch.save(self.state_dict(), weights_path)

        print(f"Model saved to {save_dir}")
|
|
|
|
| |
| |
| |
|
|
def load_motion_vqvae(
    pretrained_path: str = "khania/motion-vqvae",
    device: Optional[str] = None
) -> MotionVQVAE:
    """Load a MotionVQVAE checkpoint.

    Thin convenience wrapper around ``MotionVQVAE.from_pretrained``.

    Args:
        pretrained_path: HuggingFace repo ID or local checkpoint directory.
        device: Target device; auto-selected when ``None``.

    Returns:
        The loaded model, moved to ``device`` and set to eval mode.
    """
    return MotionVQVAE.from_pretrained(pretrained_path, device=device)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a randomly initialized model (no checkpoint) and
    # round-trip a random motion batch through encode/quantize/decode.
    print("Testing MotionVQVAE...")

    # Hyperparameters of the released checkpoint.
    config = {
        'motion_dim': 272,
        'latent_dim': 512,
        'num_groups': 64,
        'num_codes': 512,
        'down_t': 2,
        'stride_t': 2,
        'width': 512,
        'depth': 3,
        'dilation_growth_rate': 3,
        'activation': 'relu',
        'kernel_size': 3
    }

    model = MotionVQVAE(config)
    print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

    # Print a sample of state_dict keys to eyeball checkpoint compatibility.
    print("\nModel state_dict keys:")
    for k in sorted(model.state_dict().keys())[:20]:
        print(f" {k}")
    print(" ...")

    # Round-trip a random batch; normalization is skipped because mean/std
    # are still the untrained defaults (zeros/ones).
    batch_size = 2
    seq_len = 64
    motion = torch.randn(batch_size, seq_len, 272)

    model.eval()
    with torch.no_grad():
        motion_recon, tokens = model(motion, normalize=False, denormalize=False)

    print(f"\nInput shape: {motion.shape}")
    print(f"Output shape: {motion_recon.shape}")
    print(f"Tokens shape: {tokens.shape}")
    print(f"MSE (random weights): {F.mse_loss(motion, motion_recon).item():.4f}")
|
|