| import sys |
| sys.path.append('/mnt/shenzhen2cephfs/capybarali/codes/neobot') |
|
|
| import numpy as np |
| from tqdm import tqdm |
| import torch, os, random |
| from mmengine.config import Config |
| from mmengine.registry import MODELS |
| import src |
| import joblib |
|
|
# Compute device for the TMR encoder. Falls back to the CPU when no CUDA
# device is visible so the script can still run (slowly) on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TMR model config and the checkpoint whose 'state_dict' is loaded in main().
tmr_cfg_path = 'ckpts/mm_tmr_400e/tmr.py'
tmr_cfg_ckpt_path = 'ckpts/mm_tmr_400e/epoch_400.pth'

# Per-feature statistics used by MotionDataset to standardise each motion.
mean_path = 'data/motionmillion/motion_272_mean.npy'
std_path = 'data/motionmillion/motion_272_std.npy'
|
|
class MotionDataset(torch.utils.data.Dataset):
    """Motion clips loaded from disk, cropped and standardised.

    The pickled list at ``data_path_file`` holds one source path per clip.
    Each path is turned into a ``data_id`` (the path with a fixed prefix
    removed) and into the file that is actually loaded
    (``data_root + data_id`` with ``'.npz'`` removed).
    """

    # Prefix removed from every pickled path to obtain the clip's data id.
    _SOURCE_PREFIX = '/mnt/shenzhen2cephfs/capybarali/codes/neobot/data/omnihumo/2_gmr_retarget'

    def __init__(self, data_root, data_path_file, mean_path, std_path, unit_length=4):
        """Index the clips and load the normalisation statistics.

        Args:
            data_root: Directory prefix prepended to every data id.
            data_path_file: joblib pickle containing an iterable of paths.
            mean_path: ``.npy`` file with the mean subtracted from each motion.
            std_path: ``.npy`` file with the std each motion is divided by.
            unit_length: Clips are cropped to a multiple of this many frames.
        """
        self.data_list = []
        self.data_ids = []
        # NOTE: joblib.load unpickles the file; only point it at trusted data.
        for source_path in joblib.load(data_path_file):
            data_id = source_path.replace(self._SOURCE_PREFIX, '')
            # NOTE(review): '.npz' is stripped and no extension is added here;
            # confirm the files under data_root are named accordingly.
            self.data_list.append(data_root + data_id.replace('.npz', ''))
            self.data_ids.append(data_id)

        self.mean = torch.from_numpy(np.load(mean_path))
        self.std = torch.from_numpy(np.load(std_path))
        self.unit_length = unit_length

    def __len__(self):
        """Return the number of indexed clips."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(motion, motion_length, data_id)``.

        The clip is cropped to the largest multiple of ``unit_length`` frames
        at a random start offset, then standardised with ``mean`` / ``std``.
        """
        motion = np.load(self.data_list[idx])
        num_frames = motion.shape[0]

        motion_length = (num_frames // self.unit_length) * self.unit_length
        # Keep the crop offset in its own variable: reusing ``idx`` here made
        # the data-id lookup below return the id of a different clip.
        start = random.randint(0, num_frames - motion_length)
        motion = torch.from_numpy(motion[start:start + motion_length])

        # assumes mean/std broadcast against the feature axis - TODO confirm
        motion = (motion - self.mean) / self.std

        return motion, motion_length, self.data_ids[idx]
|
|
def collate_fn(batch):
    """Collate ``(motion, motion_length, data_id)`` samples into one batch.

    Returns a float tensor of the motions padded to the longest sequence
    (batch first), a tensor of the per-sample lengths, and the list of ids.
    """
    sequences = []
    lengths = []
    ids = []
    for sample in batch:
        sequences.append(sample[0])
        lengths.append(sample[1])
        ids.append(sample[2])

    padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    return padded.float(), torch.tensor(lengths), ids
|
|
def l2_normalize(x, axis=1, eps=1e-8):
    """Divide ``x`` by its Euclidean norm taken along ``axis``.

    ``eps`` is added to every norm before dividing, so an all-zero vector
    yields zeros instead of a division by zero.
    """
    norms = np.linalg.norm(x, axis=axis, keepdims=True)
    safe_norms = norms + eps
    return x / safe_norms
|
|
def kmeans_cosine(X, K, max_iters=100, tol=1e-4, seed=0):
    """Cluster the rows of ``X`` by cosine similarity (spherical k-means).

    Rows and centroids are L2-normalised, so assigning a row to the centroid
    with the largest dot product assigns it to the most cosine-similar one.
    After every update the min / max / mean of the per-cluster average
    similarity and the centroid shift are printed.

    Args:
        X: 2-D array with one sample per row.
        K: Number of clusters; must satisfy ``1 <= K <= X.shape[0]``.
        max_iters: Maximum number of assignment/update rounds.
        tol: Stop once the norm of the centroid change drops below this.
        seed: Seed for the initial centroid draw and empty-cluster re-seeding.

    Returns:
        ``(labels, centroids)``: the cluster index of every row, assigned
        against the returned centroids, and the ``(K, dim)`` unit-norm
        centroids.

    Raises:
        ValueError: If ``K`` is outside ``[1, X.shape[0]]``.
    """
    # A private generator yields the same draws as seeding the global NumPy
    # RNG, but leaves the caller's random state untouched.
    rng = np.random.RandomState(seed)

    X = l2_normalize(X)
    num_samples, _ = X.shape
    if not 1 <= K <= num_samples:
        raise ValueError(
            f'K must be between 1 and the number of samples ({num_samples}), got {K}')

    # Initialise the centroids with K distinct samples.
    centroids = X[rng.choice(num_samples, K, replace=False)]

    for _ in range(max_iters):
        # Assignment step: most similar centroid for every sample.
        labels = np.argmax(X @ centroids.T, axis=1)

        # Update step: mean of the members, re-projected onto the unit sphere.
        new_centroids = np.zeros_like(centroids)
        for k in range(K):
            members = X[labels == k]
            if len(members) == 0:
                # Re-seed an empty cluster from a random sample.
                new_centroids[k] = X[rng.randint(num_samples)]
            else:
                new_centroids[k] = members.mean(axis=0)
        new_centroids = l2_normalize(new_centroids)

        shift = np.linalg.norm(centroids - new_centroids)
        centroids = new_centroids

        # Diagnostics: average member-to-centroid similarity per cluster.
        # Empty clusters stay NaN and are skipped by the nan-aware reductions,
        # so one empty cluster no longer turns every printed value into NaN.
        cluster_avg_sim = np.full(K, np.nan)
        for k in range(K):
            members = X[labels == k]
            if len(members) != 0:
                cluster_avg_sim[k] = (members @ centroids[k]).mean()
        print(np.nanmin(cluster_avg_sim), np.nanmax(cluster_avg_sim),
              np.nanmean(cluster_avg_sim))
        print(shift)

        if shift < tol:
            break

    # Final assignment so the labels match the returned centroids (the
    # in-loop labels belong to the centroids before the last update). This
    # also defines ``labels`` when ``max_iters`` is 0.
    labels = np.argmax(X @ centroids.T, axis=1)
    return labels, centroids
|
|
def main():
    """Encode every clip with the TMR motion encoder and cluster the latents.

    Loads the TMR model from the module-level config / checkpoint paths, runs
    the motion encoder over the whole dataset, clusters the collected latents
    with ``kmeans_cosine`` (K=128) and dumps the labels to ``labels.pkl``.
    """
    tmr_cfg = Config.fromfile(tmr_cfg_path)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    tmr_ckpt = torch.load(tmr_cfg_ckpt_path, weights_only=False, map_location='cpu')
    tmr = MODELS.build(tmr_cfg.model)
    tmr.load_state_dict(tmr_ckpt['state_dict'])
    tmr.to(device).eval()

    dataset = MotionDataset(data_root='/mnt/shenzhen2cephfs/capybarali/omnihumo/processed_results/omnihumo_272',
                            data_path_file='data/omnihumo/filtered_gmr_retarget_path.pkl',
                            mean_path=mean_path, std_path=std_path)
    # shuffle=False keeps the latents (and so the labels) in dataset order.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False, collate_fn=collate_fn,
                                             num_workers=4, drop_last=False)

    all_m_latent = []
    for motion, motion_length, _ in tqdm(dataloader):
        # True for real frames, False for the padding added by collate_fn.
        max_len = int(motion_length.max())
        mask = torch.arange(max_len, device=motion.device).expand(
            motion_length.shape[0], max_len) < motion_length.unsqueeze(1)

        with torch.no_grad():
            # presumably index 0 along dim 1 selects the latent used for
            # retrieval - confirm against the TMR encoder's output layout
            m_latents = tmr.motion_encoder(motion.to(device), mask=mask.to(device))[:, 0]
        # Move each batch off the GPU right away so device memory does not
        # grow with the size of the dataset.
        all_m_latent.append(m_latents.cpu())

    all_m_latent = torch.cat(all_m_latent, dim=0)
    labels, _ = kmeans_cosine(all_m_latent.numpy(), K=128, max_iters=1000, tol=1e-4, seed=0)
    joblib.dump(labels, 'labels.pkl')


if __name__ == '__main__':
    main()
|
|