import re
import os
import tempfile

import gradio as gr
import torch
import torchaudio
import requests
from faster_whisper import WhisperModel

# ================================
# CONFIG
# ================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = os.getenv("WHISPER_MODEL", "large-v3")
FAST_MODEL_NAME = os.getenv("FAST_WHISPER_MODEL", "base")
# float16 only makes sense on GPU; int8 keeps CPU inference memory-friendly.
COMPUTE_TYPE = "float16" if torch.cuda.is_available() else "int8"
BAD_WORD_URL = (
    "https://raw.githubusercontent.com/LDNOOBW/"
    "List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/master/en"
)


# ================================
# BAD WORD LIST
# ================================
def get_bad_words():
    """Fetch the LDNOOBW English bad-word list as a normalized lowercase set.

    Each token is lowercased and stripped of non-word characters so it can be
    compared against similarly normalized transcript words. Falls back to a
    small built-in set when the download fails or returns a non-200 status.

    Returns:
        set[str]: normalized bad words (never contains the empty string).
    """
    try:
        print("🌐 Fetching bad-word list…")
        r = requests.get(BAD_WORD_URL, timeout=10)
        if r.status_code == 200:
            words = {
                re.sub(r"[^\w]", "", w.lower())
                for line in r.text.splitlines()
                for w in line.split()
                if w.strip()
            }
            # FIX: a punctuation-only token normalizes to "" — discard it so a
            # stripped-empty transcript word can never be flagged as explicit.
            words.discard("")
            # Extra words to always catch
            words.update({"hell", "dam", "damn", "yeah"})
            print(f"βœ… Loaded {len(words)} bad words.")
            return words
    except Exception as e:
        print(f"⚠️ Failed to fetch list: {e}")
    return {"fuck", "shit", "bitch", "ass", "damn", "hell"}  # fallback


BAD_WORDS = get_bad_words()


# ================================
# UTILITY: SAFE AUDIO LOAD
# ================================
def load_audio_safe(path, target_sr=16000):
    """Load *path* as mono audio resampled to *target_sr* Hz.

    Args:
        path: audio file readable by torchaudio.
        target_sr: desired output sample rate (Hz).

    Returns:
        tuple[torch.Tensor, int]: a (1, samples) waveform and *target_sr*.
    """
    wav, sr = torchaudio.load(path)
    # Down-mix multi-channel audio to a single mono channel.
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav, target_sr


# ================================
# LOAD MODELS
# ================================
print(f"πŸš€ Loading FAST Whisper: {FAST_MODEL_NAME} ({COMPUTE_TYPE}) on {DEVICE}")
fast_model = WhisperModel(FAST_MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)

print(f"πŸš€ Loading LARGE Whisper: {MODEL_NAME} ({COMPUTE_TYPE}) on {DEVICE}")
large_model = WhisperModel(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)

print("βœ… All models ready!\n")
# ================================
# MAIN TRANSCRIBE FUNCTION
# ================================
def transcribe(file_path):
    """Transcribe *file_path* into word-level entries with explicit-word flags.

    Two-stage strategy: a fast Whisper pass produces word timestamps and flags
    explicit words against BAD_WORDS; only the flagged words are then
    re-transcribed with the large model for better accuracy/timing.

    Args:
        file_path: path to the uploaded audio file.

    Returns:
        list[dict]: entries sorted by start time, each with keys
        "word", "start", "end", "explicit", "explicit_fast".
    """
    # Load + normalize audio (mono, 16 kHz).
    wav, sr = load_audio_safe(file_path)

    # FIX: write the normalized audio to a unique temp file instead of a
    # hard-coded "input_fixed.wav" in the CWD (which is clobbered by
    # concurrent requests and never cleaned up). The handle is closed
    # before torchaudio writes to the path so this also works on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        fixed_path = tmp.name
    torchaudio.save(fixed_path, wav, sr)

    try:
        # =====================================
        # 1) FAST PASS β€” detect explicit words
        # =====================================
        fast_segments, fast_info = fast_model.transcribe(
            fixed_path,
            beam_size=1,
            word_timestamps=True,
            vad_filter=True,
        )

        transcript = []
        # fast_info may not expose sample_rate; fall back to the load rate.
        sample_rate = getattr(fast_info, "sample_rate", sr)

        # NOTE: fast_segments is a lazy generator — it must be fully consumed
        # before the temp file is removed in the finally block below.
        for seg in fast_segments:
            if not getattr(seg, "words", None):
                continue
            for w in seg.words:
                # Normalize exactly like the bad-word list for membership tests.
                clean_word = re.sub(r"[^\w]", "", w.word.strip().lower())
                is_explicit = clean_word in BAD_WORDS
                transcript.append({
                    "word": w.word.strip(),
                    "start": float(w.start),
                    "end": float(w.end),
                    "explicit": is_explicit,
                    "explicit_fast": is_explicit,
                })
    finally:
        # Best-effort cleanup of the normalized temp file.
        try:
            os.remove(fixed_path)
        except OSError:
            pass

    # =====================================
    # EARLY EXIT IF NO EXPLICIT WORDS
    # =====================================
    flagged = [w for w in transcript if w["explicit_fast"]]
    if not flagged:
        print("βœ… No explicit words detected β€” returning fast transcript.")
        return transcript

    # =====================================
    # 2) REFINE PASS β€” only explicit words
    # =====================================
    final = []
    for entry in transcript:
        # Not explicit β€” keep untouched
        if not entry["explicit_fast"]:
            final.append(entry)
            continue

        # Extract the audio chunk covering just this word.
        start_s = entry["start"]
        end_s = entry["end"]
        start_sample = int(start_s * sample_rate)
        end_sample = int(end_s * sample_rate)
        chunk = wav[:, start_sample:end_sample]

        # Safety: collapsed timestamp yields an empty chunk.
        if chunk.numel() == 0:
            final.append(entry)
            continue

        # Save chunk to a temp file (handle closed before writing).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            chunk_path = tmp.name
        torchaudio.save(chunk_path, chunk, sample_rate)

        # Run the large model on the chunk.
        try:
            refined_segs, _ = large_model.transcribe(
                chunk_path,
                beam_size=5,
                word_timestamps=True,
                vad_filter=False,
            )
            # FIX: faster-whisper returns a lazy generator; decoding errors
            # only surface while iterating, so materialize INSIDE the try
            # block so the except clause can actually catch them.
            refined_segs = list(refined_segs)
        except Exception as e:
            print(f"⚠️ Large model failed on chunk: {e} β€” keeping fast result")
            final.append(entry)
            continue
        finally:
            # FIX: one cleanup path replaces the duplicated os.remove calls
            # on the success and failure branches.
            try:
                os.remove(chunk_path)
            except OSError:
                pass

        # Extract refined words, offsetting timestamps back to full-track time.
        refined_words = []
        for seg in refined_segs:
            if not getattr(seg, "words", None):
                continue
            for w in seg.words:
                refined_words.append({
                    "word": w.word.strip(),
                    "start": float(w.start) + start_s,
                    "end": float(w.end) + start_s,
                    "explicit": entry["explicit_fast"],
                    "explicit_fast": entry["explicit_fast"],
                })

        # Fallback if the large model returned nothing.
        if not refined_words:
            final.append(entry)
            continue

        final.extend(refined_words)

    # Sort by timestamp (critical for assembler).
    final.sort(key=lambda x: x["start"])
    return final


# ================================
# GRADIO UI
# ================================
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath", label="Upload Vocals"),
    outputs=gr.JSON(label="Transcript with Explicit Flags"),
    title="CleanSong AI β€” Whisper Transcriber",
    description=(
        "Fast model detects explicit words β†’ "
        "Large model refines only those segments. "
        "Returns word-level timestamps."
    ),
)

if __name__ == "__main__":
    iface.launch()