import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch.nn.functional as F
from pathlib import Path
from safetensors.torch import load_file
# Import model architecture
from rvq_model import MotionRVQ_VAE
# ==========================================
# 1. Configuration
# ==========================================
# Every asset is resolved relative to the directory that holds this script,
# so the script does not depend on the current working directory.
BASE_DIR = Path(__file__).resolve().parent
FILE_TO_TEST = BASE_DIR.joinpath("000001.npy")
WEIGHTS_PATH = BASE_DIR.joinpath("motion_rvq_weights.safetensors")
# ==========================================
# 2. Model Initialization
# ==========================================
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MotionRVQ_VAE().to(device)

# Only the file read lives inside ``try`` so the handler stays scoped to the
# one call that is expected to raise FileNotFoundError.
try:
    state_dict = load_file(str(WEIGHTS_PATH), device=str(device))
except FileNotFoundError:
    print(f"ERROR: Could not find file {WEIGHTS_PATH}.")
    # ``exit()`` is only injected by the ``site`` module and, called with no
    # argument, reports success. Raising SystemExit is always available and
    # gives the caller a non-zero exit status for this failure.
    raise SystemExit(1)

model.load_state_dict(state_dict)
print(f"Successfully loaded weights from {WEIGHTS_PATH}!")

# Put the module in evaluation mode before it is used for inference.
model.eval()
# ==========================================
# 3. Forward Pass (With Padding and Normalization)
# ==========================================
# Raw clip as stored on disk; axis 0 is the frame axis used for padding here
# and for cropping the reconstructions later.
original_data = np.load(FILE_TO_TEST)
T_orig = original_data.shape[0]

# Normalization statistics stored next to the script.
mean = np.load(BASE_DIR / "Mean.npy")
std = np.load(BASE_DIR / "Std.npy")

# Number of frames needed to reach the next multiple of 4 (0 when already
# aligned). Per the original author's note the model compresses time with
# stride 4 -- NOTE(review): confirm against the encoder definition.
pad_len = -T_orig % 4

# Pad at the end of the frame axis by repeating the last frame ("edge" mode);
# an already aligned clip is passed through untouched.
padded_data = (
    np.pad(original_data, ((0, pad_len), (0, 0)), mode="edge")
    if pad_len
    else original_data
)
padded_data = (padded_data - mean) / std

# (frames, features) -> (1, frames, features) -> (1, features, frames)
input_tensor = (
    torch.from_numpy(padded_data)
    .float()
    .unsqueeze(0)
    .permute(0, 2, 1)
    .to(device)
)
with torch.no_grad():
    # 1. Encode the clip and quantize it; only the token ids are kept.
    z = model.encoder(input_tensor)
    _, tokens, _ = model.rvq(z)

    # 2. Function for "partial" decoding
    def decode_from_levels(num_levels):
        """Decode motion using only the first ``num_levels`` quantizer levels.

        For each level ``lvl`` the token ids ``tokens[:, lvl, :]`` are looked
        up in that level's ``embedding`` table, the last two axes of the
        result are swapped, and the per-level results are summed before being
        passed to ``model.decoder``.

        Args:
            num_levels: How many leading RVQ levels contribute to the latent.

        Returns:
            Whatever ``model.decoder`` returns for the summed latent.
        """
        per_level_latents = (
            F.embedding(
                tokens[:, lvl, :], model.rvq.quantizers[lvl].embedding
            ).permute(0, 2, 1)  # Shape expected by decoder
            for lvl in range(num_levels)
        )
        # ``sum`` starts from 0 and adds each level's residual in order.
        return model.decoder(sum(per_level_latents))

    # Generate motion using only level 1 and all 4 levels
    recon_tensor_1_lvl = decode_from_levels(1)
    recon_tensor_4_lvl = decode_from_levels(4)

# Back to NumPy: drop the batch axis, put frames first, crop the padding
# frames, then undo the normalization.
recon_1_lvl, recon_4_lvl = (
    (recon.squeeze(0).permute(1, 0).cpu().numpy()[:T_orig, :] * std) + mean
    for recon in (recon_tensor_1_lvl, recon_tensor_4_lvl)
)
# ==========================================
# 4. Skeleton extraction utility
# ==========================================
def get_3d_joints(data_263):
    """Build per-frame 3D joint positions from a motion feature array.

    Column 3 of every frame is read as the root height. Columns 4..66 are
    read as 21 consecutive (x, y, z) triples for the remaining joints. The
    root joint is placed at ``(0, root_y, 0)`` and ``root_y`` is added to the
    Y component of each of the other 21 joints.

    Args:
        data_263: 2-D array of shape ``(num_frames, num_features)`` with at
            least 67 feature columns.

    Returns:
        ``float64`` array of shape ``(num_frames, 22, 3)``; joint 0 is the
        root.

    Raises:
        ValueError: If columns 4..66 cannot be reshaped into 21 triples per
            frame (too few feature columns).
    """
    num_frames = data_263.shape[0]
    root_y = data_263[:, 3]

    joints = np.zeros((num_frames, 22, 3))
    # Root joint: only its height is non-zero.
    joints[:, 0, 1] = root_y
    # Remaining joints, all frames at once instead of a per-frame loop.
    joints[:, 1:] = data_263[:, 4:67].reshape(num_frames, 21, 3)
    # NOTE(review): root_y is added to every joint's Y. If the input uses a
    # layout whose joint Y values are already absolute heights (HumanML3D's
    # positions are root-relative in X/Z only), this double-counts the root
    # height -- confirm against the dataset's own recovery code.
    joints[:, 1:, 1] += root_y[:, np.newaxis]
    return joints
# Extract skeletons for the ground-truth clip and for both reconstructions,
# in the order they are shown left to right in the figure.
joints_orig, joints_1_lvl, joints_4_lvl = map(
    get_3d_joints, (original_data, recon_1_lvl, recon_4_lvl)
)
# ==========================================
# 5. Three-panel visualization
# ==========================================
# Chains of joint indices. Each chain is drawn as one connected polyline, so
# consecutive indices within a chain end up joined by a line segment.
kinematic_tree = [
    [0, 1, 4, 7, 10],
    [0, 2, 5, 8, 11],
    [0, 3, 6, 9, 12, 15],
    [9, 13, 16, 18, 20],
    [9, 14, 17, 19, 21],
]
# One wide figure holding three side-by-side 3D panels (1 row x 3 columns).
fig = plt.figure(figsize=(15, 5))
ax1, ax2, ax3 = (
    fig.add_subplot(1, 3, column, projection='3d') for column in (1, 2, 3)
)
def update(frame):
    """Redraw the three skeleton panels for animation frame ``frame``.

    Each panel is cleared, given the same fixed limits and camera angle, has
    its axes hidden, and then gets one polyline per chain in
    ``kinematic_tree``.

    Args:
        frame: Index into the first axis of the joint arrays.
    """
    panels = (
        (ax1, joints_orig, 'blue', f"ORIGINAL\n(Frame {frame})"),
        (ax2, joints_1_lvl, 'orange',
         "RECONSTRUCTION: 1 LEVEL\n(Coarse tokens only)"),
        (ax3, joints_4_lvl, 'red',
         "RECONSTRUCTION: 4 LEVELS\n(Full RVQ detail)"),
    )
    for ax, joints, color, title in panels:
        # clear() also resets limits and view, so they are re-applied on
        # every frame.
        ax.clear()
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(0, 2)
        ax.view_init(elev=10., azim=-90)
        ax.axis('off')
        ax.set_title(title)

        pose = joints[frame]
        for chain in kinematic_tree:
            # Data component 1 is plotted on the vertical (Z) plot axis and
            # component 2 on the plot's Y axis.
            ax.plot(pose[chain, 0], pose[chain, 2], pose[chain, 1],
                    linewidth=3, marker='o', markersize=4, color=color)
# The animation object is bound to a name so it stays referenced while the
# window is open; one frame is drawn every 50 ms and playback loops.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=T_orig,
    interval=50,
    repeat=True,
)
plt.tight_layout()
plt.show()
|