# get-lyrics / app.py — WhisperX aligned lyric generator (Hugging Face Space)
# Source: Space "get-lyrics" by RayPac006, commit e7566d4 ("Update app.py", verified)
import os
# 1. Force PyTorch to allow loading "unsafe" weights (The VAD models require this).
#    NOTE: this env var must be set BEFORE torch/whisperx are imported below.
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
import gradio as gr
import whisperx
import json
import torch
import gc
# Helper used when formatting the lyric JSON below.
def format_timestamp(ts):
    """Render a second offset (float) as an "m:ss" string, e.g. 75.4 -> "1:15"."""
    whole_minutes, leftover = divmod(ts, 60)
    return f"{int(whole_minutes)}:{int(leftover):02d}"
# 2. Safety net for libraries that hardcode weights_only=True in their own
#    torch.load calls: wrap torch.load so any explicit weights_only flag is
#    forced back to False (the env var above only covers the default case).
_original_load = torch.load

def patched_load(*args, **kwargs):
    """torch.load wrapper that downgrades an explicit weights_only kwarg to False."""
    if "weights_only" in kwargs:
        kwargs = dict(kwargs, weights_only=False)
    return _original_load(*args, **kwargs)

torch.load = patched_load
# 1. Setup Device & Config
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 16  # passed to model.transcribe() in generate_lyrics
compute_type = "float16" if device == "cuda" else "int8" # int8 is faster on CPU
# Multilingual fallback aligner for languages missing from ALIGN_MODEL_MAP.
DEFAULT_XLSR = "facebook/wav2vec2-large-xlsr-53"
# Language code -> alignment model. None means whisperx's built-in default
# aligner handles that language; a string pins a specific Hugging Face
# wav2vec2 checkpoint (passed as model_name= to load_align_model).
ALIGN_MODEL_MAP = {
# =====================
# Built-in WhisperX aligner (no explicit model needed)
# =====================
"en": None, # English
"tl": None, # Tagalog / Filipino
"es": None, # Spanish
"fr": None, # French
"de": None, # German
"it": None, # Italian
"pt": None, # Portuguese
"ru": None, # Russian
"nl": None, # Dutch
# =====================
# Explicit wav2vec2 models (needed)
# =====================
# Chinese
"zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
# Japanese
"ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
# Korean
"ko": "kresnik/wav2vec2-large-xlsr-korean",
# Thai
"th": "airesearch/wav2vec2-large-xlsr-53-th",
# Vietnamese
"vi": "nguyenvulebinh/wav2vec2-large-xlsr-53-vietnamese",
# Indonesian / Malay
"id": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
"ms": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
# Arabic
"ar": "elgeish/wav2vec2-large-xlsr-53-arabic",
# Hindi
"hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
# Turkish
"tr": "savasy/wav2vec2-large-xlsr-turkish",
# Polish
"pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
# Ukrainian
"uk": "jonatasgrosman/wav2vec2-large-xlsr-53-ukrainian",
}
# Load the Whisper "medium" checkpoint once at startup so every request reuses it.
print(f"Loading WhisperX model on {device}...")
model = whisperx.load_model("medium", device, compute_type=compute_type)
def generate_lyrics(audio_file_path):
    """Transcribe an audio file with WhisperX and return timestamped lyrics.

    Runs Whisper transcription, then (best-effort) word-level alignment with
    the language-appropriate wav2vec2 model, and formats each segment with an
    m:ss start time.

    Args:
        audio_file_path: Filesystem path to the uploaded audio, or None.

    Returns:
        dict: {"lyrics": [{"time", "text", "chords"}, ...]} on success,
              {"error": <message>} when input is missing or processing fails.
    """
    if audio_file_path is None:
        return {"error": "No audio file provided"}
    try:
        audio = whisperx.load_audio(audio_file_path)
        result = model.transcribe(audio, batch_size=batch_size)

        # Normalise the detected language code (Whisper may emit e.g. "zh-cn").
        lang = result["language"].lower()
        if lang.startswith("zh"):
            lang = "zh"

        # None -> whisperx's built-in aligner for that language; unknown
        # languages fall back to the multilingual XLSR checkpoint.
        align_model_name = ALIGN_MODEL_MAP.get(lang, DEFAULT_XLSR)
        model_a = None
        try:
            if align_model_name is None:
                model_a, metadata = whisperx.load_align_model(
                    language_code=lang, device=device
                )
            else:
                model_a, metadata = whisperx.load_align_model(
                    language_code=lang, device=device, model_name=align_model_name
                )
            result = whisperx.align(
                result["segments"], model_a, metadata, audio, device,
                return_char_alignments=False,
            )
        except Exception as align_err:
            # Alignment is best-effort: keep the raw transcription timestamps
            # instead of failing the whole request.
            print(f"[WARN] Alignment skipped: {align_err}")
        finally:
            # Release the alignment model even when align() itself raised
            # (the original only freed it on the success path).
            if model_a is not None:
                del model_a

        # --- UPDATED FORMATTING ---
        formatted_lyrics = [
            {
                "time": format_timestamp(seg["start"]),  # Changed from round(seg["start"], 3)
                "text": seg["text"].strip(),
                "chords": []
            }
            for seg in result["segments"]
        ]
        # --------------------------
        return {"lyrics": formatted_lyrics}
    except Exception as e:
        return {"error": str(e)}
    finally:
        # Reclaim memory between requests on success AND failure (the
        # original skipped this cleanup whenever an exception was raised).
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
# Gradio UI: one audio-file input, JSON output produced by generate_lyrics.
demo = gr.Interface(
    fn=generate_lyrics,
    inputs=gr.Audio(type="filepath", label="Upload Vocals/Audio"),
    outputs=gr.JSON(label="JSON Result"),
    title="WhisperX Aligned Lyric Generator",
    description="Transcribes audio and provides mm:ss timestamps."
)
# Launch the Gradio server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()