# File: NMR/tools/data_process/motionmillion/gen_motion_embedding.py
# Repository page header (extraction residue): "Xxx999's picture" / commit message "upload" / commit 45950ff
import sys
sys.path.append('/mnt/shenzhen2cephfs/capybarali/codes/neobot')
import numpy as np
from tqdm import tqdm
import torch, os, random
from mmengine.config import Config
from mmengine.registry import MODELS
import src
import joblib
# Device the model and input batches are moved to in main(); hard-coded to CUDA.
device = torch.device('cuda')
# Config file and checkpoint that main() uses to build and load the TMR model.
# Both are relative paths, so they resolve against the current working directory.
tmr_cfg_path = 'ckpts/mm_tmr_400e/tmr.py'
tmr_cfg_ckpt_path = 'ckpts/mm_tmr_400e/epoch_400.pth'
# Mean / std arrays (.npy) that MotionDataset uses to normalise each motion.
mean_path = 'data/motionmillion/motion_272_mean.npy'
std_path = 'data/motionmillion/motion_272_std.npy'
class MotionDataset(torch.utils.data.Dataset):
    """Dataset of per-file motion arrays, normalised and cropped for encoding.

    Each item is loaded with ``np.load``, cropped to the largest frame count
    that is a multiple of ``unit_length`` and normalised with the stored
    mean / std.
    """

    def __init__(self, data_root, data_path_file, mean_path, std_path):
        """Build the list of motion files and load the normalisation stats.

        Args:
            data_root: Directory prefix that each relative motion path is
                appended to (plain string concatenation).
            data_path_file: File read with ``joblib.load``; must yield an
                iterable of source path strings.
            mean_path: ``.npy`` file holding the per-feature mean.
            std_path: ``.npy`` file holding the per-feature std.
        """
        self.data_list = []
        self.data_ids = []
        # NOTE(review): joblib.load unpickles its input, so only point this at
        # trusted path lists.
        data_paths = joblib.load(data_path_file)
        for path in data_paths:
            # The id is the source path with the retarget root stripped off.
            data_id = path.replace('/mnt/shenzhen2cephfs/capybarali/codes/neobot/data/omnihumo/2_gmr_retarget', '')
            # NOTE(review): the '.npz' suffix is dropped and nothing is added
            # back, so the file is loaded without an extension - confirm this
            # matches the on-disk layout under data_root.
            self.data_list.append(data_root + data_id.replace('.npz', ''))
            self.data_ids.append(data_id)
        self.mean = torch.from_numpy(np.load(mean_path))
        self.std = torch.from_numpy(np.load(std_path))
        # Sequences are cropped so their length is a multiple of this value.
        self.unit_length = 4

    def __len__(self):
        """Return the number of motion files."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load, crop and normalise the motion at position ``idx``.

        Returns:
            A ``(motion, motion_length, data_id)`` tuple: the normalised
            tensor, its frame count (a multiple of ``unit_length``) and the
            id of the file it was read from.
        """
        motion = np.load(self.data_list[idx])
        num_frames = motion.shape[0]
        motion_length = (num_frames // self.unit_length) * self.unit_length
        # Random crop start. It gets its own name: reusing ``idx`` here made
        # the returned data_id belong to a different sample than the motion.
        start = random.randint(0, num_frames - motion_length)
        motion = torch.from_numpy(motion[start:start + motion_length])
        motion = (motion - self.mean) / self.std
        return motion, motion_length, self.data_ids[idx]
def collate_fn(batch):
    """Collate ``(motion, length, data_id)`` samples into one padded batch.

    Returns:
        A tuple of the zero-padded float32 motion tensor (batch first), a
        tensor of the per-sample lengths, and the list of sample ids.
    """
    sequences = []
    lengths = []
    ids = []
    for sample in batch:
        sequences.append(sample[0])
        lengths.append(sample[1])
        ids.append(sample[2])

    # Shorter sequences are right-padded with zeros up to the longest one.
    padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    return padded.float(), torch.tensor(lengths), ids
def l2_normalize(x, axis=1, eps=1e-8):
    """Scale ``x`` to (approximately) unit L2 norm along ``axis``.

    ``eps`` is added to each norm before dividing, so all-zero vectors come
    back as zeros instead of producing a division by zero.
    """
    norms = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / (norms + eps)
def kmeans_cosine(X, K, max_iters=100, tol=1e-4, seed=0):
    """Cluster the rows of ``X`` with cosine-similarity (spherical) k-means.

    Rows are L2-normalised first, so the dot product with a normalised
    centroid is the cosine similarity. Each iteration assigns every row to
    its most similar centroid, recomputes the centroids as the normalised
    mean of their members, and prints the min / max / mean intra-cluster
    similarity followed by the centroid shift.

    Args:
        X: 2-D array of shape ``(T, C)``, one sample per row.
        K: Number of clusters; must satisfy ``1 <= K <= T``.
        max_iters: Maximum number of iterations; must be at least 1.
        tol: Stop once the Frobenius norm of the centroid change is below
            this value.
        seed: Seed for centroid initialisation and empty-cluster re-seeding.

    Returns:
        ``(labels, centroids)``. ``labels`` has shape ``(T,)`` and comes from
        the last assignment step, i.e. it was computed against the centroids
        before their final update. ``centroids`` has shape ``(K, C)`` and is
        L2-normalised.

    Raises:
        ValueError: If ``max_iters < 1``, ``X`` is not 2-D, or ``K`` is
            outside ``[1, T]``.
    """
    if max_iters < 1:
        raise ValueError(f"max_iters must be >= 1, got {max_iters}")

    # A private generator keeps the process-wide NumPy RNG untouched.
    # RandomState(seed) yields the same draws as np.random.seed(seed) followed
    # by the module-level functions, so results are unchanged.
    rng = np.random.RandomState(seed)

    # 1. L2 normalize data
    X = l2_normalize(X)
    T, _ = X.shape
    if not 1 <= K <= T:
        raise ValueError(f"K must be in [1, {T}], got {K}")

    # 2. random init centroids (K distinct rows)
    centroids = X[rng.choice(T, K, replace=False)]

    for _ in range(max_iters):
        # 3. cosine similarity = dot product (since normalized)
        sim = X @ centroids.T  # (T, K)
        labels = np.argmax(sim, axis=1)  # assign by max cosine

        # Members of every cluster, gathered once and reused for both the
        # centroid update and the statistics below.
        members_per_cluster = [X[labels == k] for k in range(K)]

        # 4. update centroids
        new_centroids = np.zeros_like(centroids)
        for k, members in enumerate(members_per_cluster):
            if len(members) == 0:
                # empty cluster -> re-seed from a random sample
                new_centroids[k] = X[rng.randint(T)]
            else:
                new_centroids[k] = members.mean(axis=0)

        # 5. normalize centroids
        new_centroids = l2_normalize(new_centroids)

        # 6. convergence check
        shift = np.linalg.norm(centroids - new_centroids)
        centroids = new_centroids

        # Intra-cluster similarity statistics. Empty clusters stay NaN and
        # are skipped by the nan-aware reductions, so one empty cluster no
        # longer turns the whole summary into NaN.
        cluster_avg_sim = np.full(K, np.nan)
        for k, members in enumerate(members_per_cluster):
            if len(members) > 0:
                cluster_avg_sim[k] = (members @ centroids[k]).mean()
        print(np.nanmin(cluster_avg_sim), np.nanmax(cluster_avg_sim), np.nanmean(cluster_avg_sim))
        print(shift)

        if shift < tol:
            break

    return labels, centroids
def main():
    """Encode every motion with the TMR motion encoder and cluster the latents.

    Builds the model from ``tmr_cfg_path`` / ``tmr_cfg_ckpt_path``, runs the
    motion encoder over the dataset in order (``shuffle=False``,
    ``drop_last=False``), clusters the resulting latents with
    ``kmeans_cosine`` and writes the cluster labels to ``labels.pkl``.
    Because the order is preserved, label ``i`` belongs to dataset item ``i``.
    The centroids are computed but not saved.
    """
    # build model
    tmr_cfg = Config.fromfile(tmr_cfg_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    tmr_ckpt = torch.load(tmr_cfg_ckpt_path, weights_only=False, map_location='cpu')
    tmr = MODELS.build(tmr_cfg.model)
    tmr.load_state_dict(tmr_ckpt['state_dict'])
    tmr.to(device).eval()

    # build dataset
    dataset = MotionDataset(data_root='/mnt/shenzhen2cephfs/capybarali/omnihumo/processed_results/omnihumo_272',
                            data_path_file='data/omnihumo/filtered_gmr_retarget_path.pkl',
                            mean_path=mean_path, std_path=std_path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False, collate_fn=collate_fn,
                                             num_workers=4, drop_last=False)

    all_m_latent = []
    for motion, motion_length, _data_id in tqdm(dataloader):
        # Boolean mask of shape (batch, max_len): True on real frames, False
        # on the zero padding added by collate_fn.
        max_len = int(motion_length.max())
        mask = torch.arange(max_len, device=motion.device).expand(
            motion_length.shape[0], max_len) < motion_length.unsqueeze(1)
        with torch.no_grad():
            # Keep index 0 along dim 1 of the encoder output.
            # NOTE(review): presumably the pooled / summary latent - confirm
            # against the motion encoder implementation.
            m_latents = tmr.motion_encoder(motion.to(device), mask=mask.to(device))[:, 0]
        # Move each batch off the GPU right away; keeping every batch on the
        # device made GPU memory grow with the size of the whole dataset.
        all_m_latent.append(m_latents.cpu())

    all_m_latent = torch.concat(all_m_latent, dim=0)  # T, C
    labels, centroids = kmeans_cosine(all_m_latent.numpy(), K=128, max_iters=1000, tol=1e-4, seed=0)
    joblib.dump(labels, 'labels.pkl')


if __name__ == '__main__':
    main()