import os

# 1. Allow PyTorch to load "unsafe" (fully pickled) checkpoints: the VAD models
#    used by whisperx are not weights-only compatible. Must be set before any
#    torch.load happens, hence before the heavy imports below.
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

import gc
import json
import gradio as gr
import torch
import whisperx


# --- NEW HELPER FUNCTION ---
def format_timestamp(ts):
    """Convert seconds (float) to mm:ss format"""
    minutes = int(ts // 60)
    seconds = int(ts % 60)
    return f"{minutes}:{seconds:02}"
# ---------------------------


# 2. Global Patch for torch.load (Backup fix for libraries that hardcode parameters)
_original_load = torch.load


def patched_load(*args, **kwargs):
    """torch.load wrapper that always disables the weights_only safeguard.

    BUGFIX: the original only overrode ``weights_only`` when the caller passed
    it explicitly (``if 'weights_only' in kwargs``). On torch >= 2.6 the
    restrictive default (``weights_only=True``) applies exactly when the kwarg
    is *omitted*, so the patch must force it unconditionally to fulfil its
    stated purpose.
    """
    kwargs["weights_only"] = False
    return _original_load(*args, **kwargs)


torch.load = patched_load

# 1. Setup Device & Config
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 16
compute_type = "float16" if device == "cuda" else "int8"  # int8 is faster on CPU

# Fallback wav2vec2 alignment model for languages with no explicit entry below.
DEFAULT_XLSR = "facebook/wav2vec2-large-xlsr-53"

ALIGN_MODEL_MAP = {
    # =====================
    # Built-in WhisperX aligner (no explicit model needed)
    # =====================
    "en": None,  # English
    "tl": None,  # Tagalog / Filipino
    "es": None,  # Spanish
    "fr": None,  # French
    "de": None,  # German
    "it": None,  # Italian
    "pt": None,  # Portuguese
    "ru": None,  # Russian
    "nl": None,  # Dutch
    # =====================
    # Explicit wav2vec2 models (needed)
    # =====================
    # Chinese
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    # Japanese
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    # Korean
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    # Thai
    "th": "airesearch/wav2vec2-large-xlsr-53-th",
    # Vietnamese
    "vi": "nguyenvulebinh/wav2vec2-large-xlsr-53-vietnamese",
    # Indonesian / Malay
    "id": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
    "ms": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
    # Arabic
    "ar": "elgeish/wav2vec2-large-xlsr-53-arabic",
    # Hindi
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
    # Turkish
    "tr": "savasy/wav2vec2-large-xlsr-turkish",
    # Polish
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    # Ukrainian
    "uk": "jonatasgrosman/wav2vec2-large-xlsr-53-ukrainian",
}

print(f"Loading WhisperX model on {device}...")
model = whisperx.load_model("medium", device, compute_type=compute_type)


def generate_lyrics(audio_file_path):
    """Transcribe an audio file and return mm:ss-timestamped lyric lines.

    Parameters
    ----------
    audio_file_path : str | None
        Path to the uploaded audio file (Gradio ``type="filepath"`` input).

    Returns
    -------
    dict
        ``{"lyrics": [{"time": "m:ss", "text": str, "chords": []}, ...]}`` on
        success, or ``{"error": <message>}`` on any failure.
    """
    if audio_file_path is None:
        return {"error": "No audio file provided"}
    try:
        audio = whisperx.load_audio(audio_file_path)
        result = model.transcribe(audio, batch_size=batch_size)

        lang = result["language"].lower()
        # Normalize regional variants (e.g. "zh-cn") to the bare map key.
        if lang.startswith("zh"):
            lang = "zh"
        align_model_name = ALIGN_MODEL_MAP.get(lang, DEFAULT_XLSR)

        model_a = None
        try:
            if align_model_name is None:
                model_a, metadata = whisperx.load_align_model(
                    language_code=lang, device=device
                )
            else:
                model_a, metadata = whisperx.load_align_model(
                    language_code=lang, device=device, model_name=align_model_name
                )
            result = whisperx.align(
                result["segments"],
                model_a,
                metadata,
                audio,
                device,
                return_char_alignments=False,
            )
        except Exception as align_err:
            # Best effort: fall back to the unaligned segment timestamps.
            print(f"[WARN] Alignment skipped: {align_err}")
        finally:
            # BUGFIX: the original deleted model_a only on the success path,
            # leaking the loaded alignment model whenever alignment raised.
            if model_a is not None:
                del model_a

        # --- UPDATED FORMATTING ---
        formatted_lyrics = [
            {
                "time": format_timestamp(seg["start"]),  # Changed from round(seg["start"], 3)
                "text": seg["text"].strip(),
                "chords": [],
            }
            for seg in result["segments"]
        ]
        # --------------------------

        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        return {"lyrics": formatted_lyrics}
    except Exception as e:
        return {"error": str(e)}


demo = gr.Interface(
    fn=generate_lyrics,
    inputs=gr.Audio(type="filepath", label="Upload Vocals/Audio"),
    outputs=gr.JSON(label="JSON Result"),
    title="WhisperX Aligned Lyric Generator",
    description="Transcribes audio and provides mm:ss timestamps.",
)

if __name__ == "__main__":
    demo.launch()