|
|
""" |
|
|
Inference script for IndicMOS |
|
|
|
|
|
Author: Sathvik Udupa (sathvikudupa66@gmail.com) |
|
|
""" |
|
|
|
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import argparse |
|
|
import torchaudio |
|
|
import numpy as np |
|
|
import torch.nn as nn |
|
|
from tqdm import tqdm |
|
|
import s3prl.hub as hub |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Command-line interface for batch MOS inference.
# NOTE: --manifest_path / --save_path are optional at parse time; they are
# validated in the __main__ block before batch inference runs.
parser = argparse.ArgumentParser(description="IndicMOS Inference")
parser.add_argument("--manifest_path", type=str, required=False, help="Path to the manifest file")
parser.add_argument("--save_path", type=str, required=False, help="Path to the save file for the scores from the manifest audios")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for the manifest file")
parser.add_argument("--use_cer", action="store_true", default=False, help="Enable to use CER as an input feature for MOS prediction")
parser.add_argument("--use_langid", action="store_true", default=False, help="Enable to use Language ID as an input feature for MOS prediction")
parser.add_argument("--device", default="cpu", help="device to run the model on")
|
|
|
|
|
|
|
|
# Hugging Face model repository hosting all IndicMOS checkpoints.
REPO_ID = "SYSPIN/IndicMOS"

# Checkpoint filenames inside the repo: the SSL backbone plus one predictor
# head per input-feature combination (base / +CER / +langid / +both).
SSL_NAME = "indicw2v_base_pretrained.pt"
BASE_PREDICTOR = "joint_indicw2v_base.pt"
CER_PREDICTOR = "joint_indicw2v_base_cer.pt"
LANG_ID_PREDICTOR = "joint_indicw2v_base_lang.pt"
CER_LANG_ID_PREDICTOR = "joint_indicw2v_base_cer_lang.pt"

# Default local cache directory for downloaded checkpoints.
HF_PATH = "hf_inference_models"

# Language-embedding index for each supported language; both the short code
# and the full name are accepted as aliases for the same index.
_LANG_ALIASES = (
    (0, "hi", "hindi"),
    (1, "te", "telugu"),
    (2, "mr", "marathi"),
    (3, "kn", "kannada"),
    (4, "bn", "bengali"),
    (5, "en", "english"),
    (6, "ch", "chhattisgarhi"),
)
LANG_ID_MAPPING = {
    alias: index
    for index, *aliases in _LANG_ALIASES
    for alias in aliases
}
|
|
|
|
|
class ssl_mospred_model(nn.Module):
    """MOS predictor head on top of an s3prl SSL (wav2vec2) backbone.

    Pools the backbone's last hidden states over time and maps them to a
    scalar MOS score, optionally concatenating a CER embedding and/or a
    learned language embedding first.

    NOTE(review): the layer sizes constructed here appear to be placeholders —
    load_model() overwrites the parameter tensors (and therefore their real
    shapes) directly from the downloaded checkpoint via `.data` assignment.
    """

    def __init__(
        self,
        ssl_model,            # pretrained s3prl upstream; called in forward()
        dim=768,              # input size of the final linear layer (SSL hidden size)
        use_cer=False,        # enable the CER input feature
        use_lang=False,       # enable the language-ID input feature
        lang_dim=32,          # language embedding size
        cer_hidden_dim=32,    # hidden size of the CER embedding MLP
        cer_final_dim=4,      # output size of the CER embedding MLP
        proj_dim=64,          # projection size applied to SSL features in CER mode
        num_langs=7           # number of supported languages (see LANG_ID_MAPPING)
    ):
        super(ssl_mospred_model, self).__init__()
        self.ssl_model = ssl_model
        # In CER mode the final linear's input dim is redefined from
        # cer_hidden_dim; real shapes come from the checkpoint (see class note).
        if use_cer:
            dim = cer_hidden_dim
        if use_lang:
            dim += lang_dim

        self.linear = nn.Linear(dim, 1)  # scalar MOS output
        self.use_cer = use_cer
        if use_cer:
            # Embeds the scalar CER value into a small feature vector.
            self.cer_embed = nn.Sequential(
                nn.Linear(1, cer_hidden_dim),
                nn.ReLU(),
                nn.Linear(cer_hidden_dim, cer_final_dim),
                nn.ReLU(),
            )
            # Projects pooled SSL features before concatenating the CER embedding.
            self.feat_proj = nn.Sequential(
                nn.ReLU(),
                nn.Linear(dim, proj_dim),
            )
        self.use_lang = use_lang
        if use_lang:
            self.lang_embed = nn.Embedding(num_langs, lang_dim)

    def handle_cer_embed(self, feats, cer):
        """Project feats and append the CER embedding; no-op when CER is disabled."""
        if not self.use_cer:
            return feats
        feats = self.feat_proj(feats)
        # Expects cer as a 1-D tensor of shape (batch,); [:, None] makes it (batch, 1).
        cer = self.cer_embed(cer[:, None])
        feats = torch.cat([feats, cer], -1)
        return feats

    def handle_lang_embed(self, feats, lang):
        """Append the language embedding; no-op when language ID is disabled."""
        if not self.use_lang:
            return feats
        lang = self.lang_embed(lang)
        feats = torch.cat([feats, lang], -1)
        return feats

    def get_padding_mask(self, x, feats, lengths):
        """Build a (batch, frames) 0/1 float mask of valid SSL frames.

        The samples-per-frame ratio is estimated from the padded waveform
        length and the SSL output length, then each utterance's true sample
        length is converted to a frame count.
        """
        max_length = feats.shape[1]
        num_frames = round(x.shape[-1]/feats.shape[1])  # samples per SSL frame
        ssl_lengths = [int(l/(num_frames)) for l in lengths]
        ssl_lengths = torch.LongTensor(ssl_lengths)
        mask = (torch.arange(max_length).expand(len(ssl_lengths), max_length) < ssl_lengths.unsqueeze(1)).float()
        return mask.to(x.device)

    def forward(self, x, cer_data=None, lang_data=None, lengths=None, batch_mode=False):
        """Predict MOS from raw waveform input.

        Args:
            x: waveform tensor passed straight to the SSL backbone.
            cer_data: per-utterance CER values (used only when use_cer).
            lang_data: per-utterance language indices (used only when use_lang).
            lengths: true sample lengths, required when batch_mode is True.
            batch_mode: when True, pool with a padding mask (masked mean).

        Returns:
            Float tensor of predicted MOS scores.
        """
        feats = self.ssl_model(x)["hidden_states"][-1]
        if batch_mode:
            # Masked mean over valid frames only, so padding doesn't dilute the pooling.
            mask = self.get_padding_mask(x, feats, lengths)
            feats = feats * mask.unsqueeze(-1)
            feats = feats.sum(1)/mask.sum(-1).unsqueeze(-1)
        else:
            # NOTE(review): single mode sums frames while batch mode averages —
            # presumably matches how each checkpoint was trained; confirm upstream.
            feats = feats.sum(1)
        feats = self.handle_cer_embed(feats, cer_data)
        feats = self.handle_lang_embed(feats, lang_data)
        feats = self.linear(feats)
        return feats.float()
|
|
|
|
|
def download_model_from_hub(chk_name, download_path):
    """
    Fetch a checkpoint file from the IndicMOS HF repo and return its local cached path.
    """
    return hf_hub_download(
        repo_id=REPO_ID,
        repo_type="model",
        filename=chk_name,
        cache_dir=download_path,
    )
|
|
|
|
|
def load_custom_model_from_s3prl(path):
    """
    Instantiate the s3prl wav2vec2 upstream from a local checkpoint file.
    """
    upstream_factory = getattr(hub, "wav2vec2_custom")
    return upstream_factory(ckpt=path)
|
|
|
|
|
def load_model(use_cer, use_langid, download_path, device):
    """
    Download the SSL backbone plus the matching predictor checkpoint and
    assemble the MOS model with the checkpoint weights.

    Args:
        use_cer: whether the CER input feature is enabled.
        use_langid: whether the language-ID input feature is enabled.
        download_path: local cache directory for HF downloads.
        device: torch device string for weight loading and the model.

    Returns:
        An eval-mode ssl_mospred_model on the requested device.
    """
    # Select the predictor head matching the enabled input features.
    predictor_by_mode = {
        (True, True): CER_LANG_ID_PREDICTOR,
        (True, False): CER_PREDICTOR,
        (False, True): LANG_ID_PREDICTOR,
        (False, False): BASE_PREDICTOR,
    }
    chk = predictor_by_mode[(use_cer, use_langid)]

    predictor_path = download_model_from_hub(chk, download_path)
    ssl_path = download_model_from_hub(SSL_NAME, download_path)
    ssl_model = load_custom_model_from_s3prl(ssl_path)
    predictor = torch.load(predictor_path, map_location=device)

    mos_model = ssl_mospred_model(ssl_model, use_cer=use_cer, use_lang=use_langid)
    # Overwrite parameter tensors directly (rather than load_state_dict) so the
    # checkpoint shapes take precedence over the constructor's placeholder dims.
    mos_model.linear.weight.data = predictor["linear.weight"]
    mos_model.linear.bias.data = predictor["linear.bias"]

    if use_cer:
        # The Linear layers of cer_embed sit at Sequential indices 0 and 2.
        for layer_idx in (0, 2):
            mos_model.cer_embed[layer_idx].weight.data = predictor["cer_embed.{}.weight".format(layer_idx)]
            mos_model.cer_embed[layer_idx].bias.data = predictor["cer_embed.{}.bias".format(layer_idx)]
        mos_model.feat_proj[1].weight.data = predictor["feat_proj.1.weight"]
        mos_model.feat_proj[1].bias.data = predictor["feat_proj.1.bias"]

    if use_langid:
        mos_model.lang_embed.weight.data = predictor["lang_embed.weight"]

    mos_model.to(device)
    mos_model.eval()
    return mos_model
|
|
|
|
|
def preprocess_single(audio_path, cer, langid):
    """
    Load one audio file and convert its metadata into model-ready tensors.

    Args:
        audio_path: path to a 16 kHz audio file.
        cer: optional scalar CER value, or None.
        langid: optional language code/name (key of LANG_ID_MAPPING), or None.

    Returns:
        (audio, cer, langid) where audio is the loaded waveform tensor and the
        other two are tensors (or None when not provided).

    Raises:
        ValueError: if the sample rate is not 16 kHz or the language is unknown.
    """
    audio, sr = torchaudio.load(audio_path)
    # Validate with a real exception: `assert` is stripped under `python -O`.
    if sr != 16000:
        raise ValueError("Audio file should be sampled at 16kHz")
    if cer is not None:
        cer = torch.tensor([cer])
    if langid is not None:
        if langid not in LANG_ID_MAPPING:
            raise ValueError("Language ID not supported, please use one of the following: {}".format(LANG_ID_MAPPING.keys()))
        langid = torch.tensor([LANG_ID_MAPPING[langid]])
    return audio, cer, langid
|
|
|
|
|
class Collate():
    """Pad a batch of variable-length audio and gather the per-item metadata.

    Each dataset item is (audio, cer, langid, key); cer/langid may be None
    when the manifest lacks those columns.
    """

    def __call__(self, batch):
        """Return (audio_padded, cers, lengths, langs, filenames).

        audio_padded is (B, max_len) zero-padded; lengths holds the true
        sample counts; cers/langs are stacked to shape (B,) when present,
        otherwise returned as lists of None.
        """
        input_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True)
        max_input_len = input_lengths[0]
        audio_padded = torch.FloatTensor(len(batch), max_input_len)
        audio_padded.zero_()
        scores, cers, langs, filenames, lengths = [], [], [], [], []
        for i in range(len(batch)):
            audio = batch[i][0]
            audio_padded[i, :audio.size(0)] = audio
            cers.append(batch[i][1])
            filenames.append(batch[i][3])
            lengths.append(audio.size(0))
            langs.append(batch[i][2])
        lengths = torch.LongTensor(lengths)
        # Bug fix: stack cers into a (B,) tensor so handle_cer_embed's
        # `cer[:, None]` works in batch mode (a plain list would fail).
        if cers[0] is not None:
            cers = torch.stack(cers, dim=0).squeeze(-1)
        # Bug fix: squeeze only the trailing dim — a bare .squeeze() would
        # collapse the batch dimension when the batch size is 1.
        if langs[0] is not None:
            langs = torch.stack(langs, dim=0).squeeze(-1)
        return audio_padded, cers, lengths, langs, filenames
|
|
|
|
|
class PreProcessBatch(torch.utils.data.Dataset):
    """Dataset over a manifest of audio paths with optional cer/langid columns.

    The manifest's first row must be a header whose first two columns are
    "id" and "audio_path"; the delimiter is auto-detected (tab, else space).
    """

    def __init__(self, manifest_path, use_cer, use_langid):
        """Parse the manifest and index rows by id.

        Args:
            manifest_path: path to the delimited manifest file.
            use_cer: require a "cer" column in the manifest.
            use_langid: require a "langid" column in the manifest.
        """
        with open(manifest_path, "r") as f:
            data = f.read().split("\n")
        # Auto-detect the delimiter: tab-separated unless the header has no tabs.
        delim = "\t"
        if len(data[0].split("\t")) < 2:
            delim = " "
        headers = data[0].strip().split(delim)
        assert headers[:2] == ["id", "audio_path"], "Manifest file should have first 2 column headers as id, audio_path, instead found {}".format(headers[:2])
        # Bug fix: the original referenced undefined names `cer`/`langid` here
        # (NameError at construction); the parameters are use_cer/use_langid.
        self.use_cer = use_cer
        self.use_langid = use_langid

        if use_cer:
            assert "cer" in headers, "Manifest file should have cer column"
        if use_langid:
            assert "langid" in headers, "Manifest file should have langid column"
        self.metadata_dict = {}
        for line in data[1:]:
            if line.strip() == "":
                continue
            fields = line.strip().split(delim)
            key = fields[0]
            # Map each remaining header to its field value for this row.
            self.metadata_dict[key] = {x: fields[idx + 1] for idx, x in enumerate(headers[1:])}
        self.all_keys = list(self.metadata_dict.keys())

    def __len__(self):
        return len(self.all_keys)

    def __getitem__(self, idx):
        """Return (audio, cer, langid, key); cer/langid are None when absent."""
        key = self.all_keys[idx]
        audio_path = self.metadata_dict[key]["audio_path"]
        cer, langid = None, None
        if "cer" in self.metadata_dict[key]:
            cer = torch.tensor([float(self.metadata_dict[key]["cer"])])
        if "langid" in self.metadata_dict[key]:
            langid = torch.tensor([LANG_ID_MAPPING[self.metadata_dict[key]["langid"]]])

        audio, sr = torchaudio.load(audio_path)
        return audio.squeeze(), cer, langid, key
|
|
|
|
|
def score(audio_path, cer=None, langid=None, use_cer=False, use_langid=False, download_path=HF_PATH, device="cpu"):
    """
    Predict the MOS of a single audio file.

    Args:
        audio_path: path to a 16 kHz audio file.
        cer: optional CER value (used when use_cer is True).
        langid: optional language code/name (used when use_langid is True).
        use_cer / use_langid: select which predictor checkpoint to load.
        download_path: local cache directory for HF checkpoints.
        device: torch device string.

    Returns:
        The predicted MOS as a Python float.
    """
    wav, cer_feat, lang_feat = preprocess_single(audio_path, cer, langid)
    mos_model = load_model(use_cer, use_langid, download_path, device)
    with torch.no_grad():
        prediction = mos_model(wav, cer_data=cer_feat, lang_data=lang_feat)
    return prediction.squeeze().cpu().item()
|
|
|
|
|
def batch_score(manifest_path, save_path, batch_size=32, use_cer=False, use_langid=False, download_path="hf_inference_models", device="cpu"):
    """
    Predict MOS for every audio listed in a manifest file.

    Writes one "<id><tab><score>" line per audio to save_path and returns the
    {id: score} mapping.

    Args:
        manifest_path: tab- or space-separated manifest with id/audio_path columns.
        save_path: output text file for the predicted scores.
        batch_size: DataLoader batch size.
        use_cer / use_langid: select which predictor checkpoint to load.
        download_path: local cache directory for HF checkpoints.
        device: torch device string.
    """
    dataset = PreProcessBatch(manifest_path, use_cer, use_langid)
    # Bug fix: the original used the module-level CLI `args.batch_size`, which
    # raises NameError for library callers; use the function parameter.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=Collate())
    mos_model = load_model(use_cer, use_langid, download_path, device)
    results = {}
    with torch.no_grad():
        for eval_data in tqdm(loader):
            audio, cer, lengths, langid, filenames = eval_data
            audio = audio.to(device)
            scores = mos_model(audio, cer_data=cer, lang_data=langid, lengths=lengths, batch_mode=True).squeeze(-1).cpu().numpy()
            for idx, filename in enumerate(filenames):
                results[filename] = scores[idx].squeeze()
    with open(save_path, "w") as f:
        for key, value in results.items():
            f.write("{}\t{}\n".format(key, value))
    # Bug fix: the original returned the module-level `score` function object.
    return results
|
|
|
|
|
if __name__ == "__main__":
    args = parser.parse_args()

    # Validate required CLI arguments with real exceptions (consistent with
    # the manifest_path check; `assert` is stripped under `python -O`).
    # The previously-dead `cer = None` / `langid = None` locals were removed.
    if args.manifest_path is None:
        raise ValueError("Please provide manifest_path for batch inference")
    if args.save_path is None:
        raise ValueError("Please provide a file path for the batch scores to be saved - save_path")

    batch_score(
        manifest_path=args.manifest_path,
        save_path=args.save_path,
        batch_size=args.batch_size,
        use_cer=args.use_cer,
        use_langid=args.use_langid,
        device=args.device,
    )
|
|
|
|
|
|