import os
import sys
import gc
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from typing import Tuple, Optional
from unittest.mock import patch
import huggingface_hub # Needed for the intercept trick
# Chatterbox TTS dependencies
from chatterbox.tts import ChatterboxTTS, Conditionals
from chatterbox.models.t3.modules.cond_enc import T3Cond
# ==============================================================================
# 1. Configuration & Utilities
# ==============================================================================
class VoxConfig:
    """Static configuration for the VoxMorph inference engine."""
    # Output audio sample rate (Hz) reported alongside synthesized waveforms.
    SAMPLE_RATE = 24000
    # Cosine-similarity cutoff above which slerp degrades to normalized lerp.
    DOT_THRESHOLD = 0.9995
    # Run on CPU only when no CUDA device is present.
    DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
    # Directory searched first for pre-downloaded model weights.
    LOCAL_WEIGHTS_DIR = Path("./Model_Weights")
    # Text synthesized when the user does not type anything.
    DEFAULT_TEXT = (
        "The concept of morphing implies a smooth and seamless transition "
        "between two distinct states, blending identity and style."
    )
class MathUtils:
    """Static utilities for geometric operations on the hypersphere."""
    @staticmethod
    def slerp(v0: np.ndarray, v1: np.ndarray, t: float,
              dot_threshold: float = 0.9995) -> np.ndarray:
        """Spherical linear interpolation between two embedding vectors.

        Both inputs are normalized onto the unit hypersphere before
        interpolating, so the result is a unit vector (except on the
        linear fallback path).

        Args:
            v0: Start vector (non-empty numpy array).
            v1: End vector, same shape as ``v0``.
            t: Interpolation factor in [0, 1]; 0 yields normalized ``v0``,
               1 yields normalized ``v1``.
            dot_threshold: |cos| similarity above which the vectors are
                treated as (anti-)parallel and a normalized lerp is used to
                avoid dividing by a vanishing sine. Defaults to the same
                0.9995 the engine configures globally.

        Returns:
            The interpolated vector as a numpy array. On any internal
            failure (e.g. empty inputs) falls back to a plain linear
            interpolation of the raw inputs.
        """
        try:
            v0_t, v1_t = torch.from_numpy(v0), torch.from_numpy(v1)
            if v0_t.numel() == 0 or v1_t.numel() == 0:
                raise ValueError("Empty vectors")
            v0_t = v0_t / torch.norm(v0_t)
            v1_t = v1_t / torch.norm(v1_t)
            dot = torch.clamp(torch.sum(v0_t * v1_t), -1.0, 1.0)
            # Nearly (anti-)parallel: slerp is numerically unstable here,
            # so lerp and re-project onto the unit sphere instead.
            if torch.abs(dot) > dot_threshold:
                v_lerp = torch.lerp(v0_t, v1_t, t)
                return (v_lerp / torch.norm(v_lerp)).numpy()
            # BUGFIX: the angle must come from the *signed* dot product.
            # The previous acos(abs(dot)) combined with the signed `dot`
            # in s0 produced wrong results whenever dot < 0 (e.g.
            # slerp(v0, v1, 1.0) did not return v1).
            theta_0 = torch.acos(dot)
            sin_theta_0 = torch.sin(theta_0)
            if sin_theta_0 < 1e-8:
                return torch.lerp(v0_t, v1_t, t).numpy()
            theta_t = theta_0 * t
            # s0*v0 + s1*v1 == sin(theta_0 - theta_t)/sin(theta_0) * v0
            #                + sin(theta_t)/sin(theta_0)          * v1
            s0 = torch.cos(theta_t) - dot * torch.sin(theta_t) / sin_theta_0
            s1 = torch.sin(theta_t) / sin_theta_0
            return (s0 * v0_t + s1 * v1_t).numpy()
        except Exception:
            # Best-effort fallback: plain lerp of the unnormalized inputs.
            return ((1 - t) * v0 + t * v1)
# ==============================================================================
# 2. Inference Engine
# ==============================================================================
class VoxMorphEngine:
    """Core engine for handling model loading and synthesis.

    Wraps a ChatterboxTTS instance and exposes `synthesize`, which morphs
    between two speaker identities by interpolating their conditioning
    embeddings.
    """
    def __init__(self):
        # ChatterboxTTS instance; populated once by _load_model().
        self.model = None
        self._load_model()
    def _load_model(self):
        """
        Loads ChatterboxTTS.
        Uses a 'patch' to force the library to use local files from ./Model_Weights/
        instead of downloading from Hugging Face.
        """
        if self.model is None:
            print(f"[VoxMorph] Initializing model on {VoxConfig.DEVICE}...", flush=True)
            # Store the original downloader function so cache misses can
            # still fall through to a real Hub download.
            original_download = huggingface_hub.hf_hub_download
            def local_redirect_download(repo_id, filename, **kwargs):
                """Redirects download requests to the local folder if file exists."""
                local_file = VoxConfig.LOCAL_WEIGHTS_DIR / filename
                if local_file.exists():
                    # BUGFIX: these messages previously printed the literal
                    # text '(unknown)'; they now interpolate the filename.
                    print(f" [Loader] Intercepted request for '{filename}' -> Using local copy.")
                    return str(local_file.absolute())
                else:
                    print(f" [Loader] '{filename}' not found locally. Downloading from Hub...")
                    return original_download(repo_id, filename, **kwargs)
            # Apply the patch only during model initialization
            try:
                with patch('huggingface_hub.hf_hub_download', side_effect=local_redirect_download):
                    # We pass ONLY device, because we are tricking the internal downloader
                    self.model = ChatterboxTTS.from_pretrained(device=VoxConfig.DEVICE)
                print("[VoxMorph] Model loaded successfully.")
            except Exception as e:
                print(f"[VoxMorph] CRITICAL ERROR loading model: {e}")
                # Fallback: try loading normally without local weights if patch fails
                self.model = ChatterboxTTS.from_pretrained(device=VoxConfig.DEVICE)
    def _cleanup_memory(self):
        """Release cached CUDA memory (if any) and force a GC pass."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()
    def _extract_embeddings(self, audio_path: str):
        """Return (t3_emb, gen_emb) speaker embeddings for a reference clip.

        Side effect: leaves self.model.conds set to the conditionals of
        `audio_path` — synthesize() relies on this to snapshot a template.
        """
        self.model.prepare_conditionals(audio_path)
        conds = Conditionals(self.model.conds.t3, self.model.conds.gen)
        emb_t3 = conds.t3.speaker_emb.detach().squeeze(0).cpu().numpy()
        emb_gen = conds.gen['embedding'].detach().squeeze(0).cpu().numpy()
        return emb_t3, emb_gen
    def synthesize(self, path_a, path_b, text, alpha, progress=gr.Progress()):
        """Morph identity A toward identity B and speak `text`.

        Args:
            path_a: File path of the reference recording for identity A.
            path_b: File path of the reference recording for identity B.
            text: Linguistic content to synthesize.
            alpha: Interpolation factor in [0, 1]; 0 -> pure A, 1 -> pure B.
            progress: Gradio progress reporter (default instance is the
                Gradio-idiomatic way to opt into progress tracking).

        Returns:
            (sample_rate, waveform ndarray) tuple consumable by gr.Audio.

        Raises:
            ValueError: if either audio source is missing.
            RuntimeError: wrapping any failure in the morphing pipeline.
        """
        self._cleanup_memory()
        if not path_a or not path_b: raise ValueError("Both audio sources must be provided.")
        try:
            # Stage 1: Extraction
            progress(0.1, desc="Profiling Source A...")
            t3_a, gen_a = self._extract_embeddings(path_a)
            # Save template structure from A (prompt tokens, emotion, ...)
            conds_template = Conditionals(self.model.conds.t3, self.model.conds.gen)
            progress(0.3, desc="Profiling Source B...")
            t3_b, gen_b = self._extract_embeddings(path_b)
            # Stage 2: Interpolation — endpoints bypass slerp so alpha 0/1
            # reproduces the source embedding bit-exactly.
            progress(0.5, desc=f"Interpolating (Alpha: {alpha:.2f})...")
            if alpha == 0.0:
                final_t3_emb = torch.from_numpy(t3_a).unsqueeze(0)
                final_gen_emb = torch.from_numpy(gen_a).unsqueeze(0)
            elif alpha == 1.0:
                final_t3_emb = torch.from_numpy(t3_b).unsqueeze(0)
                final_gen_emb = torch.from_numpy(gen_b).unsqueeze(0)
            else:
                final_t3_emb = torch.from_numpy(MathUtils.slerp(t3_a, t3_b, alpha)).unsqueeze(0)
                final_gen_emb = torch.from_numpy(MathUtils.slerp(gen_a, gen_b, alpha)).unsqueeze(0)
            # Reconstruct Conditionals around the morphed speaker embedding
            final_t3_cond = T3Cond(
                speaker_emb=final_t3_emb,
                cond_prompt_speech_tokens=conds_template.t3.cond_prompt_speech_tokens,
                emotion_adv=conds_template.t3.emotion_adv
            )
            final_gen_cond = conds_template.gen.copy()
            final_gen_cond['embedding'] = final_gen_emb
            # Stage 3: Synthesis
            progress(0.8, desc="Synthesizing...")
            self.model.conds = Conditionals(final_t3_cond, final_gen_cond).to(self.model.device)
            wav_tensor = self.model.generate(text)
            self._cleanup_memory()
            return VoxConfig.SAMPLE_RATE, wav_tensor.cpu().squeeze().numpy()
        except Exception as e:
            self._cleanup_memory()
            # Chain the original cause so the full traceback is preserved.
            raise RuntimeError(f"Morphing process failed: {str(e)}") from e
# ==============================================================================
# 3. User Interface
# ==============================================================================
def create_interface():
    """Build and return the Gradio Blocks application for VoxMorph.

    Instantiating VoxMorphEngine here loads the TTS model once at app
    startup rather than per request; the returned Blocks object is wired
    so the button runs engine.synthesize.
    """
    # Engine (and underlying model) is created once and shared by all requests.
    engine = VoxMorphEngine()
    with gr.Blocks(theme=gr.themes.Soft(primary_hue="slate"), title="VoxMorph") as app:
        gr.Markdown(
            """
            # 🗣️ VoxMorph: Scalable Zero-Shot Voice Identity Morphing
            **University of North Texas | ICASSP Accepted Paper**
            """
        )
        with gr.Row():
            # Left column: the two reference voices and the text to speak.
            with gr.Column(scale=1):
                gr.Markdown("### 1. Source Input")
                with gr.Group():
                    # type="filepath" hands synthesize() paths, matching its signature.
                    audio_input_a = gr.Audio(label="Source Identity A", type="filepath")
                    audio_input_b = gr.Audio(label="Target Identity B", type="filepath")
                text_input = gr.Textbox(label="Linguistic Content", value=VoxConfig.DEFAULT_TEXT, lines=3)
            # Right column: morphing controls and the synthesized result.
            with gr.Column(scale=1):
                gr.Markdown("### 2. Morphing Controls")
                # Alpha 0 -> pure identity A, 1 -> pure identity B.
                alpha_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Interpolation Factor (α)")
                process_btn = gr.Button("Perform Zero-Shot Morphing", variant="primary", size="lg")
                gr.Markdown("### 3. Acoustic Output")
                # type="numpy" matches synthesize()'s (sample_rate, ndarray) return.
                audio_output = gr.Audio(label="Synthesized Morph", interactive=False, type="numpy")
        process_btn.click(
            fn=engine.synthesize,
            inputs=[audio_input_a, audio_input_b, text_input, alpha_slider],
            outputs=[audio_output]
        )
    return app
if __name__ == "__main__":
    # Build the UI and serve it locally; queue() serializes GPU-bound jobs.
    create_interface().queue().launch(share=False)