import math

import numpy as np
import soundfile as sf
import torch
import yamlargparse
from scipy.signal import medfilt
from scipy.spatial import distance
from sklearn.cluster import AgglomerativeClustering

from models.models import TransformerDiarization
from utils.utils import parse_config_or_kwargs


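# Inference pipeline implemented below:
#   1. prediction()  - run the model chunk by chunk to obtain per-frame
#      speaker activities and one speaker vector per local speaker and chunk.
#   2. cluster()     - constrained AHC over the speaker vectors (cannot-link
#      pairs for co-active speakers, silence handling), then merge local
#      speakers that share a cluster and stitch the chunks into globally
#      consistent speaker columns.
#   3. make_rttm()   - threshold and median-filter the stitched activities
#      and write the segments to an RTTM file.

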
def get_cl_sil(args, acti, cls_num):
    # Collect cannot-link pairs (two local speakers that are both active in
    # the same chunk) and the indices of silent local speakers (mean
    # activity below sil_spk_th) over all chunks.
    n_chunks = len(acti)
    mean_acti = np.array([np.mean(acti[i], axis=0)
                          for i in range(n_chunks)]).flatten()
    n = args.num_speakers
    sil_spk_th = args.sil_spk_th

    cl_lst = []
    sil_lst = []
    for chunk_idx in range(n_chunks):
        if cls_num is not None:
            if args.num_speakers > cls_num:
                # More local speakers than oracle clusters: force the least
                # active local speaker of this chunk to silence.
                mean_acti_bi = np.array([mean_acti[n * chunk_idx + s_loc_idx]
                                         for s_loc_idx in range(n)])
                min_idx = np.argmin(mean_acti_bi)
                mean_acti[n * chunk_idx + min_idx] = 0.0

        for s_loc_idx in range(n):
            a = n * chunk_idx + (s_loc_idx + 0) % n
            b = n * chunk_idx + (s_loc_idx + 1) % n
            if mean_acti[a] > sil_spk_th and mean_acti[b] > sil_spk_th:
                cl_lst.append((a, b))
            else:
                if mean_acti[a] <= sil_spk_th:
                    sil_lst.append(a)

    return cl_lst, sil_lst


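# Illustration only (not executed): with num_speakers=2, sil_spk_th=0.05 and
# two chunks whose per-speaker mean activities flatten to
# [0.60, 0.01, 0.50, 0.40], get_cl_sil() returns
#   cl_lst  = [(2, 3), (3, 2)]   # both speakers of chunk 1 are active
#   sil_lst = [1]                # speaker 1 of chunk 0 is (almost) silent

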
def clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst):
    org_svec_len = len(svec)
    svec = np.delete(svec, sil_lst, 0)

    # Remap cannot-link indices to account for the removed silent speakers.
    _tbl = [i - sum(sil < i for sil in sil_lst) for i in range(org_svec_len)]
    cl_lst = [(_tbl[_cl[0]], _tbl[_cl[1]]) for _cl in cl_lst]

    distMat = distance.cdist(svec, svec, metric='euclidean')
    for cl in cl_lst:
        # Push cannot-link pairs far apart so AHC never merges them.
        distMat[cl[0], cl[1]] = args.clink_dis
        distMat[cl[1], cl[0]] = args.clink_dis

    # n_clusters and distance_threshold are mutually exclusive in sklearn:
    # use the oracle cluster count when it is known, otherwise let the
    # distance threshold determine the number of clusters.
    # (scikit-learn >= 1.4 renames `affinity` to `metric`.)
    clusterer = AgglomerativeClustering(
        n_clusters=cls_num,
        affinity='precomputed',
        linkage='average',
        distance_threshold=ahc_dis_th if cls_num is None else None)
    clusterer.fit(distMat)

    if cls_num is not None:
        print("oracle n_clusters is known")
    else:
        print("oracle n_clusters is unknown")
        print("estimated n_clusters by constrained AHC: {}"
              .format(len(np.unique(clusterer.labels_))))
        cls_num = len(np.unique(clusterer.labels_))

    # Re-insert the silent speakers with a dedicated silence label (= cls_num).
    sil_lab = cls_num
    insert_sil_lab = [sil_lab for i in range(len(sil_lst))]
    insert_sil_lab_idx = [sil_lst[i] - i for i in range(len(sil_lst))]
    print("insert_sil_lab : {}".format(insert_sil_lab))
    print("insert_sil_lab_idx : {}".format(insert_sil_lab_idx))
    clslab = np.insert(clusterer.labels_,
                       insert_sil_lab_idx,
                       insert_sil_lab).reshape(-1, args.num_speakers)
    print("clslab : {}".format(clslab))

    return clslab, cls_num


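# Continuing the illustration: vector 1 (the silent speaker) is removed and
# the cannot-link pair becomes (1, 2). If AHC then assigns labels [0, 0, 1]
# to the remaining three vectors, cls_num = 2, sil_lab = 2 and
#   clslab = [[0, 2],   # chunk 0: speaker 0 -> cluster 0, speaker 1 -> silence
#             [0, 1]]   # chunk 1: speaker 0 -> cluster 0, speaker 1 -> cluster 1

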
def merge_act_max(act, i, j):
    # Merge speaker j into speaker i frame by frame, keeping the larger
    # activity, and zero out speaker j.
    for k in range(len(act)):
        act[k, i] = max(act[k, i], act[k, j])
        act[k, j] = 0.0
    return act


def merge_acti_clslab(args, acti, clslab, cls_num):
    # If two local speakers of the same chunk were assigned the same cluster,
    # merge their activities and relabel the duplicate as silence.
    sil_lab = cls_num
    for i in range(len(clslab)):
        _lab = clslab[i].reshape(-1, 1)
        distM = distance.cdist(_lab, _lab, metric='euclidean').astype(np.int64)
        for j in range(len(distM)):
            # Keep only the upper triangle so each pair is considered once.
            distM[j][:j] = -1
        idx_lst = np.where(np.count_nonzero(distM == 0, axis=1) > 1)
        merge_done = []
        for j in idx_lst[0]:
            for k in (np.where(distM[j] == 0))[0]:
                if j != k and clslab[i, j] != sil_lab and k not in merge_done:
                    print("merge : (i, j, k) == ({}, {}, {})".format(i, j, k))
                    acti[i] = merge_act_max(acti[i], j, k)
                    clslab[i, k] = sil_lab
                    merge_done.append(j)

    return acti, clslab


def stitching(args, acti, clslab, cls_num):
    n_chunks = len(acti)
    s_loc = args.num_speakers
    sil_lab = cls_num
    s_tot = max(cls_num, s_loc - 1)

    # Extend each chunk's activity to s_tot + 1 columns (all global speakers
    # plus one silence column). Keep the result as a list because the last
    # chunk can be shorter than the others.
    add_acti = []
    for chunk_idx in range(n_chunks):
        zeros = np.zeros((len(acti[chunk_idx]), s_tot + 1))
        if s_tot + 1 > s_loc:
            zeros[:, :-(s_tot + 1 - s_loc)] = acti[chunk_idx]
        else:
            zeros = acti[chunk_idx]
        add_acti.append(zeros)
    acti = add_acti

    out_chunks = []
    for chunk_idx in range(n_chunks):
        # Start with the full set of possible global labels; the labels left
        # over after assigning the local speakers go to the padding columns.
        cls_set = set()
        for s_loc_idx in range(s_tot + 1):
            cls_set.add(s_loc_idx)

        # Map each column (local speaker or padding) to a global label.
        sloci2lab_dct = {}
        for s_loc_idx in range(s_tot + 1):
            if s_loc_idx < s_loc:
                sloci2lab_dct[s_loc_idx] = clslab[chunk_idx][s_loc_idx]
                if clslab[chunk_idx][s_loc_idx] in cls_set:
                    cls_set.remove(clslab[chunk_idx][s_loc_idx])
                else:
                    if clslab[chunk_idx][s_loc_idx] != sil_lab:
                        raise ValueError
            else:
                sloci2lab_dct[s_loc_idx] = list(cls_set)[s_loc_idx - s_loc]

        # (column index, global label) pairs sorted by global label.
        sloci2lab_lst = sorted(sloci2lab_dct.items(), key=lambda x: x[1])

        # Find the column that carries the silence label.
        sil_lab_idx = None
        for idx_lab in sloci2lab_lst:
            if idx_lab[1] == sil_lab:
                sil_lab_idx = idx_lab[0]
                break
        if sil_lab_idx is None:
            raise ValueError

        # swap_idx[lab] = column holding global label `lab`
        # (labels without a column default to the silence column).
        swap_idx = [sil_lab_idx for j in range(s_tot + 1)]
        for lab in range(s_tot + 1):
            for idx_lab in sloci2lab_lst:
                if lab == idx_lab[1]:
                    swap_idx[lab] = idx_lab[0]

        print("swap_idx {}".format(swap_idx))
        swap_acti = acti[chunk_idx][:, swap_idx]
        swap_acti = np.delete(swap_acti, sil_lab, 1)
        out_chunks.append(swap_acti)

    return out_chunks


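# Continuing the illustration: with clslab = [[0, 2], [0, 1]] and cls_num = 2,
# stitching() computes swap_idx = [0, 2, 1] for chunk 0 and [0, 1, 2] for
# chunk 1; after dropping the silence column, chunk 0 yields
# [local speaker 0, zeros] and chunk 1 yields [local speaker 0, local
# speaker 1], so column c of every chunk now belongs to global speaker c.

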
def prediction(num_speakers, net, wav_list, chunk_len_list):
    acti_lst = []
    svec_lst = []
    len_list = []

    with torch.no_grad():
        for wav, chunk_len in zip(wav_list, chunk_len_list):
            wav = wav.to('cuda')
            outputs = net.batch_estimate(torch.unsqueeze(wav, 0))
            ys = outputs[0]

            # One speaker vector per local speaker of this chunk.
            for i in range(num_speakers):
                spkivecs = outputs[i + 1]
                svec_lst.append(spkivecs[0].cpu().detach().numpy())

            # Keep only the last chunk_len frames; the final chunk is shifted
            # backwards in main(), so its leading frames overlap the
            # previous chunk.
            acti = ys[0][-chunk_len:].cpu().detach().numpy()
            acti_lst.append(acti)
            len_list.append(chunk_len)

    acti_arr = np.concatenate(acti_lst, axis=0)
    svec_arr = np.stack(svec_lst)
    len_arr = np.array(len_list)

    return acti_arr, svec_arr, len_arr


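# Shapes returned by prediction(), as consumed by cluster() below
# (C chunks, S = num_speakers, D = the model's speaker-embedding size):
#   acti_arr : (total subsampled frames, S)  per-frame speaker activities
#   svec_arr : (C * S, D)                    one vector per local speaker
#   len_arr  : (C,)                          frames kept from each chunk

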
def cluster(args, conf, acti_arr, svec_arr, len_arr):
    # Split the concatenated activities back into per-chunk arrays. Keep them
    # in a list: the last chunk can be shorter than the others, so stacking
    # them into a single ndarray would fail (ragged shapes).
    acti_list = []
    n_chunks = len_arr.shape[0]
    start = 0
    for i in range(n_chunks):
        chunk_len = len_arr[i]
        acti_list.append(acti_arr[start: start + chunk_len])
        start += chunk_len
    acti = acti_list
    svec = svec_arr

    # The number of speakers is not given; constrained AHC estimates it.
    cls_num = None
    ahc_dis_th = args.ahc_dis_th

    cl_lst, sil_lst = get_cl_sil(args, acti, cls_num)

    n_samples = n_chunks * args.num_speakers - len(sil_lst)
    min_n_samples = 2
    if cls_num is not None:
        min_n_samples = cls_num

    if n_samples >= min_n_samples:
        # Constrained AHC over the speaker vectors.
        clslab, cls_num = \
            clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst)
        # Merge local speakers of a chunk that share a cluster.
        acti, clslab = merge_acti_clslab(args, acti, clslab, cls_num)
        # Reorder each chunk's columns to a common global speaker order.
        out_chunks = stitching(args, acti, clslab, cls_num)
    else:
        out_chunks = acti

    outdata = np.vstack(out_chunks)

    return outdata


def make_rttm(args, conf, cluster_data):
    args.frame_shift = conf['model']['frame_shift']
    args.subsampling = conf['model']['subsampling']
    args.sampling_rate = conf['dataset']['sampling_rate']

    with open(args.out_rttm_file, 'w') as wf:
        a = np.where(cluster_data > args.threshold, 1, 0)
        if args.median > 1:
            a = medfilt(a, (args.median, 1))
        for spkid, frames in enumerate(a.T):
            # Pad with zeros so segments touching the start or end of the
            # recording still produce an on/off change.
            frames = np.pad(frames, (1, 1), 'constant')
            changes, = np.where(np.diff(frames, axis=0) != 0)
            fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
            for s, e in zip(changes[::2], changes[1::2]):
                print(fmt.format(
                    args.session,
                    s * args.frame_shift * args.subsampling / args.sampling_rate,
                    (e - s) * args.frame_shift * args.subsampling / args.sampling_rate,
                    args.session + "_" + str(spkid)), file=wf)


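# Example of a line written by make_rttm() (session 'Anonymous', speaker 0,
# a segment starting at 12.48 s with a duration of 3.20 s):
#   SPEAKER Anonymous 1   12.48    3.20 <NA> <NA> Anonymous_0 <NA>

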
def main(args):
    conf = parse_config_or_kwargs(args.config_path)
    num_speakers = conf['dataset']['num_speakers']
    args.num_speakers = num_speakers

    # Build the model; the size of the speaker-embedding table is taken from
    # the checkpoint itself.
    model_parameter_dict = torch.load(args.model_init)['model']
    model_all_n_speakers = model_parameter_dict["embed.weight"].shape[0]
    conf['model']['all_n_speakers'] = model_all_n_speakers
    net = TransformerDiarization(**conf['model'])
    net.load_state_dict(model_parameter_dict, strict=False)
    net.eval()
    net = net.to("cuda")

    # Split the waveform into fixed-size chunks. The last chunk is shifted
    # backwards so that it still spans a full chunk; prediction() keeps only
    # its last chunk_len frames.
    audio, sr = sf.read(args.wav_path, dtype="float32")
    audio_len = audio.shape[0]
    chunk_size = conf['dataset']['chunk_size']
    frame_shift = conf['model']['frame_shift']
    subsampling = conf['model']['subsampling']
    scale_ratio = int(frame_shift * subsampling)
    chunk_audio_size = chunk_size * scale_ratio
    wav_list, chunk_len_list = [], []
    for i in range(0, math.ceil(1.0 * audio_len / chunk_audio_size)):
        start, end = i * chunk_audio_size, (i + 1) * chunk_audio_size
        if end > audio_len:
            chunk_len_list.append(int((audio_len - start) / scale_ratio))
            end = audio_len
            start = max(0, audio_len - chunk_audio_size)
        else:
            chunk_len_list.append(chunk_size)
        wav_list.append(audio[start:end])
    wav_list = [torch.from_numpy(wav).float() for wav in wav_list]

    acti_arr, svec_arr, len_arr = prediction(
        num_speakers, net, wav_list, chunk_len_list)
    cluster_data = cluster(args, conf, acti_arr, svec_arr, len_arr)
    make_rttm(args, conf, cluster_data)


if __name__ == '__main__':
    parser = yamlargparse.ArgumentParser(description='decoding')
    parser.add_argument('--wav_path',
                        help='the input wav path',
                        default="tmp/mix_0000496.wav")
    parser.add_argument('--config_path',
                        help='config file path',
                        default="config/infer_est_nspk1.yaml")
    parser.add_argument('--model_init',
                        help='model initialize path',
                        default="")
    parser.add_argument('--sil_spk_th', default=0.05, type=float)
    parser.add_argument('--ahc_dis_th', default=1.0, type=float)
    parser.add_argument('--clink_dis', default=1.0e+4, type=float)
    parser.add_argument('--session', default='Anonymous',
                        help='the name of the output speaker')
    parser.add_argument('--out_rttm_file', default='out.rttm',
                        help='the output rttm file')
    parser.add_argument('--threshold', default=0.4, type=float)
    parser.add_argument('--median', default=25, type=int)

    args = parser.parse_args()
    main(args)
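
# Example invocation (the script and checkpoint names are placeholders;
# --model_init has no usable default and must point to a trained checkpoint):
#   python infer.py \
#       --wav_path tmp/mix_0000496.wav \
#       --config_path config/infer_est_nspk1.yaml \
#       --model_init exp/model/final.pt \
#       --out_rttm_file out.rttm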