import os
# 1. Force PyTorch to allow loading "unsafe" weights (The VAD models require this)
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
import gradio as gr
import whisperx
import json
import torch
import gc
# --- NEW HELPER FUNCTION ---
def format_timestamp(ts):
    """Convert seconds (float) to mm:ss format."""
    minutes = int(ts // 60)
    seconds = int(ts % 60)
    return f"{minutes}:{seconds:02}"
# ---------------------------
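# Example of the helper's behaviour: format_timestamp(75.4) -> "1:15",
# format_timestamp(5.0) -> "0:05". Hours are not folded in, so 3700 seconds
# renders as "61:40".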
# 2. Global Patch for torch.load (Backup fix for libraries that hardcode parameters)
_original_load = torch.load
def patched_load(*args, **kwargs):
    if 'weights_only' in kwargs:
        kwargs['weights_only'] = False
    return _original_load(*args, **kwargs)
torch.load = patched_load
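# Context: PyTorch 2.6+ defaults torch.load(weights_only=True), which rejects the
# pickled VAD/alignment checkpoints WhisperX downloads. The env var above covers
# calls that omit the argument; this wrapper covers libraries that pass
# weights_only explicitly.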
# 3. Setup Device & Config
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 16
compute_type = "float16" if device == "cuda" else "int8" # int8 is faster on CPU
DEFAULT_XLSR = "facebook/wav2vec2-large-xlsr-53"
ALIGN_MODEL_MAP = {
    # =====================
    # Built-in WhisperX aligner (no explicit model needed)
    # =====================
    "en": None,  # English
    "tl": None,  # Tagalog / Filipino
    "es": None,  # Spanish
    "fr": None,  # French
    "de": None,  # German
    "it": None,  # Italian
    "pt": None,  # Portuguese
    "ru": None,  # Russian
    "nl": None,  # Dutch
    # =====================
    # Explicit wav2vec2 models (needed)
    # =====================
    # Chinese
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    # Japanese
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    # Korean
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    # Thai
    "th": "airesearch/wav2vec2-large-xlsr-53-th",
    # Vietnamese
    "vi": "nguyenvulebinh/wav2vec2-large-xlsr-53-vietnamese",
    # Indonesian / Malay
    "id": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
    "ms": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
    # Arabic
    "ar": "elgeish/wav2vec2-large-xlsr-53-arabic",
    # Hindi
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
    # Turkish
    "tr": "savasy/wav2vec2-large-xlsr-turkish",
    # Polish
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    # Ukrainian
    "uk": "jonatasgrosman/wav2vec2-large-xlsr-53-ukrainian",
}
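# Any language not listed above falls back to the multilingual XLSR checkpoint, e.g.
# ALIGN_MODEL_MAP.get("sv", DEFAULT_XLSR) -> "facebook/wav2vec2-large-xlsr-53"
# (see the lookup in generate_lyrics below).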
print(f"Loading WhisperX model on {device}...")
model = whisperx.load_model("medium", device, compute_type=compute_type)
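# Note: "medium" is a middle ground between speed and accuracy; other checkpoint
# names accepted by WhisperX such as "small" (lighter) or "large-v2" (more accurate)
# can be swapped in here if the hardware allows it.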
def generate_lyrics(audio_file_path):
    if audio_file_path is None:
        return {"error": "No audio file provided"}
    try:
        audio = whisperx.load_audio(audio_file_path)
        result = model.transcribe(audio, batch_size=batch_size)
        lang = result["language"].lower()
        if lang.startswith("zh"):
            lang = "zh"
        align_model_name = ALIGN_MODEL_MAP.get(lang, DEFAULT_XLSR)
        try:
            if align_model_name is None:
                model_a, metadata = whisperx.load_align_model(language_code=lang, device=device)
            else:
                model_a, metadata = whisperx.load_align_model(language_code=lang, device=device, model_name=align_model_name)
            result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
            del model_a
        except Exception as align_err:
            print(f"[WARN] Alignment skipped: {align_err}")
        # --- UPDATED FORMATTING ---
        formatted_lyrics = [
            {
                "time": format_timestamp(seg["start"]),  # Changed from round(seg["start"], 3)
                "text": seg["text"].strip(),
                "chords": []
            }
            for seg in result["segments"]
        ]
        # --------------------------
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
        return {"lyrics": formatted_lyrics}
    except Exception as e:
        return {"error": str(e)}
demo = gr.Interface(
    fn=generate_lyrics,
    inputs=gr.Audio(type="filepath", label="Upload Vocals/Audio"),
    outputs=gr.JSON(label="JSON Result"),
    title="WhisperX Aligned Lyric Generator",
    description="Transcribes audio and provides mm:ss timestamps."
)
if __name__ == "__main__":
    demo.launch()