from sklearn.cluster import AgglomerativeClustering, Birch, KMeans, SpectralClustering
from sklearn import metrics
from sklearn.mixture import GaussianMixture
import os
import numpy as np
import config
import yaml


def _build_model(algorithm, n_clusters):
    """Map the -a/--algorithm CLI flag to a clustering model.

    "b" -> Birch, "s" -> SpectralClustering, "a" -> AgglomerativeClustering,
    anything else (including the default "k") -> KMeans with a fixed seed
    so repeated runs produce the same clusters.
    """
    if algorithm == "b":
        return Birch(n_clusters=n_clusters, threshold=0.2)
    if algorithm == "s":
        return SpectralClustering(n_clusters=n_clusters)
    if algorithm == "a":
        return AgglomerativeClustering(n_clusters=n_clusters)
    return KMeans(n_clusters=n_clusters, random_state=10)


def _load_embeddings(files):
    """Load `<file>.emo.npy` for each wav path in *files*.

    Returns (embeddings, wavnames) of equal length. Missing or unreadable
    embedding files are reported and skipped (best-effort, as in the
    original script) instead of aborting the whole speaker.
    """
    embs = []
    wavnames = []
    for file in files:
        try:
            embs.append(np.load(f"{os.path.splitext(file)[0]}.emo.npy"))
            wavnames.append(os.path.basename(file))
        except Exception as e:  # best-effort: report and continue with the rest
            print(e)
    return embs, wavnames


if __name__ == "__main__":
    # CLI flags must be registered on the shared parser BEFORE the config
    # object is imported below, because importing it parses the arguments.
    config.parser.add_argument("-a", "--algorithm", default="k", help="choose algorithm", type=str)
    config.parser.add_argument("-n", "--num_clusters", default=3, help="number of clusters", type=int)
    config.parser.add_argument("-r", "--range", default=4, help="number of files in a class", type=int)
    args = config.parser.parse_args()

    # Rebinds the name `config` from the module to its config object
    # (provides .preprocess_text_config and .dataset_path).
    from config import config

    # Group wav paths by speaker from the cleaned filelist; each line is
    # "wav_path|speaker|...".
    filelist_dict = {}
    yml_result = {}
    with open(config.preprocess_text_config.cleaned_path, mode="r", encoding="utf-8") as f:
        for line in f:
            fields = line.split("|")
            speaker = fields[1]
            if speaker not in filelist_dict:
                filelist_dict[speaker] = []
                yml_result[speaker] = {}
            filelist_dict[speaker].append(fields[0])

    for speaker, files in filelist_dict.items():
        print("\nspeaker: " + speaker)
        embs, wavnames = _load_embeddings(files)
        if not embs:
            # Nothing loadable for this speaker: np.stack([]) would raise,
            # so skip instead of crashing the whole run.
            continue

        # Build an (N, D) matrix. Flatten each embedding and stack rather
        # than expand_dims/concatenate/squeeze: squeeze() would collapse a
        # single-sample speaker to 1-D and break fit_predict.
        x = np.stack([np.asarray(e).reshape(-1) for e in embs])

        model = _build_model(args.algorithm, args.num_clusters)
        y_predict = model.fit_predict(x)

        # Bucket wav names by predicted cluster label.
        classes = [[] for _ in range(y_predict.max() + 1)]
        for idx, wavname in enumerate(wavnames):
            classes[y_predict[idx]].append(wavname)

        # Report each cluster and keep at most args.range sample files per
        # cluster for the YAML summary.
        for i in range(y_predict.max() + 1):
            class_length = len(classes[i])
            print("类别:", i, "本类中样本数量:", class_length)
            picked = classes[i][: args.range]
            for name in picked:
                print(name)
            yml_result[speaker][f"class{i}"] = picked

    with open(os.path.join(config.dataset_path, "emo_clustering.yml"), "w", encoding="utf-8") as f:
        yaml.dump(yml_result, f)