import os
# 1. Force PyTorch to allow loading "unsafe" weights (the VAD models require this)
os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

import gradio as gr
import whisperx
import torch
import gc

# --- Helper: timestamp formatting ---
def format_timestamp(ts):
    """Convert seconds (float) to mm:ss format"""
    minutes = int(ts // 60)
    seconds = int(ts % 60)
    return f"{minutes}:{seconds:02}"
# ---------------------------

# 2. Global patch for torch.load (fallback for libraries that pass weights_only explicitly)
_original_load = torch.load
def patched_load(*args, **kwargs):
    if 'weights_only' in kwargs:
        kwargs['weights_only'] = False
    return _original_load(*args, **kwargs)
torch.load = patched_load

# 3. Device & model configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 16  # reduce if you run out of GPU memory
compute_type = "float16" if device == "cuda" else "int8" # int8 is faster on CPU
DEFAULT_XLSR = "facebook/wav2vec2-large-xlsr-53"
ALIGN_MODEL_MAP = {
    # =====================
    # Built-in WhisperX aligner (no explicit model needed)
    # =====================
    "en": None,     # English
    "tl": None,     # Tagalog / Filipino
    "es": None,     # Spanish
    "fr": None,     # French
    "de": None,     # German
    "it": None,     # Italian
    "pt": None,     # Portuguese
    "ru": None,     # Russian
    "nl": None,     # Dutch

    # =====================
    # Explicit wav2vec2 models (needed)
    # =====================

    # Chinese
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",

    # Japanese
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",

    # Korean
    "ko": "kresnik/wav2vec2-large-xlsr-korean",

    # Thai
    "th": "airesearch/wav2vec2-large-xlsr-53-th",

    # Vietnamese
    "vi": "nguyenvulebinh/wav2vec2-large-xlsr-53-vietnamese",

    # Indonesian / Malay
    "id": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
    "ms": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",

    # Arabic
    "ar": "elgeish/wav2vec2-large-xlsr-53-arabic",

    # Hindi
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",

    # Turkish
    "tr": "savasy/wav2vec2-large-xlsr-turkish",

    # Polish
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",

    # Ukrainian
    "uk": "jonatasgrosman/wav2vec2-large-xlsr-53-ukrainian",

}

print(f"Loading WhisperX model on {device}...")
model = whisperx.load_model("medium", device, compute_type=compute_type)
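# Note: weights are downloaded on first run; "small" or "large-v2" are drop-in alternatives if you want to trade accuracy for speed/VRAM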

def generate_lyrics(audio_file_path):
    if audio_file_path is None:
        return {"error": "No audio file provided"}

    try:
        audio = whisperx.load_audio(audio_file_path)
        result = model.transcribe(audio, batch_size=batch_size)

        lang = result["language"].lower()
        if lang.startswith("zh"): lang = "zh"
        
        align_model_name = ALIGN_MODEL_MAP.get(lang, DEFAULT_XLSR)  # falls back to DEFAULT_XLSR for unlisted languages

        try:
            if align_model_name is None:
                model_a, metadata = whisperx.load_align_model(language_code=lang, device=device)
            else:
                model_a, metadata = whisperx.load_align_model(language_code=lang, device=device, model_name=align_model_name)

            result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
            del model_a
        except Exception as align_err:
            print(f"[WARN] Alignment skipped: {align_err}")

        # --- Build lyric entries with m:ss timestamps ---
        formatted_lyrics = [
            {
                "time": format_timestamp(seg["start"]), # Changed from round(seg["start"], 3)
                "text": seg["text"].strip(),
                "chords": []
            }
            for seg in result["segments"]
        ]
        # --------------------------
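        # Example entry: {"time": "0:42", "text": "some lyric line", "chords": []}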

        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        return {"lyrics": formatted_lyrics}

    except Exception as e:
        return {"error": str(e)}

demo = gr.Interface(
    fn=generate_lyrics,
    inputs=gr.Audio(type="filepath", label="Upload Vocals/Audio"),
    outputs=gr.JSON(label="JSON Result"),
    title="WhisperX Aligned Lyric Generator",
    description="Transcribes audio and provides mm:ss timestamps."
)

if __name__ == "__main__":
    demo.launch()  # pass share=True for a temporary public link