File size: 6,485 Bytes
73c22fe
 
 
 
bce4a85
52f0004
20277ed
316c57c
61386ba
20277ed
20dd96d
 
 
e96d8eb
20dd96d
 
61386ba
20dd96d
e96d8eb
73c22fe
 
61386ba
4927bb9
20dd96d
 
 
e96d8eb
 
20dd96d
e96d8eb
 
a6041ec
 
 
 
 
 
73c22fe
 
20dd96d
e96d8eb
 
 
 
73c22fe
e96d8eb
20dd96d
73c22fe
20dd96d
 
 
 
 
 
 
52f0004
6ab727a
20dd96d
 
 
 
 
73c22fe
 
 
 
 
 
 
 
 
 
 
 
20dd96d
 
 
 
 
73c22fe
20dd96d
52f0004
20dd96d
bce4a85
20dd96d
 
 
316c57c
61386ba
316c57c
61386ba
73c22fe
61386ba
52f0004
 
20dd96d
 
316c57c
73c22fe
20dd96d
73c22fe
 
 
 
 
 
 
 
 
 
 
20dd96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73c22fe
20dd96d
 
 
 
73c22fe
20dd96d
 
 
 
6ab727a
 
73c22fe
20dd96d
 
 
 
73c22fe
6ab727a
20dd96d
 
7a79027
73c22fe
6ab727a
20dd96d
 
 
 
73c22fe
20dd96d
73c22fe
 
20dd96d
73c22fe
20dd96d
 
 
 
73c22fe
20dd96d
 
73c22fe
 
 
 
 
 
 
 
 
 
 
 
20dd96d
 
 
 
 
 
73c22fe
20dd96d
 
 
 
 
 
 
52f0004
 
 
e96d8eb
73c22fe
 
 
 
 
 
52f0004
bce4a85
52f0004
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import re
import os
import tempfile

import gradio as gr
import torch
import torchaudio
import requests
from faster_whisper import WhisperModel

# ================================
# CONFIG
# ================================
# Run on GPU when available; float16 requires CUDA, int8 is the CPU-friendly mode.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Accurate model used for the refinement pass (overridable via env var).
MODEL_NAME = os.getenv("WHISPER_MODEL", "large-v3")
# Small model used for the quick first pass (overridable via env var).
FAST_MODEL_NAME = os.getenv("FAST_WHISPER_MODEL", "base")
COMPUTE_TYPE = "float16" if torch.cuda.is_available() else "int8"

# English word list from the LDNOOBW profanity-list project on GitHub.
BAD_WORD_URL = (
    "https://raw.githubusercontent.com/LDNOOBW/"
    "List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/master/en"
)

# ================================
# BAD WORD LIST
# ================================
def get_bad_words():
    """Fetch the LDNOOBW English profanity list as a set of normalized tokens.

    Each token is lowercased and stripped of non-word characters so it can be
    compared against similarly normalized transcript words.  Falls back to a
    small hard-coded set when the request fails or returns a non-200 status.

    Returns:
        set[str]: normalized bad words (never contains the empty string).
    """
    try:
        print("🌐 Fetching bad-word list…")
        r = requests.get(BAD_WORD_URL, timeout=10)
        if r.status_code == 200:
            words = {
                re.sub(r"[^\w]", "", w.lower())
                for line in r.text.splitlines()
                for w in line.split()
                if w.strip()
            }
            # FIX: tokens made entirely of punctuation normalize to "" above;
            # leaving "" in the set would later flag punctuation-only
            # transcript tokens as explicit.
            words.discard("")
            # Extra words to always catch
            words.update({"hell", "dam", "damn", "yeah"})
            print(f"βœ… Loaded {len(words)} bad words.")
            return words
    except Exception as e:
        print(f"⚠️ Failed to fetch list: {e}")

    return {"fuck", "shit", "bitch", "ass", "damn", "hell"}  # fallback


# Module-level cache: fetched once at import time and reused for every request.
BAD_WORDS = get_bad_words()


# ================================
# UTILITY: SAFE AUDIO LOAD
# ================================
def load_audio_safe(path, target_sr=16000):
    """Load an audio file, downmix to mono, and resample to *target_sr* Hz.

    Returns a ``(waveform, sample_rate)`` pair where the waveform has shape
    ``(1, num_samples)`` and the returned rate always equals ``target_sr``.
    """
    waveform, native_sr = torchaudio.load(path)
    # Average the channels when the file is not already mono.
    mono = waveform if waveform.shape[0] == 1 else waveform.mean(dim=0, keepdim=True)
    # Resample only when the native rate differs from the target.
    if native_sr == target_sr:
        return mono, target_sr
    return torchaudio.functional.resample(mono, native_sr, target_sr), target_sr


# ================================
# LOAD MODELS
# ================================
# Both models are instantiated eagerly at import time so the first request
# does not pay the model-load cost.
print(f"πŸš€ Loading FAST Whisper: {FAST_MODEL_NAME} ({COMPUTE_TYPE}) on {DEVICE}")
fast_model = WhisperModel(FAST_MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)

print(f"πŸš€ Loading LARGE Whisper: {MODEL_NAME} ({COMPUTE_TYPE}) on {DEVICE}")
large_model = WhisperModel(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)

print("βœ… All models ready!\n")


# ================================
# MAIN TRANSCRIBE FUNCTION
# ================================
def transcribe(file_path):
    """Two-pass, word-level transcription with explicit-word flagging.

    Pass 1 runs the fast Whisper model over the whole file and flags any word
    whose normalized form appears in BAD_WORDS.  If nothing is flagged the
    fast transcript is returned as-is.  Otherwise pass 2 re-transcribes only
    the flagged word spans with the large model for better text/timestamps.

    Args:
        file_path: path to an audio file readable by torchaudio.

    Returns:
        list[dict]: sorted by "start"; each dict has keys
        "word", "start", "end", "explicit", "explicit_fast".
    """
    # Load + normalize audio (mono, 16 kHz)
    wav, sr = load_audio_safe(file_path)

    # FIX: write the normalized audio to a unique temp file instead of a
    # fixed "input_fixed.wav" in the CWD, which raced between concurrent
    # requests and was never cleaned up.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        fixed_path = tmp.name

    try:
        torchaudio.save(fixed_path, wav, sr)

        # =====================================
        # 1) FAST PASS β€” detect explicit words
        # =====================================
        fast_segments, fast_info = fast_model.transcribe(
            fixed_path,
            beam_size=1,
            word_timestamps=True,
            vad_filter=True,
        )

        transcript = []
        # NOTE(review): fast_info may not expose sample_rate; fall back to
        # the rate we resampled to β€” confirm against faster-whisper version.
        sample_rate = getattr(fast_info, "sample_rate", sr)

        # faster-whisper yields segments lazily and reads the file as the
        # generator is consumed, so it must be fully iterated before the
        # temp file is removed in the finally below.
        for seg in fast_segments:
            if not getattr(seg, "words", None):
                continue
            for w in seg.words:
                clean_word = re.sub(r"[^\w]", "", w.word.strip().lower())
                # FIX: punctuation-only tokens normalize to "" β€” never flag those.
                is_explicit = bool(clean_word) and clean_word in BAD_WORDS
                transcript.append({
                    "word": w.word.strip(),
                    "start": float(w.start),
                    "end": float(w.end),
                    "explicit": is_explicit,
                    "explicit_fast": is_explicit,
                })
    finally:
        # FIX: the normalized-audio temp file was previously never deleted.
        os.remove(fixed_path)

    # =====================================
    # EARLY EXIT IF NO EXPLICIT WORDS
    # =====================================
    if not any(w["explicit_fast"] for w in transcript):
        print("βœ… No explicit words detected β€” returning fast transcript.")
        return transcript

    # =====================================
    # 2) REFINE PASS β€” only explicit words
    # =====================================
    final = []
    for entry in transcript:
        if not entry["explicit_fast"]:
            # Not explicit β€” keep untouched
            final.append(entry)
            continue
        refined = _refine_explicit_word(entry, wav, sample_rate)
        # Fall back to the fast-pass entry when refinement produced nothing.
        final.extend(refined if refined else [entry])

    # Sort by timestamp (critical for assembler)
    final.sort(key=lambda x: x["start"])
    return final


def _refine_explicit_word(entry, wav, sample_rate):
    """Re-transcribe the audio span of one flagged word with the large model.

    Returns a list of refined word dicts with timestamps offset back to
    full-track time, or None when the span is empty, the large model fails,
    or it returns no words (caller then keeps the fast-pass entry).
    """
    start_s = entry["start"]
    end_s = entry["end"]
    chunk = wav[:, int(start_s * sample_rate):int(end_s * sample_rate)]

    # Safety: collapsed timestamps yield an empty chunk.
    if chunk.numel() == 0:
        return None

    # Save chunk to temp file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        chunk_path = tmp.name
    torchaudio.save(chunk_path, chunk, sample_rate)

    try:
        refined_segs, _ = large_model.transcribe(
            chunk_path,
            beam_size=5,
            word_timestamps=True,
            vad_filter=False,
        )
        # FIX: consume the lazy segment generator inside the try β€” decode
        # errors surface during iteration, which the original try around only
        # the transcribe() call could not catch; it also guarantees the file
        # still exists while being read.
        refined_words = [
            {
                "word": w.word.strip(),
                "start": float(w.start) + start_s,
                "end": float(w.end) + start_s,
                "explicit": entry["explicit_fast"],
                "explicit_fast": entry["explicit_fast"],
            }
            for seg in refined_segs
            if getattr(seg, "words", None)
            for w in seg.words
        ]
    except Exception as e:
        print(f"⚠️ Large model failed on chunk: {e} β€” keeping fast result")
        return None
    finally:
        # FIX: remove the chunk file on every path, including exceptions.
        os.remove(chunk_path)

    return refined_words or None


# ================================
# GRADIO UI
# ================================
# Single-input UI: upload an audio file, receive the word-level JSON
# transcript with per-word explicit flags.
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Vocals"),
    outputs=gr.JSON(label="Transcript with Explicit Flags"),
    title="CleanSong AI β€” Whisper Transcriber",
    description=(
        "Fast model detects explicit words β†’ "
        "Large model refines only those segments. "
        "Returns word-level timestamps."
    ),
)

# Launch the server only when executed as a script (not when imported).
if __name__ == "__main__":
    iface.launch()