| | import json |
| | import os |
| | import argparse |
| | import torch |
| | from tqdm.auto import tqdm |
| | from typing import List, Dict, Any |
| | import logging |
| | import sys |
| | from torchaudio.functional import resample |
| |
|
| | |
| | from fairseq import checkpoint_utils |
| | import joblib |
| | import torchaudio |
| |
|
# Configure root logging once at import time; the LOGLEVEL environment
# variable (default "INFO") controls verbosity, and output is sent to
# stdout so log lines interleave cleanly with tqdm's stderr progress bar.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger('generate_pseudo_language')

# Run on GPU when available; models and tensors below are moved to DEVICE.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
| |
|
class FeatureReader:
    """Extract intermediate-layer features from a fairseq (m)HuBERT checkpoint.

    The model is loaded once in ``__init__``; ``read_audio`` loads and
    resamples a file, and ``get_feats`` runs chunked feature extraction so
    long recordings do not exhaust device memory.
    """

    def __init__(self, ckpt_path, layer, max_chunk=1600000, fp16=False, sampling_rate=16000):
        # load_model_ensemble_and_task returns (models, cfg, task); only the
        # first (and here, only) model of the ensemble is used.
        (model, cfg, task) = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = model[0].eval().to(DEVICE)
        self.task = task
        self.layer = layer          # transformer layer whose output is returned
        self.max_chunk = max_chunk  # samples per forward pass
        self.fp16 = fp16
        if fp16:
            self.model.half()
        self.target_sample_hz = sampling_rate

    def read_audio(self, path):
        """Load *path* and resample it to the model's expected rate if needed."""
        wav, sr = torchaudio.load(path)
        if sr != self.target_sample_hz:
            wav = resample(wav, sr, self.target_sample_hz)
        return wav

    @torch.no_grad()
    def get_feats(self, waveform):
        """Return layer features of shape (frames, dim) for *waveform*.

        The waveform is flattened to a single channel dimension and processed
        in chunks of at most ``max_chunk`` samples.
        """
        x = waveform
        # BUGFIX: the original called .cuda() unconditionally, which raises on
        # CPU-only machines even though DEVICE correctly falls back to "cpu".
        if self.fp16:
            x = x.half().to(DEVICE)
        else:
            x = x.float().to(DEVICE)
        if self.task.cfg.normalize:
            x = torch.nn.functional.layer_norm(x, x.shape)
        x = x.view(1, -1)
        feat = []
        for start in range(0, x.size(1), self.max_chunk):
            x_chunk = x[:, start: start + self.max_chunk]
            feat_chunk, _ = self.model.extract_features(
                source=x_chunk,
                padding_mask=None,
                mask=False,
                output_layer=self.layer,
            )
            feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)
| |
|
class ApplyKmeans:
    """Assign each feature frame to its nearest k-means centroid.

    The transposed centroid matrix and its squared column norms are
    precomputed and cached on DEVICE, so nearest-centroid search reduces to
    one matrix multiplication per call.
    """

    def __init__(self, km_path):
        self.km_model = joblib.load(km_path)
        centers_t = self.km_model.cluster_centers_.transpose()
        self.C_np = centers_t
        self.Cnorm_np = (centers_t ** 2).sum(0, keepdims=True)
        self.C = torch.from_numpy(self.C_np).to(DEVICE)
        self.Cnorm = torch.from_numpy(self.Cnorm_np).to(DEVICE)

    def __call__(self, x):
        """Return a numpy array with one cluster id per row of *x*."""
        x = x.to(DEVICE)
        # ||x - c||^2 = ||x||^2 - 2*x.c + ||c||^2, minimized over centroids c.
        sq_norms = x.pow(2).sum(1, keepdim=True)
        cross_term = torch.matmul(x, self.C)
        dist = sq_norms - 2 * cross_term + self.Cnorm
        return dist.argmin(dim=1).cpu().numpy()
| |
|
class Speech2Unit(torch.nn.Module):
    """Convert a speech file into a pseudo-language unit string.

    Pipeline: mHuBERT layer features -> k-means cluster ids ->
    ``"<sosp><id>...<eosp>"`` token string, optionally with consecutive
    duplicate units collapsed.
    """

    def __init__(self, ckpt_dir, layer=11, max_chunk=1600000, fp16=False, sampling_rate=16000):
        super().__init__()
        # Filenames follow the fairseq mHuBERT (en/es/fr) public release.
        ckpt_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3.pt")
        km_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3_L11_km1000.bin")
        self.feature_reader = FeatureReader(ckpt_path, layer, max_chunk, fp16, sampling_rate)
        self.apply_kmeans = ApplyKmeans(km_path)

    @staticmethod
    def merge_duplicates(cluster_ids):
        """Collapse runs of identical ids; return (unique_ids, run_lengths)."""
        dup_cluster_list = []
        duration_list = []
        count = 1
        for i in range(len(cluster_ids)):
            # Extend the current run while the next id matches this one.
            if i + 1 < len(cluster_ids) and cluster_ids[i] == cluster_ids[i + 1]:
                count += 1
            else:
                dup_cluster_list.append(cluster_ids[i])
                duration_list.append(count)
                count = 1
        return dup_cluster_list, duration_list

    def __call__(self, path, merged=True):
        """Return the unit string for the audio file at *path*.

        With ``merged=True`` (default), consecutive duplicate units are
        collapsed before rendering.
        """
        waveform = self.feature_reader.read_audio(path).to(DEVICE)
        feat = self.feature_reader.get_feats(waveform)
        cluster_ids = self.apply_kmeans(feat).tolist()
        # PERF: the original built both the merged and unmerged strings on
        # every call and discarded one; build only the requested variant.
        if merged:
            ids, _ = self.merge_duplicates(cluster_ids)
        else:
            ids = cluster_ids
        return "<sosp>" + "".join(f"<{x}>" for x in ids) + "<eosp>"
| |
|
def process_jsonl(input_path: str, output_path: str, ckpt_dir: str) -> None:
    """Annotate a JSONL manifest with target speech units.

    For each record whose "path_ans" field points to an existing audio file,
    a "tgt_units" field with the merged unit string is added and the record
    is written to *output_path*. Records without a usable "path_ans" are
    dropped from the output (as before), but are now counted and logged
    instead of disappearing silently.
    """
    s2u = Speech2Unit(ckpt_dir=ckpt_dir)
    skipped = 0
    with open(input_path, 'r', encoding='utf-8') as infile, \
            open(output_path, 'w', encoding='utf-8') as outfile:
        for line in tqdm(infile, desc="Processing JSONL"):
            data = json.loads(line)
            path = data.get("path_ans")
            if path and os.path.exists(path):
                data["tgt_units"] = s2u(path)
                outfile.write(json.dumps(data, ensure_ascii=False) + '\n')
            else:
                # Keep the original drop-on-miss behavior (downstream may
                # expect a filtered file), but make the data loss visible.
                skipped += 1
                logger.warning("Skipping record with missing audio: %r", path)
    if skipped:
        logger.info("Skipped %d record(s) without a valid 'path_ans'", skipped)
| |
|
if __name__ == '__main__':
    # CLI entry point: read a JSONL manifest, attach pseudo-language units,
    # and write the augmented records to a new JSONL file.
    cli = argparse.ArgumentParser(description="Process JSONL with Speech2Unit")
    cli.add_argument("--input_jsonl", type=str, required=True, help="Input JSONL file path")
    cli.add_argument("--output_jsonl", type=str, required=True, help="Output JSONL file path")
    cli.add_argument("--ckpt_dir", type=str, required=False, help="Directory of checkpoint and kmeans model", default='models/')
    parsed = cli.parse_args()
    process_jsonl(parsed.input_jsonl, parsed.output_jsonl, parsed.ckpt_dir)
| |
|