"""Cluster TMR motion-encoder latents with cosine k-means.

The script encodes every motion listed in a path file with a pretrained TMR
motion encoder, runs spherical (cosine-similarity) k-means on the resulting
latents, and writes the per-motion cluster labels to ``labels.pkl``.
"""
import sys

# Must run before ``import src`` so the project package is importable.
sys.path.append('/mnt/shenzhen2cephfs/capybarali/codes/neobot')
import numpy as np
from tqdm import tqdm
import torch
import os
import random
from mmengine.config import Config
from mmengine.registry import MODELS
import src
import joblib

device = torch.device('cuda')
tmr_cfg_path = 'ckpts/mm_tmr_400e/tmr.py'
tmr_cfg_ckpt_path = 'ckpts/mm_tmr_400e/epoch_400.pth'
mean_path = 'data/motionmillion/motion_272_mean.npy'
std_path = 'data/motionmillion/motion_272_std.npy'


class MotionDataset(torch.utils.data.Dataset):
    """Dataset of normalized motion clips loaded from per-sample array files.

    Args:
        data_root: Directory prefix prepended to every relative sample id.
        data_path_file: joblib file holding an iterable of source paths.
        mean_path: ``.npy`` file with the normalization mean.
        std_path: ``.npy`` file with the normalization standard deviation.
    """

    def __init__(self, data_root, data_path_file, mean_path, std_path):
        self.data_list = []
        self.data_ids = []
        # NOTE: joblib.load unpickles; only use with trusted path files.
        data_paths = joblib.load(data_path_file)
        for path in data_paths:
            # The id is the source path with the retarget root stripped.
            data_id = path.replace(
                '/mnt/shenzhen2cephfs/capybarali/codes/neobot/data/omnihumo/2_gmr_retarget', '')
            # NOTE(review): the '.npz' suffix is dropped and nothing is
            # appended, so the processed files are presumably stored without
            # that extension -- confirm against the data layout.
            self.data_list.append(data_root + data_id.replace('.npz', ''))
            self.data_ids.append(data_id)

        self.mean = torch.from_numpy(np.load(mean_path))
        self.std = torch.from_numpy(np.load(std_path))
        # Clip lengths are truncated to a multiple of this many frames.
        self.unit_length = 4

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(normalized_motion, motion_length, data_id)`` for ``idx``.

        The clip is cropped to the largest multiple of ``unit_length`` frames,
        starting at a random offset (unseeded ``random`` module, so the crop is
        not reproducible across runs).
        """
        motion = np.load(self.data_list[idx])
        num_frames = motion.shape[0]
        motion_length = (num_frames // self.unit_length) * self.unit_length
        # NOTE: clips shorter than ``unit_length`` yield a zero-length motion.
        # Keep the crop offset in its own variable: the original code reused
        # ``idx`` here, which made the returned data id belong to a different
        # sample than the returned motion.
        start = random.randint(0, num_frames - motion_length)
        motion = torch.from_numpy(motion[start:start + motion_length])
        motion = (motion - self.mean) / self.std
        return motion, motion_length, self.data_ids[idx]


def collate_fn(batch):
    """Collate ``(motion, length, data_id)`` samples into a padded batch.

    Returns:
        A tuple ``(motion, motion_length, data_id)`` where ``motion`` is the
        zero-padded, batch-first float tensor of all clips, ``motion_length``
        is a tensor of the unpadded lengths, and ``data_id`` is a list of ids.
    """
    motion = torch.nn.utils.rnn.pad_sequence(
        [sample[0] for sample in batch], batch_first=True).float()
    motion_length = torch.tensor([sample[1] for sample in batch])
    data_id = [sample[2] for sample in batch]
    return motion, motion_length, data_id


def l2_normalize(x, axis=1, eps=1e-8):
    """Scale ``x`` to unit L2 norm along ``axis``; ``eps`` avoids division by zero."""
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def kmeans_cosine(X, K, max_iters=100, tol=1e-4, seed=0):
    """Run k-means with cosine similarity on the rows of ``X``.

    Args:
        X: 2-D array with one sample per row.
        K: Number of clusters; must not exceed the number of rows.
        max_iters: Maximum number of assignment/update iterations.
        tol: Stop once the Frobenius norm of the centroid change is below this.
        seed: Seed passed to ``np.random.seed`` (this reseeds NumPy's global
            random state).

    Returns:
        ``(labels, centroids)``: ``labels`` holds the cluster index of each row
        from the last assignment step, and ``centroids`` the unit-norm
        centroids after the last update step.

    Raises:
        ValueError: If ``K`` is larger than the number of rows (raised by
            ``np.random.choice``) or ``X`` is not 2-D.
    """
    np.random.seed(seed)

    # 1. L2-normalize the data so that dot products are cosine similarities.
    X = l2_normalize(X)
    num_points, _ = X.shape

    # 2. Initialize centroids from K distinct random samples.
    indices = np.random.choice(num_points, K, replace=False)
    centroids = X[indices]

    labels = None
    for _ in range(max_iters):
        # 3. Assign each sample to the centroid with the highest cosine
        #    similarity (plain dot product, since everything is normalized).
        sim = X @ centroids.T  # (T, K)
        labels = np.argmax(sim, axis=1)

        # 4. Update centroids as the mean of their members.
        new_centroids = np.zeros_like(centroids)
        for k in range(K):
            members = X[labels == k]
            if members.shape[0] == 0:
                # Empty cluster: re-initialize from a random sample.
                new_centroids[k] = X[np.random.randint(num_points)]
            else:
                new_centroids[k] = members.mean(axis=0)

        # 5. Project the updated centroids back onto the unit sphere.
        new_centroids = l2_normalize(new_centroids)

        # 6. Measure how far the centroids moved in this iteration.
        shift = np.linalg.norm(centroids - new_centroids)
        centroids = new_centroids

        # ===== Added: intra-cluster similarity statistics =====
        cluster_avg_sim = np.zeros(K)
        cluster_sizes = np.zeros(K, dtype=int)
        for k in range(K):
            members = X[labels == k]
            cluster_sizes[k] = members.shape[0]
            if members.shape[0] == 0:
                cluster_avg_sim[k] = np.nan
            else:
                cluster_avg_sim[k] = (members @ centroids[k]).mean()
        # Empty clusters are marked NaN; use NaN-aware reductions so a single
        # empty cluster does not turn the whole report into NaN.
        print(np.nanmin(cluster_avg_sim), np.nanmax(cluster_avg_sim),
              np.nanmean(cluster_avg_sim))
        print(shift)

        if shift < tol:
            break

    if labels is None:
        # ``max_iters <= 0``: no iteration ran, so assign against the initial
        # centroids instead of failing on an unbound ``labels``.
        labels = np.argmax(X @ centroids.T, axis=1)

    return labels, centroids


def main():
    """Encode all motions with TMR, cluster the latents, and save the labels."""
    # Build the model.
    tmr_cfg = Config.fromfile(tmr_cfg_path)
    # NOTE: ``weights_only=False`` unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    tmr_ckpt = torch.load(tmr_cfg_ckpt_path, weights_only=False, map_location='cpu')
    tmr = MODELS.build(tmr_cfg.model)
    tmr.load_state_dict(tmr_ckpt['state_dict'])
    tmr.to(device).eval()

    # Build the dataset.
    dataset = MotionDataset(
        data_root='/mnt/shenzhen2cephfs/capybarali/omnihumo/processed_results/omnihumo_272',
        data_path_file='data/omnihumo/filtered_gmr_retarget_path.pkl',
        mean_path=mean_path,
        std_path=std_path)
    # ``shuffle=False`` keeps the latents (and therefore the saved labels) in
    # dataset order.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=256, shuffle=False, collate_fn=collate_fn,
        num_workers=4, drop_last=False)

    all_m_latents = []
    for motion, motion_length, data_id in tqdm(dataloader):
        max_len = int(motion_length.max())
        # True for real frames, False for padding.
        mask = torch.arange(max_len, device=motion.device).expand(
            motion_length.shape[0], max_len) < motion_length.unsqueeze(1)
        with torch.no_grad():
            # Keep index 0 along dim 1 of the encoder output (presumably the
            # latent/summary token -- confirm against the TMR encoder).
            m_latents = tmr.motion_encoder(
                motion.to(device), mask=mask.to(device))[:, 0]
        # Move each batch off the GPU right away so the accumulated latents do
        # not hold device memory for the whole dataset.
        all_m_latents.append(m_latents.cpu())
    all_m_latents = torch.concat(all_m_latents, dim=0)  # T, C

    labels, centroids = kmeans_cosine(
        all_m_latents.numpy(), K=128, max_iters=1000, tol=1e-4, seed=0)
    joblib.dump(labels, 'labels.pkl')


if __name__ == '__main__':
    main()