run_ctc_common_voice.py / run_ctc_model.py
patrickvonplaten's picture
up
1763424
raw
history blame
905 Bytes
#!/usr/bin/env python3
import sys
import torch
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset
import torchaudio.functional as F
# Transcribe the first test sample of a Common Voice language with a CTC model
# and print the reference sentence next to the model's prediction.
#
# Usage: run_ctc_model.py <model_id> <lang>
if len(sys.argv) < 3:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit(f"usage: {sys.argv[0]} <model_id> <lang>")

model_id = sys.argv[1]
lang = sys.argv[2]

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor first: its feature extractor defines the sampling rate
# the model expects, which is the resampling target below.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCTC.from_pretrained(model_id).to(device)
target_sampling_rate = processor.feature_extractor.sampling_rate

# Stream the split so only the first example is fetched, not the whole dataset.
ds = load_dataset("common_voice", lang, split="test", streaming=True)
sample = next(iter(ds))

# Use the sampling rate reported by the decoded audio rather than assuming a
# fixed source rate, so the resampling ratio is always correct.
audio = sample["audio"]
resampled_audio = F.resample(
    torch.tensor(audio["array"]),
    audio["sampling_rate"],
    target_sampling_rate,
).numpy()

# Passing sampling_rate lets the feature extractor verify that the input
# matches the rate it was configured for.
input_values = processor(
    resampled_audio,
    sampling_rate=target_sampling_rate,
    return_tensors="pt",
).input_values

# Inference only: no gradients needed.
with torch.no_grad():
    logits = model(input_values.to(device)).logits

# Greedy CTC decoding: pick the most likely token per frame; batch_decode
# turns the id sequences into text (a list with one string per batch item).
prediction_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(prediction_ids)

print(f"Correct: {sample['sentence']}")
print(f"Predict: {transcription}")