File size: 5,299 Bytes
316a030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch.nn.functional as F
from pathlib import Path
from safetensors.torch import load_file

# Import model architecture
from rvq_model import MotionRVQ_VAE

# ==========================================
# 1. Configuration
# ==========================================
# All assets are looked up next to this script, independent of the
# current working directory.
BASE_DIR = Path(__file__).resolve().parents[0]
FILE_TO_TEST = BASE_DIR.joinpath("000001.npy")
WEIGHTS_PATH = BASE_DIR.joinpath("motion_rvq_weights.safetensors")

# ==========================================
# 2. Model Initialization
# ==========================================
# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MotionRVQ_VAE().to(device)

# Only the file read is guarded: a missing weights file is the one failure
# this script reports itself. Any other error (e.g. a state-dict mismatch)
# propagates with its full traceback.
try:
    state_dict = load_file(str(WEIGHTS_PATH), device=str(device))
except FileNotFoundError:
    print(f"ERROR: Could not find file {WEIGHTS_PATH}.")
    # `exit()` is an interactive helper injected by the `site` module and
    # would report success (status 0); raise SystemExit with a non-zero
    # status so callers and shells can detect the failure.
    raise SystemExit(1) from None
else:
    model.load_state_dict(state_dict)
    print(f"Successfully loaded weights from {WEIGHTS_PATH}!")

# Inference only: switch the module to evaluation mode.
model.eval()

# ==========================================
# 3. Forward Pass (With Padding and Normalization)
# ==========================================
original_data = np.load(FILE_TO_TEST)
T_orig = original_data.shape[0]

# Normalization statistics stored alongside the script
mean = np.load(BASE_DIR / "Mean.npy")
std = np.load(BASE_DIR / "Std.npy")

# Frames missing to reach the next multiple of 4 (0 when already aligned);
# the gap is filled by repeating the final frame.
pad_len = -T_orig % 4
padded_data = original_data
if pad_len > 0:
    last_frame = original_data[-1:]
    padded_data = np.concatenate(
        (original_data, np.repeat(last_frame, pad_len, axis=0)),
        axis=0,
    )

padded_data = (padded_data - mean) / std

# (frames, features) -> (1, features, frames), as float32 on the target device
input_tensor = (
    torch.from_numpy(padded_data)
    .float()
    .unsqueeze(0)
    .permute(0, 2, 1)
    .to(device)
)

with torch.no_grad():
    # Encode the clip and quantize it; only the token ids are kept.
    z = model.encoder(input_tensor)
    _, tokens, _ = model.rvq(z)

    def decode_from_levels(num_levels):
        """Decode motion using only the first ``num_levels`` quantizer levels.

        Each level's token ids are looked up in that level's codebook, the
        looked-up vectors are summed across levels, and the sum is passed to
        the decoder.
        """

        def lookup(level):
            # tokens is indexed as [batch, level, time] — TODO confirm
            level_ids = tokens[:, level, :]
            codebook = model.rvq.quantizers[level].embedding
            # Swap the last two axes into the layout the decoder is fed
            return F.embedding(level_ids, codebook).permute(0, 2, 1)

        # sum() starts from 0, the same seed as an explicit running total
        return model.decoder(sum(lookup(level) for level in range(num_levels)))

    # Coarsest level alone versus the first four levels combined
    recon_tensor_1_lvl = decode_from_levels(1)
    recon_tensor_4_lvl = decode_from_levels(4)

# Back to NumPy, trimmed to the original length and de-normalized
def _to_denormalized_frames(recon_tensor):
    """Turn a decoder output into a (frames, features) array in data units."""
    frames = recon_tensor.squeeze(0).permute(1, 0).cpu().numpy()
    # Drop the frames that were only added as padding, then undo the
    # normalization applied before encoding.
    return (frames[:T_orig, :] * std) + mean


recon_1_lvl = _to_denormalized_frames(recon_tensor_1_lvl)
recon_4_lvl = _to_denormalized_frames(recon_tensor_4_lvl)

# ==========================================
# 4. Skeleton extraction utility
# ==========================================
def get_3d_joints(data_263):
    """Build per-frame 3D joint positions from a motion feature array.

    Args:
        data_263: Array indexed as ``[frame, feature]``. Column 3 is read as
            the root height and columns 4..66 as 21 ``(x, y, z)`` triples.

    Returns:
        A float64 array of shape ``(num_frames, 22, 3)``. Joint 0 is
        ``(0, root_y, 0)``; joints 1..21 are the 21 triples with the frame's
        root height added to their y component.
    """
    num_frames = data_263.shape[0]
    root_y = data_263[:, 3]

    # Allocate in float64 and fill in place so the result dtype does not
    # depend on the input dtype.
    joints = np.zeros((num_frames, 22, 3))
    joints[:, 1:] = data_263[:, 4:67].reshape(num_frames, 21, 3)

    # Joint 0 starts at zero, so one broadcast add both places the root at
    # (0, root_y, 0) and lifts every other joint by the root height.
    # NOTE(review): if these features store joint heights as absolute values
    # (as some 263-d motion representations do), adding root_y to joints
    # 1..21 would double-count the height — confirm against the data format.
    joints[:, :, 1] += root_y[:, np.newaxis]
    return joints

# Joint positions for the source clip and for both reconstructions
joints_orig, joints_1_lvl, joints_4_lvl = (
    get_3d_joints(motion)
    for motion in (original_data, recon_1_lvl, recon_4_lvl)
)

# ==========================================
# 5. Three-panel visualization
# ==========================================
# Chains of joint indices; each chain is drawn as one connected polyline.
kinematic_tree = [
    [0, 1, 4, 7, 10],
    [0, 2, 5, 8, 11],
    [0, 3, 6, 9, 12, 15],
    [9, 13, 16, 18, 20],
    [9, 14, 17, 19, 21],
]

# One row of three 3D panels, created left to right
fig = plt.figure(figsize=(15, 5))
ax1, ax2, ax3 = (
    fig.add_subplot(1, 3, column, projection='3d') for column in (1, 2, 3)
)

def update(frame):
    """Redraw all three skeleton panels for animation frame ``frame``."""
    # (axes, joint positions, line color) for each panel, left to right
    panels = (
        (ax1, joints_orig, 'blue'),
        (ax2, joints_1_lvl, 'orange'),
        (ax3, joints_4_lvl, 'red'),
    )

    # clear() discards limits, camera and titles, so they are re-applied
    # on every frame.
    for axes, _, _ in panels:
        axes.clear()
        axes.set_xlim(-1, 1)
        axes.set_ylim(-1, 1)
        axes.set_zlim(0, 2)
        axes.view_init(elev=10., azim=-90)
        axes.axis('off')

    ax1.set_title(f"ORIGINAL\n(Frame {frame})")
    ax2.set_title("RECONSTRUCTION: 1 LEVEL\n(Coarse tokens only)")
    ax3.set_title("RECONSTRUCTION: 4 LEVELS\n(Full RVQ detail)")

    for chain in kinematic_tree:
        for axes, joints, color in panels:
            pose = joints[frame, chain]
            # Data components 1 and 2 are swapped so that component 1
            # (height) lands on the plot's vertical z axis.
            axes.plot(
                pose[:, 0], pose[:, 2], pose[:, 1],
                linewidth=3, marker='o', markersize=4, color=color,
            )

# The animation stays bound to `ani` so it remains referenced while the
# window is open; one frame is drawn every 50 ms and playback loops.
ani = animation.FuncAnimation(
    fig=fig,
    func=update,
    frames=T_orig,
    interval=50,
    repeat=True,
)
plt.tight_layout()
plt.show()