# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "torchaudio>=2.0.0",
#     "transformers>=4.36.0",
#     "click>=8.0.0",
#     "python-dotenv>=1.0.0",
# ]
# ///
"""Transcription script for ASR-1.

Usage:
    uv run src/transcribe.py --model models/asr-1 --audio audio.wav
    uv run src/transcribe.py --model models/asr-1 --audio-dir ./recordings/
"""

import sys
from pathlib import Path

from dotenv import load_dotenv

# Load .env before importing transformers so credentials like HF_TOKEN
# are available when the model is downloaded.
load_dotenv()

import click
import torch
import torchaudio
from transformers import WhisperForConditionalGeneration, WhisperProcessor


def load_audio(audio_path: str, target_sr: int = 16000):
    """Load and resample audio to the target sample rate."""
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != target_sr:
        waveform = torchaudio.transforms.Resample(sample_rate, target_sr)(waveform)
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform.squeeze().numpy()


@click.command()
@click.option('--model', '-m', default='models/asr-1', help='Model path or HuggingFace model ID')
@click.option('--audio', '-a', default=None, help='Path to audio file')
@click.option('--audio-dir', '-d', default=None, help='Directory of audio files')
@click.option('--output', '-o', default=None, help='Output file for transcriptions')
def transcribe(model, audio, audio_dir, output):
    """Transcribe Vietnamese audio to text."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model
    click.echo(f"Loading model: {model}")
    processor = WhisperProcessor.from_pretrained(model)
    asr_model = WhisperForConditionalGeneration.from_pretrained(model).to(device)
    asr_model.eval()

    # Collect audio files
    audio_files = []
    if audio:
        audio_files.append(Path(audio))
    elif audio_dir:
        audio_dir = Path(audio_dir)
        audio_files = sorted(
            p for p in audio_dir.iterdir()
            if p.suffix.lower() in ('.wav', '.mp3', '.flac', '.ogg', '.m4a')
        )
    else:
        click.echo("Error: Provide --audio or --audio-dir")
        sys.exit(1)

    click.echo(f"Transcribing {len(audio_files)} file(s)...\n")

    results = []
    for audio_path in audio_files:
        # Load audio at 16 kHz mono, as Whisper expects
        waveform = load_audio(str(audio_path))

        # Extract log-mel input features
        input_features = processor.feature_extractor(
            waveform,
            sampling_rate=16000,
            return_tensors="pt",
        ).input_features.to(device)

        # Generate token IDs, forcing Vietnamese transcription
        with torch.no_grad():
            predicted_ids = asr_model.generate(
                input_features,
                language="vi",
                task="transcribe",
            )

        # Decode token IDs back to text
        transcription = processor.tokenizer.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]

        click.echo(f"{audio_path.name}: {transcription}")
        results.append((str(audio_path), transcription))

    # Save output as tab-separated path/text pairs
    if output:
        with open(output, "w", encoding="utf-8") as f:
            for path, text in results:
                f.write(f"{path}\t{text}\n")
        click.echo(f"\nTranscriptions saved to: {output}")


if __name__ == '__main__':
    transcribe()