Autolyrics / app.py
vansh71600's picture
Upload app.py
f974ed0 verified
"""
AutoLyrics — Gradio Demo
Fine-tuned Whisper-small + LoRA for lyrics transcription.
Usage:
pip install gradio transformers peft torch torchaudio librosa pyloudnorm jiwer
python app.py
"""
import re
import torch
import torchaudio
import torchaudio.transforms as T
import librosa
import pyloudnorm as pyln
import numpy as np
import gradio as gr
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from peft import PeftModel
# ──────────────────────────────────────────────────────────────────
# CONFIGURATION β€” adjust paths if needed
# ──────────────────────────────────────────────────────────────────
# Base checkpoint and the LoRA adapter that is layered on top of it.
MODEL_NAME = "openai/whisper-small"
LORA_DIR = "./checkpoints/lora_best"  # path where you saved the LoRA adapter

# Half precision is only selected when a CUDA device is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
    MODEL_DTYPE = torch.float16
else:
    DEVICE = "cpu"
    MODEL_DTYPE = torch.float32

# Audio is resampled to TARGET_SR and cut/padded to MAX_DURATION seconds.
TARGET_SR = 16000
MAX_DURATION = 30.0

# Decoding settings; BEAM_SIZE and MAX_NEW_TOKENS seed the UI sliders.
LANGUAGE = "en"
TASK = "transcribe"
BEAM_SIZE = 3
MAX_NEW_TOKENS = 200
# ──────────────────────────────────────────────────────────────────
# AUDIO PREPROCESSING (mirrors your notebook pipeline)
# ──────────────────────────────────────────────────────────────────
# Target integrated loudness handed to pyloudnorm's loudness normaliser.
LUFS_TARGET = -23.0
# Headroom in dB: samples are clipped to +/- 10 ** (-LUFS_HEADROOM / 20).
LUFS_HEADROOM = 1.0
# `top_db` threshold passed to librosa.effects.trim when trimming silence.
SILENCE_TOP_DB = 30
def _remove_dc_offset(waveform: torch.Tensor) -> torch.Tensor:
    """Return *waveform* shifted so that its mean over all elements is zero."""
    dc_component = waveform.mean()
    return waveform - dc_component
def _trim_silence(waveform: torch.Tensor) -> torch.Tensor:
    """Cut leading and trailing silence from *waveform* with librosa.

    The tensor is handed to ``librosa.effects.trim`` as a NumPy array using
    ``SILENCE_TOP_DB`` as the threshold; the trimmed samples come back as a
    tensor and the index pair librosa also returns is discarded.
    """
    samples = waveform.numpy()
    non_silent = librosa.effects.trim(samples, top_db=SILENCE_TOP_DB)[0]
    return torch.from_numpy(non_silent)
def _loudness_normalize(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """Normalise *waveform* to ``LUFS_TARGET`` and clip it to the headroom limit.

    The integrated loudness is measured with pyloudnorm. When it cannot be
    measured (the clip is too short for the meter) or is not above -70 LUFS
    (including ``-inf`` / NaN), the signal is peak-normalised instead.

    Args:
        waveform: Audio samples; converted to a float64 NumPy array for the
            measurement.
        sr: Sample rate of *waveform* in Hz.

    Returns:
        A float32 tensor whose samples lie within
        ``+/- 10 ** (-LUFS_HEADROOM / 20)``.
    """
    arr = waveform.numpy().astype("float64")
    meter = pyln.Meter(sr)
    try:
        loudness = meter.integrated_loudness(arr)
    except ValueError:
        # pyloudnorm refuses input shorter than its gating block; treat such
        # clips as unmeasurable so they take the peak fallback, not a crash.
        loudness = float("-inf")

    if loudness > -70.0:
        arr = pyln.normalize.loudness(arr, loudness, LUFS_TARGET)
    else:
        # Scale by the absolute peak: a signed max() would ignore a larger
        # negative excursion and could even flip the polarity.
        peak = np.abs(arr).max() if arr.size else 0.0
        if peak > 0:
            arr = arr / peak

    limit = 10 ** (-LUFS_HEADROOM / 20.0)
    arr = arr.clip(-limit, limit)
    return torch.from_numpy(arr.astype("float32"))
def preprocess_audio(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """Run the preprocessing chain: mono → resample → DC removal → trim → loudness.

    Args:
        waveform: Audio tensor. A 1-D tensor is treated as a single channel;
            otherwise dimension 0 is averaged away as the channel axis.
        sr: Sample rate of *waveform* in Hz.

    Returns:
        A 1-D tensor at ``TARGET_SR``. If nothing is left after silence
        trimming, the empty tensor is returned without loudness normalisation.
    """
    # Give a bare 1-D signal an explicit leading channel axis.
    channels = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform

    # Down-mix multi-channel audio to a single channel.
    if channels.shape[0] > 1:
        channels = channels.mean(dim=0, keepdim=True)

    # Bring the audio to the sample rate used by the rest of the pipeline.
    if sr != TARGET_SR:
        resampler = T.Resample(sr, TARGET_SR)
        channels = resampler(channels)

    mono = channels.squeeze(0)
    mono = _trim_silence(_remove_dc_offset(mono))
    if len(mono) == 0:
        return mono
    return _loudness_normalize(mono, TARGET_SR)
# ──────────────────────────────────────────────────────────────────
# MODEL LOADING
# ──────────────────────────────────────────────────────────────────
# Everything below runs once at import time; the resulting objects are
# module-level and are reused by every transcription request.
print(f"Loading model on {DEVICE}…")

# The processor (feature extractor + tokenizer) is read from the adapter
# directory, while the model weights come from the base checkpoint.
processor = WhisperProcessor.from_pretrained(LORA_DIR, language=LANGUAGE, task=TASK)
base_model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_NAME, torch_dtype=MODEL_DTYPE
).to(DEVICE)

# Clear the preset forced/suppressed tokens; language and task are passed
# explicitly on each generate() call instead.
base_model.config.forced_decoder_ids = None
base_model.generation_config.forced_decoder_ids = None
base_model.generation_config.suppress_tokens = []

# Attach the LoRA adapter, then keep a handle on the wrapped model that
# generate() is called on (presumably to bypass the PEFT wrapper's own
# generate(); confirm before changing).
model = PeftModel.from_pretrained(base_model, LORA_DIR).to(DEVICE)
inner_model = model.base_model.model
inner_model.eval()
print("Model loaded ✓")
# ──────────────────────────────────────────────────────────────────
# TRANSCRIPTION FUNCTION
# ──────────────────────────────────────────────────────────────────
def transcribe(audio_path: str, beam_size: int, max_new_tokens: int) -> str:
    """Transcribe one audio file with the LoRA-adapted Whisper model.

    The clip is loaded, run through ``preprocess_audio``, cut or zero-padded
    to exactly ``MAX_DURATION`` seconds and decoded with beam search. Every
    failure is reported as a message string instead of being raised, so the
    caller always receives text to display.

    Args:
        audio_path: Path of the uploaded or recorded file; ``None`` or an
            empty string means nothing was provided.
        beam_size: Number of beams (values below 1 are raised to 1).
        max_new_tokens: Upper bound on generated tokens; capped so that it
            fits in the decoder context together with the prompt tokens.

    Returns:
        The transcript, a blank line, and an italic info line with the
        processed duration and the device; or a warning/error message.
    """
    if not audio_path:
        return "⚠️ Please upload or record an audio file."

    try:
        waveform, sr = torchaudio.load(audio_path)
    except Exception as e:
        return f"❌ Error loading audio: {e}"

    # Degenerate clips can make the preprocessing libraries raise; report
    # that the same way as a load failure instead of crashing the request.
    try:
        waveform = preprocess_audio(waveform, sr)
    except Exception as e:
        return f"❌ Error preprocessing audio: {e}"
    if len(waveform) == 0:
        return "⚠️ Audio appears to be silent after preprocessing."

    # Keep at most MAX_DURATION seconds, then zero-pad up to that length
    # for the feature extractor.
    max_samples = int(MAX_DURATION * TARGET_SR)
    waveform = waveform[:max_samples]
    duration = len(waveform) / TARGET_SR
    if len(waveform) < max_samples:
        waveform = torch.nn.functional.pad(waveform, (0, max_samples - len(waveform)))

    features = processor.feature_extractor(
        waveform.numpy(), sampling_rate=TARGET_SR, return_tensors="pt"
    ).input_features.to(DEVICE, dtype=MODEL_DTYPE)

    # Whisper rejects requests where prompt tokens plus max_new_tokens exceed
    # the decoder context. The prompt here is start-of-transcript, language,
    # task and no-timestamps (4 tokens), so cap the budget accordingly.
    prompt_tokens = 4
    token_cap = getattr(inner_model.config, "max_target_positions", 448) - prompt_tokens
    token_budget = max(1, min(int(max_new_tokens), token_cap))
    beams = max(1, int(beam_size))

    try:
        with torch.no_grad():
            generated_ids = inner_model.generate(
                input_features=features,
                num_beams=beams,
                max_new_tokens=token_budget,
                language=LANGUAGE,
                task=TASK,
                suppress_tokens=[],
                forced_decoder_ids=None,
            )
    except Exception as e:
        return f"❌ Error during transcription: {e}"

    transcript = processor.tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0].strip()
    info = f"_(duration: {duration:.1f}s, device: {DEVICE})_"
    return f"{transcript}\n\n{info}"
# ──────────────────────────────────────────────────────────────────
# GRADIO INTERFACE
# ──────────────────────────────────────────────────────────────────
# Two-column layout: audio input and decoding controls on the left, the
# transcript on the right. Gradio dedents Markdown text, so the indented
# triple-quoted strings render as normal Markdown.
with gr.Blocks(title="AutoLyrics — Whisper LoRA", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🎵 AutoLyrics — Whisper-small + LoRA
        **Fine-tuned on `gmenon/slt-lyrics-audio` for music lyrics transcription.**
        Upload a song clip (≤ 30 s) or record directly from your microphone, then hit **Transcribe**.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # type="filepath" makes Gradio hand `transcribe` a path string.
            audio_input = gr.Audio(
                label="🎤 Audio Input",
                sources=["upload", "microphone"],
                type="filepath",
            )
            with gr.Accordion("⚙️ Advanced settings", open=False):
                beam_slider = gr.Slider(
                    minimum=1, maximum=5, value=BEAM_SIZE, step=1,
                    label="Beam size (1 = greedy, higher = better but slower)"
                )
                tokens_slider = gr.Slider(
                    minimum=50, maximum=448, value=MAX_NEW_TOKENS, step=10,
                    label="Max new tokens"
                )
            transcribe_btn = gr.Button("🎶 Transcribe", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="📝 Transcription",
                lines=8,
                placeholder="Lyrics will appear here…",
                show_copy_button=True,
            )
    # The input order must match the parameter order of `transcribe`.
    transcribe_btn.click(
        fn=transcribe,
        inputs=[audio_input, beam_slider, tokens_slider],
        outputs=output_text,
    )
    gr.Examples(
        examples=[],  # add example audio paths here if you have them
        inputs=audio_input,
    )
    gr.Markdown(
        """
        ---
        **Model:** `openai/whisper-small` + LoRA (r=8, α=16) &nbsp;|&nbsp;
        **Dataset:** `gmenon/slt-lyrics-audio` &nbsp;|&nbsp;
        **Preprocessing:** EBU R128 loudness normalisation, silence trimming
        """
    )

if __name__ == "__main__":
    demo.launch(share=True)  # share=True gives a public URL; remove for local only