# motion-mgvqvae / motion_vqvae_hf.py
# (Provenance: uploaded by "khania", commit "Motion VQ-VAE model update",
#  8d54c4c, verified — web-page residue commented out so the file parses.)
"""
MotionVQVAE - Motion Vector Quantized VAE for HuggingFace
Load and use the MotionVQVAE model for motion tokenization and reconstruction.
Usage:
from motion_vqvae_hf import MotionVQVAE
# Load from HuggingFace Hub
model = MotionVQVAE.from_pretrained("khania/motion-vqvae")
# Encode motion to tokens
tokens = model.encode(motion_array) # (B, T, 272) -> (num_groups, B, T')
# Decode tokens back to motion
motion_recon = model.decode(tokens) # (num_groups, B, T') -> (B, T, 272)
# Full forward pass
motion_recon, tokens = model(motion_array)
"""
import os
import json
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Union, Optional, Dict, Any, Tuple
from pathlib import Path
try:
from huggingface_hub import snapshot_download
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
# =============================================================================
# Encoder / Decoder Components (matching original architecture exactly)
# =============================================================================
class ResConv1DBlock(nn.Module):
    """Residual 1D convolution block - matches original models/resnet.py exactly.

    Computes ``x + conv2(act2(norm2(conv1(act1(norm1(x))))))`` on input of
    shape (B, C, T).  Submodules are named norm1/norm2, activation1/activation2
    and conv1/conv2 because those names must match the checkpoint keys.

    Args:
        n_in: number of input (and output) channels.
        n_state: number of hidden channels between conv1 and conv2.
        dilation: dilation of the first convolution.
        activation: 'relu', 'silu' or 'gelu'; any other value falls back to
            ReLU (mirrors the original training code).
        norm: 'LN', 'GN', 'BN' or None for no normalization.
        kernel_size: kernel size of the first convolution.
    """
    def __init__(self, n_in: int, n_state: int, dilation: int = 1,
                 activation: str = 'relu', norm: Optional[str] = None,
                 kernel_size: int = 3):
        super().__init__()
        # "same" padding for the dilated conv so the temporal length is kept
        padding = dilation * (kernel_size - 1) // 2
        self.norm = norm
        # Norm layers (Identity when no normalization is requested)
        if norm == "LN":
            self.norm1 = nn.LayerNorm(n_in)
            self.norm2 = nn.LayerNorm(n_in)
        elif norm == "GN":
            self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
            self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
        elif norm == "BN":
            self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
            self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()
        # Activation layers; unknown names deliberately fall back to ReLU
        if activation == "relu":
            self.activation1 = nn.ReLU()
            self.activation2 = nn.ReLU()
        elif activation == "silu":
            self.activation1 = nn.SiLU()
            self.activation2 = nn.SiLU()
        elif activation == "gelu":
            self.activation1 = nn.GELU()
            self.activation2 = nn.GELU()
        else:
            self.activation1 = nn.ReLU()
            self.activation2 = nn.ReLU()
        # Convolution layers - MUST be named conv1 and conv2 to match checkpoint
        self.conv1 = nn.Conv1d(n_in, n_state, kernel_size, 1, padding, dilation)
        self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block to ``x`` of shape (B, C, T)."""
        x_orig = x
        if self.norm == "LN":
            # LayerNorm normalizes the last dim: swap (C, T) -> (T, C) first,
            # then swap back before the activation output is convolved.
            x = self.norm1(x.transpose(-2, -1))
            x = self.activation1(x.transpose(-2, -1))
        else:
            x = self.norm1(x)
            x = self.activation1(x)
        x = self.conv1(x)
        if self.norm == "LN":
            x = self.norm2(x.transpose(-2, -1))
            x = self.activation2(x.transpose(-2, -1))
        else:
            x = self.norm2(x)
            x = self.activation2(x)
        x = self.conv2(x)
        x = x + x_orig
        return x
class Resnet1D(nn.Module):
    """Stack of ResConv1DBlock with exponentially growing dilation.

    Block ``d`` uses dilation ``dilation_growth_rate ** d``; when
    ``reverse_dilation`` is True the stack runs largest-dilation first.
    The stack is stored as ``self.model`` (nn.Sequential) so state-dict keys
    look like 'model.0.conv1.weight', matching the original checkpoint.
    """
    def __init__(self, n_in: int, n_depth: int, dilation_growth_rate: int = 1,
                 reverse_dilation: bool = False, activation: str = 'relu',
                 norm: str = None, kernel_size: int = 3):
        super().__init__()
        layers = []
        for level in range(n_depth):
            layers.append(ResConv1DBlock(
                n_in, n_in, dilation=dilation_growth_rate ** level,
                activation=activation, norm=norm, kernel_size=kernel_size))
        if reverse_dilation:
            layers.reverse()
        # MUST be named 'model' to match checkpoint keys like 'model.0.conv1.weight'
        self.model = nn.Sequential(*layers)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the residual stack on (B, C, T) input; shape is preserved."""
        return self.model(x)
class Encoder(nn.Module):
    """1D CNN Encoder - matches original models/encdec.py exactly.

    Layer layout (the attribute must be named 'model' so checkpoint keys match):
      - model.0: Conv1d (input projection)
      - model.1: ReLU
      - model.2 .. model.(1+down_t): Sequential(strided Conv1d, Resnet1D),
        one per temporal downsample step
      - last:    Conv1d (output projection)
    """
    def __init__(
        self,
        input_emb_width: int = 272,
        output_emb_width: int = 512,
        down_t: int = 2,
        stride_t: int = 2,
        width: int = 512,
        depth: int = 3,
        dilation_growth_rate: int = 3,
        activation: str = 'relu',
        norm: str = None,
        kernel_size: int = 3
    ):
        super().__init__()
        filter_t, pad_t = stride_t * 2, stride_t // 2
        same_pad = (kernel_size - 1) // 2
        layers = [
            nn.Conv1d(input_emb_width, width, kernel_size, 1, same_pad),  # model.0
            nn.ReLU(),                                                    # model.1
        ]
        # model.2 ..: one strided conv + residual stack per downsample step
        for _ in range(down_t):
            layers.append(nn.Sequential(
                nn.Conv1d(width, width, filter_t, stride_t, pad_t),
                Resnet1D(width, depth, dilation_growth_rate,
                         activation=activation, norm=norm,
                         kernel_size=kernel_size),
            ))
        # output projection to the latent width
        layers.append(nn.Conv1d(width, output_emb_width, kernel_size, 1, same_pad))
        # MUST be named 'model' to match checkpoint keys
        self.model = nn.Sequential(*layers)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode (B, input_emb_width, T) to (B, output_emb_width, T')."""
        return self.model(x)
class Decoder(nn.Module):
    """1D CNN Decoder - matches original models/encdec.py exactly.

    Layer layout (the attribute must be named 'model' so checkpoint keys match):
      - model.0: Conv1d (input projection from the latent width)
      - model.1: ReLU
      - model.2 .. model.(1+down_t): Sequential(Resnet1D, Upsample x2, Conv1d),
        one per temporal upsample step
      - then:    Conv1d, ReLU, Conv1d (output projection)

    NOTE(review): the upsample factor is hard-coded to 2, which matches the
    default stride_t=2 — confirm before using other strides.
    """
    def __init__(
        self,
        input_emb_width: int = 272,
        output_emb_width: int = 512,
        down_t: int = 2,
        stride_t: int = 2,
        width: int = 512,
        depth: int = 3,
        dilation_growth_rate: int = 3,
        activation: str = 'relu',
        norm: str = None,
        kernel_size: int = 3
    ):
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        layers = [
            nn.Conv1d(output_emb_width, width, kernel_size, 1, same_pad),  # model.0
            nn.ReLU(),                                                     # model.1
        ]
        # model.2 ..: residual stack (reversed dilation) + x2 upsample per step
        for _ in range(down_t):
            layers.append(nn.Sequential(
                Resnet1D(width, depth, dilation_growth_rate,
                         reverse_dilation=True, activation=activation,
                         norm=norm, kernel_size=kernel_size),
                nn.Upsample(scale_factor=2, mode='linear', align_corners=False),
                nn.Conv1d(width, width, 3, 1, 1),
            ))
        # output head: conv -> ReLU -> conv back to the motion width
        layers.append(nn.Conv1d(width, width, kernel_size, 1, same_pad))
        layers.append(nn.ReLU())
        layers.append(nn.Conv1d(width, input_emb_width, kernel_size, 1, same_pad))
        # MUST be named 'model' to match checkpoint keys
        self.model = nn.Sequential(*layers)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode (B, output_emb_width, T') to (B, input_emb_width, T)."""
        return self.model(x)
# =============================================================================
# Vector Quantizer
# =============================================================================
class VectorQuantizerEMA(nn.Module):
    """Single vector quantizer over one latent group.

    The codebook is an nn.Parameter named 'codebook' so its state-dict key
    matches the training checkpoint.  NOTE: despite the name, this inference
    implementation performs no EMA update; ``decay`` and ``epsilon`` are kept
    only for constructor compatibility.
    """
    def __init__(
        self,
        num_embeddings: int = 512,
        embedding_dim: int = 8,  # per-group dimension
        decay: float = 0.99,
        epsilon: float = 1e-5
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.decay = decay
        self.epsilon = epsilon
        # MUST be named 'codebook' to match checkpoint keys
        self.codebook = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize features to their nearest codebook entries.

        Args:
            z: (B, D, T) latent features for this group
        Returns:
            z_q: (B, D, T) quantized features (straight-through gradients)
            indices: (B, T) codebook indices
        """
        batch, dim, steps = z.shape
        # Flatten to one row per (batch, time) position: (B*T, D)
        flat = z.permute(0, 2, 1).reshape(-1, dim)
        # Squared Euclidean distance, expanded: ||z||^2 + ||e||^2 - 2 z.e
        sq_z = torch.sum(flat ** 2, dim=1, keepdim=True)
        sq_e = torch.sum(self.codebook ** 2, dim=1)
        cross = torch.matmul(flat, self.codebook.t())
        nearest = torch.argmin(sq_z + sq_e - 2 * cross, dim=1)
        # Look up code vectors and restore the (B, D, T) layout
        quantized = F.embedding(nearest, self.codebook)
        quantized = quantized.reshape(batch, steps, dim).permute(0, 2, 1)
        # Straight-through estimator: values are code vectors, grads flow to z
        quantized = z + (quantized - z).detach()
        return quantized, nearest.reshape(batch, steps)
    def decode_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up indices in the codebook.

        Args:
            indices: (B, T) codebook indices
        Returns:
            z_q: (B, D, T) quantized features
        """
        batch, steps = indices.shape
        vectors = F.embedding(indices.reshape(-1), self.codebook)
        return vectors.reshape(batch, steps, -1).permute(0, 2, 1)
class MultiGroupVectorQuantizer(nn.Module):
    """Multi-group vector quantizer: splits the latent channel-wise into
    ``num_groups`` equal slices, each quantized with its own codebook.

    Per-group quantizers live in ``self.quantizers`` (nn.ModuleList) so
    checkpoint keys look like 'quantizer.quantizers.0.codebook'.
    """
    def __init__(
        self,
        num_groups: int = 64,
        num_embeddings: int = 512,
        embedding_dim: int = 512,  # total latent dim
        decay: float = 0.99,
        epsilon: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.group_dim = embedding_dim // num_groups
        # MUST be named 'quantizers' to match checkpoint keys
        self.quantizers = nn.ModuleList(
            VectorQuantizerEMA(num_embeddings, self.group_dim, decay, epsilon)
            for _ in range(num_groups)
        )
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize each channel group independently.

        Args:
            z: (B, D, T) latent features
        Returns:
            z_q: (B, D, T) quantized features (groups concatenated back)
            indices: (num_groups, B, T) codebook indices per group
        """
        quantized_parts = []
        index_parts = []
        for part, vq in zip(z.chunk(self.num_groups, dim=1), self.quantizers):
            q_part, q_idx = vq(part)
            quantized_parts.append(q_part)
            index_parts.append(q_idx)
        z_q = torch.cat(quantized_parts, dim=1)
        indices = torch.stack(index_parts, dim=0)
        return z_q, indices
    def decode_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode per-group indices back to concatenated latent features.

        Args:
            indices: (num_groups, B, T) codebook indices
        Returns:
            z_q: (B, D, T) quantized features
        """
        parts = [vq.decode_indices(idx)
                 for vq, idx in zip(self.quantizers, indices)]
        return torch.cat(parts, dim=1)
# =============================================================================
# Main Model
# =============================================================================
class MotionVQVAE(nn.Module):
    """Motion Vector Quantized VAE for HuggingFace.

    Architecture matches the original training code exactly (component
    attribute names 'encoder', 'decoder', 'quantizer' and their internal key
    structure) to ensure checkpoint compatibility.

    Data flow: (B, T, motion_dim) motion -> normalize -> Encoder ->
    MultiGroupVectorQuantizer -> Decoder -> denormalize -> (B, T, motion_dim).
    """
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Build the model from a config dict; missing keys use defaults.

        Args:
            config: hyper-parameter dict. Both naming conventions are
                supported for several keys ('motion_dim'/'input_dim',
                'latent_dim'/'code_dim', 'num_codes'/'nb_code').
        """
        super().__init__()
        # Default config
        if config is None:
            config = {}
        self.config = config
        # Model parameters
        self.motion_dim = config.get('motion_dim', config.get('input_dim', 272))
        self.latent_dim = config.get('latent_dim', config.get('code_dim', 512))
        self.num_groups = config.get('num_groups', 64)
        self.num_codes = config.get('num_codes', config.get('nb_code', 512))
        self.down_t = config.get('down_t', 2)
        self.stride_t = config.get('stride_t', 2)
        self.width = config.get('width', 512)
        self.depth = config.get('depth', 3)
        self.dilation_growth_rate = config.get('dilation_growth_rate', 3)
        self.activation = config.get('activation', 'relu')
        # FIX: 'norm' was previously never read, so a config requesting
        # LN/GN/BN silently built an un-normalized network. Default None
        # keeps the old behavior for configs that do not set it.
        self.norm = config.get('norm', None)
        self.kernel_size = config.get('kernel_size', 3)
        # Normalization stats as buffers: they follow .to(device) and are
        # saved/loaded with the state_dict. Defaults are identity stats.
        self.register_buffer('mean', torch.zeros(self.motion_dim))
        self.register_buffer('std', torch.ones(self.motion_dim))
        # Build model components - names must match checkpoint
        self.encoder = Encoder(
            input_emb_width=self.motion_dim,
            output_emb_width=self.latent_dim,
            down_t=self.down_t,
            stride_t=self.stride_t,
            width=self.width,
            depth=self.depth,
            dilation_growth_rate=self.dilation_growth_rate,
            activation=self.activation,
            norm=self.norm,
            kernel_size=self.kernel_size
        )
        self.decoder = Decoder(
            input_emb_width=self.motion_dim,
            output_emb_width=self.latent_dim,
            down_t=self.down_t,
            stride_t=self.stride_t,
            width=self.width,
            depth=self.depth,
            dilation_growth_rate=self.dilation_growth_rate,
            activation=self.activation,
            norm=self.norm,
            kernel_size=self.kernel_size
        )
        self.quantizer = MultiGroupVectorQuantizer(
            num_groups=self.num_groups,
            num_embeddings=self.num_codes,
            embedding_dim=self.latent_dim
        )
    def normalize(self, motion: torch.Tensor) -> torch.Tensor:
        """Normalize motion with the stored mean/std; clamp to [-20, 20].

        Accepts (B, T, D) or (B, D, T); the layout is inferred from the last
        dimension. NOTE(review): inference is ambiguous when T == motion_dim.
        """
        mean = self.mean.view(1, 1, -1)
        std = self.std.view(1, 1, -1)
        # Guard against near-zero std to avoid exploding normalized values
        std_safe = torch.clamp(std, min=0.01)
        if motion.shape[-1] != self.motion_dim:
            # (B, D, T) format: move the stats to the channel dimension
            mean = mean.permute(0, 2, 1)
            std_safe = std_safe.permute(0, 2, 1)
        normalized = (motion - mean) / std_safe
        return torch.clamp(normalized, -20, 20)
    def denormalize(self, motion: torch.Tensor) -> torch.Tensor:
        """Invert normalize(): motion * std + mean (same layout inference,
        same std clamping; the [-20, 20] clamp is not inverted)."""
        mean = self.mean.view(1, 1, -1)
        std = self.std.view(1, 1, -1)
        std_safe = torch.clamp(std, min=0.01)
        if motion.shape[-1] != self.motion_dim:
            # (B, D, T) format
            mean = mean.permute(0, 2, 1)
            std_safe = std_safe.permute(0, 2, 1)
        return motion * std_safe + mean
    def encode(self, motion: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """
        Encode motion to discrete tokens.
        Args:
            motion: (B, T, D) motion data where D=motion_dim
            normalize: whether to normalize input
        Returns:
            tokens: (num_groups, B, T') discrete tokens where
                T' = T // (stride_t ** down_t)
        """
        if normalize:
            motion = self.normalize(motion)
        # Convert to (B, D, T) for conv layers
        x = motion.permute(0, 2, 1)
        z = self.encoder(x)
        # Only the indices are needed; quantized values are discarded
        _, indices = self.quantizer(z)
        return indices
    def decode(self, tokens: torch.Tensor, denormalize: bool = True) -> torch.Tensor:
        """
        Decode discrete tokens to motion.
        Args:
            tokens: (num_groups, B, T') discrete tokens
            denormalize: whether to denormalize output
        Returns:
            motion: (B, T, D) reconstructed motion
        """
        # Tokens -> latent -> motion
        z_q = self.quantizer.decode_indices(tokens)
        x_recon = self.decoder(z_q)
        # Convert back to (B, T, D)
        motion = x_recon.permute(0, 2, 1)
        if denormalize:
            motion = self.denormalize(motion)
        return motion
    def forward(
        self,
        motion: torch.Tensor,
        normalize: bool = True,
        denormalize: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Full forward pass: encode to tokens and decode back.
        Args:
            motion: (B, T, D) motion data
            normalize: whether to normalize input
            denormalize: whether to denormalize output
        Returns:
            motion_recon: (B, T, D) reconstructed motion
            tokens: (num_groups, B, T') discrete tokens
        """
        if normalize:
            motion_normalized = self.normalize(motion)
        else:
            motion_normalized = motion
        # Convert to (B, D, T) for conv layers
        x = motion_normalized.permute(0, 2, 1)
        z = self.encoder(x)
        # Straight-through quantization keeps the graph differentiable
        z_q, indices = self.quantizer(z)
        x_recon = self.decoder(z_q)
        # Convert back to (B, T, D)
        motion_recon = x_recon.permute(0, 2, 1)
        if denormalize:
            motion_recon = self.denormalize(motion_recon)
        return motion_recon, indices
    @classmethod
    def from_pretrained(
        cls,
        pretrained_path: str,
        device: Optional[str] = None,
        **kwargs
    ) -> "MotionVQVAE":
        """
        Load pretrained model from HuggingFace Hub or local path.
        Args:
            pretrained_path: HuggingFace repo ID (e.g., "khania/motion-vqvae")
                or local directory path
            device: Device to load model on ('cuda', 'cpu', or None for auto)
            **kwargs: Additional arguments merged over the loaded config
        Returns:
            Loaded MotionVQVAE model (in eval mode, on the chosen device)
        Raises:
            ValueError: if the path cannot be resolved or weights are missing
        """
        # Determine if path is HF repo or local
        if os.path.isdir(pretrained_path):
            model_dir = pretrained_path
        elif HF_HUB_AVAILABLE:
            model_dir = snapshot_download(repo_id=pretrained_path)
        else:
            raise ValueError(
                f"Path {pretrained_path} is not a local directory and "
                "huggingface_hub is not installed. Install with: pip install huggingface_hub"
            )
        # Load config (missing file -> defaults)
        config_path = os.path.join(model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
        else:
            config = {}
        # Override config with kwargs
        config.update(kwargs)
        model = cls(config)
        # Load weights (includes mean/std buffers if converted with updated script)
        weights_path = os.path.join(model_dir, "pytorch_model.bin")
        if os.path.exists(weights_path):
            # SECURITY: torch.load is pickle-based; weights_only=True prevents
            # a malicious checkpoint from executing arbitrary code on load.
            try:
                state_dict = torch.load(weights_path, map_location='cpu',
                                        weights_only=True)
            except TypeError:
                # Older torch (<1.13) has no weights_only argument.
                state_dict = torch.load(weights_path, map_location='cpu')
            # strict=False: older checkpoints may lack mean/std buffers
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            # mean/std may legitimately be missing (kept in separate .npy files)
            missing_filtered = [k for k in missing if k not in ['mean', 'std']]
            if missing_filtered:
                print(f"Warning: Missing keys in state_dict: {missing_filtered}")
            if unexpected:
                print(f"Warning: Unexpected keys in state_dict: {unexpected}")
        else:
            raise ValueError(f"Model weights not found at {weights_path}")
        # Fallback: Load mean/std from numpy files if not in state_dict
        # (supports old format: separate files; new format: in weights)
        mean_path = os.path.join(model_dir, "mean.npy")
        std_path = os.path.join(model_dir, "std.npy")
        if 'mean' not in state_dict and os.path.exists(mean_path):
            model.mean = torch.from_numpy(np.load(mean_path)).float()
        if 'std' not in state_dict and os.path.exists(std_path):
            model.std = torch.from_numpy(np.load(std_path)).float()
        # Move to device (auto-select when not specified)
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)
        model.eval()
        return model
    def save_pretrained(self, save_dir: str):
        """
        Save model to directory in HuggingFace format.

        Writes config.json, mean.npy/std.npy (for backward compatibility; the
        stats are also inside the state_dict) and pytorch_model.bin.
        Args:
            save_dir: Directory to save model to (created if missing)
        """
        os.makedirs(save_dir, exist_ok=True)
        # Save config
        config_path = os.path.join(save_dir, "config.json")
        with open(config_path, 'w') as f:
            json.dump(self.config, f, indent=2)
        # Save normalization stats
        mean_path = os.path.join(save_dir, "mean.npy")
        std_path = os.path.join(save_dir, "std.npy")
        np.save(mean_path, self.mean.cpu().numpy())
        np.save(std_path, self.std.cpu().numpy())
        # Save weights
        weights_path = os.path.join(save_dir, "pytorch_model.bin")
        torch.save(self.state_dict(), weights_path)
        print(f"Model saved to {save_dir}")
# =============================================================================
# Utility Functions
# =============================================================================
def load_motion_vqvae(
    pretrained_path: str = "khania/motion-vqvae",
    device: Optional[str] = None
) -> MotionVQVAE:
    """Convenience wrapper around :meth:`MotionVQVAE.from_pretrained`.

    Args:
        pretrained_path: HuggingFace repo ID or local directory path.
        device: Device string, or None to auto-select.

    Returns:
        The loaded :class:`MotionVQVAE` in eval mode.
    """
    return MotionVQVAE.from_pretrained(pretrained_path, device=device)
if __name__ == "__main__":
# Test model creation and forward pass
print("Testing MotionVQVAE...")
config = {
'motion_dim': 272,
'latent_dim': 512,
'num_groups': 64,
'num_codes': 512,
'down_t': 2,
'stride_t': 2,
'width': 512,
'depth': 3,
'dilation_growth_rate': 3,
'activation': 'relu',
'kernel_size': 3
}
model = MotionVQVAE(config)
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
# Print model architecture for debugging
print("\nModel state_dict keys:")
for k in sorted(model.state_dict().keys())[:20]:
print(f" {k}")
print(" ...")
# Test forward pass
batch_size = 2
seq_len = 64
motion = torch.randn(batch_size, seq_len, 272)
model.eval()
with torch.no_grad():
motion_recon, tokens = model(motion, normalize=False, denormalize=False)
print(f"\nInput shape: {motion.shape}")
print(f"Output shape: {motion_recon.shape}")
print(f"Tokens shape: {tokens.shape}")
print(f"MSE (random weights): {F.mse_loss(motion, motion_recon).item():.4f}")