# Model-card tags: Automatic Speech Recognition, Transformers, Vietnamese,
# vietnamese, whisper, speech-to-text
# Source: asr-1/src/transcribe.py — rain1024,
# "Initial commit: ASR-1 Vietnamese speech recognition model" (5763d9e)
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch>=2.0.0",
# "torchaudio>=2.0.0",
# "transformers>=4.36.0",
# "click>=8.0.0",
# "python-dotenv>=1.0.0",
# ]
# ///
"""
Transcription script for ASR-1.
Usage:
uv run src/transcribe.py --model models/asr-1 --audio audio.wav
uv run src/transcribe.py --model models/asr-1 --audio-dir ./recordings/
"""
import sys
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
import torch
import torchaudio
import click
from transformers import WhisperProcessor, WhisperForConditionalGeneration
def load_audio(audio_path: str, target_sr: int = 16000):
    """Load an audio file as a mono 1-D float waveform at ``target_sr`` Hz.

    Args:
        audio_path: Path to any file torchaudio can decode (wav/mp3/flac/...).
        target_sr: Desired sample rate in Hz (Whisper expects 16 kHz).

    Returns:
        1-D numpy array of samples resampled to ``target_sr``.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    # Downmix to mono BEFORE resampling: Resample is linear, so the result is
    # the same, but we only resample one channel instead of N.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != target_sr:
        waveform = torchaudio.transforms.Resample(sample_rate, target_sr)(waveform)
    # squeeze(0), not squeeze(): a clip of exactly one sample must stay a
    # 1-D array rather than collapsing to a 0-d scalar.
    return waveform.squeeze(0).numpy()
@click.command()
@click.option('--model', '-m', default='models/asr-1', help='Model path or HuggingFace model ID')
@click.option('--audio', '-a', default=None, help='Path to audio file')
@click.option('--audio-dir', '-d', default=None, help='Directory of audio files')
@click.option('--output', '-o', default=None, help='Output file for transcriptions')
def transcribe(model, audio, audio_dir, output):
    """Transcribe Vietnamese audio to text.

    Exactly one of --audio / --audio-dir must be given. Transcriptions are
    echoed to stdout and, with --output, also written as TSV (path<TAB>text).
    Exits with status 1 on bad input paths or when no audio files are found.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load processor + model once, up front; eval() disables dropout etc.
    click.echo(f"Loading model: {model}")
    processor = WhisperProcessor.from_pretrained(model)
    asr_model = WhisperForConditionalGeneration.from_pretrained(model).to(device)
    asr_model.eval()
    # Collect and validate audio files before any expensive work, so path
    # mistakes fail fast with a clear message instead of a torchaudio traceback.
    audio_files = []
    if audio:
        single = Path(audio)
        if not single.is_file():
            click.echo(f"Error: audio file not found: {single}", err=True)
            sys.exit(1)
        audio_files.append(single)
    elif audio_dir:
        audio_dir = Path(audio_dir)
        if not audio_dir.is_dir():
            click.echo(f"Error: not a directory: {audio_dir}", err=True)
            sys.exit(1)
        audio_files = sorted(
            p for p in audio_dir.iterdir()
            if p.suffix.lower() in ('.wav', '.mp3', '.flac', '.ogg', '.m4a')
        )
        if not audio_files:
            # An empty directory is almost certainly a user error; don't
            # silently report success after transcribing nothing.
            click.echo(f"Error: no audio files found in: {audio_dir}", err=True)
            sys.exit(1)
    else:
        click.echo("Error: Provide --audio or --audio-dir", err=True)
        sys.exit(1)
    click.echo(f"Transcribing {len(audio_files)} file(s)...\n")
    results = []
    for audio_path in audio_files:
        # Load audio as mono 16 kHz float samples.
        waveform = load_audio(str(audio_path))
        # Extract log-mel features for the Whisper encoder.
        input_features = processor.feature_extractor(
            waveform,
            sampling_rate=16000,
            return_tensors="pt",
        ).input_features.to(device)
        # Greedy/beam generation without gradient tracking.
        with torch.no_grad():
            predicted_ids = asr_model.generate(
                input_features,
                language="vi",
                task="transcribe",
            )
        # Decode token ids back to text, dropping special tokens.
        transcription = processor.tokenizer.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]
        click.echo(f"{audio_path.name}: {transcription}")
        results.append((str(audio_path), transcription))
    # Save output as UTF-8 TSV: one "path<TAB>transcription" row per file.
    if output:
        with open(output, "w", encoding="utf-8") as f:
            for path, text in results:
                f.write(f"{path}\t{text}\n")
        click.echo(f"\nTranscriptions saved to: {output}")
if __name__ == '__main__':
    transcribe()