"""AI Voice Masker.

Gradio application that replaces the singer of a song with a cloned voice:

1. Demucs splits the song into a vocal stem and an instrumental stem.
2. Whisper transcribes the vocal stem.
3. Coqui XTTS re-speaks the transcription using a reference voice sample.
4. The synthesized vocal is mixed back over the instrumental stem.
"""
import logging
import os
import sys
import tempfile
from pathlib import Path

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from demucs.apply import apply_model
from demucs.pretrained import get_model
from transformers import pipeline

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Must be set before Coqui TTS is imported / CUDA is initialised.
os.environ["COQUI_TOS_AGREED"] = "1"
os.environ["CUDA_MODULE_LOADING"] = "LAZY"

try:
    from TTS.api import TTS
except ImportError:
    # Keep the app importable without Coqui TTS; get_tts() reports the
    # problem clearly instead of failing later with a NameError.
    TTS = None
    logger.warning("Coqui TTS is not installed; voice synthesis is unavailable.")
else:
    try:
        from TTS.config.shared_configs import BaseDatasetConfig

        # Allow-list the config class so torch.load can unpickle checkpoints
        # that reference it.
        torch.serialization.add_safe_globals([BaseDatasetConfig])
    except Exception as e:  # best effort: older torch/TTS versions lack these
        logger.warning(f"Could not register TTS safe globals: {e}")

# Sample rate used for every intermediate and final audio file.
SAMPLE_RATE = 44100
# Gains applied to the synthesized vocal and the instrumental when mixing.
TTS_GAIN = 1.0
INSTRUMENTAL_GAIN = 0.8
# Whisper's receptive field is 30 s; longer audio has to be chunked.
WHISPER_CHUNK_SECONDS = 30


class ProcessingManager:
    """Lazily loads and caches the models used by the pipeline."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Cache of loaded models keyed by a short name.
        self.models = {}
        # Working directory for intermediate and output audio files.
        self.temp_dir = Path(tempfile.gettempdir()) / "voice_mask_pro"
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def get_whisper(self, model_size="large-v3"):
        """Return the cached Whisper ASR pipeline for *model_size*."""
        key = f"whisper_{model_size}"
        if key not in self.models:
            on_gpu = self.device == "cuda"
            self.models[key] = pipeline(
                "automatic-speech-recognition",
                model=f"openai/whisper-{model_size}",
                device=self.device,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
                # Songs are longer than Whisper's 30 s window; without
                # chunking the input is truncated or rejected.
                chunk_length_s=WHISPER_CHUNK_SECONDS,
            )
        return self.models[key]

    def get_demucs(self):
        """Return the cached ``htdemucs`` separation model on ``self.device``."""
        if "demucs" not in self.models:
            model = get_model("htdemucs")
            model.to(self.device)
            self.models["demucs"] = model
        return self.models["demucs"]

    def get_tts(self):
        """Return the cached XTTS v2 model on ``self.device``.

        Raises:
            RuntimeError: If the Coqui TTS package could not be imported.
        """
        if TTS is None:
            raise RuntimeError(
                "Coqui TTS is not installed; cannot synthesize the new voice"
            )
        if "tts" not in self.models:
            self.models["tts"] = TTS(
                "tts_models/multilingual/multi-dataset/xtts_v2"
            ).to(self.device)
        return self.models["tts"]


manager = ProcessingManager()


def _separate_stems(audio_path):
    """Split *audio_path* into ``(vocals, instrumental)`` channel-first arrays."""
    demucs_model = manager.get_demucs()
    wav, _ = librosa.load(audio_path, sr=SAMPLE_RATE, mono=False)
    if wav.ndim == 1:
        # Duplicate a mono signal into two identical channels.
        wav = np.stack([wav, wav])

    mix = torch.tensor(wav).to(manager.device)
    stems = apply_model(
        demucs_model,
        mix[None],  # add the batch dimension
        shifts=1,
        split=True,
        overlap=0.25,
        progress=False,
    )[0]
    stems = stems.cpu().numpy()

    # Look the vocal stem up by name rather than relying on a fixed index.
    vocal_idx = list(demucs_model.sources).index("vocals")
    vocals = stems[vocal_idx]
    instrumental = np.delete(stems, vocal_idx, axis=0).sum(axis=0)
    return vocals, instrumental


def _mix_tracks(vocal, backing):
    """Mix two mono tracks, truncated to the shorter one, without clipping."""
    length = min(len(vocal), len(backing))
    mixed = vocal[:length] * TTS_GAIN + backing[:length] * INSTRUMENTAL_GAIN
    peak = float(np.max(np.abs(mixed))) if mixed.size else 0.0
    if peak > 1.0:
        # Scale down only when needed so the written file does not hard-clip.
        mixed = mixed / peak
    return mixed


def process_audio_pipeline(
    audio_path,
    language,
    speaker_ref_path,
    voice_cleanup_slider,
    pitch_shift,
    progress=gr.Progress()
):
    """Run separation, transcription, synthesis and mixing for one song.

    Args:
        audio_path: Path of the source song.
        language: Language code passed to both Whisper and XTTS.
        speaker_ref_path: Path of the reference voice sample to clone.
        voice_cleanup_slider: Accepted for UI compatibility; currently unused.
        pitch_shift: Semitones to shift the synthesized vocal before mixing;
            ``0`` leaves it untouched.
        progress: Gradio progress tracker.

    Returns:
        A 5-tuple ``(final_mix, vocals, instrumental, tts_output, text)`` of
        file paths plus the transcription. On failure the four paths are
        ``None`` and the last element is ``"Error: <message>"``.
    """
    try:
        if not audio_path:
            raise ValueError("No audio file provided")
        if not speaker_ref_path:
            raise ValueError("Reference voice (MP3) is required")

        progress(0.1, desc="Separating Vocals...")
        vocals, instrumental = _separate_stems(audio_path)
        vocal_path = manager.temp_dir / "vocals.wav"
        inst_path = manager.temp_dir / "instrumental.wav"
        # soundfile expects (frames, channels), hence the transpose.
        sf.write(vocal_path, vocals.T, SAMPLE_RATE)
        sf.write(inst_path, instrumental.T, SAMPLE_RATE)

        progress(0.4, desc="Transcribing...")
        whisper = manager.get_whisper()
        transcription = whisper(
            str(vocal_path),
            generate_kwargs={"task": "transcribe", "language": language},
        )
        original_text = transcription["text"]
        if not original_text or not original_text.strip():
            raise ValueError("No lyrics could be transcribed from the vocals")

        progress(0.6, desc="Synthesizing with Reference Voice...")
        tts_model = manager.get_tts()
        output_tts_path = manager.temp_dir / "tts_output.wav"
        tts_model.tts_to_file(
            text=original_text,
            speaker_wav=speaker_ref_path,
            language=language,
            file_path=str(output_tts_path),
            split_sentences=True,
        )

        progress(0.9, desc="Mixing...")
        # librosa.load down-mixes to mono by default, so both tracks are 1-D.
        tts_wav, _ = librosa.load(str(output_tts_path), sr=SAMPLE_RATE)
        inst_wav, _ = librosa.load(str(inst_path), sr=SAMPLE_RATE)
        if pitch_shift:
            # Honour the "Pitch Shift" slider; the raw TTS file stays as is.
            tts_wav = librosa.effects.pitch_shift(
                tts_wav, sr=SAMPLE_RATE, n_steps=pitch_shift
            )
        mixed = _mix_tracks(tts_wav, inst_wav)

        final_path = manager.temp_dir / "final_mix.wav"
        sf.write(final_path, mixed, SAMPLE_RATE)

        return (
            str(final_path),
            str(vocal_path),
            str(inst_path),
            str(output_tts_path),
            original_text,
        )
    except Exception as e:
        # UI boundary: log the traceback and show the message in the textbox.
        logger.error(f"Pipeline failed: {str(e)}", exc_info=True)
        return None, None, None, None, f"Error: {str(e)}"


custom_css = """
.container { max-width: 900px; margin: auto; }
.gr-box { border-radius: 10px !important; border: 1px solid #e0e0e0; box-shadow: 0 4px 6px rgba(0,0,0,0.05); }
"""

with gr.Blocks(title="AI Voice Masker") as demo:
    gr.Markdown("# 🎤 AI Voice Masker")

    with gr.Row():
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### 1. Input & Settings")
            input_audio = gr.Audio(label="Source Song", type="filepath")
            ref_audio = gr.Audio(
                label="Reference Voice (MP3 Required)", type="filepath"
            )
            language = gr.Dropdown(
                ["en", "es", "fr", "it", "de", "pt", "ja"],
                value="es",
                label="Song Language",
            )

            with gr.Accordion("Advanced Audio", open=False):
                cleanup = gr.Slider(0, 1, value=0.5, label="Voice Cleanup")
                pitch = gr.Slider(-12, 12, value=0, step=1, label="Pitch Shift")

            btn_process = gr.Button(
                "🚀 Start Masking", variant="primary", size="lg"
            )

        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### 2. Output Results")
            final_output = gr.Audio(label="Final Mixed Song")

            with gr.Tabs():
                with gr.Tab("Lyrics"):
                    orig_txt = gr.Textbox(
                        label="Transcribed Lyrics", lines=8, interactive=False
                    )
                with gr.Tab("Stems"):
                    voc_out = gr.Audio(label="Original Vocals")
                    inst_out = gr.Audio(label="Instrumental")
                    tts_out = gr.Audio(label="Generated Vocals (Raw)")

    btn_process.click(
        fn=process_audio_pipeline,
        inputs=[input_audio, language, ref_audio, cleanup, pitch],
        outputs=[final_output, voc_out, inst_out, tts_out, orig_txt],
    )

if __name__ == "__main__":
    # NOTE(review): passing theme/css to launch() matches the Gradio 6 API;
    # on earlier releases they belong on gr.Blocks - confirm the pinned version.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=gr.themes.Soft(),
        css=custom_css,
    )