import re
import os
import tempfile
import gradio as gr
import torch
import torchaudio
import requests
from faster_whisper import WhisperModel
# ================================
# CONFIG
# ================================
# Prefer GPU when available; faster-whisper accepts "cuda" or "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Accurate model used for the refine pass; overridable via env var.
MODEL_NAME = os.getenv("WHISPER_MODEL", "large-v3")
# Small model used for the quick first pass.
FAST_MODEL_NAME = os.getenv("FAST_WHISPER_MODEL", "base")
# float16 on GPU; int8 quantization keeps CPU inference tractable.
COMPUTE_TYPE = "float16" if torch.cuda.is_available() else "int8"
# Public English profanity list (LDNOOBW) fetched at startup.
BAD_WORD_URL = (
    "https://raw.githubusercontent.com/LDNOOBW/"
    "List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/master/en"
)
# ================================
# BAD WORD LIST
# ================================
def get_bad_words():
    """Fetch the LDNOOBW English bad-word list and return it as a set.

    Each word is lower-cased and stripped of non-word characters so it can
    be matched against identically normalized transcript tokens.  Falls back
    to a small built-in set when the download fails or returns non-200.
    """
    try:
        print("🔍 Fetching bad-word list…")
        r = requests.get(BAD_WORD_URL, timeout=10)
        if r.status_code == 200:
            words = {
                re.sub(r"[^\w]", "", w.lower())
                for line in r.text.splitlines()
                for w in line.split()
                if w.strip()
            }
            # Extra words the public list misses but we always want flagged.
            words.update({"hell", "dam", "damn", "yeah"})
            print(f"✅ Loaded {len(words)} bad words.")
            return words
        # FIX: a non-200 response previously fell through and returned None,
        # which would break every later `word in BAD_WORDS` membership test.
        print(f"⚠️ Bad-word list returned HTTP {r.status_code}; using fallback.")
    except Exception as e:
        # Broad catch is deliberate: startup must survive any network/parse
        # failure and degrade to the built-in fallback set.
        print(f"⚠️ Failed to fetch list: {e}")
    return {"fuck", "shit", "bitch", "ass", "damn", "hell"}  # fallback


BAD_WORDS = get_bad_words()
# ================================
# UTILITY: SAFE AUDIO LOAD
# ================================
def load_audio_safe(path, target_sr=16000):
    """Load *path* as a mono waveform at *target_sr*; return (waveform, sr)."""
    waveform, source_sr = torchaudio.load(path)
    # Downmix multi-channel audio to a single averaged channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample only when the source rate differs from the requested one.
    if source_sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, source_sr, target_sr)
    return waveform, target_sr
# ================================
# LOAD MODELS
# ================================
# Both models are loaded once at import time so every request reuses them.
print(f"🚀 Loading FAST Whisper: {FAST_MODEL_NAME} ({COMPUTE_TYPE}) on {DEVICE}")
fast_model = WhisperModel(FAST_MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)
print(f"🚀 Loading LARGE Whisper: {MODEL_NAME} ({COMPUTE_TYPE}) on {DEVICE}")
large_model = WhisperModel(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)
# FIX: the final status message was a string literal broken across two lines
# (a syntax error as extracted); restored to a single print.
print("✅ All models ready!\n")
# ================================
# MAIN TRANSCRIBE FUNCTION
# ================================
def transcribe(file_path):
    """Transcribe an audio file into word-level timestamps with explicit flags.

    Two-pass strategy: the fast model transcribes everything and flags
    explicit words against BAD_WORDS; the large model then re-transcribes
    only the flagged word spans for more reliable text and timing.

    Returns a list of dicts sorted by start time, each with keys
    "word", "start", "end", "explicit", "explicit_fast".
    """
    # Load + normalize audio (mono, 16 kHz) and persist a cleaned copy,
    # since faster-whisper takes a file path rather than a tensor.
    wav, sr = load_audio_safe(file_path)
    fixed_path = "input_fixed.wav"
    torchaudio.save(fixed_path, wav, sr)

    # =====================================
    # 1) FAST PASS — detect explicit words
    # =====================================
    fast_segments, fast_info = fast_model.transcribe(
        fixed_path,
        beam_size=1,
        word_timestamps=True,
        vad_filter=True,
    )
    transcript = []
    # NOTE(review): fast_info may not expose sample_rate; fall back to the
    # rate we saved the audio at so chunk slicing below stays consistent.
    sample_rate = getattr(fast_info, "sample_rate", sr)
    for seg in fast_segments:
        if not getattr(seg, "words", None):
            continue
        for w in seg.words:
            # Normalize the token the same way BAD_WORDS entries were built.
            clean_word = re.sub(r"[^\w]", "", w.word.strip().lower())
            is_explicit = clean_word in BAD_WORDS
            transcript.append({
                "word": w.word.strip(),
                "start": float(w.start),
                "end": float(w.end),
                "explicit": is_explicit,
                "explicit_fast": is_explicit,
            })

    # =====================================
    # EARLY EXIT IF NO EXPLICIT WORDS
    # =====================================
    flagged = [w for w in transcript if w["explicit_fast"]]
    if not flagged:
        print("✅ No explicit words detected — returning fast transcript.")
        return transcript

    # =====================================
    # 2) REFINE PASS — only explicit words
    # =====================================
    final = []
    for entry in transcript:
        # Not explicit → keep untouched.
        if not entry["explicit_fast"]:
            final.append(entry)
            continue

        # Slice the audio chunk covering just this word.
        start_s = entry["start"]
        end_s = entry["end"]
        start_sample = int(start_s * sample_rate)
        end_sample = int(end_s * sample_rate)
        chunk = wav[:, start_sample:end_sample]

        # Safety: collapsed timestamp → keep the fast-pass entry.
        if chunk.numel() == 0:
            final.append(entry)
            continue

        # Save the chunk to a temp file for the large model.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            chunk_path = tmp.name
        torchaudio.save(chunk_path, chunk, sample_rate)

        try:
            refined_segs, _ = large_model.transcribe(
                chunk_path,
                beam_size=5,
                word_timestamps=True,
                vad_filter=False,
            )
            # FIX: faster-whisper returns a lazy generator, so transcription
            # errors only surface while iterating.  Materialize inside the
            # try so the except below actually catches those failures.
            refined_segs = list(refined_segs)
        except Exception as e:
            print(f"⚠️ Large model failed on chunk: {e} — keeping fast result")
            final.append(entry)
            continue
        finally:
            # FIX: single cleanup path — the temp chunk is removed on both
            # success and failure (was two separate os.remove calls).
            os.remove(chunk_path)

        # Extract refined words, offsetting timestamps back to full-track time.
        refined_words = []
        for seg in refined_segs:
            if not getattr(seg, "words", None):
                continue
            for w in seg.words:
                refined_words.append({
                    "word": w.word.strip(),
                    "start": float(w.start) + start_s,
                    "end": float(w.end) + start_s,
                    "explicit": entry["explicit_fast"],
                    "explicit_fast": entry["explicit_fast"],
                })

        # Fallback if the large model returned nothing usable.
        if not refined_words:
            final.append(entry)
            continue
        final.extend(refined_words)

    # Sort by timestamp (critical for the downstream assembler).
    final.sort(key=lambda x: x["start"])
    return final
# ================================
# GRADIO UI
# ================================
# FIX: title/description contained mojibake ("β" residue) where dashes and
# arrows once were; restored readable user-facing text.
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Vocals"),
    outputs=gr.JSON(label="Transcript with Explicit Flags"),
    title="CleanSong AI — Whisper Transcriber",
    description=(
        "Fast model detects explicit words → "
        "Large model refines only those segments. "
        "Returns word-level timestamps."
    ),
)

if __name__ == "__main__":
    iface.launch()