import os
import sys
import gc
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from typing import Tuple, Optional
from unittest.mock import patch
import huggingface_hub  # Needed for the intercept trick

# Chatterbox TTS dependencies
from chatterbox.tts import ChatterboxTTS, Conditionals
from chatterbox.models.t3.modules.cond_enc import T3Cond

# ==============================================================================
# 1. Configuration & Utilities
# ==============================================================================

class VoxConfig:
    """Configuration parameters for the VoxMorph inference engine."""

    # Sample rate reported alongside the synthesized waveform.
    SAMPLE_RATE = 24000
    # Cosine similarity above which MathUtils.slerp treats two directions as
    # parallel and uses normalized linear interpolation instead.
    DOT_THRESHOLD = 0.9995
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Path to local weights (checked before falling back to a Hub download).
    LOCAL_WEIGHTS_DIR = Path("./Model_Weights")

    # Initial value of the "Linguistic Content" textbox in the UI.
    DEFAULT_TEXT = (
        "The concept of morphing implies a smooth and seamless transition "
        "between two distinct states, blending identity and style."
    )


class MathUtils:
    """Static utilities for geometric operations on the hypersphere."""

    @staticmethod
    def slerp(v0: np.ndarray, v1: np.ndarray, t: float) -> np.ndarray:
        """Spherically interpolate between the directions of ``v0`` and ``v1``.

        Both inputs are L2-normalised first, so on the main path the result is
        a unit vector that equals ``v0/|v0|`` at ``t == 0`` and ``v1/|v1|`` at
        ``t == 1``.

        Args:
            v0: Start vector (NumPy array).
            v1: End vector (NumPy array, same shape as ``v0``).
            t: Interpolation factor, normally in ``[0, 1]``.

        Returns:
            The interpolated vector as a NumPy array. If either input is empty
            or has (near-)zero norm, or torch cannot process the inputs, the
            plain linear blend ``(1 - t) * v0 + t * v1`` of the raw inputs is
            returned instead.
        """
        try:
            v0_t, v1_t = torch.from_numpy(v0), torch.from_numpy(v1)
            if v0_t.numel() == 0 or v1_t.numel() == 0:
                raise ValueError("Empty vectors")

            norm0, norm1 = torch.norm(v0_t), torch.norm(v1_t)
            if norm0 < 1e-12 or norm1 < 1e-12:
                # A zero vector has no direction; normalising it would produce
                # NaNs instead of raising, so route it to the linear fallback.
                raise ValueError("Zero-norm vector")
            v0_t = v0_t / norm0
            v1_t = v1_t / norm1

            dot = torch.clamp(torch.sum(v0_t * v1_t), -1.0, 1.0)

            # Nearly parallel: normalized lerp is accurate and avoids dividing
            # by a tiny sin(theta). Only the *positive* side qualifies —
            # near-antipodal vectors would lerp through the origin.
            if dot > VoxConfig.DOT_THRESHOLD:
                v_lerp = torch.lerp(v0_t, v1_t, t)
                return (v_lerp / torch.norm(v_lerp)).numpy()

            # Use the signed angle. v1 is not negated to take the "short way"
            # (as quaternion slerp does), so the angle must match the signed
            # dot used in the coefficients below, or the endpoints drift.
            theta_0 = torch.acos(dot)
            sin_theta_0 = torch.sin(theta_0)
            if sin_theta_0 < 1e-8:
                # Antipodal inputs: the great circle is not unique.
                return torch.lerp(v0_t, v1_t, t).numpy()

            theta_t = theta_0 * t
            s0 = torch.cos(theta_t) - dot * torch.sin(theta_t) / sin_theta_0
            s1 = torch.sin(theta_t) / sin_theta_0
            return (s0 * v0_t + s1 * v1_t).numpy()
        except (TypeError, ValueError, RuntimeError):
            # Best-effort fallback (non-ndarray input, non-float dtype, shape
            # mismatch, empty or zero vectors): plain linear interpolation.
            return ((1 - t) * v0 + t * v1)
# ==============================================================================
# 2. Inference Engine
# ==============================================================================

class VoxMorphEngine:
    """Core engine for handling model loading and synthesis."""

    def __init__(self):
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load ChatterboxTTS, preferring files found in ``VoxConfig.LOCAL_WEIGHTS_DIR``.

        While ``ChatterboxTTS.from_pretrained`` runs, ``hf_hub_download`` is
        replaced by a redirect. The redirect returns the local copy of a
        requested file when it exists and otherwise defers to the real
        downloader. If loading with the redirect fails, a plain
        ``from_pretrained`` call (no redirect) is attempted; an error from
        that call propagates.
        """
        if self.model is not None:
            return

        from contextlib import ExitStack

        print(f"[VoxMorph] Initializing model on {VoxConfig.DEVICE}...", flush=True)

        # Capture the real downloader before anything is patched.
        original_download = huggingface_hub.hf_hub_download

        def local_redirect_download(repo_id, filename, **kwargs):
            """Redirects download requests to the local folder if file exists."""
            local_file = VoxConfig.LOCAL_WEIGHTS_DIR / filename
            if local_file.exists():
                print(f" [Loader] Intercepted request for '{filename}' -> Using local copy.")
                return str(local_file.absolute())
            print(f" [Loader] '{filename}' not found locally. Downloading from Hub...")
            return original_download(repo_id, filename, **kwargs)

        # A module that did `from huggingface_hub import hf_hub_download` holds
        # its own reference, which patching `huggingface_hub.hf_hub_download`
        # alone does not touch. Also patch the module that defines
        # ChatterboxTTS when it carries such a reference.
        tts_module = sys.modules.get(ChatterboxTTS.__module__)

        try:
            with ExitStack() as stack:
                stack.enter_context(
                    patch('huggingface_hub.hf_hub_download', side_effect=local_redirect_download)
                )
                if tts_module is not None and hasattr(tts_module, "hf_hub_download"):
                    stack.enter_context(
                        patch.object(tts_module, "hf_hub_download", side_effect=local_redirect_download)
                    )
                # We pass ONLY device, because we are tricking the internal downloader
                self.model = ChatterboxTTS.from_pretrained(device=VoxConfig.DEVICE)
            print("[VoxMorph] Model loaded successfully.")
        except Exception as e:
            print(f"[VoxMorph] CRITICAL ERROR loading model: {e}")
            # Fallback: try loading normally without local weights if patch fails
            self.model = ChatterboxTTS.from_pretrained(device=VoxConfig.DEVICE)

    def _cleanup_memory(self):
        """Release cached CUDA memory (when CUDA is available) and run the GC."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

    def _extract_embeddings(self, audio_path: str):
        """Profile ``audio_path`` and return its speaker embeddings.

        Side effect: ``self.model.conds`` is replaced by the conditionals that
        ``prepare_conditionals`` builds for this recording.

        Returns:
            ``(emb_t3, emb_gen)`` as CPU NumPy arrays, with the leading
            dimension squeezed (presumably a batch dim of 1 — TODO confirm).
        """
        self.model.prepare_conditionals(audio_path)
        conds = self.model.conds
        emb_t3 = conds.t3.speaker_emb.detach().squeeze(0).cpu().numpy()
        emb_gen = conds.gen['embedding'].detach().squeeze(0).cpu().numpy()
        return emb_t3, emb_gen

    def synthesize(self, path_a, path_b, text, alpha, progress=gr.Progress()):
        """Blend the voices of two recordings and speak ``text`` with the result.

        Args:
            path_a: File path of the source-identity recording (``alpha == 0``).
            path_b: File path of the target-identity recording (``alpha == 1``).
            text: Text to synthesize.
            alpha: Interpolation factor between A (0.0) and B (1.0).
            progress: Gradio progress tracker. Gradio injects a live tracker
                when the parameter's default is a ``gr.Progress`` instance.

        Returns:
            ``(sample_rate, waveform)`` as consumed by ``gr.Audio(type="numpy")``.

        Raises:
            ValueError: If either audio path is missing.
            RuntimeError: If any extraction/synthesis stage fails; the
                original exception is chained as ``__cause__``.
        """
        self._cleanup_memory()
        if not path_a or not path_b:
            raise ValueError("Both audio sources must be provided.")

        try:
            # Stage 1: Extraction
            progress(0.1, desc="Profiling Source A...")
            t3_a, gen_a = self._extract_embeddings(path_a)
            # Snapshot A's conditionals before profiling B replaces model.conds;
            # A's prompt tokens and emotion setting serve as the template.
            conds_template = Conditionals(self.model.conds.t3, self.model.conds.gen)

            progress(0.3, desc="Profiling Source B...")
            t3_b, gen_b = self._extract_embeddings(path_b)

            # Stage 2: Interpolation
            progress(0.5, desc=f"Interpolating (Alpha: {alpha:.2f})...")
            if alpha == 0.0:
                # Endpoints reuse the raw embeddings unchanged; MathUtils.slerp
                # would return unit-normalized vectors instead.
                mixed_t3, mixed_gen = t3_a, gen_a
            elif alpha == 1.0:
                mixed_t3, mixed_gen = t3_b, gen_b
            else:
                mixed_t3 = MathUtils.slerp(t3_a, t3_b, alpha)
                mixed_gen = MathUtils.slerp(gen_a, gen_b, alpha)
            final_t3_emb = torch.from_numpy(mixed_t3).unsqueeze(0)
            final_gen_emb = torch.from_numpy(mixed_gen).unsqueeze(0)

            # Reconstruct Conditionals
            final_t3_cond = T3Cond(
                speaker_emb=final_t3_emb,
                cond_prompt_speech_tokens=conds_template.t3.cond_prompt_speech_tokens,
                emotion_adv=conds_template.t3.emotion_adv,
            )
            # Shallow copy: only the 'embedding' entry is replaced.
            final_gen_cond = conds_template.gen.copy()
            final_gen_cond['embedding'] = final_gen_emb

            # Stage 3: Synthesis
            progress(0.8, desc="Synthesizing...")
            self.model.conds = Conditionals(final_t3_cond, final_gen_cond).to(self.model.device)
            wav_tensor = self.model.generate(text)

            # Prefer the model's own output rate when it exposes one.
            sample_rate = getattr(self.model, "sr", VoxConfig.SAMPLE_RATE)
            return sample_rate, wav_tensor.cpu().squeeze().numpy()
        except Exception as e:
            raise RuntimeError(f"Morphing process failed: {str(e)}") from e
        finally:
            self._cleanup_memory()
# ==============================================================================
# 3. User Interface
# ==============================================================================

def create_interface():
    """Build the VoxMorph Gradio app, wired to a freshly loaded engine.

    Constructing the engine loads the TTS model, so this call is slow.

    Returns:
        The ``gr.Blocks`` app (queued/launched by the caller).
    """
    engine = VoxMorphEngine()

    with gr.Blocks(theme=gr.themes.Soft(primary_hue="slate"), title="VoxMorph") as app:
        # The heading and affiliation must be on separate lines. Collapsed
        # onto one line, the affiliation would render inside the H1 heading.
        gr.Markdown(
            "# 🗣️ VoxMorph: Scalable Zero-Shot Voice Identity Morphing\n"
            "**University of North Texas | ICASSP Accepted Paper**"
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 1. Source Input")
                with gr.Group():
                    audio_input_a = gr.Audio(label="Source Identity A", type="filepath")
                    audio_input_b = gr.Audio(label="Target Identity B", type="filepath")
                text_input = gr.Textbox(label="Linguistic Content", value=VoxConfig.DEFAULT_TEXT, lines=3)

            with gr.Column(scale=1):
                gr.Markdown("### 2. Morphing Controls")
                alpha_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Interpolation Factor (α)")
                process_btn = gr.Button("Perform Zero-Shot Morphing", variant="primary", size="lg")

                gr.Markdown("### 3. Acoustic Output")
                # type="numpy" matches the (sample_rate, waveform) tuple
                # returned by VoxMorphEngine.synthesize.
                audio_output = gr.Audio(label="Synthesized Morph", interactive=False, type="numpy")

        process_btn.click(
            fn=engine.synthesize,
            inputs=[audio_input_a, audio_input_b, text_input, alpha_slider],
            outputs=[audio_output],
        )

    return app


if __name__ == "__main__":
    demo = create_interface()
    demo.queue().launch(share=False)