# Finnish-ASR-Canary-v2 / inference_example.py
# Uploaded by RASMUS via huggingface_hub (revision bf41f30).
from nemo.collections.asr.models import EncDecMultiTaskModel
from omegaconf import OmegaConf
import os
import argparse
def main():
    """Transcribe one Finnish audio file from the command line.

    Parses CLI arguments, loads a finetuned NeMo Canary multitask model,
    optionally switches the decoder to beam search fused with a KenLM
    n-gram language model, and prints the resulting transcription.

    Exits early (plain ``return`` after an error message) when the model
    or audio file path does not exist.
    """
    parser = argparse.ArgumentParser(description="Finnish ASR Inference Example")
    parser.add_argument("--audio", type=str, required=True,
                        help="Path to the audio file (.wav)")
    parser.add_argument("--model", type=str, default="models/canary-finnish.nemo",
                        help="Path to the finetuned .nemo model")
    parser.add_argument("--kenlm", type=str, default="models/kenlm_5M.nemo",
                        help="Path to the KenLM model")
    parser.add_argument("--beam_size", type=int, default=4,
                        help="Beam size for decoding")
    parser.add_argument("--pnc", type=str, default="yes",
                        help="Enable Punctuation and Capitalization (yes/no)")
    args = parser.parse_args()

    # Validate both input paths up front so we fail fast instead of
    # discovering a missing audio file only after the slow model load.
    if not os.path.exists(args.model):
        print(f"Error: Model not found at {args.model}")
        return
    if not os.path.exists(args.audio):
        print(f"Error: Audio sample not found at {args.audio}")
        return

    # 1. Load Model and KenLM Bundle
    print(f"Loading model from {args.model}...")
    model = EncDecMultiTaskModel.restore_from(args.model)

    # Switch to beam-search decoding with the KenLM n-gram LM when the
    # model file is present; otherwise keep the default greedy decoder.
    if args.kenlm and os.path.exists(args.kenlm):
        print(f"Configuring decoding strategy with KenLM from {args.kenlm}...")
        model.change_decoding_strategy(
            decoding_cfg=OmegaConf.create({
                'strategy': 'beam',
                'beam': {
                    'beam_size': args.beam_size,
                    'ngram_lm_model': args.kenlm,
                    # LM interpolation weight; 0.2 presumably tuned for this
                    # model/dataset — adjust if swapping the KenLM model.
                    'ngram_lm_alpha': 0.2,
                },
                'batch_size': 1,
            })
        )
    else:
        print("Using greedy decoding (no KenLM found or specified).")

    # 2. Transcribe with Finnish Prompts
    print(f"Transcribing {args.audio}...")
    transcription = model.transcribe(
        audio=[args.audio],
        taskname="asr",
        source_lang="fi",
        target_lang="fi",
        pnc=args.pnc,
    )

    print("-" * 30)
    print(f"Result: {transcription[0]}")
    print("-" * 30)
if __name__ == "__main__":
    # Script entry point: run the CLI transcription flow.
    main()