viks66 commited on
Commit
3efb944
·
1 Parent(s): eda8f8c

update inference script

Browse files
infer_indicmos.py CHANGED
@@ -33,6 +33,7 @@ BASE_PREDICTOR = "joint_indicw2v_base.pt"
33
  CER_PREDICTOR = "joint_indicw2v_base_cer.pt"
34
  LANG_ID_PREDICTOR = "joint_indicw2v_base_lang.pt"
35
  CER_LANG_ID_PREDICTOR = "joint_indicw2v_base_cer_lang.pt"
 
36
 
37
  LANG_ID_MAPPING = {
38
  "hi": 0,
@@ -209,7 +210,7 @@ class Collate():
209
  return audio_padded, cers, lengths, langs, filenames
210
 
211
  class PreProcessBatch(torch.utils.data.Dataset):
212
- def __init__(self, manifest_path, cer, langid):
213
  with open(manifest_path, "r") as f:
214
  data = f.read().split("\n")
215
  delim = "\t"
@@ -248,7 +249,7 @@ class PreProcessBatch(torch.utils.data.Dataset):
248
  audio, sr = torchaudio.load(audio_path)
249
  return audio.squeeze(), cer, langid, key
250
 
251
- def score(audio_path, cer=None, langid=None, use_cer=False, use_langid=False, download_path="hf_inference_models", device="cpu"):
252
  """
253
  Single audio mos prediction
254
  """
@@ -258,11 +259,11 @@ def score(audio_path, cer=None, langid=None, use_cer=False, use_langid=False, do
258
  score = mos_model(audio, cer_data=cer, lang_data=langid).squeeze().cpu().item()
259
  return score
260
 
261
- def batch_score(manifest_path, save_path, batch_size=32, cer=None, langid=None, use_cer=False, use_langid=False, download_path="hf_inference_models", device="cpu"):
262
  """
263
  batch audio mos prediction
264
  """
265
- dataset = PreProcessBatch(manifest_path, cer, langid)
266
  loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=Collate())
267
  mos_model = load_model(use_cer, use_langid, download_path, device)
268
  results = {}
@@ -288,17 +289,17 @@ if __name__ == "__main__":
288
  raise ValueError("Please provide manifest_path for batch inference")
289
 
290
  cer = None
291
- if cer is not None:
292
- if cer > 1:
293
- print("WARNING: Use raw CER value, not percentage")
294
  langid = None
295
  # langid = "kn"
296
- if args.audio_path is not None:
297
  ###FIX THIS
298
- score = score(audio_path=args.audio_path, cer=cer, langid=langid, use_cer=args.use_cer, use_langid=args.use_langid)
299
- print("predicted MOS", score)
300
- else:
301
- assert args.save_path is not None, "Please provide a file path for the batch scores to be saved - save_path"
302
- batch_score(manifest_path=args.manifest_path, save_path=args.save_path, batch_size=args.batch_size, cer=cer, langid=langid, use_cer=args.use_cer, use_langid=args.use_langid, device=args.device)
303
 
304
 
 
33
  CER_PREDICTOR = "joint_indicw2v_base_cer.pt"
34
  LANG_ID_PREDICTOR = "joint_indicw2v_base_lang.pt"
35
  CER_LANG_ID_PREDICTOR = "joint_indicw2v_base_cer_lang.pt"
36
+ HF_PATH = "hf_inference_models"
37
 
38
  LANG_ID_MAPPING = {
39
  "hi": 0,
 
210
  return audio_padded, cers, lengths, langs, filenames
211
 
212
  class PreProcessBatch(torch.utils.data.Dataset):
213
+ def __init__(self, manifest_path, use_cer, use_langid):
214
  with open(manifest_path, "r") as f:
215
  data = f.read().split("\n")
216
  delim = "\t"
 
249
  audio, sr = torchaudio.load(audio_path)
250
  return audio.squeeze(), cer, langid, key
251
 
252
+ def score(audio_path, cer=None, langid=None, use_cer=False, use_langid=False, download_path=HF_PATH, device="cpu"):
253
  """
254
  Single audio mos prediction
255
  """
 
259
  score = mos_model(audio, cer_data=cer, lang_data=langid).squeeze().cpu().item()
260
  return score
261
 
262
+ def batch_score(manifest_path, save_path, batch_size=32, use_cer=False, use_langid=False, download_path="hf_inference_models", device="cpu"):
263
  """
264
  batch audio mos prediction
265
  """
266
+ dataset = PreProcessBatch(manifest_path, use_cer, use_langid)
267
  loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=Collate())
268
  mos_model = load_model(use_cer, use_langid, download_path, device)
269
  results = {}
 
289
  raise ValueError("Please provide manifest_path for batch inference")
290
 
291
  cer = None
292
+ # if cer is not None:
293
+ # if cer > 1:
294
+ # print("WARNING: Use raw CER value, not percentage")
295
  langid = None
296
  # langid = "kn"
297
+ # if args.audio_path is not None:
298
  ###FIX THIS
299
+ # score = score(audio_path=args.audio_path, cer=cer, langid=langid, use_cer=args.use_cer, use_langid=args.use_langid)
300
+ # print("predicted MOS", score)
301
+ # else:
302
+ assert args.save_path is not None, "Please provide a file path for the batch scores to be saved - save_path"
303
+ batch_score(manifest_path=args.manifest_path, save_path=args.save_path, batch_size=args.batch_size, use_cer=args.use_cer, use_langid=args.use_langid, device=args.device)
304
 
305
 
sample_manifest/manifest.txt CHANGED
@@ -1,4 +1,4 @@
1
- id audio_path langid
2
  1 ../sample_audio/kn_audio1.wav
3
  2 ../sample_audio/hi_audio2.wav
4
- 4 ../sample_audio/mr_audio3.wav
 
1
+ id audio_path
2
  1 ../sample_audio/kn_audio1.wav
3
  2 ../sample_audio/hi_audio2.wav
4
+ 3 ../sample_audio/mr_audio3.wav
sample_manifest/manifest_lang.txt CHANGED
@@ -1,4 +1,4 @@
1
  id audio_path langid
2
  1 ../sample_audio/kn_audio1.wav kn
3
  2 ../sample_audio/hi_audio2.wav hi
4
- 4 ../sample_audio/mr_audio3.wav mr
 
1
  id audio_path langid
2
  1 ../sample_audio/kn_audio1.wav kn
3
  2 ../sample_audio/hi_audio2.wav hi
4
+ 3 ../sample_audio/mr_audio3.wav mr