""" Inference script for IndicMOS Author: Sathvik Udupa (sathvikudupa66@gmail.com) """ import warnings warnings.filterwarnings("ignore") import os import torch import argparse import torchaudio import numpy as np import torch.nn as nn from tqdm import tqdm import s3prl.hub as hub from huggingface_hub import hf_hub_download parser = argparse.ArgumentParser(description="IndicMOS Inference") parser.add_argument("--manifest_path", type=str, required=False, help="Path to the manifest file") parser.add_argument("--save_path", type=str, required=False, help="Path to the save file for the scores from the manifest audios") # parser.add_argument("--audio_path", type=str, required=False, help="Path to the audio file") parser.add_argument("--batch_size", type=int, default=32, help="Batch size for the manifest file") parser.add_argument("--use_cer", action="store_true", default=False, help="Enable to use CER as an input feature for MOS prediction") parser.add_argument("--use_langid", action="store_true", default=False, help="Enable to use Language ID as an input feature for MOS prediction") parser.add_argument("--device", default="cpu", help="device to run the model on") REPO_ID = "SYSPIN/IndicMOS" SSL_NAME = "indicw2v_base_pretrained.pt" BASE_PREDICTOR = "joint_indicw2v_base.pt" CER_PREDICTOR = "joint_indicw2v_base_cer.pt" LANG_ID_PREDICTOR = "joint_indicw2v_base_lang.pt" CER_LANG_ID_PREDICTOR = "joint_indicw2v_base_cer_lang.pt" HF_PATH = "hf_inference_models" LANG_ID_MAPPING = { "hi": 0, "te": 1, "mr": 2, "kn": 3, "bn": 4, "en": 5, "ch": 6, "hindi": 0, "telugu": 1, "marathi": 2, "kannada": 3, "bengali": 4, "english": 5, "chhattisgarhi": 6, } class ssl_mospred_model(nn.Module): def __init__( self, ssl_model, dim=768, use_cer=False, use_lang=False, lang_dim=32, cer_hidden_dim=32, cer_final_dim=4, proj_dim=64, num_langs=7 ): super(ssl_mospred_model, self).__init__() self.ssl_model = ssl_model if use_cer: dim = cer_hidden_dim if use_lang: dim += lang_dim self.linear = nn.Linear(dim, 1) self.use_cer = use_cer if use_cer: self.cer_embed = nn.Sequential( nn.Linear(1, cer_hidden_dim), nn.ReLU(), nn.Linear(cer_hidden_dim, cer_final_dim), nn.ReLU(), ) self.feat_proj = nn.Sequential( nn.ReLU(), nn.Linear(dim, proj_dim), ) self.use_lang = use_lang if use_lang: self.lang_embed = nn.Embedding(num_langs, lang_dim) def handle_cer_embed(self, feats, cer): if not self.use_cer: return feats feats = self.feat_proj(feats) cer = self.cer_embed(cer[:, None]) feats = torch.cat([feats, cer], -1) return feats def handle_lang_embed(self, feats, lang): if not self.use_lang: return feats lang = self.lang_embed(lang) feats = torch.cat([feats, lang], -1) return feats def get_padding_mask(self, x, feats, lengths): max_length = feats.shape[1] num_frames = round(x.shape[-1]/feats.shape[1]) ssl_lengths = [int(l/(num_frames)) for l in lengths] ssl_lengths = torch.LongTensor(ssl_lengths) mask = (torch.arange(max_length).expand(len(ssl_lengths), max_length) < ssl_lengths.unsqueeze(1)).float() return mask.to(x.device) def forward(self, x, cer_data=None, lang_data=None, lengths=None, batch_mode=False): feats = self.ssl_model(x)["hidden_states"][-1] if batch_mode: mask = self.get_padding_mask(x, feats, lengths) feats = feats * mask.unsqueeze(-1) feats = feats.sum(1)/mask.sum(-1).unsqueeze(-1) else: feats = feats.sum(1) feats = self.handle_cer_embed(feats, cer_data) feats = self.handle_lang_embed(feats, lang_data) feats = self.linear(feats) return feats.float() def download_model_from_hub(chk_name, download_path): """ Download the model from the model repo """ path = hf_hub_download(repo_id=REPO_ID, repo_type="model", filename=chk_name, cache_dir=download_path) return path def load_custom_model_from_s3prl(path): """ Load the custom model from the local s3prl file """ ssl_model = getattr(hub, "wav2vec2_custom")(ckpt=path) return ssl_model def load_model(use_cer, use_langid, download_path, device): """ Load the model from the hub """ if use_cer and use_langid: chk = CER_LANG_ID_PREDICTOR elif use_cer: chk = CER_PREDICTOR elif use_langid: chk = LANG_ID_PREDICTOR else: chk = BASE_PREDICTOR predictor_path = download_model_from_hub(chk, download_path) ssl_path = download_model_from_hub(SSL_NAME, download_path) ssl_model = load_custom_model_from_s3prl(ssl_path) predictor = torch.load(predictor_path, map_location=device) mos_model = ssl_mospred_model(ssl_model, use_cer=use_cer, use_lang=use_langid) mos_model.linear.weight.data = predictor["linear.weight"] mos_model.linear.bias.data = predictor["linear.bias"] if use_cer: mos_model.cer_embed[0].weight.data = predictor["cer_embed.0.weight"] mos_model.cer_embed[0].bias.data = predictor["cer_embed.0.bias"] mos_model.cer_embed[2].weight.data = predictor["cer_embed.2.weight"] mos_model.cer_embed[2].bias.data = predictor["cer_embed.2.bias"] mos_model.feat_proj[1].weight.data = predictor["feat_proj.1.weight"] mos_model.feat_proj[1].bias.data = predictor["feat_proj.1.bias"] if use_langid: mos_model.lang_embed.weight.data = predictor["lang_embed.weight"] mos_model.to(device) mos_model.eval() return mos_model def preprocess_single(audio_path, cer, langid): """ Preprocess the audio file and metadata """ audio, sr = torchaudio.load(audio_path) assert sr == 16000, "Audio file should be sampled at 16kHz" if cer is not None: cer = torch.tensor([cer]) if langid is not None: if langid not in LANG_ID_MAPPING: raise ValueError("Language ID not supported, please use one of the following: {}".format(LANG_ID_MAPPING.keys())) langid = torch.tensor([LANG_ID_MAPPING[langid]]) return audio, cer, langid class Collate(): def __call__(self, batch): input_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor([len(x[0]) for x in batch]),dim=0, descending=True) max_input_len = input_lengths[0] audio_padded = torch.FloatTensor(len(batch), max_input_len) audio_padded.zero_() scores, cers, langs, filenames, lengths = [], [], [], [], [] for i in range(len(batch)): audio = batch[i][0] audio_padded[i, :audio.size(0)] = audio cers.append(batch[i][1]) filenames.append(batch[i][3]) lengths.append(audio.size(0)) langs.append(batch[i][2]) lengths = torch.LongTensor(lengths) if langs[0] is not None: langs = torch.stack(langs, dim=0).squeeze() return audio_padded, cers, lengths, langs, filenames class PreProcessBatch(torch.utils.data.Dataset): def __init__(self, manifest_path, use_cer, use_langid): with open(manifest_path, "r") as f: data = f.read().split("\n") delim = "\t" if len(data[0].split("\t")) < 2: delim = " " headers = data[0].strip().split(delim) assert headers[:2] == ["id", "audio_path"], "Manifest file should have first 2 column headers as id, audio_path, instead found {}".format(headers[:2]) self.cer = cer self.langid = langid if cer is not None: assert "cer" in headers, "Manifest file should have cer column" if langid is not None: assert "langid" in headers, "Manifest file should have langid column" self.metadata_dict = {} for line in data[1:]: if line.strip() == "": continue fields = line.strip().split(delim) key, audio_path = fields[:2] self.metadata_dict[key] = {x:fields[idx+1] for idx, x in enumerate(headers[1:])} self.all_keys = list(self.metadata_dict.keys()) def __len__(self): return len(self.all_keys) def __getitem__(self, idx): key = self.all_keys[idx] audio_path = self.metadata_dict[key]["audio_path"] cer, langid = None, None if "cer" in self.metadata_dict[key]: cer = torch.tensor([float(self.metadata_dict[key]["cer"])]) if "langid" in self.metadata_dict[key]: langid = torch.tensor([LANG_ID_MAPPING[self.metadata_dict[key]["langid"]]]) audio, sr = torchaudio.load(audio_path) return audio.squeeze(), cer, langid, key def score(audio_path, cer=None, langid=None, use_cer=False, use_langid=False, download_path=HF_PATH, device="cpu"): """ Single audio mos prediction """ audio, cer, langid = preprocess_single(audio_path, cer, langid) mos_model = load_model(use_cer, use_langid, download_path, device) with torch.no_grad(): score = mos_model(audio, cer_data=cer, lang_data=langid).squeeze().cpu().item() return score def batch_score(manifest_path, save_path, batch_size=32, use_cer=False, use_langid=False, download_path="hf_inference_models", device="cpu"): """ batch audio mos prediction """ dataset = PreProcessBatch(manifest_path, use_cer, use_langid) loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=Collate()) mos_model = load_model(use_cer, use_langid, download_path, device) results = {} with torch.no_grad(): for eval_data in tqdm(loader): audio, cer, lengths, langid, filenames = eval_data audio = audio.to(device) scores = mos_model(audio, cer_data=cer, lang_data=langid, lengths=lengths, batch_mode=True).squeeze(-1).cpu().numpy() for idx, filename in enumerate(filenames): results[filename] = scores[idx].squeeze() with open(save_path, "w") as f: for key, value in results.items(): f.write("{}\t{}\n".format(key, value)) return score if __name__ == "__main__": args = parser.parse_args() # if args.audio_path is None and args.manifest_path is None: # raise ValueError("Please provide either audio_path - (single file inference) or manifest_path - (batch inference)") if args.manifest_path is None: raise ValueError("Please provide manifest_path for batch inference") cer = None # if cer is not None: # if cer > 1: # print("WARNING: Use raw CER value, not percentage") langid = None # langid = "kn" # if args.audio_path is not None: ###FIX THIS # score = score(audio_path=args.audio_path, cer=cer, langid=langid, use_cer=args.use_cer, use_langid=args.use_langid) # print("predicted MOS", score) # else: assert args.save_path is not None, "Please provide a file path for the batch scores to be saved - save_path" batch_score(manifest_path=args.manifest_path, save_path=args.save_path, batch_size=args.batch_size, use_cer=args.use_cer, use_langid=args.use_langid, device=args.device)