import argparse
import itertools
import json
import logging
import os
import sys
from typing import Any, Dict, List

import joblib
import torch
import torchaudio
from fairseq import checkpoint_utils
from torchaudio.functional import resample
from tqdm.auto import tqdm
|
|
# Configure root logging once at import time. Verbosity is controlled by the
# LOGLEVEL environment variable (defaults to INFO); logs go to stdout so they
# interleave with batch-job / tqdm output capture.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger('generate_pseudo_language')


# Single global device shared by the model, the k-means centroids and inputs.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
class FeatureReader:
    """Loads a fairseq (m)HuBERT checkpoint and extracts activations from one
    of its transformer layers for a raw waveform.
    """

    def __init__(self, ckpt_path, layer, max_chunk=1600000, fp16=False, sampling_rate=16000):
        """
        Args:
            ckpt_path: Path to the fairseq checkpoint file.
            layer: Transformer layer whose activations are returned.
            max_chunk: Maximum number of samples per forward pass, bounding
                memory use on long audio.
            fp16: If True, run the model and inputs in half precision.
            sampling_rate: Sample rate the model expects; `read_audio`
                resamples inputs to it.
        """
        (model, cfg, task) = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = model[0].eval().to(DEVICE)
        self.task = task
        self.layer = layer
        self.max_chunk = max_chunk
        self.fp16 = fp16
        if fp16:
            self.model.half()
        self.target_sample_hz = sampling_rate

    def read_audio(self, path):
        """Load an audio file, resampling to the model's rate if needed.

        Returns the (channels, samples) tensor produced by torchaudio.load.
        """
        wav, sr = torchaudio.load(path)
        if sr != self.target_sample_hz:
            wav = resample(wav, sr, self.target_sample_hz)
        return wav

    @torch.no_grad()
    def get_feats(self, waveform):
        """Return layer features of shape (frames, feature_dim) for `waveform`.

        NOTE(review): the input is flattened with view(1, -1), so multi-channel
        audio would be concatenated channel-after-channel; callers appear to
        pass mono audio — confirm upstream.
        """
        # Fix: the original called .cuda() unconditionally, which crashes on
        # CPU-only hosts even though the model itself was placed on DEVICE.
        if self.fp16:
            x = waveform.half().to(DEVICE)
        else:
            x = waveform.float().to(DEVICE)
        if self.task.cfg.normalize:
            x = torch.nn.functional.layer_norm(x, x.shape)
        x = x.view(1, -1)
        feat = []
        # Chunk long inputs so a single forward pass never sees > max_chunk samples.
        for start in range(0, x.size(1), self.max_chunk):
            x_chunk = x[:, start: start + self.max_chunk]
            feat_chunk, _ = self.model.extract_features(
                source=x_chunk,
                padding_mask=None,
                mask=False,
                output_layer=self.layer,
            )
            feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)
|
|
class ApplyKmeans:
    """Assigns feature frames to their nearest pre-trained k-means centroid."""

    def __init__(self, km_path):
        """Load an sklearn k-means model and cache its centroids as tensors.

        Centroids are kept transposed (feature_dim x n_clusters) so that the
        cross term of the distance is a single matmul per batch of frames.
        """
        self.km_model = joblib.load(km_path)
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)
        self.C = torch.from_numpy(self.C_np).to(DEVICE)
        self.Cnorm = torch.from_numpy(self.Cnorm_np).to(DEVICE)

    def __call__(self, x):
        """Return a numpy array with the nearest-centroid index per row of `x`."""
        x = x.to(DEVICE)
        # Squared euclidean distance, expanded: |x|^2 - 2 x.C + |C|^2.
        frame_sq = x.pow(2).sum(1, keepdim=True)
        cross_term = torch.matmul(x, self.C)
        dist = frame_sq - 2 * cross_term + self.Cnorm
        return dist.argmin(dim=1).cpu().numpy()
|
|
class Speech2Unit(torch.nn.Module):
    """Converts a speech file into a string of discrete pseudo-language units.

    Pipeline: mHuBERT layer features -> k-means cluster ids -> (optionally
    run-length merged) "<sosp><id>...<eosp>" token string.
    """

    def __init__(self, ckpt_dir, layer=11, max_chunk=1600000, fp16=False, sampling_rate=16000,
                 ckpt_name="mhubert_base_vp_en_es_fr_it3.pt",
                 km_name="mhubert_base_vp_en_es_fr_it3_L11_km1000.bin"):
        """
        Args:
            ckpt_dir: Directory containing the HuBERT checkpoint and k-means model.
            layer: Transformer layer to read features from (L11 matches the
                default k-means model name).
            max_chunk: Maximum samples fed to the model per forward pass.
            fp16: Run the feature extractor in half precision.
            sampling_rate: Sample rate the model expects.
            ckpt_name: Checkpoint filename (previously hard-coded; default unchanged).
            km_name: K-means model filename (previously hard-coded; default unchanged).
        """
        super().__init__()
        ckpt_path = os.path.join(ckpt_dir, ckpt_name)
        km_path = os.path.join(ckpt_dir, km_name)
        self.feature_reader = FeatureReader(ckpt_path, layer, max_chunk, fp16, sampling_rate)
        self.apply_kmeans = ApplyKmeans(km_path)

    @staticmethod
    def merge_duplicates(cluster_ids):
        """Run-length encode `cluster_ids`.

        Returns:
            (unique consecutive ids, corresponding run lengths) — both lists.
        """
        dup_cluster_list = []
        duration_list = []
        for cid, run in itertools.groupby(cluster_ids):
            dup_cluster_list.append(cid)
            duration_list.append(sum(1 for _ in run))
        return dup_cluster_list, duration_list

    def __call__(self, path, merged=True):
        """Return the unit token string for the audio file at `path`.

        Args:
            path: Audio file readable by torchaudio.
            merged: If True, collapse consecutive duplicate cluster ids.
        """
        waveform = self.feature_reader.read_audio(path).to(DEVICE)
        feat = self.feature_reader.get_feats(waveform)
        cluster_ids = self.apply_kmeans(feat).tolist()
        # Only build the string actually requested (the original built both).
        ids = self.merge_duplicates(cluster_ids)[0] if merged else cluster_ids
        return "<sosp>" + "".join(f"<{x}>" for x in ids) + "<eosp>"
|
|
def process_jsonl(input_path: str, output_path: str, ckpt_dir: str) -> None:
    """Annotate each record with discrete speech units and write the result.

    NOTE: despite the name, the input is parsed as ONE JSON document (a list
    of dicts), not line-delimited JSONL; the output is likewise a single
    JSON array.

    Records without a "speech_gpt" entry, or whose referenced audio file does
    not exist, are skipped (with a warning) rather than written through —
    same filtering as before, now logged instead of silent.

    Args:
        input_path: Path to a JSON file containing a list of record dicts.
        output_path: Destination path for the annotated JSON array.
        ckpt_dir: Directory holding the mHuBERT checkpoint and k-means model.
    """
    s2u = Speech2Unit(ckpt_dir=ckpt_dir)
    with open(input_path, 'r', encoding='utf-8') as infile:
        data_all = json.load(infile)

    data_tgt_units = []
    skipped = 0
    for data in tqdm(data_all):
        # Skip records that carry no audio reference or point at a missing file.
        if "speech_gpt" not in data or not os.path.exists(data["speech_gpt"][0]):
            skipped += 1
            continue
        data["tgt_units"] = s2u(data["speech_gpt"][0])
        data_tgt_units.append(data)

    if skipped:
        logger.warning("Skipped %d record(s) with missing speech_gpt audio", skipped)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data_tgt_units, f, ensure_ascii=False, indent=2)
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: annotate every record of the input file with pseudo
    # speech units and write the result to the output path.
    # NOTE: despite the flag names, process_jsonl reads/writes a single JSON
    # array (json.load / json.dump), not line-delimited JSONL.
    parser = argparse.ArgumentParser(description="Process JSONL with Speech2Unit")
    parser.add_argument("--input_jsonl", type=str, required=True, help="Input JSONL file path")
    parser.add_argument("--output_jsonl", type=str, required=True, help="Output JSONL file path")
    parser.add_argument("--ckpt_dir", type=str, required=False, help="Directory of checkpoint and kmeans model", default='models/')
    args = parser.parse_args()


    process_jsonl(args.input_jsonl, args.output_jsonl, args.ckpt_dir)
|
|