Spaces:
Sleeping
Sleeping
import gc
import json
import os

# Allow PyTorch to do full (non-weights-only) unpickling; the VAD models that
# WhisperX loads require it. Set before the ML imports below so it is already
# in place before anything gets loaded.
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

import gradio as gr
import whisperx
import torch
# Helper: render a segment start time (seconds) for the JSON output.
def format_timestamp(ts):
    """Format a time offset given in seconds as ``M:SS``.

    Minutes are not zero-padded and may exceed 59 (e.g. 3725 -> "62:05");
    the fractional part of the seconds is truncated, not rounded.
    """
    whole_minutes, remainder = divmod(ts, 60)
    return f"{int(whole_minutes)}:{int(remainder):02}"
# Global patch for torch.load: a backup for libraries that pass weights_only
# explicitly. Per the torch.load docs, TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD only
# applies when a call site leaves weights_only unset, so explicit
# weights_only=True call sites need this override as well.
import functools

# If this module is executed again (e.g. by a reloader), start from the real
# torch.load instead of wrapping our own previous wrapper.
_original_load = getattr(torch.load, "_unpatched_load", torch.load)


@functools.wraps(_original_load)
def patched_load(*args, **kwargs):
    """Forward to the original ``torch.load`` with ``weights_only`` forced off.

    Only an explicitly passed ``weights_only`` keyword is rewritten to False;
    when the keyword is absent the call is forwarded unchanged.
    """
    if "weights_only" in kwargs:
        kwargs["weights_only"] = False
    return _original_load(*args, **kwargs)


# Marker used above to detect (and undo) a previous application of this patch.
patched_load._unpatched_load = _original_load
torch.load = patched_load
# Runtime configuration: prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
batch_size = 16
# float16 on the GPU; int8 on the CPU (where it is the faster choice).
compute_type = {"cuda": "float16"}.get(device, "int8")
DEFAULT_XLSR = "facebook/wav2vec2-large-xlsr-53"

# Whisper language code -> alignment model.
#   None -> no model_name is passed, so WhisperX picks its built-in aligner.
#   str  -> Hugging Face id of a wav2vec2 checkpoint passed as model_name.
# NOTE(review): a None entry only works if the installed WhisperX version
# ships a default aligner for that language (e.g. "tl") - confirm.
ALIGN_MODEL_MAP = {
    # English, Tagalog/Filipino, Spanish, French, German, Italian,
    # Portuguese, Russian, Dutch.
    **dict.fromkeys(("en", "tl", "es", "fr", "de", "it", "pt", "ru", "nl")),
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",  # Chinese
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",  # Japanese
    "ko": "kresnik/wav2vec2-large-xlsr-korean",  # Korean
    "th": "airesearch/wav2vec2-large-xlsr-53-th",  # Thai
    "vi": "nguyenvulebinh/wav2vec2-large-xlsr-53-vietnamese",  # Vietnamese
    # Indonesian and Malay share one checkpoint.
    "id": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
    "ms": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
    "ar": "elgeish/wav2vec2-large-xlsr-53-arabic",  # Arabic
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",  # Hindi
    "tr": "savasy/wav2vec2-large-xlsr-turkish",  # Turkish
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",  # Polish
    "uk": "jonatasgrosman/wav2vec2-large-xlsr-53-ukrainian",  # Ukrainian
}
# Load the Whisper "medium" model once at import time; generate_lyrics reuses
# this module-level `model` for every request.
print(f"Loading WhisperX model on {device}...")
model = whisperx.load_model("medium", device, compute_type=compute_type)
def _align_segments(result, audio, lang):
    """Return ``result`` re-timed by WhisperX forced alignment, or unchanged.

    Alignment is best-effort: any failure (no aligner for ``lang``, model
    load error, alignment error) is logged and the unaligned transcription
    is returned instead.
    """
    # Listed languages use their mapped checkpoint (None = WhisperX built-in).
    # Unlisted languages also get model_name=None so WhisperX can use its own
    # default aligner. DEFAULT_XLSR is deliberately not used as a fallback:
    # "facebook/wav2vec2-large-xlsr-53" is a pretrained base model without a
    # CTC vocabulary/tokenizer, so WhisperX cannot build an aligner from it.
    model_name = ALIGN_MODEL_MAP.get(lang)
    try:
        model_a, metadata = whisperx.load_align_model(
            language_code=lang, device=device, model_name=model_name
        )
        return whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            device,
            return_char_alignments=False,
        )
    except Exception as align_err:
        print(f"[WARN] Alignment skipped: {align_err}")
        return result


def generate_lyrics(audio_file_path):
    """Transcribe an audio file and return timestamped lyric lines.

    Args:
        audio_file_path: Path to the uploaded audio file, or None when nothing
            was uploaded.

    Returns:
        ``{"lyrics": [{"time": "M:SS", "text": str, "chords": []}, ...]}`` on
        success, or ``{"error": message}`` on failure. Errors are returned
        instead of raised so the JSON output component always shows a result.
    """
    if audio_file_path is None:
        return {"error": "No audio file provided"}
    try:
        audio = whisperx.load_audio(audio_file_path)
        result = model.transcribe(audio, batch_size=batch_size)
        lang = result["language"].lower()
        # Collapse regional Chinese variants onto the single "zh" map entry.
        if lang.startswith("zh"):
            lang = "zh"
        result = _align_segments(result, audio, lang)
        formatted_lyrics = [
            {
                "time": format_timestamp(seg["start"]),
                "text": seg["text"].strip(),
                "chords": [],
            }
            for seg in result["segments"]
        ]
        return {"lyrics": formatted_lyrics}
    except Exception as e:
        import traceback

        traceback.print_exc()  # keep the full trace in the server logs
        return {"error": str(e)}
    finally:
        # Free memory on success AND failure; the alignment model loaded in
        # _align_segments is already unreferenced once that helper returns.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
# Gradio UI: a single audio upload in, the JSON returned by generate_lyrics out.
# Components are created in the same order as before (input, then output).
_audio_input = gr.Audio(type="filepath", label="Upload Vocals/Audio")
_json_output = gr.JSON(label="JSON Result")
demo = gr.Interface(
    fn=generate_lyrics,
    inputs=_audio_input,
    outputs=_json_output,
    title="WhisperX Aligned Lyric Generator",
    description="Transcribes audio and provides mm:ss timestamps.",
)

if __name__ == "__main__":
    demo.launch()