| import os |
| from glob import glob |
| from pathlib import Path |
| import torch |
| import logging |
| import argparse |
| import torch |
| import numpy as np |
| from sklearn.cluster import KMeans, MiniBatchKMeans |
| import tqdm |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
| import time |
| import random |
|
|
def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False):
    """Fit a k-means model on all ``*.soft.pt`` feature files in a directory.

    Args:
        in_dir: directory (``Path`` or ``str``) containing per-utterance
            ``*.soft.pt`` tensors.
        n_clusters: number of cluster centroids to fit.
        use_minibatch: if True use ``MiniBatchKMeans`` (faster, approximate);
            otherwise full ``KMeans``.
        verbose: forwarded to the scikit-learn estimator.

    Returns:
        dict with the fitted-model attributes needed to rebuild a predictor:
        ``n_features_in_``, ``_n_threads`` and ``cluster_centers_``.

    Raises:
        ValueError: if no ``*.soft.pt`` files are found in ``in_dir``.
    """
    in_dir = Path(in_dir)  # accept plain string paths as well
    logger.info(f"Loading features from {in_dir}")
    features = []
    nums = 0
    for path in tqdm.tqdm(in_dir.glob("*.soft.pt")):
        # Each file is assumed to hold a (1, dim, frames) tensor; squeeze(0).T
        # yields (frames, dim) rows -- confirm against the feature writer.
        features.append(torch.load(path).squeeze(0).numpy().T)
        nums += 1  # bug fix: counter was never incremented, always reported 0

    if not features:
        # np.concatenate([]) would raise an opaque error; fail with context.
        raise ValueError(f"no *.soft.pt feature files found in {in_dir}")

    features = np.concatenate(features, axis=0)
    logger.info(
        "%d files, %.1f MB, shape: %s, dtype: %s",
        nums, features.nbytes / 1024 ** 2, features.shape, features.dtype,
    )
    features = features.astype(np.float32)
    logger.info(f"Clustering features of shape: {features.shape}")
    t = time.time()
    if use_minibatch:
        kmeans = MiniBatchKMeans(
            n_clusters=n_clusters, verbose=verbose, batch_size=4096, max_iter=80
        ).fit(features)
    else:
        kmeans = KMeans(n_clusters=n_clusters, verbose=verbose).fit(features)
    logger.info("clustering took %.2f s", time.time() - t)

    # Keep only the attributes inference needs instead of pickling the
    # whole sklearn estimator.
    return {
        "n_features_in_": kmeans.n_features_in_,
        "_n_threads": kmeans._n_threads,
        "cluster_centers_": kmeans.cluster_centers_,
    }
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=Path, default="./dataset/44k",
                        help='path of training data directory')
    parser.add_argument('--output', type=Path, default="logs/44k",
                        help='path of model output directory')
    # Generalized: cluster count was a hard-coded constant; same default kept.
    parser.add_argument('--n-clusters', type=int, default=10000,
                        help='number of kmeans cluster centers per speaker')

    args = parser.parse_args()

    checkpoint_dir = args.output
    dataset = args.dataset
    n_clusters = args.n_clusters

    # Train one k-means model per speaker sub-directory and bundle all of
    # them into a single checkpoint keyed by speaker name.
    ckpt = {}
    for spk_dir in sorted(dataset.iterdir()):  # sorted for deterministic order
        if spk_dir.is_dir():
            spk = spk_dir.name
            print(f"train kmeans for {spk}...")
            ckpt[spk] = train_cluster(spk_dir, n_clusters, verbose=False)

    checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt"
    checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(
        ckpt,
        checkpoint_path,
    )
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|