"""Convert a speech recording into discrete mHuBERT k-means unit tokens.

Pipeline: load audio -> extract hidden features from one layer of a fairseq
(m)HuBERT checkpoint -> assign every frame to its nearest k-means centroid ->
render the centroid ids as a string such as ``<12><507><3>``.
"""
import argparse
import glob
import itertools
import logging
import os
import re
import sys
from functools import partial
from typing import List, Union

import einops
import fairseq
import fire
import joblib
import numpy as np
import soundfile as sf
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import rearrange
from fairseq import checkpoint_utils
from torchaudio.functional import resample
from tqdm.auto import tqdm
import tqdm  # NOTE: kept after the line above, so `tqdm` is bound to the module.

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger('generate_pseudo_language')

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Default location of the checkpoint directory used by the command-line entry
# point; override it with --ckpt_dir.
DEFAULT_CKPT_DIR = "/data1/speech/anhnmt2/Speech2Speech/LLaMA-Omni/utils/speech2unit/models"


class FeatureReader(object):
    """Load a fairseq checkpoint and extract per-frame features from audio."""

    def __init__(self, ckpt_path, layer, max_chunk=1600000, fp16=False, sampling_rate=16000):
        """
        Args:
            ckpt_path(str): path to the fairseq checkpoint file.
            layer(int): index passed (plus ``layer_shift``) as ``output_layer``
                to the model's ``extract_features``.
            max_chunk(int): maximum number of samples sent to the model in one
                forward pass; longer inputs are processed chunk by chunk.
            fp16(bool): run the model and inputs in half precision.
            sampling_rate(int): sample rate audio is resampled to on load.
        """
        (
            model,
            cfg,
            task,
        ) = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        # Only the first model of the returned ensemble is used.
        self.model = model[0].eval().to(DEVICE)
        self.task = task
        self.layer = layer
        self.max_chunk = max_chunk
        self.fp16 = fp16
        if fp16:
            self.model.half()

        self.layer_shift = 0
        self.target_sample_hz = sampling_rate

        logger.info("TASK CONFIG:\n%s", self.task.cfg)

    def read_audio(self, path):
        """Load ``path`` as a mono ``(1, samples)`` tensor at the target rate.

        Multi-channel files are averaged down to one channel: ``get_feats``
        flattens its input to a single sequence, so extra channels would
        otherwise be concatenated end-to-end as if they were more audio.
        """
        wav, sr = torchaudio.load(path)
        if wav.dim() == 2 and wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sr != self.target_sample_hz:
            wav = resample(wav, sr, self.target_sample_hz)
        return wav

    @torch.no_grad()
    def get_feats(self, waveform):
        """Return the features of ``waveform`` taken from the configured layer.

        Args:
            waveform(torch.Tensor): audio samples; flattened to one sequence.

        Returns:
            torch.Tensor: chunk outputs concatenated along the time axis with
            the leading batch axis squeezed out, or an empty ``(0, 0)`` tensor
            when the input holds no samples.
        """
        if waveform.numel() == 0:
            return torch.zeros(0, 0)

        x = waveform.half() if self.fp16 else waveform.float()
        # Follow DEVICE instead of forcing CUDA so CPU-only hosts work too.
        x = x.to(DEVICE)
        if self.task.cfg.normalize:
            x = F.layer_norm(x, x.shape)
        x = x.view(1, -1)

        feat = []
        for start in range(0, x.size(1), self.max_chunk):
            x_chunk = x[:, start: start + self.max_chunk]
            feat_chunk, _ = self.model.extract_features(
                source=x_chunk,
                padding_mask=None,
                mask=False,
                output_layer=self.layer + self.layer_shift,
            )
            feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)


class ApplyKmeans(object):
    """Assign feature frames to the nearest centroid of a saved k-means model."""

    def __init__(self, km_path):
        """
        Args:
            km_path(str): path to a joblib dump of a fitted model exposing
                ``cluster_centers_``.
        """
        # NOTE(review): joblib.load unpickles the file and can execute
        # arbitrary code; only load k-means files from a trusted source.
        self.km_model = joblib.load(km_path)
        # Centroids are stored transposed so that `x @ C` gives one column per
        # cluster; Cnorm caches the squared norm of every centroid.
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np).to(DEVICE)
        self.Cnorm = torch.from_numpy(self.Cnorm_np).to(DEVICE)

    def __call__(self, x):
        """Return the nearest-centroid index of every row of ``x``.

        Args:
            x(torch.Tensor | np.ndarray): 2-D array with one feature vector
                per row.

        Returns:
            np.ndarray: one cluster index per row; empty for empty input.
        """
        if x.shape[0] == 0:
            # Nothing to classify; the matmul below would otherwise fail on
            # the (0, 0) placeholder produced for empty audio.
            return np.zeros(0, dtype=np.int64)

        if isinstance(x, torch.Tensor):
            # Match the centroids to the device and dtype of the input.
            self.C = self.C.to(x)
            self.Cnorm = self.Cnorm.to(x)
            # ||x - c||^2 expanded as ||x||^2 - 2 x.c + ||c||^2.
            dist = (
                x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
            )
            return dist.argmin(dim=1).cpu().numpy()

        dist = (
            (x ** 2).sum(1, keepdims=True)
            - 2 * np.matmul(x, self.C_np)
            + self.Cnorm_np
        )
        return np.argmin(dist, axis=1)


class Speech2Unit(torch.nn.Module):
    """Turn an audio file into a string of discrete unit tokens."""

    def __init__(
        self,
        ckpt_dir,
        layer=11,
        max_chunk=1600000,
        fp16=False,
        sampling_rate=16000,
    ):
        """
        Args:
            ckpt_dir(str): directory containing both
                ``mhubert_base_vp_en_es_fr_it3.pt`` and
                ``mhubert_base_vp_en_es_fr_it3_L11_km1000.bin``.
            layer(int): model layer the features are taken from, default 11.
            max_chunk(int): maximum samples per forward pass, default 1600000.
            fp16(bool): run in half precision, default False.
            sampling_rate(int): sample rate audio is resampled to,
                default 16000.
        """
        super().__init__()
        ckpt_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3.pt")
        km_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3_L11_km1000.bin")
        self.feature_reader = FeatureReader(ckpt_path, layer, max_chunk, fp16, sampling_rate)
        self.apply_kmeans = ApplyKmeans(km_path)

    @staticmethod
    def merge_duplicates(cluster_ids):
        """Run-length encode ``cluster_ids``.

        Returns:
            tuple[list, list]: the ids with consecutive repeats collapsed, and
            the length of each collapsed run.
        """
        dup_cluster_list = []
        duration_list = []
        for cluster_id, run in itertools.groupby(cluster_ids):
            dup_cluster_list.append(cluster_id)
            duration_list.append(sum(1 for _ in run))
        return dup_cluster_list, duration_list

    def __call__(self, path, merged=True):
        """Encode the audio file at ``path`` as a unit string.

        Args:
            path(str): audio file to encode.
            merged(bool): collapse consecutive identical units when True;
                emit one token per frame when False.

        Returns:
            str: concatenated tokens, each formatted as ``<id>``.
        """
        waveform = self.feature_reader.read_audio(path).to(DEVICE)
        feat = self.feature_reader.get_feats(waveform)
        cluster_ids = self.apply_kmeans(feat).tolist()

        if merged:
            cluster_ids, _ = self.merge_duplicates(cluster_ids)
        return "".join(f"<{unit}>" for unit in cluster_ids)


def main():
    """Command-line entry point: print the unit string of one audio file."""
    parser = argparse.ArgumentParser(
        description="Convert an audio file into discrete speech units."
    )
    parser.add_argument("--wav", type=str, required=True, help="audio file to encode")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=DEFAULT_CKPT_DIR,
        help="directory holding the mHuBERT checkpoint and the k-means model",
    )
    args = parser.parse_args()

    s2u = Speech2Unit(ckpt_dir=args.ckpt_dir)
    units = s2u(args.wav)
    print(units)


if __name__ == '__main__':
    main()