# (extraction-artifact header: file size 8,947 bytes, revision 9b88428)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# (extraction artifact: stray line-number residue removed)
import os
import sys
import gc
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from typing import Tuple, Optional
from unittest.mock import patch
import huggingface_hub  # Needed for the intercept trick

# Chatterbox TTS dependencies
from chatterbox.tts import ChatterboxTTS, Conditionals
from chatterbox.models.t3.modules.cond_enc import T3Cond

# ==============================================================================
# 1. Configuration & Utilities
# ==============================================================================

class VoxConfig:
    """Configuration parameters for the VoxMorph inference engine."""

    # Output rate of the underlying TTS model, in Hz.
    SAMPLE_RATE = 24000
    # |cos| above this value treats two embeddings as colinear during slerp.
    DOT_THRESHOLD = 0.9995
    # Prefer the GPU whenever torch can see one.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Directory checked first for pre-downloaded model weights.
    LOCAL_WEIGHTS_DIR = Path("./Model_Weights")

    # Default linguistic content pre-filled in the UI textbox.
    DEFAULT_TEXT = (
        "The concept of morphing implies a smooth and seamless transition "
        "between two distinct states, blending identity and style."
    )

class MathUtils:
    """Static utilities for geometric operations on the hypersphere."""

    @staticmethod
    def slerp(v0: np.ndarray, v1: np.ndarray, t: float,
              dot_threshold: Optional[float] = None) -> np.ndarray:
        """Spherical linear interpolation between two embedding vectors.

        Both endpoints are projected onto the unit hypersphere before
        interpolating, so the magnitudes of the inputs are discarded.

        Args:
            v0: Start vector (non-empty 1-D float array).
            v1: End vector, same shape as ``v0``.
            t: Interpolation factor in [0, 1]; 0 yields normalized ``v0``,
               1 yields normalized ``v1``.
            dot_threshold: |dot| above which the endpoints are treated as
                colinear and a normalized lerp is used instead. Defaults
                to ``VoxConfig.DOT_THRESHOLD``.

        Returns:
            The interpolated vector as a NumPy array (unit norm up to
            numerical error). On any internal failure, falls back to a
            plain linear interpolation of the raw inputs (best-effort).
        """
        try:
            threshold = VoxConfig.DOT_THRESHOLD if dot_threshold is None else dot_threshold
            v0_t, v1_t = torch.from_numpy(v0), torch.from_numpy(v1)
            if v0_t.numel() == 0 or v1_t.numel() == 0:
                raise ValueError("Empty vectors")

            # Project both endpoints onto the unit hypersphere.
            v0_t = v0_t / torch.norm(v0_t)
            v1_t = v1_t / torch.norm(v1_t)

            dot = torch.clamp(torch.sum(v0_t * v1_t), -1.0, 1.0)

            # Nearly colinear: slerp is numerically unstable, so use a
            # normalized lerp instead.
            if torch.abs(dot) > threshold:
                v_lerp = torch.lerp(v0_t, v1_t, t)
                return (v_lerp / torch.norm(v_lerp)).numpy()

            # BUGFIX: the angle must come from the SIGNED dot product.
            # The previous acos(|dot|) disagreed with the coefficient
            # formula below (which uses the raw dot), breaking the t=1
            # endpoint whenever the embeddings were anti-correlated.
            theta_0 = torch.acos(dot)
            sin_theta_0 = torch.sin(theta_0)
            if sin_theta_0 < 1e-8:
                return torch.lerp(v0_t, v1_t, t).numpy()

            theta_t = theta_0 * t
            # Equivalent to sin(theta_0 - theta_t) / sin(theta_0).
            s0 = torch.cos(theta_t) - dot * torch.sin(theta_t) / sin_theta_0
            s1 = torch.sin(theta_t) / sin_theta_0

            return (s0 * v0_t + s1 * v1_t).numpy()
        except Exception:
            # Deliberate best-effort fallback: plain lerp on the raw inputs.
            return ((1 - t) * v0 + t * v1)

# ==============================================================================
# 2. Inference Engine
# ==============================================================================

class VoxMorphEngine:
    """Core engine for handling model loading and synthesis."""

    def __init__(self):
        # Lazily populated by _load_model; stays None until loading succeeds.
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load ChatterboxTTS, preferring local weights in ./Model_Weights/.

        Patches ``huggingface_hub.hf_hub_download`` for the duration of
        ``from_pretrained`` so that any requested file already present
        locally is served from disk instead of being fetched from the Hub.
        """
        if self.model is None:
            print(f"[VoxMorph] Initializing model on {VoxConfig.DEVICE}...", flush=True)

            # Keep a handle on the real downloader so we can delegate to it.
            original_download = huggingface_hub.hf_hub_download

            def local_redirect_download(repo_id, filename, **kwargs):
                """Redirects download requests to the local folder if file exists."""
                local_file = VoxConfig.LOCAL_WEIGHTS_DIR / filename
                if local_file.exists():
                    # BUGFIX: log the actual requested filename; the previous
                    # message printed a literal '(unknown)' placeholder.
                    print(f"  [Loader] Intercepted request for '{filename}' -> Using local copy.")
                    return str(local_file.absolute())
                else:
                    print(f"  [Loader] '{filename}' not found locally. Downloading from Hub...")
                    return original_download(repo_id, filename, **kwargs)

            # Apply the patch only during model initialization.
            try:
                with patch('huggingface_hub.hf_hub_download', side_effect=local_redirect_download):
                    # Pass ONLY device: the internal downloader is what we intercept.
                    self.model = ChatterboxTTS.from_pretrained(device=VoxConfig.DEVICE)
                print("[VoxMorph] Model loaded successfully.")
            except Exception as e:
                print(f"[VoxMorph] CRITICAL ERROR loading model: {e}")
                # Fallback: try loading normally without local weights if patch fails
                self.model = ChatterboxTTS.from_pretrained(device=VoxConfig.DEVICE)

    def _cleanup_memory(self):
        """Release cached GPU memory (when CUDA is present) and force a GC pass."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

    def _extract_embeddings(self, audio_path: str):
        """Return ``(t3_speaker_emb, gen_emb)`` NumPy vectors for a reference clip.

        Side effect: leaves ``self.model.conds`` set to the conditionals
        computed from ``audio_path`` — ``synthesize`` relies on this to
        snapshot a template after profiling Source A.
        """
        self.model.prepare_conditionals(audio_path)
        conds = Conditionals(self.model.conds.t3, self.model.conds.gen)
        emb_t3 = conds.t3.speaker_emb.detach().squeeze(0).cpu().numpy()
        emb_gen = conds.gen['embedding'].detach().squeeze(0).cpu().numpy()
        return emb_t3, emb_gen

    def synthesize(self, path_a, path_b, text, alpha, progress=gr.Progress()):
        """Morph between two reference voices and synthesize ``text``.

        Args:
            path_a: File path of the Source Identity A reference clip.
            path_b: File path of the Target Identity B reference clip.
            text: Linguistic content to synthesize.
            alpha: Interpolation factor in [0, 1]; 0 is pure A, 1 is pure B.
            progress: Gradio progress reporter (injected by the UI).

        Returns:
            ``(sample_rate, waveform)`` tuple suitable for a numpy gr.Audio.

        Raises:
            ValueError: If either reference path is missing.
            RuntimeError: If any stage of the pipeline fails (original
                exception chained as the cause).
        """
        self._cleanup_memory()
        if not path_a or not path_b:
            raise ValueError("Both audio sources must be provided.")

        try:
            # Stage 1: Extraction
            progress(0.1, desc="Profiling Source A...")
            t3_a, gen_a = self._extract_embeddings(path_a)
            # Save template structure from A (prompt tokens, emotion, etc.)
            conds_template = Conditionals(self.model.conds.t3, self.model.conds.gen)

            progress(0.3, desc="Profiling Source B...")
            t3_b, gen_b = self._extract_embeddings(path_b)

            # Stage 2: Interpolation — skip slerp entirely at the exact
            # endpoints so alpha 0/1 reproduce the sources bit-for-bit.
            progress(0.5, desc=f"Interpolating (Alpha: {alpha:.2f})...")

            if alpha == 0.0:
                final_t3_emb = torch.from_numpy(t3_a).unsqueeze(0)
                final_gen_emb = torch.from_numpy(gen_a).unsqueeze(0)
            elif alpha == 1.0:
                final_t3_emb = torch.from_numpy(t3_b).unsqueeze(0)
                final_gen_emb = torch.from_numpy(gen_b).unsqueeze(0)
            else:
                final_t3_emb = torch.from_numpy(MathUtils.slerp(t3_a, t3_b, alpha)).unsqueeze(0)
                final_gen_emb = torch.from_numpy(MathUtils.slerp(gen_a, gen_b, alpha)).unsqueeze(0)

            # Reconstruct Conditionals around the interpolated embeddings,
            # reusing the non-embedding fields from the A template.
            final_t3_cond = T3Cond(
                speaker_emb=final_t3_emb,
                cond_prompt_speech_tokens=conds_template.t3.cond_prompt_speech_tokens,
                emotion_adv=conds_template.t3.emotion_adv
            )
            final_gen_cond = conds_template.gen.copy()
            final_gen_cond['embedding'] = final_gen_emb

            # Stage 3: Synthesis
            progress(0.8, desc="Synthesizing...")
            self.model.conds = Conditionals(final_t3_cond, final_gen_cond).to(self.model.device)
            wav_tensor = self.model.generate(text)

            self._cleanup_memory()
            return VoxConfig.SAMPLE_RATE, wav_tensor.cpu().squeeze().numpy()

        except Exception as e:
            self._cleanup_memory()
            # BUGFIX: chain the original exception so the traceback keeps
            # the real failure point instead of only the wrapped message.
            raise RuntimeError(f"Morphing process failed: {str(e)}") from e

# ==============================================================================
# 3. User Interface
# ==============================================================================

def create_interface():
    """Build and return the Gradio Blocks app wired to a fresh engine."""
    tts_engine = VoxMorphEngine()

    with gr.Blocks(theme=gr.themes.Soft(primary_hue="slate"), title="VoxMorph") as app:
        gr.Markdown(
            """

            # 🗣️ VoxMorph: Scalable Zero-Shot Voice Identity Morphing

            **University of North Texas | ICASSP Accepted Paper**

            """
        )

        with gr.Row():
            # Left column: reference audio clips plus the text to speak.
            with gr.Column(scale=1):
                gr.Markdown("### 1. Source Input")
                with gr.Group():
                    src_a = gr.Audio(label="Source Identity A", type="filepath")
                    src_b = gr.Audio(label="Target Identity B", type="filepath")
                content_box = gr.Textbox(label="Linguistic Content", value=VoxConfig.DEFAULT_TEXT, lines=3)

            # Right column: morphing controls and the synthesized result.
            with gr.Column(scale=1):
                gr.Markdown("### 2. Morphing Controls")
                morph_alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Interpolation Factor (α)")
                run_button = gr.Button("Perform Zero-Shot Morphing", variant="primary", size="lg")
                gr.Markdown("### 3. Acoustic Output")
                result_audio = gr.Audio(label="Synthesized Morph", interactive=False, type="numpy")

        # Wire the button to the engine; outputs land in the audio widget.
        run_button.click(
            fn=tts_engine.synthesize,
            inputs=[src_a, src_b, content_box, morph_alpha],
            outputs=[result_audio]
        )

    return app

if __name__ == "__main__":
    # queue() serializes GPU-bound synthesis requests; no public share link.
    application = create_interface()
    application.queue().launch(share=False)