# hubert-large-ls960-ft / speech_recognition_example.py
# Uploaded by mzbac via huggingface_hub (revision cb7382b, verified)
#!/usr/bin/env python3
"""
Example: Using converted HuBERT CTC model for speech recognition
"""
import torch
from transformers import Wav2Vec2Processor, HubertForCTC
from datasets import load_dataset
# --- Model and processor setup ---------------------------------------
print("Loading model and processor...")
model_dir = "./converted_ctc_models"
processor = Wav2Vec2Processor.from_pretrained(model_dir)
model = HubertForCTC.from_pretrained(model_dir)

# Prefer GPU when available; eval() disables dropout for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
# --- Sample audio ----------------------------------------------------
print("\nLoading sample audio...")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
sample = ds[0]

# Raw waveform (1-D float array) and its native sampling rate.
audio_input = sample["audio"]["array"]
sampling_rate = sample["audio"]["sampling_rate"]

# The Wav2Vec2 feature extractor expects 16 kHz audio; resample if the
# source clip uses a different rate.
if sampling_rate != 16000:
    import librosa  # local import: only needed on the resample path
    audio_input = librosa.resample(audio_input, orig_sr=sampling_rate, target_sr=16000)
    sampling_rate = 16000
# --- Single-sample inference -----------------------------------------
# Normalize the waveform and pack it into model-ready tensors, then move
# them to the same device as the model.
inputs = processor(
    audio_input,
    return_tensors="pt",
    sampling_rate=sampling_rate,
).to(device)

print("\nRunning inference...")
with torch.no_grad():  # no gradients needed for inference
    logits = model(**inputs).logits

# Greedy CTC decoding: argmax over the vocabulary at each frame, then
# collapse repeats and blank tokens inside processor.decode().
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])
print(f"\nTranscription: '{transcription}'")
print(f"Expected: '{sample['text']}'")
# --- Batch processing example ----------------------------------------
print("\n\nBatch processing example:")
print("-" * 40)

# Guard against the dummy split having fewer than 4 samples.
batch_size = min(4, len(ds))
audio_samples = [ds[i]["audio"]["array"] for i in range(batch_size)]

# padding=True pads every clip to the longest one in the batch; the
# processor also returns an attention mask so padded frames are ignored.
inputs = processor(
    audio_samples,
    return_tensors="pt",
    padding=True,
    sampling_rate=16000,
).to(device)

# Batch inference without gradient tracking.
with torch.no_grad():
    logits = model(**inputs).logits

# Greedy CTC decode for every item in the batch.
predicted_ids = torch.argmax(logits, dim=-1)
transcriptions = processor.batch_decode(predicted_ids)
for i, transcription in enumerate(transcriptions):
    print(f"Sample {i+1}: '{transcription}'")