""" AutoLyrics — Gradio Demo Fine-tuned Whisper-small + LoRA for lyrics transcription. Usage: pip install gradio transformers peft torch torchaudio librosa pyloudnorm jiwer python app.py """ import re import torch import torchaudio import torchaudio.transforms as T import librosa import pyloudnorm as pyln import numpy as np import gradio as gr from transformers import WhisperForConditionalGeneration, WhisperProcessor from peft import PeftModel # ────────────────────────────────────────────────────────────────── # CONFIGURATION — adjust paths if needed # ────────────────────────────────────────────────────────────────── MODEL_NAME = "openai/whisper-small" LORA_DIR = "./checkpoints/lora_best" # path where you saved the LoRA adapter DEVICE = "cuda" if torch.cuda.is_available() else "cpu" MODEL_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32 TARGET_SR = 16000 MAX_DURATION = 30.0 LANGUAGE = "en" TASK = "transcribe" BEAM_SIZE = 3 MAX_NEW_TOKENS = 200 # ────────────────────────────────────────────────────────────────── # AUDIO PREPROCESSING (mirrors your notebook pipeline) # ────────────────────────────────────────────────────────────────── LUFS_TARGET = -23.0 LUFS_HEADROOM = 1.0 SILENCE_TOP_DB = 30 def _remove_dc_offset(waveform: torch.Tensor) -> torch.Tensor: return waveform - waveform.mean() def _trim_silence(waveform: torch.Tensor) -> torch.Tensor: arr = waveform.numpy() trimmed, _ = librosa.effects.trim(arr, top_db=SILENCE_TOP_DB) return torch.from_numpy(trimmed) def _loudness_normalize(waveform: torch.Tensor, sr: int) -> torch.Tensor: arr = waveform.numpy().astype("float64") meter = pyln.Meter(sr) loudness = meter.integrated_loudness(arr) if not (loudness > -70.0): peak = arr.max() if arr.max() != 0 else 1.0 arr = arr / peak else: arr = pyln.normalize.loudness(arr, loudness, LUFS_TARGET) limit = 10 ** (-LUFS_HEADROOM / 20.0) arr = arr.clip(-limit, limit) return torch.from_numpy(arr.astype("float32")) def preprocess_audio(waveform: torch.Tensor, sr: int) -> torch.Tensor: """Full preprocessing chain: resample → mono → DC → trim → loudness.""" # Convert to mono if waveform.dim() == 1: waveform = waveform.unsqueeze(0) if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True) # Resample if sr != TARGET_SR: waveform = T.Resample(sr, TARGET_SR)(waveform) waveform = waveform.squeeze(0) # Preprocessing chain waveform = _remove_dc_offset(waveform) waveform = _trim_silence(waveform) if len(waveform) == 0: return waveform waveform = _loudness_normalize(waveform, TARGET_SR) return waveform # ────────────────────────────────────────────────────────────────── # MODEL LOADING # ────────────────────────────────────────────────────────────────── print(f"Loading model on {DEVICE}…") processor = WhisperProcessor.from_pretrained(LORA_DIR, language=LANGUAGE, task=TASK) base_model = WhisperForConditionalGeneration.from_pretrained( MODEL_NAME, torch_dtype=MODEL_DTYPE ).to(DEVICE) base_model.config.forced_decoder_ids = None base_model.generation_config.forced_decoder_ids = None base_model.generation_config.suppress_tokens = [] model = PeftModel.from_pretrained(base_model, LORA_DIR).to(DEVICE) inner_model = model.base_model.model inner_model.eval() print("Model loaded ✓") # ────────────────────────────────────────────────────────────────── # TRANSCRIPTION FUNCTION # ────────────────────────────────────────────────────────────────── def transcribe(audio_path: str, beam_size: int, max_new_tokens: int) -> str: """Load audio, preprocess, run Whisper LoRA, return transcript.""" if audio_path is None: return "⚠️ Please upload or record an audio file." try: waveform, sr = torchaudio.load(audio_path) except Exception as e: return f"❌ Error loading audio: {e}" waveform = preprocess_audio(waveform, sr) if len(waveform) == 0: return "⚠️ Audio appears to be silent after preprocessing." duration = len(waveform) / TARGET_SR if duration > MAX_DURATION: waveform = waveform[: int(MAX_DURATION * TARGET_SR)] duration = MAX_DURATION # Pad to 30 s for the feature extractor max_samples = int(MAX_DURATION * TARGET_SR) if len(waveform) < max_samples: waveform = torch.nn.functional.pad(waveform, (0, max_samples - len(waveform))) features = processor.feature_extractor( waveform.numpy(), sampling_rate=TARGET_SR, return_tensors="pt" ).input_features.to(DEVICE, dtype=MODEL_DTYPE) with torch.no_grad(): generated_ids = inner_model.generate( input_features=features, num_beams=int(beam_size), max_new_tokens=int(max_new_tokens), language=LANGUAGE, task=TASK, suppress_tokens=[], forced_decoder_ids=None, ) transcript = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() info = f"_(duration: {duration:.1f}s, device: {DEVICE})_" return f"{transcript}\n\n{info}" # ────────────────────────────────────────────────────────────────── # GRADIO INTERFACE # ────────────────────────────────────────────────────────────────── with gr.Blocks(title="AutoLyrics — Whisper LoRA", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # 🎵 AutoLyrics — Whisper-small + LoRA **Fine-tuned on `gmenon/slt-lyrics-audio` for music lyrics transcription.** Upload a song clip (≤ 30 s) or record directly from your microphone, then hit **Transcribe**. """ ) with gr.Row(): with gr.Column(scale=1): audio_input = gr.Audio( label="🎤 Audio Input", sources=["upload", "microphone"], type="filepath", ) with gr.Accordion("⚙️ Advanced settings", open=False): beam_slider = gr.Slider( minimum=1, maximum=5, value=BEAM_SIZE, step=1, label="Beam size (1 = greedy, higher = better but slower)" ) tokens_slider = gr.Slider( minimum=50, maximum=448, value=MAX_NEW_TOKENS, step=10, label="Max new tokens" ) transcribe_btn = gr.Button("🎶 Transcribe", variant="primary") with gr.Column(scale=1): output_text = gr.Textbox( label="📝 Transcription", lines=8, placeholder="Lyrics will appear here…", show_copy_button=True, ) transcribe_btn.click( fn=transcribe, inputs=[audio_input, beam_slider, tokens_slider], outputs=output_text, ) gr.Examples( examples=[], # add example audio paths here if you have them inputs=audio_input, ) gr.Markdown( """ --- **Model:** `openai/whisper-small` + LoRA (r=8, α=16)  |  **Dataset:** `gmenon/slt-lyrics-audio`  |  **Preprocessing:** EBU R128 loudness normalisation, silence trimming """ ) if __name__ == "__main__": demo.launch(share=True) # share=True gives a public URL; remove for local only