import json
import os
import argparse
import torch
from tqdm.auto import tqdm
from typing import List, Dict, Any
import logging
import sys
from torchaudio.functional import resample

# Include previous imports for Speech2Unit
from fairseq import checkpoint_utils
import joblib
import torchaudio

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger('generate_pseudo_language')

# Single global device choice; all tensors/models below are moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class FeatureReader:
    """Loads a (m)HuBERT checkpoint and extracts intermediate-layer features.

    Args:
        ckpt_path: Path to the fairseq model checkpoint.
        layer: Transformer layer whose activations are returned.
        max_chunk: Max samples per forward pass (bounds peak memory).
        fp16: If True, run the model in half precision.
        sampling_rate: Sample rate the model expects; inputs are resampled to it.
    """

    def __init__(self, ckpt_path, layer, max_chunk=1600000, fp16=False, sampling_rate=16000):
        (model, cfg, task) = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = model[0].eval().to(DEVICE)
        self.task = task
        self.layer = layer
        self.max_chunk = max_chunk
        self.fp16 = fp16
        if fp16:
            self.model.half()
        self.target_sample_hz = sampling_rate

    def read_audio(self, path):
        """Load an audio file and resample it to ``target_sample_hz`` if needed."""
        wav, sr = torchaudio.load(path)
        if sr != self.target_sample_hz:
            wav = resample(wav, sr, self.target_sample_hz)
        return wav

    @torch.no_grad()
    def get_feats(self, waveform):
        """Return features of shape (frames, dim) for ``waveform``.

        BUG FIX: the original called ``.cuda()`` / ``.half().cuda()``
        unconditionally, which crashed on CPU-only machines even though
        DEVICE falls back to "cpu". Now uses ``.to(DEVICE)``.
        """
        if self.fp16:
            x = waveform.half().to(DEVICE)
        else:
            x = waveform.float().to(DEVICE)
        if self.task.cfg.normalize:
            x = torch.nn.functional.layer_norm(x, x.shape)
        # NOTE(review): view(1, -1) flattens any channel dimension — assumes
        # mono input; multi-channel audio would be concatenated. Confirm upstream.
        x = x.view(1, -1)
        feat = []
        # Process the waveform in fixed-size chunks to bound memory use.
        for start in range(0, x.size(1), self.max_chunk):
            x_chunk = x[:, start:start + self.max_chunk]
            feat_chunk, _ = self.model.extract_features(
                source=x_chunk,
                padding_mask=None,
                mask=False,
                output_layer=self.layer,
            )
            feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)


class ApplyKmeans:
    """Assigns each feature frame to its nearest k-means centroid.

    Args:
        km_path: Path to a joblib-serialized sklearn k-means model.
    """

    def __init__(self, km_path):
        self.km_model = joblib.load(km_path)
        # Pre-transpose centroids and precompute their squared norms so the
        # per-call distance reduces to a single matmul.
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)
        self.C = torch.from_numpy(self.C_np).to(DEVICE)
        self.Cnorm = torch.from_numpy(self.Cnorm_np).to(DEVICE)

    def __call__(self, x):
        """Return a numpy array of cluster ids, one per row of ``x``."""
        x = x.to(DEVICE)
        # Squared Euclidean distance: |x|^2 - 2 x.C + |C|^2 (argmin unaffected
        # by the constant |x|^2 term, but kept for clarity).
        dist = (x.pow(2).sum(1, keepdim=True) - 2 *
                torch.matmul(x, self.C) + self.Cnorm)
        return dist.argmin(dim=1).cpu().numpy()


class Speech2Unit(torch.nn.Module):
    """Converts a speech file into a string of discrete pseudo-language units.

    Expects ``ckpt_dir`` to contain the mHuBERT checkpoint and its matching
    layer-11 / 1000-cluster k-means model.
    """

    def __init__(self, ckpt_dir, layer=11, max_chunk=1600000, fp16=False, sampling_rate=16000):
        super().__init__()
        ckpt_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3.pt")
        km_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3_L11_km1000.bin")
        self.feature_reader = FeatureReader(ckpt_path, layer, max_chunk, fp16, sampling_rate)
        self.apply_kmeans = ApplyKmeans(km_path)

    @staticmethod
    def merge_duplicates(cluster_ids):
        """Run-length encode ``cluster_ids``.

        Collapses consecutive repeats; returns (unique_run_ids, run_durations).
        """
        dup_cluster_list = []
        duration_list = []
        count = 1
        for i in range(len(cluster_ids)):
            if i + 1 < len(cluster_ids) and cluster_ids[i] == cluster_ids[i + 1]:
                count += 1
            else:
                dup_cluster_list.append(cluster_ids[i])
                duration_list.append(count)
                count = 1
        return dup_cluster_list, duration_list

    def __call__(self, path, merged=True):
        """Return the unit string (e.g. ``"<12><998>..."``) for an audio file.

        If ``merged`` is True, consecutive duplicate units are collapsed.
        """
        waveform = self.feature_reader.read_audio(path).to(DEVICE)
        feat = self.feature_reader.get_feats(waveform)
        cluster_ids = self.apply_kmeans(feat).tolist()
        dup_cluster_list, _ = self.merge_duplicates(cluster_ids)
        # Cleaned up: removed no-op "" + ... + "" concatenations and the
        # redundant str() inside the f-string; output is byte-identical.
        merged_units = "".join(f"<{x}>" for x in dup_cluster_list)
        unmerged_units = "".join(f"<{x}>" for x in cluster_ids)
        return merged_units if merged else unmerged_units


def process_jsonl(input_path: str, output_path: str, ckpt_dir: str) -> None:
    """Annotate each JSONL record having an existing ``path_ans`` audio file
    with its pseudo-language units under the ``tgt_units`` key.

    Records without ``path_ans`` (or whose file is missing) are passed through
    unchanged; missing files are now logged instead of silently skipped.
    """
    s2u = Speech2Unit(ckpt_dir=ckpt_dir)
    with open(input_path, 'r', encoding='utf-8') as infile, \
            open(output_path, 'w', encoding='utf-8') as outfile:
        for line in tqdm(infile, desc="Processing JSONL"):
            data = json.loads(line)
            if "path_ans" in data:
                if os.path.exists(data["path_ans"]):
                    data["tgt_units"] = s2u(data["path_ans"])
                else:
                    # Robustness: surface missing audio instead of silence.
                    logger.warning("Audio file not found: %s", data["path_ans"])
            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Process JSONL with Speech2Unit")
    parser.add_argument("--input_jsonl", type=str, required=True,
                        help="Input JSONL file path")
    parser.add_argument("--output_jsonl", type=str, required=True,
                        help="Output JSONL file path")
    parser.add_argument("--ckpt_dir", type=str, required=False,
                        help="Directory of checkpoint and kmeans model",
                        default='models/')
    args = parser.parse_args()
    process_jsonl(args.input_jsonl, args.output_jsonl, args.ckpt_dir)