# CleanSong — app.py (commit 73c22fe, verified)
import re
import os
import tempfile
import gradio as gr
import torch
import torchaudio
import requests
from faster_whisper import WhisperModel
# ================================
# CONFIG
# ================================
# Prefer GPU when available; everything below degrades gracefully to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Accurate model: used only to re-transcribe flagged (explicit) segments.
MODEL_NAME = os.getenv("WHISPER_MODEL", "large-v3")
# Cheap model: used for the first full-track detection pass.
FAST_MODEL_NAME = os.getenv("FAST_WHISPER_MODEL", "base")
# float16 requires CUDA; int8 keeps CPU inference small and fast.
COMPUTE_TYPE = "float16" if torch.cuda.is_available() else "int8"
# LDNOOBW community profanity list (English), one term per line.
BAD_WORD_URL = (
"https://raw.githubusercontent.com/LDNOOBW/"
"List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/master/en"
)
# ================================
# BAD WORD LIST
# ================================
def get_bad_words():
    """Fetch the LDNOOBW English profanity list and return it as a set.

    Every term is lower-cased and stripped of non-word characters so it
    matches transcript tokens normalized the same way in the fast pass.
    A small built-in fallback set is returned when the download fails for
    any reason, so the app always has a usable filter.

    Returns:
        set[str]: normalized bad words.
    """
    fallback = {"fuck", "shit", "bitch", "ass", "damn", "hell"}
    try:
        print("🌐 Fetching bad-word list…")
        r = requests.get(BAD_WORD_URL, timeout=10)
        # FIX: a non-200 response previously fell through the `if` and the
        # function implicitly returned None, which crashed every later
        # `word in BAD_WORDS` membership test.  Raising here routes any
        # HTTP error through the except branch and the fallback set.
        r.raise_for_status()
        words = {
            re.sub(r"[^\w]", "", w.lower())
            for line in r.text.splitlines()
            for w in line.split()
            if w.strip()
        }
        # Extra words to always catch, regardless of the remote list.
        words.update({"hell", "dam", "damn", "yeah"})
        print(f"βœ… Loaded {len(words)} bad words.")
        return words
    except Exception as e:
        print(f"⚠️ Failed to fetch list: {e}")
        return fallback
# Fetched once at import time; read-only set shared by every request.
BAD_WORDS = get_bad_words()
# ================================
# UTILITY: SAFE AUDIO LOAD
# ================================
def load_audio_safe(path, target_sr=16000):
    """Load an audio file as mono at ``target_sr`` Hz.

    Multi-channel input is averaged down to one channel, and the waveform
    is resampled whenever the file's native rate differs from the target.

    Returns:
        tuple: (waveform tensor of shape (1, samples), target_sr)
    """
    waveform, native_sr = torchaudio.load(path)
    # Collapse stereo/multi-channel audio to mono by channel averaging.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample only when needed; Whisper expects 16 kHz input.
    if native_sr == target_sr:
        return waveform, target_sr
    resampled = torchaudio.functional.resample(waveform, native_sr, target_sr)
    return resampled, target_sr
# ================================
# LOAD MODELS
# ================================
# Both models are instantiated once at import time so every Gradio request
# reuses the same loaded weights.
print(f"πŸš€ Loading FAST Whisper: {FAST_MODEL_NAME} ({COMPUTE_TYPE}) on {DEVICE}")
# Small model: quick full-track pass that flags candidate explicit words.
fast_model = WhisperModel(FAST_MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)
print(f"πŸš€ Loading LARGE Whisper: {MODEL_NAME} ({COMPUTE_TYPE}) on {DEVICE}")
# Large model: accurate re-transcription of only the flagged chunks.
large_model = WhisperModel(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)
print("βœ… All models ready!\n")
# ================================
# MAIN TRANSCRIBE FUNCTION
# ================================
def transcribe(file_path):
    """Transcribe an audio file and flag explicit words.

    Two-stage pipeline:
      1. The fast Whisper model transcribes the whole track with word-level
         timestamps; each word is checked against ``BAD_WORDS``.
      2. Only the flagged words are re-transcribed by the large model; all
         other words are kept from the fast pass untouched.

    Args:
        file_path: path to the uploaded audio file.

    Returns:
        list[dict]: word entries sorted by start time, each with keys
        ``word``, ``start``, ``end``, ``explicit``, ``explicit_fast``.
    """
    # Load + normalize audio (mono, 16 kHz).
    wav, sr = load_audio_safe(file_path)

    # FIX: write the normalized track to a unique temp file instead of a
    # hard-coded "input_fixed.wav" in the CWD, which concurrent Gradio
    # requests would clobber — and which was never deleted.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        fixed_path = tmp.name
    try:
        torchaudio.save(fixed_path, wav, sr)
        transcript, sample_rate = _fast_pass(fixed_path, sr)
    finally:
        if os.path.exists(fixed_path):
            os.remove(fixed_path)

    # =====================================
    # EARLY EXIT IF NO EXPLICIT WORDS
    # =====================================
    if not any(w["explicit_fast"] for w in transcript):
        print("βœ… No explicit words detected β€” returning fast transcript.")
        return transcript

    # =====================================
    # 2) REFINE PASS β€” only explicit words
    # =====================================
    final = []
    for entry in transcript:
        if not entry["explicit_fast"]:
            # Not explicit — keep untouched.
            final.append(entry)
            continue
        refined = _refine_entry(entry, wav, sample_rate)
        if refined:
            final.extend(refined)
        else:
            # Collapsed chunk, model failure, or empty result: keep fast word.
            final.append(entry)

    # Sort by timestamp (critical for the downstream assembler).
    final.sort(key=lambda x: x["start"])
    return final


def _fast_pass(audio_path, default_sr):
    """Fast-model pass over the whole track; returns (transcript, sample_rate)."""
    fast_segments, fast_info = fast_model.transcribe(
        audio_path,
        beam_size=1,
        word_timestamps=True,
        vad_filter=True,
    )
    # NOTE(review): TranscriptionInfo does not normally expose sample_rate,
    # so this falls back to the rate the audio was saved at — TODO confirm.
    sample_rate = getattr(fast_info, "sample_rate", default_sr)
    transcript = []
    for seg in fast_segments:
        if not getattr(seg, "words", None):
            continue
        for w in seg.words:
            # Normalize the token exactly like the bad-word list was built.
            clean_word = re.sub(r"[^\w]", "", w.word.strip().lower())
            is_explicit = clean_word in BAD_WORDS
            transcript.append({
                "word": w.word.strip(),
                "start": float(w.start),
                "end": float(w.end),
                "explicit": is_explicit,
                "explicit_fast": is_explicit,
            })
    return transcript, sample_rate


def _refine_entry(entry, wav, sample_rate):
    """Re-transcribe one flagged word's audio chunk with the large model.

    Returns a list of refined word dicts with timestamps offset back to
    full-track time, or None when the chunk is empty, the model fails, or
    it returns nothing — the caller then keeps the fast-pass entry.
    """
    start_s = entry["start"]
    end_s = entry["end"]
    chunk = wav[:, int(start_s * sample_rate):int(end_s * sample_rate)]
    # Safety: collapsed timestamp produces an empty slice.
    if chunk.numel() == 0:
        return None

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        chunk_path = tmp.name
    # FIX: cleanup now lives in a finally block; the original duplicated
    # os.remove() on two paths and leaked the file if save() raised.
    try:
        torchaudio.save(chunk_path, chunk, sample_rate)
        refined_segs, _ = large_model.transcribe(
            chunk_path,
            beam_size=5,
            word_timestamps=True,
            vad_filter=False,
        )
        refined_words = []
        for seg in refined_segs:
            if not getattr(seg, "words", None):
                continue
            for w in seg.words:
                refined_words.append({
                    "word": w.word.strip(),
                    # Offset chunk-local timestamps back to full-track time.
                    "start": float(w.start) + start_s,
                    "end": float(w.end) + start_s,
                    "explicit": entry["explicit_fast"],
                    "explicit_fast": entry["explicit_fast"],
                })
        return refined_words or None
    except Exception as e:
        print(f"⚠️ Large model failed on chunk: {e} β€” keeping fast result")
        return None
    finally:
        os.remove(chunk_path)
# ================================
# GRADIO UI
# ================================
# Single-page app: upload vocals, get back the word-level JSON transcript.
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Vocals"),
    outputs=gr.JSON(label="Transcript with Explicit Flags"),
    title="CleanSong AI β€” Whisper Transcriber",
    description=(
        "Fast model detects explicit words β†’ "
        "Large model refines only those segments. "
        "Returns word-level timestamps."
    ),
)

# Start the web server only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()