"""WhisperX wrapper for lyrics extraction with word-level timestamps."""

import gc
import json
import sys
from pathlib import Path
from typing import Optional

import whisperx


def _default_output_dir(vocal_path: str | Path) -> Path:
    """Return the song directory for a stem path.

    ``data/<song>/stems/vocals.wav`` -> ``data/<song>`` (the parent of ``stems/``).
    """
    return Path(vocal_path).parent.parent


def _release_memory(device: str) -> None:
    """Reclaim memory after a model has been ``del``-ed.

    Runs the garbage collector and, when running on a CUDA device, releases
    PyTorch's cached GPU blocks so the next model can load.
    """
    gc.collect()
    if device.startswith("cuda"):
        import torch  # imported lazily: only needed when running on a CUDA device

        torch.cuda.empty_cache()


def _flatten_words(segments: list[dict]) -> list[dict]:
    """Flatten WhisperX aligned segments into ``{"word", "start", "end"}`` dicts.

    Words lacking a "start" or "end" key are dropped (presumably tokens the
    aligner could not place — confirm against WhisperX output).
    """
    return [
        {"word": word["word"].strip(), "start": word["start"], "end": word["end"]}
        for segment in segments
        for word in segment.get("words", [])
        if "start" in word and "end" in word
    ]


def extract_lyrics(
    vocal_path: str | Path,
    model_name: str = "large-v2",
    device: str = "cpu",
    language: str = "en",
    output_dir: Optional[str | Path] = None,
    compute_type: str = "int8",
    batch_size: int = 4,
) -> list[dict]:
    """Extract timestamped lyrics from an isolated vocal stem.

    Args:
        vocal_path: Path to the isolated vocal audio file
            (e.g. data/<song>/stems/vocals.wav).
        model_name: Whisper model size. Default "large-v2" (best for lyrics).
        device: Device to run on ("cpu", "cuda").
        language: Language code used for both transcription and alignment.
        output_dir: Directory to save lyrics.json. Defaults to data/<song>/
            (the parent of the stem's ``stems/`` directory). Created if missing.
        compute_type: Precision passed to ``whisperx.load_model``.
        batch_size: Batch size passed to the Whisper model's ``transcribe``.

    Returns:
        List of word dicts with keys: "word", "start", "end".
        Example: [{"word": "hello", "start": 0.5, "end": 0.8}, ...]

    Raises:
        OSError: If the output directory cannot be created or lyrics.json
            cannot be written.
    """
    vocal_path = str(vocal_path)

    audio = whisperx.load_audio(vocal_path)

    # Stage 1: transcription.
    model = whisperx.load_model(
        model_name, device, compute_type=compute_type, language=language
    )
    result = model.transcribe(audio, batch_size=batch_size)
    # Free the Whisper model *before* loading the alignment model so both are
    # never resident at once.
    del model
    _release_memory(device)

    # Stage 2: forced alignment for word-level timestamps.
    model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device)
    del model_a, metadata
    _release_memory(device)

    words = _flatten_words(result["segments"])

    out_dir = _default_output_dir(vocal_path) if output_dir is None else Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    output_path = out_dir / "lyrics.json"
    # UTF-8 + ensure_ascii=False keeps non-English lyrics human-readable.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(words, f, indent=2, ensure_ascii=False)

    return words


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python -m src.lyrics_extractor <vocal_path>", file=sys.stderr)
        sys.exit(1)

    vocal_file = sys.argv[1]
    words = extract_lyrics(vocal_file)
    for w in words:
        print(f"{w['start']:6.2f} - {w['end']:6.2f}: {w['word']}")

    print(f"\nSaved to {_default_output_dir(vocal_file) / 'lyrics.json'}")