# VoxMorph-Models / app.py
# Author: BharathK333
# Commit 9b88428 (verified): "Upload folder using huggingface_hub"
import os
import sys
import gc
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from typing import Tuple, Optional
from unittest.mock import patch
import huggingface_hub # Needed for the intercept trick
# Chatterbox TTS dependencies
from chatterbox.tts import ChatterboxTTS, Conditionals
from chatterbox.models.t3.modules.cond_enc import T3Cond
# ==============================================================================
# 1. Configuration & Utilities
# ==============================================================================
class VoxConfig:
    """Static settings for the VoxMorph inference engine and UI.

    Every value is a plain class attribute, read as ``VoxConfig.NAME``.
    """

    # Sample rate returned together with the synthesized waveform.
    SAMPLE_RATE: int = 24000

    # Cosine cutoff used by MathUtils.slerp to switch to normalized linear
    # interpolation when the two unit vectors are (almost) aligned.
    DOT_THRESHOLD: float = 0.9995

    # Torch device string, chosen once when the class body is executed.
    DEVICE: str = ("cuda" if torch.cuda.is_available() else "cpu")

    # Folder searched first for model weight files. The path is relative, so it
    # is resolved against the current working directory.
    LOCAL_WEIGHTS_DIR: Path = Path(".") / "Model_Weights"

    # Text pre-filled in the UI's text box.
    DEFAULT_TEXT: str = (
        "The concept of morphing implies a smooth and seamless "
        "transition between two distinct states, "
        "blending identity and style."
    )
class MathUtils:
    """Static utilities for geometric operations on the hypersphere."""

    @staticmethod
    def slerp(
        v0: np.ndarray,
        v1: np.ndarray,
        t: float,
        dot_threshold: Optional[float] = None,
    ) -> np.ndarray:
        """Spherically interpolate between the directions of ``v0`` and ``v1``.

        Both inputs are scaled to unit length. The result is the point a
        fraction ``t`` of the way along the great-circle arc between them, so
        its norm is 1 whatever the input norms are: ``t=0`` gives ``v0/|v0|``
        and ``t=1`` gives ``v1/|v1|``. The arc angle is the ``acos`` of the
        *signed* cosine, so the endpoints are also reproduced when the vectors
        form an obtuse angle.

        Args:
            v0: First vector (anything ``np.asarray`` accepts).
            v1: Second vector, same shape as ``v0``.
            t: Interpolation fraction (0 -> v0, 1 -> v1).
            dot_threshold: Cosine above which the inputs are treated as
                parallel, in which case normalized linear interpolation is
                used. Defaults to ``VoxConfig.DOT_THRESHOLD``.

        Returns:
            The interpolated vector as a NumPy array. Fallbacks:
            - Exactly antipodal inputs have no unique arc and return an
              unnormalized linear blend of the unit vectors.
            - Empty or zero-norm inputs, and inputs on which torch raises
              ``RuntimeError``/``TypeError``/``ValueError`` (e.g. mismatched
              shapes, integer dtypes), return the raw blend
              ``(1 - t) * v0 + t * v1``.
        """
        # Evaluated lazily so callers passing an explicit threshold do not
        # depend on VoxConfig.
        threshold = VoxConfig.DOT_THRESHOLD if dot_threshold is None else dot_threshold
        try:
            v0_t = torch.from_numpy(np.asarray(v0))
            v1_t = torch.from_numpy(np.asarray(v1))
            if v0_t.numel() == 0 or v1_t.numel() == 0:
                raise ValueError("Empty vectors")
            norm0, norm1 = torch.norm(v0_t), torch.norm(v1_t)
            # A zero vector has no direction; dividing by its norm would give NaN.
            if norm0 == 0 or norm1 == 0:
                raise ValueError("Zero-norm vector has no direction")
            u0 = v0_t / norm0
            u1 = v1_t / norm1
            dot = torch.clamp(torch.sum(u0 * u1), -1.0, 1.0)
            # Nearly parallel: sin(theta) ~ 0 makes the slerp weights unstable,
            # so use a renormalized lerp instead. Only the positive side is
            # checked: near-antipodal pairs are handled by the slerp below,
            # whereas a lerp between them would collapse towards the zero vector.
            if dot > threshold:
                v_lerp = torch.lerp(u0, u1, t)
                return (v_lerp / torch.norm(v_lerp)).numpy()
            theta_0 = torch.acos(dot)  # signed cosine -> angle in [0, pi]
            sin_theta_0 = torch.sin(theta_0)
            if sin_theta_0 < 1e-8:
                # Exactly antipodal: the great circle is not unique.
                return torch.lerp(u0, u1, t).numpy()
            theta_t = theta_0 * t
            s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
            s1 = torch.sin(theta_t) / sin_theta_0
            return (s0 * u0 + s1 * u1).numpy()
        except (RuntimeError, TypeError, ValueError):
            # Deliberate best-effort fallback: plain linear blend of raw inputs.
            return (1 - t) * np.asarray(v0) + t * np.asarray(v1)
# ==============================================================================
# 2. Inference Engine
# ==============================================================================
class VoxMorphEngine:
    """Core engine for handling model loading and synthesis.

    Wraps a single ``ChatterboxTTS`` instance. Every ``synthesize`` call
    overwrites ``self.model.conds`` (first with each source's conditionals,
    then with the morphed ones), so the model object carries state from one
    call to the next.
    """

    def __init__(self):
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load ChatterboxTTS, preferring files in ``VoxConfig.LOCAL_WEIGHTS_DIR``.

        ``ChatterboxTTS.from_pretrained`` is called while ``hf_hub_download`` is
        temporarily replaced by a redirector. The redirector returns the
        absolute path of ``LOCAL_WEIGHTS_DIR / filename`` when that file exists
        and otherwise defers to the real downloader.

        The redirector is installed on each already-imported module among
        ``huggingface_hub`` and ``chatterbox.tts`` that has an
        ``hf_hub_download`` attribute. Patching only
        ``huggingface_hub.hf_hub_download`` would miss a module that ran
        ``from huggingface_hub import hf_hub_download`` at import time, because
        such a module holds its own reference to the original function.

        If loading with the redirect raises, the error is printed and the
        model is loaded again with the unpatched downloader. Does nothing if a
        model is already loaded.
        """
        if self.model is not None:
            return
        from contextlib import ExitStack

        print(f"[VoxMorph] Initializing model on {VoxConfig.DEVICE}...", flush=True)
        # Keep the real downloader so the redirect can fall back to it.
        original_download = huggingface_hub.hf_hub_download

        def local_redirect_download(repo_id, filename, **kwargs):
            """Redirects download requests to the local folder if file exists."""
            local_file = VoxConfig.LOCAL_WEIGHTS_DIR / filename
            if local_file.exists():
                print(f" [Loader] Intercepted request for '{filename}' -> Using local copy.")
                return str(local_file.absolute())
            print(f" [Loader] '{filename}' not found locally. Downloading from Hub...")
            return original_download(repo_id, filename, **kwargs)

        try:
            # The patches are active only during model initialization.
            with ExitStack() as stack:
                for module_name in ("huggingface_hub", "chatterbox.tts"):
                    module = sys.modules.get(module_name)
                    if module is not None and hasattr(module, "hf_hub_download"):
                        stack.enter_context(
                            patch.object(module, "hf_hub_download", new=local_redirect_download)
                        )
                # Only the device is passed; file locations come from the redirect.
                self.model = ChatterboxTTS.from_pretrained(device=VoxConfig.DEVICE)
            print("[VoxMorph] Model loaded successfully.")
        except Exception as e:
            print(f"[VoxMorph] CRITICAL ERROR loading model: {e}")
            # Fallback: try loading normally without local weights if patch fails
            self.model = ChatterboxTTS.from_pretrained(device=VoxConfig.DEVICE)

    def _cleanup_memory(self):
        """Release cached CUDA memory (when CUDA is available) and run Python's GC."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

    def _extract_embeddings(self, audio_path: str):
        """Condition the model on ``audio_path`` and return its speaker embeddings.

        Side effect: ``prepare_conditionals`` is expected to replace
        ``self.model.conds`` with the conditionals for ``audio_path``.

        Returns:
            ``(t3_speaker_emb, gen_embedding)`` as CPU NumPy arrays with dim 0
            squeezed out (presumably a batch axis of size 1 — TODO confirm).
        """
        self.model.prepare_conditionals(audio_path)
        conds = self.model.conds
        emb_t3 = conds.t3.speaker_emb.detach().squeeze(0).cpu().numpy()
        emb_gen = conds.gen['embedding'].detach().squeeze(0).cpu().numpy()
        return emb_t3, emb_gen

    @staticmethod
    def _blend(emb_a, emb_b, alpha):
        """Interpolate two NumPy embeddings and return a tensor with a new leading axis.

        When ``alpha`` is exactly 0.0 or 1.0, the raw endpoint is used
        unchanged. Any other value uses ``MathUtils.slerp``.

        NOTE(review): slerp returns a unit-norm direction while the endpoints
        keep their original norm, so the blended magnitude can jump as alpha
        leaves 0/1 — confirm the downstream models are insensitive to it.
        """
        if alpha == 0.0:
            blended = emb_a
        elif alpha == 1.0:
            blended = emb_b
        else:
            blended = MathUtils.slerp(emb_a, emb_b, alpha)
        return torch.from_numpy(blended).unsqueeze(0)

    def synthesize(self, path_a, path_b, text, alpha, progress=gr.Progress()):
        """Speak ``text`` in a voice interpolated between two reference clips.

        Only the speaker embeddings are interpolated. The other conditioning
        fields come from source A even at ``alpha == 1.0``: the T3 prompt
        speech tokens, ``emotion_adv``, and all non-``'embedding'`` entries of
        the generator's conditioning mapping.

        Args:
            path_a: File path of the reference audio for identity A.
            path_b: File path of the reference audio for identity B.
            text: Text to synthesize.
            alpha: Interpolation factor; 0.0 keeps A's embeddings, 1.0 uses
                B's, and values in between are SLERP-blended.
            progress: Gradio progress tracker (supplied by Gradio).

        Returns:
            ``(VoxConfig.SAMPLE_RATE, waveform)`` with the waveform as a NumPy
            array.

        Raises:
            ValueError: If either audio path is empty/missing.
            RuntimeError: Wrapping (and chained to) any failure during
                extraction, interpolation or synthesis.
        """
        self._cleanup_memory()
        if not path_a or not path_b:
            raise ValueError("Both audio sources must be provided.")
        try:
            # Stage 1: Extraction
            progress(0.1, desc="Profiling Source A...")
            t3_a, gen_a = self._extract_embeddings(path_a)
            # Snapshot A's t3/gen objects in a fresh wrapper before B's
            # extraction changes self.model.conds. Assumes prepare_conditionals
            # builds new t3/gen objects rather than mutating A's — TODO confirm.
            conds_template = Conditionals(self.model.conds.t3, self.model.conds.gen)

            progress(0.3, desc="Profiling Source B...")
            t3_b, gen_b = self._extract_embeddings(path_b)

            # Stage 2: Interpolation
            progress(0.5, desc=f"Interpolating (Alpha: {alpha:.2f})...")
            final_t3_emb = self._blend(t3_a, t3_b, alpha)
            final_gen_emb = self._blend(gen_a, gen_b, alpha)

            # Reconstruct Conditionals around the blended embeddings.
            final_t3_cond = T3Cond(
                speaker_emb=final_t3_emb,
                cond_prompt_speech_tokens=conds_template.t3.cond_prompt_speech_tokens,
                emotion_adv=conds_template.t3.emotion_adv,
            )
            # Copy so the template's mapping is not modified.
            final_gen_cond = conds_template.gen.copy()
            final_gen_cond['embedding'] = final_gen_emb

            # Stage 3: Synthesis
            progress(0.8, desc="Synthesizing...")
            self.model.conds = Conditionals(final_t3_cond, final_gen_cond).to(self.model.device)
            wav_tensor = self.model.generate(text)
            return VoxConfig.SAMPLE_RATE, wav_tensor.cpu().squeeze().numpy()
        except Exception as e:
            raise RuntimeError(f"Morphing process failed: {str(e)}") from e
        finally:
            # Runs on both success and failure, replacing the duplicated calls.
            self._cleanup_memory()
# ==============================================================================
# 3. User Interface
# ==============================================================================
def create_interface():
    """Build the VoxMorph Gradio app and wire its button to a new engine.

    Creating the engine loads the TTS model, so calling this function is
    expensive.

    Returns:
        The ``gr.Blocks`` application (not yet queued or launched).
    """
    engine = VoxMorphEngine()
    header_md = (
        "\n"
        "# 🗣️ VoxMorph: Scalable Zero-Shot Voice Identity Morphing\n"
        "**University of North Texas | ICASSP Accepted Paper**\n"
    )
    soft_theme = gr.themes.Soft(primary_hue="slate")

    with gr.Blocks(theme=soft_theme, title="VoxMorph") as app:
        gr.Markdown(header_md)

        with gr.Row():
            # Left column: the two reference clips and the text to speak.
            with gr.Column(scale=1):
                gr.Markdown("### 1. Source Input")
                with gr.Group():
                    clip_a = gr.Audio(label="Source Identity A", type="filepath")
                    clip_b = gr.Audio(label="Target Identity B", type="filepath")
                content_box = gr.Textbox(
                    label="Linguistic Content",
                    value=VoxConfig.DEFAULT_TEXT,
                    lines=3,
                )

            # Right column: alpha control, trigger button and the result.
            with gr.Column(scale=1):
                gr.Markdown("### 2. Morphing Controls")
                alpha_control = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.05, label="Interpolation Factor (α)"
                )
                run_button = gr.Button(
                    "Perform Zero-Shot Morphing", variant="primary", size="lg"
                )
                gr.Markdown("### 3. Acoustic Output")
                morph_player = gr.Audio(
                    label="Synthesized Morph", interactive=False, type="numpy"
                )

        # Input order must match VoxMorphEngine.synthesize(path_a, path_b, text, alpha).
        run_button.click(
            fn=engine.synthesize,
            inputs=[clip_a, clip_b, content_box, alpha_control],
            outputs=[morph_player],
        )
    return app
if __name__ == "__main__":
    demo = create_interface()
    # Enable Gradio's request queue, then serve the app without a public share link.
    served_app = demo.queue()
    served_app.launch(share=False)