File size: 6,331 Bytes
4af101a
1e63917
 
 
 
 
4af101a
1e63917
 
90d9071
 
 
 
5a9b5a5
7b4a7c6
 
 
 
 
 
1e63917
 
 
 
 
 
7b4a7c6
90d9071
5a9b5a5
1e63917
881c51d
 
1e63917
90d9071
1e63917
 
 
 
 
 
881c51d
 
1e63917
 
881c51d
1e63917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90d9071
1e63917
 
 
 
 
4af101a
1e63917
 
90d9071
 
 
1e63917
 
 
 
4af101a
1e63917
 
5a9b5a5
1e63917
 
4af101a
1e63917
 
8132a8c
e20cadb
1e63917
 
4af101a
1e63917
 
abf2936
90d9071
1e63917
90d9071
1e63917
 
90d9071
1e63917
1b42d2d
1e63917
1b42d2d
1e63917
90d9071
 
 
1e63917
 
 
 
 
 
 
5a9b5a5
1e63917
 
5a9b5a5
1e63917
 
 
 
 
 
 
 
90d9071
1e63917
 
1b42d2d
7b4a7c6
90d9071
4af101a
1e63917
 
 
 
 
90d9071
 
abf2936
 
1e63917
 
 
90d9071
5a9b5a5
90d9071
5a9b5a5
1e63917
 
 
 
90d9071
1e63917
 
 
 
abf2936
1e63917
 
90d9071
1e63917
 
90d9071
1e63917
90d9071
1e63917
 
 
90d9071
 
4af101a
 
 
7b4a7c6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import sys
import logging
import tempfile
import numpy as np
import torch
import soundfile as sf
import gradio as gr
from pathlib import Path
import librosa
from transformers import pipeline
from demucs.pretrained import get_model
from demucs.apply import apply_model

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Coqui TTS requires explicit license agreement before loading XTTS weights;
# LAZY CUDA module loading defers kernel initialization until first use.
os.environ["COQUI_TOS_AGREED"] = "1"
os.environ["CUDA_MODULE_LOADING"] = "LAZY"

try:
    from TTS.api import TTS
    from TTS.config.shared_configs import BaseDatasetConfig
    # torch >= 2.6 defaults torch.load to weights_only=True; allow-list the
    # TTS config class so XTTS checkpoints can still be deserialized.
    torch.serialization.add_safe_globals([BaseDatasetConfig])
except ImportError:
    # TTS is optional at import time; manager.get_tts() will fail later if
    # synthesis is actually requested without the package installed.
    pass
except Exception as e:
    logger.warning(f"TTS import succeeded but safe-globals setup failed: {e}")

class ProcessingManager:
    """Lazily loads and caches the heavy ML models (Whisper, Demucs, XTTS).

    Models are loaded on first request and kept in ``self.models`` so repeated
    Gradio calls do not pay the load cost again.
    """

    def __init__(self):
        # Prefer the GPU when available; every model is moved to this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.models = {}  # cache: model key -> loaded model instance
        self.temp_dir = Path(tempfile.gettempdir()) / "voice_mask_pro"
        self.temp_dir.mkdir(exist_ok=True)

    def get_whisper(self, model_size="large-v3"):
        """Return a cached Whisper ASR pipeline for *model_size*."""
        key = f"whisper_{model_size}"
        cached = self.models.get(key)
        if cached is None:
            # fp16 only makes sense on CUDA; CPU inference stays in fp32.
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            cached = pipeline(
                "automatic-speech-recognition",
                model=f"openai/whisper-{model_size}",
                device=self.device,
                torch_dtype=dtype,
            )
            self.models[key] = cached
        return cached

    def get_demucs(self):
        """Return the cached htdemucs separation model, loading on first use."""
        try:
            return self.models["demucs"]
        except KeyError:
            model = get_model("htdemucs")
            model.to(self.device)
            self.models["demucs"] = model
            return model

    def get_tts(self):
        """Return the cached XTTS-v2 synthesis model, loading on first use."""
        try:
            return self.models["tts"]
        except KeyError:
            model = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(self.device)
            self.models["tts"] = model
            return model

# Module-level singleton so loaded models stay cached across Gradio requests.
manager = ProcessingManager()

def process_audio_pipeline(
    audio_path, 
    language, 
    speaker_ref_path, 
    voice_cleanup_slider, 
    pitch_shift,
    progress=gr.Progress()
):
    """Separate vocals, transcribe them, re-synthesize with a cloned voice, and remix.

    Args:
        audio_path: Path to the source song file.
        language: ISO language code used for both transcription and synthesis.
        speaker_ref_path: Path to the reference voice sample for XTTS cloning.
        voice_cleanup_slider: 0..1 cleanup amount.
            NOTE(review): currently unused — TODO wire up or remove the control.
        pitch_shift: Semitones to shift the synthesized vocals (0 = no shift).
        progress: Gradio progress reporter.

    Returns:
        Tuple of (final_mix_path, vocals_path, instrumental_path, tts_path,
        lyrics_text) as strings, or (None, None, None, None, error_message)
        on failure.
    """
    try:
        if not audio_path:
            raise ValueError("No audio file provided")

        if not speaker_ref_path:
            raise ValueError("Reference voice (MP3) is required")

        progress(0.1, desc="Separating Vocals...")
        demucs_model = manager.get_demucs()
        wav, sr = librosa.load(audio_path, sr=44100, mono=False)

        # Demucs expects stereo input: duplicate a mono channel.
        if len(wav.shape) == 1:
            wav = np.stack([wav, wav])

        ref = torch.tensor(wav).to(manager.device)
        sources = apply_model(demucs_model, ref[None], shifts=1, split=True, overlap=0.25, progress=False)[0]

        sources = sources.cpu().numpy()
        # htdemucs source order: drums, bass, other, vocals.
        vocals = sources[3]
        instrumental = sources[0] + sources[1] + sources[2]

        vocal_path = manager.temp_dir / "vocals.wav"
        inst_path = manager.temp_dir / "instrumental.wav"

        sf.write(vocal_path, vocals.T, 44100)
        sf.write(inst_path, instrumental.T, 44100)

        progress(0.4, desc="Transcribing...")
        whisper = manager.get_whisper()
        # chunk_length_s enables long-form transcription; without it the
        # pipeline rejects audio longer than 30 seconds.
        transcription = whisper(
            str(vocal_path),
            chunk_length_s=30,
            generate_kwargs={"task": "transcribe", "language": language},
        )
        original_text = transcription["text"]

        progress(0.6, desc="Synthesizing with Reference Voice...")
        tts_model = manager.get_tts()

        output_tts_path = manager.temp_dir / "tts_output.wav"

        tts_model.tts_to_file(
            text=original_text,
            speaker_wav=speaker_ref_path,
            language=language,
            file_path=str(output_tts_path),
            split_sentences=True
        )

        progress(0.9, desc="Mixing...")
        tts_wav, _ = librosa.load(str(output_tts_path), sr=44100)
        inst_wav, _ = librosa.load(str(inst_path), sr=44100)

        # Bug fix: the Pitch Shift control was accepted but never applied.
        if pitch_shift:
            tts_wav = librosa.effects.pitch_shift(tts_wav, sr=44100, n_steps=pitch_shift)

        min_len = min(len(tts_wav), len(inst_wav))
        mixed = tts_wav[:min_len] * 1.0 + inst_wav[:min_len] * 0.8

        # Normalize only if the sum exceeds full scale, to avoid clipping.
        peak = float(np.max(np.abs(mixed))) if min_len else 0.0
        if peak > 1.0:
            mixed = mixed / peak

        final_path = manager.temp_dir / "final_mix.wav"
        sf.write(final_path, mixed, 44100)

        # Return str paths consistently (final_path was previously a bare Path
        # while the other outputs were strings).
        return (
            str(final_path),
            str(vocal_path),
            str(inst_path),
            str(output_tts_path),
            original_text
        )

    except Exception as e:
        logger.error(f"Pipeline failed: {str(e)}", exc_info=True)
        return None, None, None, None, f"Error: {str(e)}"

# Custom CSS for the Gradio UI: centered 900px container, rounded cards.
custom_css = """
.container { max_width: 900px; margin: auto; }
.gr-box { border-radius: 10px !important; border: 1px solid #e0e0e0; box-shadow: 0 4px 6px rgba(0,0,0,0.05); }
"""

# Build the Gradio UI. theme/css are constructor arguments of gr.Blocks —
# Blocks.launch() does not accept them.
with gr.Blocks(title="AI Voice Masker", theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("# 🎤 AI Voice Masker")
    
    with gr.Row():
        # Left panel: inputs and processing settings.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### 1. Input & Settings")
            input_audio = gr.Audio(label="Source Song", type="filepath")
            ref_audio = gr.Audio(label="Reference Voice (MP3 Required)", type="filepath")
            
            language = gr.Dropdown(["en", "es", "fr", "it", "de", "pt", "ja"], value="es", label="Song Language")
            
            with gr.Accordion("Advanced Audio", open=False):
                cleanup = gr.Slider(0, 1, value=0.5, label="Voice Cleanup")
                pitch = gr.Slider(-12, 12, value=0, step=1, label="Pitch Shift")

            btn_process = gr.Button("🚀 Start Masking", variant="primary", size="lg")

        # Right panel: final mix plus intermediate artifacts for inspection.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### 2. Output Results")
            final_output = gr.Audio(label="Final Mixed Song")
            
            with gr.Tabs():
                with gr.Tab("Lyrics"):
                    orig_txt = gr.Textbox(label="Transcribed Lyrics", lines=8, interactive=False)
                
                with gr.Tab("Stems"):
                    voc_out = gr.Audio(label="Original Vocals")
                    inst_out = gr.Audio(label="Instrumental")
                    tts_out = gr.Audio(label="Generated Vocals (Raw)")

    # Input order must match process_audio_pipeline's signature:
    # (audio_path, language, speaker_ref_path, voice_cleanup_slider, pitch_shift).
    btn_process.click(
        fn=process_audio_pipeline,
        inputs=[input_audio, language, ref_audio, cleanup, pitch],
        outputs=[final_output, voc_out, inst_out, tts_out, orig_txt]
    )

if __name__ == "__main__":
    # Bug fix: Blocks.launch() has no `theme` or `css` parameters — passing
    # them raised a TypeError at startup. Styling belongs on the gr.Blocks
    # constructor instead.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )