| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Transcription script for ASR-1. |
| | |
| | Usage: |
| | uv run src/transcribe.py --model models/asr-1 --audio audio.wav |
| | uv run src/transcribe.py --model models/asr-1 --audio-dir ./recordings/ |
| | """ |
| |
|
| | import sys |
| | from pathlib import Path |
| |
|
| | from dotenv import load_dotenv |
| | load_dotenv() |
| |
|
| | import torch |
| | import torchaudio |
| | import click |
| | from transformers import WhisperProcessor, WhisperForConditionalGeneration |
| |
|
| |
|
def load_audio(audio_path: str, target_sr: int = 16000):
    """Load an audio file as a 1-D mono float numpy array at ``target_sr`` Hz.

    Args:
        audio_path: Path to an audio file in any format torchaudio can read.
        target_sr: Desired sample rate in Hz (Whisper models expect 16 kHz).

    Returns:
        1-D numpy array of mono samples at ``target_sr``.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # Downmix to mono BEFORE resampling: resampling is linear, so the result
    # is identical, but we only resample one channel's worth of samples.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != target_sr:
        # Functional API avoids constructing a new Resample module per call.
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sr)

    # squeeze(0) (not squeeze()) so a pathological 1-sample file still yields
    # a 1-D array instead of collapsing to a 0-d scalar.
    return waveform.squeeze(0).numpy()
| |
|
| |
|
@click.command()
@click.option('--model', '-m', default='models/asr-1', help='Model path or HuggingFace model ID')
@click.option('--audio', '-a', default=None, help='Path to audio file')
@click.option('--audio-dir', '-d', default=None, help='Directory of audio files')
@click.option('--output', '-o', default=None, help='Output file for transcriptions')
def transcribe(model, audio, audio_dir, output):
    """Transcribe Vietnamese audio to text.

    Exactly one of --audio / --audio-dir is required. Each file is processed
    independently: a failure on one file is reported to stderr and the
    remaining files are still transcribed. With --output, results are written
    as tab-separated ``<path>\\t<text>`` lines.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    click.echo(f"Loading model: {model}")
    processor = WhisperProcessor.from_pretrained(model)
    asr_model = WhisperForConditionalGeneration.from_pretrained(model).to(device)
    asr_model.eval()

    # Collect the list of files to transcribe.
    audio_files = []
    if audio:
        audio_files.append(Path(audio))
    elif audio_dir:
        audio_dir = Path(audio_dir)
        audio_files = sorted(
            p for p in audio_dir.iterdir()
            if p.suffix.lower() in ('.wav', '.mp3', '.flac', '.ogg', '.m4a')
        )
    else:
        click.echo("Error: Provide --audio or --audio-dir")
        sys.exit(1)

    if not audio_files:
        # Empty directory, or one containing no supported audio extensions:
        # fail loudly instead of silently "transcribing 0 file(s)".
        click.echo("Error: No audio files found")
        sys.exit(1)

    click.echo(f"Transcribing {len(audio_files)} file(s)...\n")

    results = []
    for audio_path in audio_files:
        try:
            waveform = load_audio(str(audio_path))
        except Exception as e:
            # One corrupt/unreadable file must not abort the whole batch;
            # report it and keep going.
            click.echo(f"{audio_path.name}: ERROR - {e}", err=True)
            continue

        # Log-mel input features; model input is fixed at 16 kHz.
        input_features = processor.feature_extractor(
            waveform,
            sampling_rate=16000,
            return_tensors="pt",
        ).input_features.to(device)

        # Decode with forced Vietnamese-language transcription task tokens.
        with torch.no_grad():
            predicted_ids = asr_model.generate(
                input_features,
                language="vi",
                task="transcribe",
            )

        transcription = processor.tokenizer.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]

        click.echo(f"{audio_path.name}: {transcription}")
        results.append((str(audio_path), transcription))

    # Optionally persist successful transcriptions as TSV.
    if output:
        with open(output, "w", encoding="utf-8") as f:
            for path, text in results:
                f.write(f"{path}\t{text}\n")
        click.echo(f"\nTranscriptions saved to: {output}")
| |
|
| |
|
# Script entry point: hand control to the click command, which parses
# sys.argv and exits with the appropriate status code.
if __name__ == '__main__':
    transcribe()
| |
|