|
|
|
|
|
""" |
|
|
Example: using a converted HuBERT CTC model (loaded from ./converted_ctc_models) for speech recognition.
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import Wav2Vec2Processor, HubertForCTC |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
# Load the converted checkpoint: the processor (feature extractor + tokenizer)
# and the HuBERT model with its CTC head both come from the same local export.
print("Loading model and processor...")
processor = Wav2Vec2Processor.from_pretrained("./converted_ctc_models")
model = HubertForCTC.from_pretrained("./converted_ctc_models")

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the weights to the chosen device and switch the model to evaluation mode.
model.to(device).eval()
|
|
|
|
|
|
|
|
# Fetch a tiny LibriSpeech validation split and take its first utterance.
print("\nLoading sample audio...")
ds = load_dataset(
    "patrickvonplaten/librispeech_asr_dummy",
    "clean",
    split="validation",
    trust_remote_code=True,
)
sample = ds[0]

# Pull out the raw waveform together with the rate it was recorded at.
audio_field = sample["audio"]
audio_input = audio_field["array"]
sampling_rate = audio_field["sampling_rate"]

# This script feeds the processor 16 kHz audio; resample only when the source
# rate differs (librosa is imported lazily because it is needed only here).
if sampling_rate != 16000:
    import librosa

    audio_input = librosa.resample(audio_input, orig_sr=sampling_rate, target_sr=16000)
    sampling_rate = 16000
|
|
|
|
|
|
|
|
# Turn the waveform into model-ready PyTorch tensors on the model's device.
inputs = processor(
    audio_input,
    sampling_rate=sampling_rate,
    return_tensors="pt",
).to(device)

print("\nRunning inference...")
# No gradients are needed for a forward-only pass.
with torch.no_grad():
    logits = model(**inputs).logits

# Greedy decoding: pick the highest-scoring token id along the last axis of the
# logits, then let the processor convert the first (only) sequence into text.
predicted_ids = logits.argmax(dim=-1)
transcription = processor.decode(predicted_ids[0])

print(f"\nTranscription: '{transcription}'")
print(f"Expected: '{sample['text']}'")
|
|
|
|
|
|
|
|
print("\n\nBatch processing example:")
print("-"*40)

# Number of dataset clips to transcribe together.
batch_size = 4
# The processor receives a single sampling_rate for the whole batch, so every
# clip must be brought to this rate first (same target as the single-sample run).
target_sr = 16000

# Never ask for more rows than the dataset holds (avoids IndexError on tiny splits).
num_samples = min(batch_size, len(ds))

audio_samples = []
for idx in range(num_samples):
    clip = ds[idx]["audio"]
    waveform = clip["array"]
    # Resample any clip that is not already at target_sr; otherwise the
    # processor would be told the wrong rate for it.
    if clip["sampling_rate"] != target_sr:
        import librosa  # lazily imported: only needed when resampling

        waveform = librosa.resample(
            waveform, orig_sr=clip["sampling_rate"], target_sr=target_sr
        )
    audio_samples.append(waveform)

# padding=True pads the shorter clips so the batch stacks into one tensor;
# any attention mask the processor returns is forwarded via **inputs.
inputs = processor(
    audio_samples,
    return_tensors="pt",
    padding=True,
    sampling_rate=target_sr,
).to(device)

with torch.no_grad():
    logits = model(**inputs).logits

# Greedy decoding per clip, then convert every id sequence to text at once.
predicted_ids = torch.argmax(logits, dim=-1)
transcriptions = processor.batch_decode(predicted_ids)

for i, transcription in enumerate(transcriptions, start=1):
    print(f"Sample {i}: '{transcription}'")
|
|
|