viks66 committed on
Commit
eda8f8c
·
1 Parent(s): 14fe765

add inference code

Browse files
infer_indicmos.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for IndicMOS
3
+
4
+ Author: Sathvik Udupa (sathvikudupa66@gmail.com)
5
+ """
6
+
7
+ import warnings
8
+ warnings.filterwarnings("ignore")
9
+
10
+ import os
11
+ import torch
12
+ import argparse
13
+ import torchaudio
14
+ import numpy as np
15
+ import torch.nn as nn
16
+ from tqdm import tqdm
17
+ import s3prl.hub as hub
18
+ from huggingface_hub import hf_hub_download
19
+
20
# Command-line options for manifest-based (batch) inference.
parser = argparse.ArgumentParser(description="IndicMOS Inference")
parser.add_argument(
    "--manifest_path",
    type=str,
    help="Path to the manifest file",
)
parser.add_argument(
    "--save_path",
    type=str,
    help="Path to the save file for the scores from the manifest audios",
)
# NOTE: a single-file "--audio_path" option is deliberately not registered here.
parser.add_argument(
    "--batch_size",
    type=int,
    default=32,
    help="Batch size for the manifest file",
)
parser.add_argument(
    "--use_cer",
    action="store_true",
    help="Enable to use CER as an input feature for MOS prediction",
)
parser.add_argument(
    "--use_langid",
    action="store_true",
    help="Enable to use Language ID as an input feature for MOS prediction",
)
parser.add_argument(
    "--device",
    default="cpu",
    help="device to run the model on",
)

# Hugging Face repository id and the checkpoint file names stored in it.
REPO_ID = "SYSPIN/IndicMOS"
SSL_NAME = "indicw2v_base_pretrained.pt"
BASE_PREDICTOR = "joint_indicw2v_base.pt"
CER_PREDICTOR = "joint_indicw2v_base_cer.pt"
LANG_ID_PREDICTOR = "joint_indicw2v_base_lang.pt"
CER_LANG_ID_PREDICTOR = "joint_indicw2v_base_cer_lang.pt"

# Each language is addressable by its short code or its full name; the
# position inside either tuple is the integer id the mapping returns.
_LANG_CODES = ("hi", "te", "mr", "kn", "bn", "en", "ch")
_LANG_NAMES = (
    "hindi",
    "telugu",
    "marathi",
    "kannada",
    "bengali",
    "english",
    "chhattisgarhi",
)
LANG_ID_MAPPING = {
    label: index
    for labels in (_LANG_CODES, _LANG_NAMES)
    for index, label in enumerate(labels)
}
53
+
54
class ssl_mospred_model(nn.Module):
    """MOS predictor: pooled SSL features (+ optional CER / language id) -> linear head.

    The last entry of ``ssl_model(x)["hidden_states"]`` is averaged over the
    time axis, optionally combined with a CER embedding and/or a language
    embedding, and mapped to a single score by ``self.linear``.

    Args:
        ssl_model: callable returning a mapping with a ``"hidden_states"`` list.
        dim: feature size used to build the layers.
        use_cer: enable the CER branch (``feat_proj`` + ``cer_embed``).
        use_lang: enable the language-embedding branch.
        lang_dim: size of the language embedding.
        cer_hidden_dim: hidden size of the CER MLP.
        cer_final_dim: output size of the CER MLP.
        proj_dim: output size of ``feat_proj``.
        num_langs: number of rows in the language embedding table.
    """

    def __init__(
        self,
        ssl_model,
        dim=768,
        use_cer=False,
        use_lang=False,
        lang_dim=32,
        cer_hidden_dim=32,
        cer_final_dim=4,
        proj_dim=64,
        num_langs=7,
    ):
        super().__init__()
        self.ssl_model = ssl_model
        # NOTE(review): with use_cer the layer sizes derived from `dim` below do
        # not obviously match the tensors produced in forward(); the loader in
        # this file overwrites the weights with checkpoint tensors, so the
        # sizing is left exactly as it was -- confirm against the checkpoints.
        if use_cer:
            dim = cer_hidden_dim
        if use_lang:
            dim += lang_dim

        self.linear = nn.Linear(dim, 1)
        self.use_cer = use_cer
        if use_cer:
            self.cer_embed = nn.Sequential(
                nn.Linear(1, cer_hidden_dim),
                nn.ReLU(),
                nn.Linear(cer_hidden_dim, cer_final_dim),
                nn.ReLU(),
            )
            self.feat_proj = nn.Sequential(
                nn.ReLU(),
                nn.Linear(dim, proj_dim),
            )
        self.use_lang = use_lang
        if use_lang:
            self.lang_embed = nn.Embedding(num_langs, lang_dim)

    def handle_cer_embed(self, feats, cer):
        """Project ``feats`` and append the CER embedding; no-op unless ``use_cer``.

        ``cer`` is indexed as ``cer[:, None]``, i.e. one value per batch item.
        """
        if not self.use_cer:
            return feats
        if cer is None:
            raise ValueError("cer_data is required when use_cer is enabled")
        feats = self.feat_proj(feats)
        cer = self.cer_embed(cer[:, None])
        return torch.cat([feats, cer], -1)

    def handle_lang_embed(self, feats, lang):
        """Append the language embedding to ``feats``; no-op unless ``use_lang``."""
        if not self.use_lang:
            return feats
        if lang is None:
            raise ValueError("lang_data is required when use_lang is enabled")
        lang = self.lang_embed(lang)
        return torch.cat([feats, lang], -1)

    def get_padding_mask(self, x, feats, lengths):
        """Return a float mask (batch, frames): 1 for valid SSL frames, 0 for padding.

        The samples-per-frame ratio is estimated from the padded input length
        and the number of frames in ``feats``; each entry of ``lengths`` is
        divided by that ratio to get its number of valid frames.
        """
        max_length = feats.shape[1]
        # Guard both divisions so degenerate shapes cannot divide by zero.
        num_frames = max(1, round(x.shape[-1] / max(1, max_length)))
        ssl_lengths = torch.tensor(
            [int(l / num_frames) for l in lengths],
            dtype=torch.long,
            device=x.device,
        )
        positions = torch.arange(max_length, device=x.device)
        return (positions.unsqueeze(0) < ssl_lengths.unsqueeze(1)).float()

    def forward(self, x, cer_data=None, lang_data=None, lengths=None, batch_mode=False):
        """Predict one score per input row.

        Args:
            x: input passed straight to ``ssl_model``.
            cer_data: CER values, required when ``use_cer`` is set.
            lang_data: language ids, required when ``use_lang`` is set.
            lengths: unpadded sample counts, required when ``batch_mode`` is set.
            batch_mode: when True, padded frames are excluded from the average.

        Returns:
            Float tensor produced by the final linear layer.

        Raises:
            ValueError: if an input required by the enabled options is missing.
        """
        feats = self.ssl_model(x)["hidden_states"][-1]
        if batch_mode:
            if lengths is None:
                raise ValueError("lengths is required when batch_mode is enabled")
            mask = self.get_padding_mask(x, feats, lengths)
            # clamp keeps an all-padding row from producing NaN (0 / 0).
            valid_frames = mask.sum(-1).clamp(min=1.0).unsqueeze(-1)
            feats = (feats * mask.unsqueeze(-1)).sum(1) / valid_frames
        else:
            # Mean over time, matching the masked mean of batch mode; a plain
            # sum would scale the features by the number of frames.
            feats = feats.mean(1)
        feats = self.handle_cer_embed(feats, cer_data)
        feats = self.handle_lang_embed(feats, lang_data)
        feats = self.linear(feats)
        return feats.float()
126
+
127
def download_model_from_hub(chk_name, download_path):
    """Fetch ``chk_name`` from the ``REPO_ID`` model repo and return its local path.

    ``download_path`` is handed to ``hf_hub_download`` as its ``cache_dir``.
    """
    return hf_hub_download(
        repo_id=REPO_ID,
        repo_type="model",
        filename=chk_name,
        cache_dir=download_path,
    )
133
+
134
def load_custom_model_from_s3prl(path):
    """Build the s3prl ``wav2vec2_custom`` upstream from the checkpoint at ``path``."""
    build_upstream = hub.wav2vec2_custom
    return build_upstream(ckpt=path)
140
+
141
def load_model(use_cer, use_langid, download_path, device):
    """Download the checkpoints and assemble the MOS model in eval mode on ``device``.

    The predictor checkpoint is selected by the (use_cer, use_langid) pair.
    Its tensors are assigned directly onto the corresponding layers.
    """
    checkpoint_by_flags = {
        (True, True): CER_LANG_ID_PREDICTOR,
        (True, False): CER_PREDICTOR,
        (False, True): LANG_ID_PREDICTOR,
        (False, False): BASE_PREDICTOR,
    }
    chk = checkpoint_by_flags[(bool(use_cer), bool(use_langid))]

    predictor_path = download_model_from_hub(chk, download_path)
    ssl_path = download_model_from_hub(SSL_NAME, download_path)
    ssl_model = load_custom_model_from_s3prl(ssl_path)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(predictor_path, map_location=device)

    mos_model = ssl_mospred_model(ssl_model, use_cer=use_cer, use_lang=use_langid)

    # (checkpoint key prefix, layer) pairs that carry both a weight and a bias.
    affine_layers = [("linear", mos_model.linear)]
    if use_cer:
        affine_layers.extend(
            [
                ("cer_embed.0", mos_model.cer_embed[0]),
                ("cer_embed.2", mos_model.cer_embed[2]),
                ("feat_proj.1", mos_model.feat_proj[1]),
            ]
        )
    for prefix, layer in affine_layers:
        # Assigning .data swaps in the checkpoint tensor as-is.
        layer.weight.data = state[prefix + ".weight"]
        layer.bias.data = state[prefix + ".bias"]

    if use_langid:
        mos_model.lang_embed.weight.data = state["lang_embed.weight"]

    mos_model.to(device)
    mos_model.eval()
    return mos_model
177
+
178
def preprocess_single(audio_path, cer, langid):
    """Load one audio file and turn its optional metadata into tensors.

    Returns ``(audio, cer, langid)``; ``cer`` and ``langid`` stay ``None`` when
    they were passed as ``None``. Raises ``ValueError`` for a language that is
    not a key of ``LANG_ID_MAPPING``.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    assert sample_rate == 16000, "Audio file should be sampled at 16kHz"

    cer_tensor = None if cer is None else torch.tensor([cer])

    if langid is None:
        return waveform, cer_tensor, None
    if langid not in LANG_ID_MAPPING:
        raise ValueError(f"Language ID not supported, please use one of the following: {LANG_ID_MAPPING.keys()}")
    lang_tensor = torch.tensor([LANG_ID_MAPPING[langid]])
    return waveform, cer_tensor, lang_tensor
191
+
192
class Collate():
    """Collate ``(audio, cer, langid, key)`` items into one zero-padded batch."""

    def __call__(self, batch):
        """Pad the audio to the longest item and gather the metadata.

        Args:
            batch: sequence of ``(audio, cer, langid, key)`` tuples, where
                ``audio`` is a 1-D tensor and ``cer`` / ``langid`` are either
                ``None`` or one-element tensors.

        Returns:
            ``(audio_padded, cers, lengths, langs, filenames)``. ``cers`` and
            ``langs`` are 1-D tensors with one entry per item when every item
            provides a value; otherwise they are returned as plain lists.
        """
        audios = [item[0] for item in batch]
        cers = [item[1] for item in batch]
        langs = [item[2] for item in batch]
        filenames = [item[3] for item in batch]

        lengths = torch.LongTensor([audio.size(0) for audio in audios])
        audio_padded = torch.zeros(len(batch), int(lengths.max()))
        for row, audio in enumerate(audios):
            audio_padded[row, :audio.size(0)] = audio

        # reshape(-1) keeps the batch axis even when the batch has one item
        # (a bare squeeze() would collapse it to a 0-dim tensor).
        if all(cer is not None for cer in cers):
            cers = torch.stack(cers, dim=0).reshape(-1)
        if all(lang is not None for lang in langs):
            langs = torch.stack(langs, dim=0).reshape(-1)
        return audio_padded, cers, lengths, langs, filenames
210
+
211
class PreProcessBatch(torch.utils.data.Dataset):
    """Dataset over a manifest whose first two header columns are ``id`` and ``audio_path``.

    The manifest is tab-delimited when the header contains a tab, otherwise
    space-delimited. Optional ``cer`` and ``langid`` columns are picked up
    per row when present.
    """

    def __init__(self, manifest_path, cer, langid):
        """Parse the manifest into ``metadata_dict`` (row id -> column values).

        Args:
            manifest_path: path of the manifest text file.
            cer: when not ``None``, the header must contain a ``cer`` column.
            langid: when not ``None``, the header must contain a ``langid`` column.

        Raises:
            AssertionError: if the header does not match the expected columns.
            ValueError: if a data row has fewer than two fields.
        """
        with open(manifest_path, "r") as f:
            data = f.read().split("\n")
        delim = "\t"
        if len(data[0].split("\t")) < 2:
            delim = " "
        headers = data[0].strip().split(delim)
        assert headers[:2] == ["id", "audio_path"], "Manifest file should have first 2 column headers as id, audio_path, instead found {}".format(headers[:2])
        self.cer = cer
        self.langid = langid

        if cer is not None:
            assert "cer" in headers, "Manifest file should have cer column"
        if langid is not None:
            assert "langid" in headers, "Manifest file should have langid column"

        self.metadata_dict = {}
        for line_no, line in enumerate(data[1:], start=2):
            if line.strip() == "":
                continue
            fields = line.strip().split(delim)
            if len(fields) < 2:
                raise ValueError(
                    "Manifest line {} should have at least id and audio_path, found {!r}".format(line_no, line)
                )
            # zip stops at the shorter side, so a row that omits trailing
            # optional columns simply has no entry for them (instead of
            # raising IndexError).
            self.metadata_dict[fields[0]] = dict(zip(headers[1:], fields[1:]))
        self.all_keys = list(self.metadata_dict)

    def __len__(self):
        """Return the number of manifest rows."""
        return len(self.all_keys)

    def __getitem__(self, idx):
        """Load row ``idx`` and return ``(audio, cer, langid, key)``.

        ``cer`` and ``langid`` are one-element tensors, or ``None`` when the
        row has no (non-empty) value for them.
        """
        key = self.all_keys[idx]
        entry = self.metadata_dict[key]
        audio_path = entry["audio_path"]
        cer, langid = None, None
        if entry.get("cer", "") != "":
            cer = torch.tensor([float(entry["cer"])])
        if entry.get("langid", "") != "":
            langid = torch.tensor([LANG_ID_MAPPING[entry["langid"]]])

        audio, sr = torchaudio.load(audio_path)
        # Same requirement the single-file preprocessing enforces.
        assert sr == 16000, "Audio file should be sampled at 16kHz"
        return audio.squeeze(), cer, langid, key
250
+
251
def score(audio_path, cer=None, langid=None, use_cer=False, use_langid=False, download_path="hf_inference_models", device="cpu"):
    """Predict the MOS of a single audio file.

    Args:
        audio_path: path of the audio file to score.
        cer: optional CER value for the utterance.
        langid: optional language code/name (a key of ``LANG_ID_MAPPING``).
        use_cer: load and run the CER-conditioned predictor.
        use_langid: load and run the language-conditioned predictor.
        download_path: cache directory for the downloaded checkpoints.
        device: device the model and its inputs are placed on.

    Returns:
        The predicted score as a Python float.
    """
    audio, cer, langid = preprocess_single(audio_path, cer, langid)
    mos_model = load_model(use_cer, use_langid, download_path, device)

    # The model lives on `device`, so every tensor input has to follow it.
    audio = audio.to(device)
    if cer is not None:
        cer = cer.to(device)
    if langid is not None:
        langid = langid.to(device)

    with torch.no_grad():
        prediction = mos_model(audio, cer_data=cer, lang_data=langid)
    return prediction.squeeze().cpu().item()
260
+
261
def batch_score(manifest_path, save_path, batch_size=32, cer=None, langid=None, use_cer=False, use_langid=False, download_path="hf_inference_models", device="cpu"):
    """Predict MOS for every row of a manifest and write the scores to a file.

    Args:
        manifest_path: manifest consumed by ``PreProcessBatch``.
        save_path: output file; one ``<id>\\t<score>`` line per manifest row.
        batch_size: number of audios per forward pass.
        cer: forwarded to ``PreProcessBatch`` (header check only).
        langid: forwarded to ``PreProcessBatch`` (header check only).
        use_cer: load and run the CER-conditioned predictor.
        use_langid: load and run the language-conditioned predictor.
        download_path: cache directory for the downloaded checkpoints.
        device: device the model and its inputs are placed on.

    Returns:
        Dict mapping each manifest id to its predicted score.
    """
    dataset = PreProcessBatch(manifest_path, cer, langid)
    # Use the function's own parameter, not the CLI-only global `args`.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=Collate())
    mos_model = load_model(use_cer, use_langid, download_path, device)

    results = {}
    with torch.no_grad():
        for audio, batch_cer, lengths, batch_lang, filenames in tqdm(loader):
            audio = audio.to(device)
            # Metadata may arrive as tensors or as plain lists (when absent);
            # only tensors can and need to be moved to the model's device.
            if torch.is_tensor(batch_cer):
                batch_cer = batch_cer.to(device)
            if torch.is_tensor(batch_lang):
                batch_lang = batch_lang.to(device)
            scores = mos_model(
                audio,
                cer_data=batch_cer,
                lang_data=batch_lang,
                lengths=lengths,
                batch_mode=True,
            ).squeeze(-1).cpu().numpy()
            for idx, filename in enumerate(filenames):
                results[filename] = scores[idx].squeeze()

    with open(save_path, "w") as f:
        for key, value in results.items():
            f.write("{}\t{}\n".format(key, value))
    return results
280
+
281
if __name__ == "__main__":
    args = parser.parse_args()

    if args.manifest_path is None:
        raise ValueError("Please provide manifest_path for batch inference")

    # Per-utterance metadata for the single-file path; batch inference reads
    # these values from the manifest instead.
    cer = None
    if cer is not None and cer > 1:
        print("WARNING: Use raw CER value, not percentage")
    langid = None

    # "--audio_path" is not a registered option, so the attribute may be
    # missing from the namespace; getattr avoids an AttributeError.
    audio_path = getattr(args, "audio_path", None)
    if audio_path is not None:
        # Bind to a new name so the score() function is not overwritten.
        predicted_mos = score(
            audio_path=audio_path,
            cer=cer,
            langid=langid,
            use_cer=args.use_cer,
            use_langid=args.use_langid,
            device=args.device,
        )
        print("predicted MOS", predicted_mos)
    else:
        assert args.save_path is not None, "Please provide a file path for the batch scores to be saved - save_path"
        batch_score(
            manifest_path=args.manifest_path,
            save_path=args.save_path,
            batch_size=args.batch_size,
            cer=cer,
            langid=langid,
            use_cer=args.use_cer,
            use_langid=args.use_langid,
            device=args.device,
        )
303
+
304
+
sample_manifest/manifest.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ id audio_path langid
2
+ 1 ../sample_audio/kn_audio1.wav
3
+ 2 ../sample_audio/hi_audio2.wav
4
+ 4 ../sample_audio/mr_audio3.wav
sample_manifest/manifest_lang.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ id audio_path langid
2
+ 1 ../sample_audio/kn_audio1.wav kn
3
+ 2 ../sample_audio/hi_audio2.wav hi
4
+ 4 ../sample_audio/mr_audio3.wav mr