update inference script
Browse files
- infer_indicmos.py +14 -13
- sample_manifest/manifest.txt +2 -2
- sample_manifest/manifest_lang.txt +1 -1
infer_indicmos.py
CHANGED
|
@@ -33,6 +33,7 @@ BASE_PREDICTOR = "joint_indicw2v_base.pt"
|
|
| 33 |
CER_PREDICTOR = "joint_indicw2v_base_cer.pt"
|
| 34 |
LANG_ID_PREDICTOR = "joint_indicw2v_base_lang.pt"
|
| 35 |
CER_LANG_ID_PREDICTOR = "joint_indicw2v_base_cer_lang.pt"
|
|
|
|
| 36 |
|
| 37 |
LANG_ID_MAPPING = {
|
| 38 |
"hi": 0,
|
|
@@ -209,7 +210,7 @@ class Collate():
|
|
| 209 |
return audio_padded, cers, lengths, langs, filenames
|
| 210 |
|
| 211 |
class PreProcessBatch(torch.utils.data.Dataset):
|
| 212 |
-
def __init__(self, manifest_path,
|
| 213 |
with open(manifest_path, "r") as f:
|
| 214 |
data = f.read().split("\n")
|
| 215 |
delim = "\t"
|
|
@@ -248,7 +249,7 @@ class PreProcessBatch(torch.utils.data.Dataset):
|
|
| 248 |
audio, sr = torchaudio.load(audio_path)
|
| 249 |
return audio.squeeze(), cer, langid, key
|
| 250 |
|
| 251 |
-
def score(audio_path, cer=None, langid=None, use_cer=False, use_langid=False, download_path=
|
| 252 |
"""
|
| 253 |
Single audio mos prediction
|
| 254 |
"""
|
|
@@ -258,11 +259,11 @@ def score(audio_path, cer=None, langid=None, use_cer=False, use_langid=False, do
|
|
| 258 |
score = mos_model(audio, cer_data=cer, lang_data=langid).squeeze().cpu().item()
|
| 259 |
return score
|
| 260 |
|
| 261 |
-
def batch_score(manifest_path, save_path, batch_size=32,
|
| 262 |
"""
|
| 263 |
batch audio mos prediction
|
| 264 |
"""
|
| 265 |
-
dataset = PreProcessBatch(manifest_path,
|
| 266 |
loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=Collate())
|
| 267 |
mos_model = load_model(use_cer, use_langid, download_path, device)
|
| 268 |
results = {}
|
|
@@ -288,17 +289,17 @@ if __name__ == "__main__":
|
|
| 288 |
raise ValueError("Please provide manifest_path for batch inference")
|
| 289 |
|
| 290 |
cer = None
|
| 291 |
-
if cer is not None:
|
| 292 |
-
if cer > 1:
|
| 293 |
-
print("WARNING: Use raw CER value, not percentage")
|
| 294 |
langid = None
|
| 295 |
# langid = "kn"
|
| 296 |
-
if args.audio_path is not None:
|
| 297 |
###FIX THIS
|
| 298 |
-
score = score(audio_path=args.audio_path, cer=cer, langid=langid, use_cer=args.use_cer, use_langid=args.use_langid)
|
| 299 |
-
print("predicted MOS", score)
|
| 300 |
-
else:
|
| 301 |
-
|
| 302 |
-
|
| 303 |
|
| 304 |
|
|
|
|
| 33 |
CER_PREDICTOR = "joint_indicw2v_base_cer.pt"
|
| 34 |
LANG_ID_PREDICTOR = "joint_indicw2v_base_lang.pt"
|
| 35 |
CER_LANG_ID_PREDICTOR = "joint_indicw2v_base_cer_lang.pt"
|
| 36 |
+
HF_PATH = "hf_inference_models"
|
| 37 |
|
| 38 |
LANG_ID_MAPPING = {
|
| 39 |
"hi": 0,
|
|
|
|
| 210 |
return audio_padded, cers, lengths, langs, filenames
|
| 211 |
|
| 212 |
class PreProcessBatch(torch.utils.data.Dataset):
|
| 213 |
+
def __init__(self, manifest_path, use_cer, use_langid):
|
| 214 |
with open(manifest_path, "r") as f:
|
| 215 |
data = f.read().split("\n")
|
| 216 |
delim = "\t"
|
|
|
|
| 249 |
audio, sr = torchaudio.load(audio_path)
|
| 250 |
return audio.squeeze(), cer, langid, key
|
| 251 |
|
| 252 |
+
def score(audio_path, cer=None, langid=None, use_cer=False, use_langid=False, download_path=HF_PATH, device="cpu"):
|
| 253 |
"""
|
| 254 |
Single audio mos prediction
|
| 255 |
"""
|
|
|
|
| 259 |
score = mos_model(audio, cer_data=cer, lang_data=langid).squeeze().cpu().item()
|
| 260 |
return score
|
| 261 |
|
| 262 |
+
def batch_score(manifest_path, save_path, batch_size=32, use_cer=False, use_langid=False, download_path="hf_inference_models", device="cpu"):
|
| 263 |
"""
|
| 264 |
batch audio mos prediction
|
| 265 |
"""
|
| 266 |
+
dataset = PreProcessBatch(manifest_path, use_cer, use_langid)
|
| 267 |
loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=Collate())
|
| 268 |
mos_model = load_model(use_cer, use_langid, download_path, device)
|
| 269 |
results = {}
|
|
|
|
| 289 |
raise ValueError("Please provide manifest_path for batch inference")
|
| 290 |
|
| 291 |
cer = None
|
| 292 |
+
# if cer is not None:
|
| 293 |
+
# if cer > 1:
|
| 294 |
+
# print("WARNING: Use raw CER value, not percentage")
|
| 295 |
langid = None
|
| 296 |
# langid = "kn"
|
| 297 |
+
# if args.audio_path is not None:
|
| 298 |
###FIX THIS
|
| 299 |
+
# score = score(audio_path=args.audio_path, cer=cer, langid=langid, use_cer=args.use_cer, use_langid=args.use_langid)
|
| 300 |
+
# print("predicted MOS", score)
|
| 301 |
+
# else:
|
| 302 |
+
assert args.save_path is not None, "Please provide a file path for the batch scores to be saved - save_path"
|
| 303 |
+
batch_score(manifest_path=args.manifest_path, save_path=args.save_path, batch_size=args.batch_size, use_cer=args.use_cer, use_langid=args.use_langid, device=args.device)
|
| 304 |
|
| 305 |
|
sample_manifest/manifest.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
id audio_path
|
| 2 |
1 ../sample_audio/kn_audio1.wav
|
| 3 |
2 ../sample_audio/hi_audio2.wav
|
| 4 |
-
|
|
|
|
| 1 |
+
id audio_path
|
| 2 |
1 ../sample_audio/kn_audio1.wav
|
| 3 |
2 ../sample_audio/hi_audio2.wav
|
| 4 |
+
3 ../sample_audio/mr_audio3.wav
|
sample_manifest/manifest_lang.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
id audio_path langid
|
| 2 |
1 ../sample_audio/kn_audio1.wav kn
|
| 3 |
2 ../sample_audio/hi_audio2.wav hi
|
| 4 |
-
|
|
|
|
| 1 |
id audio_path langid
|
| 2 |
1 ../sample_audio/kn_audio1.wav kn
|
| 3 |
2 ../sample_audio/hi_audio2.wav hi
|
| 4 |
+
3 ../sample_audio/mr_audio3.wav mr
|