# Uploaded via the Hugging Face upload-large-folder tool (commit 652c91e, verified).
import json
import os
import argparse
import torch
from tqdm.auto import tqdm
from typing import List, Dict, Any
import logging
import sys
from torchaudio.functional import resample
# Include previous imports for Speech2Unit
from fairseq import checkpoint_utils
import joblib
import torchaudio
# Log to stdout; verbosity is configurable via the LOGLEVEL environment variable.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger('generate_pseudo_language')
# Run on GPU when available; models and tensors below are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class FeatureReader:
    """Extracts intermediate-layer features from a fairseq (HuBERT-style) checkpoint.

    Loads the model once in ``__init__`` and exposes :meth:`read_audio` (load +
    resample) and :meth:`get_feats` (chunked feature extraction).
    """

    def __init__(self, ckpt_path, layer, max_chunk=1600000, fp16=False, sampling_rate=16000):
        """
        Args:
            ckpt_path: path to the fairseq model checkpoint.
            layer: transformer layer index whose features are extracted.
            max_chunk: maximum number of samples fed to the model per forward
                pass (long audio is processed in chunks to bound memory).
            fp16: run the model and inputs in half precision.
            sampling_rate: sample rate the model expects; inputs are resampled to it.
        """
        (model, cfg, task) = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = model[0].eval().to(DEVICE)
        self.task = task
        self.layer = layer
        self.max_chunk = max_chunk
        self.fp16 = fp16
        if fp16:
            self.model.half()
        self.target_sample_hz = sampling_rate

    def read_audio(self, path):
        """Load an audio file and resample it to ``self.target_sample_hz`` if needed."""
        wav, sr = torchaudio.load(path)
        if sr != self.target_sample_hz:
            wav = resample(wav, sr, self.target_sample_hz)
        return wav

    @torch.no_grad()
    def get_feats(self, waveform):
        """Return layer features for *waveform* as a (frames, dim) tensor.

        Args:
            waveform: 1-D (or flattenable) audio tensor at the target sample rate.
        """
        x = waveform
        # BUG FIX: previously used hard-coded .cuda(), which crashes on
        # CPU-only machines even though the rest of the file falls back to
        # DEVICE="cpu". Use the shared DEVICE instead.
        if self.fp16:
            x = x.half().to(DEVICE)
        else:
            x = x.float().to(DEVICE)
        if self.task.cfg.normalize:
            # Normalize over the whole utterance, matching the model's training config.
            x = torch.nn.functional.layer_norm(x, x.shape)
        x = x.view(1, -1)
        feat = []
        # Process in fixed-size chunks so arbitrarily long audio fits in memory.
        for start in range(0, x.size(1), self.max_chunk):
            x_chunk = x[:, start: start + self.max_chunk]
            feat_chunk, _ = self.model.extract_features(
                source=x_chunk,
                padding_mask=None,
                mask=False,
                output_layer=self.layer,
            )
            feat.append(feat_chunk)
        # Concatenate chunks along the time axis and drop the batch dimension.
        return torch.cat(feat, 1).squeeze(0)
class ApplyKmeans:
    """Assigns each feature frame to the index of its nearest k-means centroid."""

    def __init__(self, km_path):
        """Load a joblib-serialized k-means model and cache its centroids on DEVICE.

        Args:
            km_path: path to the serialized k-means model (sklearn-style, with
                a ``cluster_centers_`` attribute).
        """
        self.km_model = joblib.load(km_path)
        # Centroids stored transposed (dim, K) so distances reduce to a matmul.
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)
        self.C = torch.from_numpy(self.C_np).to(DEVICE)
        self.Cnorm = torch.from_numpy(self.Cnorm_np).to(DEVICE)

    def __call__(self, x):
        """Return the nearest-centroid index for each row of *x* as a numpy array."""
        x = x.to(DEVICE)
        # Squared Euclidean distance via the expansion ||x||^2 - 2*x.C + ||C||^2;
        # the constant ||x||^2 term is kept so `dist` is a true squared distance.
        x_sq = x.pow(2).sum(1, keepdim=True)
        cross = torch.matmul(x, self.C)
        dist = x_sq - 2 * cross + self.Cnorm
        return dist.argmin(dim=1).cpu().numpy()
class Speech2Unit(torch.nn.Module):
    """Turns a speech file into a pseudo-language unit string (mHuBERT + k-means)."""

    def __init__(self, ckpt_dir, layer=11, max_chunk=1600000, fp16=False, sampling_rate=16000):
        """
        Args:
            ckpt_dir: directory containing the mHuBERT checkpoint and the
                k-means model (fixed filenames below).
            layer: transformer layer to read features from.
            max_chunk: max samples per model forward pass.
            fp16: run feature extraction in half precision.
            sampling_rate: expected input sample rate.
        """
        super().__init__()
        ckpt_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3.pt")
        km_path = os.path.join(ckpt_dir, "mhubert_base_vp_en_es_fr_it3_L11_km1000.bin")
        self.feature_reader = FeatureReader(ckpt_path, layer, max_chunk, fp16, sampling_rate)
        self.apply_kmeans = ApplyKmeans(km_path)

    @staticmethod
    def merge_duplicates(cluster_ids):
        """Collapse consecutive runs of equal ids.

        Returns:
            (values, durations): the id of each run and its length, in order.
        """
        values = []
        durations = []
        for cid in cluster_ids:
            if values and values[-1] == cid:
                # Same run as the previous id: just extend its duration.
                durations[-1] += 1
            else:
                values.append(cid)
                durations.append(1)
        return values, durations

    def __call__(self, path, merged=True):
        """Return the unit string for the audio at *path*.

        Args:
            path: audio file path.
            merged: when True, consecutive duplicate units are collapsed.
        """
        waveform = self.feature_reader.read_audio(path).to(DEVICE)
        feat = self.feature_reader.get_feats(waveform)
        cluster_ids = self.apply_kmeans(feat).tolist()
        if merged:
            ids, _ = self.merge_duplicates(cluster_ids)
        else:
            ids = cluster_ids
        return "<sosp>" + "".join(f"<{i}>" for i in ids) + "<eosp>"
def process_jsonl(input_path: str, output_path: str, ckpt_dir: str) -> None:
    """Annotate each JSONL record with pseudo-language target units.

    For every line of *input_path*: if the record has a "path_ans" audio path
    that exists on disk, run Speech2Unit on it and store the merged unit
    string under "tgt_units". Every record is written to *output_path*
    whether or not it could be annotated, so no data is silently dropped.

    Args:
        input_path: input JSONL file, one JSON object per line.
        output_path: output JSONL file (overwritten).
        ckpt_dir: directory holding the mHuBERT checkpoint and k-means model.
    """
    s2u = Speech2Unit(ckpt_dir=ckpt_dir)
    with open(input_path, 'r', encoding='utf-8') as infile, open(output_path, 'w', encoding='utf-8') as outfile:
        for line in tqdm(infile, desc="Processing JSONL"):
            line = line.strip()
            if not line:
                # Tolerate blank/trailing empty lines instead of crashing json.loads.
                continue
            data = json.loads(line)
            audio_path = data.get("path_ans")
            if audio_path is not None:
                if os.path.exists(audio_path):
                    data["tgt_units"] = s2u(audio_path)
                else:
                    # Previously missing files were skipped silently; surface them.
                    logger.warning("Audio file not found, record left without tgt_units: %s", audio_path)
            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')
if __name__ == '__main__':
    # CLI entry point: annotate a JSONL manifest with pseudo-language units.
    cli = argparse.ArgumentParser(description="Process JSONL with Speech2Unit")
    cli.add_argument("--input_jsonl", type=str, required=True, help="Input JSONL file path")
    cli.add_argument("--output_jsonl", type=str, required=True, help="Output JSONL file path")
    cli.add_argument("--ckpt_dir", type=str, required=False, help="Directory of checkpoint and kmeans model", default='models/')
    opts = cli.parse_args()
    process_jsonl(opts.input_jsonl, opts.output_jsonl, opts.ckpt_dir)