import sys

import torch
import torchaudio.functional as F
from datasets import load_dataset
from transformers import AutoModelForCTC, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
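
# checkpoint id and Common Voice language code are passed on the command line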
model_id = sys.argv[1]
lang = sys.argv[2]
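
# stream the Common Voice test split so nothing is downloaded up front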
ds = load_dataset("common_voice", lang, split="test", streaming=True)
sample = next(iter(ds))
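
# Common Voice ships 48 kHz audio; the CTC model expects 16 kHz input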
resampled_audio = F.resample(
    torch.tensor(sample["audio"]["array"]), 48_000, 16_000
).numpy()

model = AutoModelForCTC.from_pretrained(model_id).to(device)
processor = AutoProcessor.from_pretrained(model_id)
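
# the processor converts the raw waveform into a model-ready input tensor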
input_values = processor(
    resampled_audio, sampling_rate=16_000, return_tensors="pt"
).input_values
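
# inference only, so skip gradient bookkeeping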
with torch.no_grad():
    logits = model(input_values.to(device)).logits
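
# greedy CTC decoding: most likely token per frame, then collapse repeats and blanks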
prediction_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(prediction_ids)

print(f"Correct: {sample['sentence']}")
print(f"Predict: {transcription[0]}")