#!/usr/bin/env python3
"""
HuggingFace Space Demo for TextSyncMimi Speech Editing with Token-Level Embedding Swapping

This demo loads the model from HuggingFace Hub and allows:
- Generating speech with different voices using Kokoro TTS
- Swapping speech embeddings at specific token positions
- Real-time speech editing

Prerequisites:
- Model will be loaded from HuggingFace Hub
- A Hugging Face token (``--hf_token`` or the ``HF_TOKEN`` env var) is passed to
  every ``from_pretrained`` call, for checkpoints/tokenizers that require one
"""

import os

# Import spaces for GPU support BEFORE torch or other CUDA packages
try:
    import spaces
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

    # Create dummy decorator if spaces not available
    class spaces:
        """Local stand-in for the ``spaces`` package so ``@spaces.GPU`` becomes a no-op."""

        @staticmethod
        def GPU(func=None, **_kwargs):
            """Return *func* unchanged.

            Supports both the bare ``@spaces.GPU`` form and the parametrized
            ``@spaces.GPU(duration=...)`` form (where *func* is None).
            """
            if func is None:
                return lambda f: f
            return func

import json
import tempfile
import argparse
from typing import List, Tuple, Optional
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import soundfile as sf
import gradio as gr
from kokoro import KPipeline
from transformers import (
    AutoModel,
    AutoFeatureExtractor,
    AutoTokenizer,
    MimiModel,
)

# Constants
SAMPLE_RATE = 24000  # Hz; rate for Kokoro output files, Mimi input features, and decoded audio
FRAME_RATE = 12.5  # presumably Mimi's latent frame rate (frames/s); not referenced below
TTS_VOICES = [
    "af_heart", "af_alloy", "af_aoede", "af_bella", "af_jessica", "af_kore",
    "af_nicole", "af_nova", "af_river", "af_sky",
    "am_adam", "am_echo", "am_eric", "am_fenrir", "am_liam", "am_michael",
    "am_onyx", "am_puck", "am_santa",
]
MAX_Z_TOKENS = 50  # hard cap on autoregressively generated speech latents per utterance
END_TOKEN_THRESHOLD = 0.5  # generation stops once sigmoid(end-token logit) >= this value

# Global variables, populated by initialize_models()
model = None
mimi_model = None
tokenizer = None
feature_extractor = None
device = None
kokoro_pipeline = None


def _param_dtype(module: nn.Module) -> torch.dtype:
    """Return the dtype of *module*'s first parameter, or float32 if it has none."""
    first = next(module.parameters(), None)
    return first.dtype if first is not None else torch.float32


def _tensor_to_list(tensor: torch.Tensor) -> list:
    """Convert *tensor* to nested Python lists for JSON storage.

    Floating tensors are upcast to float32 first because NumPy/JSON paths do
    not handle bfloat16; float32 input is returned with identical values.
    """
    tensor = tensor.detach()
    if tensor.is_floating_point():
        tensor = tensor.float()
    return tensor.cpu().tolist()


def _write_temp_wav(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str:
    """Write *audio* to a new temporary ``.wav`` file and return its path.

    The file is created with ``delete=False`` so Gradio can serve it after this
    function returns. The handle is closed before soundfile opens the path,
    because reopening a still-open temporary file is not portable.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        path = temp_file.name
    sf.write(path, audio, sample_rate)
    return path


def _parse_swap_indices(swap_indices: str) -> List[int]:
    """Parse a comma-separated list of non-negative integers, skipping malformed entries.

    ``str.isdecimal`` is used instead of ``str.isdigit``: ``isdigit`` also accepts
    characters such as ``"²"`` for which ``int()`` raises ``ValueError``.
    """
    parts = (part.strip() for part in swap_indices.split(","))
    return [int(part) for part in parts if part.isdecimal()]


def load_audio_to_inputs(feature_extractor, audio_path: str, sample_rate: int) -> torch.Tensor:
    """Load *audio_path* as mono audio resampled to *sample_rate* and featurize it.

    Args:
        feature_extractor: Mimi feature extractor (``AutoFeatureExtractor``).
        audio_path: Path to an audio file readable by librosa.
        sample_rate: Target sample rate for loading and featurization.

    Returns:
        The extractor's ``input_values`` tensor (PyTorch).
    """
    import librosa  # local import: only needed when audio is actually loaded

    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    audio_inputs = feature_extractor(raw_audio=audio, return_tensors="pt", sampling_rate=sample_rate)
    return audio_inputs.input_values


def initialize_models(model_id: str, tokenizer_id: str = "meta-llama/Llama-3.1-8B-Instruct", hf_token: Optional[str] = None):
    """Initialize all models from HuggingFace Hub into the module-level globals.

    Loads the TextSyncMimi model (remote code), the Mimi codec named by
    ``model.config.mimi_model_id`` (default ``"kyutai/mimi"``), the text
    tokenizer (pad token falls back to EOS), the Mimi feature extractor, and an
    English Kokoro pipeline. Uses CUDA when available, otherwise CPU.
    """
    global model, mimi_model, tokenizer, feature_extractor, device, kokoro_pipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print(f"Loading TextSyncMimi model from {model_id}...")
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )
    model.to(device)
    model.eval()

    # Get mimi_model_id from config
    mimi_model_id = getattr(model.config, "mimi_model_id", "kyutai/mimi")

    print("Loading Mimi model...")
    mimi_model = MimiModel.from_pretrained(mimi_model_id, token=hf_token)
    mimi_model.to(device)
    mimi_model.eval()

    print(f"Loading tokenizer from {tokenizer_id}...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, token=hf_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading feature extractor...")
    feature_extractor = AutoFeatureExtractor.from_pretrained(mimi_model_id, token=hf_token)

    print("Initializing Kokoro pipeline...")
    kokoro_pipeline = KPipeline(lang_code='a')

    print("✅ All models loaded successfully!")


@torch.no_grad()
def compute_cross_attention_s(
    model,
    text_embeddings: torch.Tensor,
    input_values: torch.Tensor,
    device: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute projected text embeddings and cross-attended speech embeddings.

    Args:
        model: TextSyncMimi model.
        text_embeddings: Token embeddings indexed ``[batch, seq, ...]``; the
            all-ones masks below assume batch size 1.
        input_values: Audio features from ``load_audio_to_inputs``.
        device: Device the computation runs on.

    Returns:
        ``(text_proj, cross_out, text_attention_mask)``: the projected text,
        the cross-attention transformer's last hidden state (text queries
        attending over encoded speech), and an all-ones bool mask of shape
        ``(1, text_seq_len)``.
    """
    audio_attention_mask = torch.ones(1, input_values.shape[-1], dtype=torch.bool, device=device)
    text_attention_mask = torch.ones(1, text_embeddings.shape[1], dtype=torch.bool, device=device)

    # Encode speech; transpose so the time axis is dim 1.
    speech_embeddings = model.encode_audio_to_representation(
        input_values.to(device),
        audio_attention_mask=audio_attention_mask,
    ).transpose(1, 2)

    # Project text
    text_proj = model.text_proj(text_embeddings.to(device))
    mask_dtype = text_proj.dtype

    # Additive self-attention mask: 0 where (causal AND not padding), -inf elsewhere.
    batch_size, text_seq_len = text_proj.shape[:2]
    causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=mask_dtype))
    causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
    pad_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
    formatted_text_attention_mask = torch.where(
        (causal_mask * pad_mask).bool(), 0.0, float("-inf")
    ).to(mask_dtype)  # keep the mask in the activations' dtype (no-op for float32)

    # Additive cross-attention mask: every speech frame is attendable.
    speech_seq_len = speech_embeddings.shape[1]
    speech_mask = torch.ones(batch_size, speech_seq_len, dtype=torch.bool, device=device)
    formatted_speech_attention_mask = torch.where(
        speech_mask.view(batch_size, 1, 1, speech_seq_len), 0.0, float("-inf")
    ).to(mask_dtype)

    # Cross attention
    cross_out = model.cross_attention_transformer(
        hidden_states=text_proj,
        encoder_hidden_states=speech_embeddings,
        attention_mask=formatted_text_attention_mask,
        encoder_attention_mask=formatted_speech_attention_mask,
        alignment_chunk_sizes=None,
    ).last_hidden_state

    return text_proj, cross_out, text_attention_mask


@torch.no_grad()
def ar_generate_and_decode(
    model,
    mimi_model,
    text_proj: torch.Tensor,
    s_tokens: torch.Tensor,
    text_attention_mask: torch.Tensor,
    max_z_tokens: int,
    end_token_threshold: float,
    device: str,
) -> np.ndarray:
    """Generate speech latents autoregressively and decode them to a waveform.

    For each batch item the prompt is ``[text_speech_latent_embed, t_0, s_0,
    t_1, s_1, ..., time_speech_start_embed]``. It is fed to
    ``model.ar_transformer``, and the last hidden state is appended as the next
    z token. Generation stops when ``sigmoid(end_token_classifier(...))``
    reaches *end_token_threshold* or after *max_z_tokens* tokens. The whole
    sequence is re-run every step (no KV cache), so cost grows quadratically
    with length.

    The z tokens are decoded with Mimi's ``upsample`` -> ``decoder_transformer``
    -> ``decoder``.

    NOTE: tokens from all batch items are collected into one list and decoded
    as a single utterance; every caller in this file passes batch size 1.

    Returns:
        A 1-D float32 waveform, with NaN/inf replaced by finite values. If no
        token was generated, 1000 zero samples are returned.
    """
    batch_size, text_seq_len = text_proj.shape[:2]
    zero_idx = torch.zeros(1, dtype=torch.long, device=device)
    text_speech_latent_emb = model.text_speech_latent_embed(zero_idx)
    time_speech_start_emb = model.time_speech_start_embed(zero_idx)

    generated_z_tokens: List[torch.Tensor] = []

    for b in range(batch_size):
        if text_attention_mask is not None:
            valid_text_len = int(text_attention_mask[b].sum().item())
        else:
            valid_text_len = text_seq_len

        # Interleave projected text and speech embeddings token by token.
        sequence: List[torch.Tensor] = [text_speech_latent_emb]
        for i in range(valid_text_len):
            sequence.extend([text_proj[b, i:i + 1], s_tokens[b, i:i + 1]])
        sequence.append(time_speech_start_emb)

        for _ in range(max_z_tokens):
            current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)
            ar_attention_mask = torch.ones(1, current_sequence.shape[1], dtype=torch.bool, device=device)
            ar_outputs = model.ar_transformer(
                hidden_states=current_sequence,
                attention_mask=ar_attention_mask,
            )
            last_prediction = ar_outputs.last_hidden_state[0, -1:, :]

            end_token_logit = model.end_token_classifier(last_prediction).squeeze(-1)
            if torch.sigmoid(end_token_logit).item() >= end_token_threshold:
                break

            sequence.append(last_prediction)
            generated_z_tokens.append(last_prediction.squeeze(0))

    # Decode z tokens to audio
    if not generated_z_tokens:
        audio_tensor = torch.zeros(1, 1, 1000, device=device)
    else:
        z_tokens_batch = torch.stack(generated_z_tokens, dim=0).unsqueeze(0)
        # Match Mimi's precision in case the two models were loaded with different dtypes.
        z_tokens_batch = z_tokens_batch.to(_param_dtype(mimi_model))
        embeddings_bct = z_tokens_batch.transpose(1, 2)
        embeddings_upsampled = mimi_model.upsample(embeddings_bct)
        decoder_outputs = mimi_model.decoder_transformer(embeddings_upsampled.transpose(1, 2), return_dict=True)
        embeddings_after_dec = decoder_outputs.last_hidden_state.transpose(1, 2)
        audio_tensor = mimi_model.decoder(embeddings_after_dec)

    # .float() first: NumPy has no bfloat16, so .numpy() would fail on half-precision output.
    audio_numpy = audio_tensor.squeeze().detach().float().cpu().numpy()
    if not np.isfinite(audio_numpy).all():
        audio_numpy = np.nan_to_num(audio_numpy)
    if audio_numpy.ndim > 1:
        audio_numpy = audio_numpy.flatten()
    return audio_numpy


def generate_tts_audio(text: str, voice: str) -> str:
    """Synthesize *text* with Kokoro voice *voice* and return a temp ``.wav`` path.

    Raises:
        RuntimeError: if ``initialize_models`` has not created the Kokoro pipeline.
    """
    if kokoro_pipeline is None:
        raise RuntimeError("Kokoro pipeline not initialized")

    # Kokoro uses KPipeline output which is an iterator of (graphemes, phonemes, audio)
    generator = kokoro_pipeline(text, voice=voice, speed=1.0, split_pattern=r'\n+')
    audio_chunks = [audio for _, _, audio in generator if audio is not None]

    if audio_chunks:
        audio_np = np.concatenate(audio_chunks)
    else:
        # Fallback: one second of silence
        audio_np = np.zeros(SAMPLE_RATE)

    return _write_temp_wav(audio_np, SAMPLE_RATE)


@spaces.GPU
def process_inputs(transcript_text: str, voice1: str, voice2: str):
    """Synthesize both voices, compute embeddings, and reconstruct voice 1.

    Returns an 8-tuple matching the Gradio outputs: status message, token
    listing, audio-1 path, audio-2 path, baseline reconstruction path, JSON with
    the stored embeddings, audio-1 path, audio-2 path. On validation failure,
    the status message is followed by seven ``None`` values.
    """
    no_outputs = (None,) * 7
    required = (model, mimi_model, tokenizer, feature_extractor, kokoro_pipeline)
    if any(component is None for component in required):
        return ("Please initialize models first!",) + no_outputs

    text = (transcript_text or "").strip()
    if not text:
        return ("Please provide a transcript!",) + no_outputs

    if not voice1 or not voice2:
        return ("Please select voices for both audio samples!",) + no_outputs

    # Tokenize
    tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    text_token_strs = tokenizer.convert_ids_to_tokens(tokens.input_ids.squeeze(0).tolist())
    text_token_ids = tokens.input_ids.to(device)
    token_display = "".join(f"Token {i}: {tok}\n" for i, tok in enumerate(text_token_strs))

    # Generate TTS audio
    print(f"Generating TTS audio with voice '{voice1}'...")
    audio1_path = generate_tts_audio(text, voice1)
    print(f"Generating TTS audio with voice '{voice2}'...")
    audio2_path = generate_tts_audio(text, voice2)

    # Load audio
    input_values_utt1 = load_audio_to_inputs(feature_extractor, audio1_path, SAMPLE_RATE)
    input_values_utt2 = load_audio_to_inputs(feature_extractor, audio2_path, SAMPLE_RATE)

    # Get text embeddings using model's built-in text_token_embedding
    with torch.no_grad():
        text_embeddings = model.text_token_embedding(text_token_ids)

    # Compute cross-attention embeddings for both renditions of the same text.
    t1_proj, s1_cross, text_attention_mask = compute_cross_attention_s(
        model, text_embeddings, input_values_utt1, device
    )
    _, s2_cross, _ = compute_cross_attention_s(
        model, text_embeddings, input_values_utt2, device
    )

    # Generate baseline audio (voice 1, no swap)
    baseline_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s1_cross,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )
    baseline_path = _write_temp_wav(baseline_audio, SAMPLE_RATE)

    embeddings_json = json.dumps({
        "t1_proj": _tensor_to_list(t1_proj),
        "s1_cross": _tensor_to_list(s1_cross),
        "s2_cross": _tensor_to_list(s2_cross),
        "text_attention_mask": _tensor_to_list(text_attention_mask),
        "num_tokens": len(text_token_strs),
    })

    return (
        "Processing completed successfully!",
        token_display,
        audio1_path,
        audio2_path,
        baseline_path,
        embeddings_json,
        audio1_path,
        audio2_path,
    )


@spaces.GPU
def swap_embeddings(embeddings_json: str, swap_indices: str):
    """Replace voice-1 speech embeddings with voice-2 ones at chosen token indices.

    Args:
        embeddings_json: JSON produced by ``process_inputs``.
        swap_indices: Comma-separated token indices, e.g. ``"0,2,5"``.
            Malformed entries are skipped, out-of-range ones are dropped, and
            duplicates are removed (first-occurrence order is kept).

    Returns:
        ``(status_message, wav_path_or_None)``.
    """
    if not embeddings_json:
        return "Please process inputs first!", None

    if not swap_indices or not swap_indices.strip():
        return "Please specify token indices to swap (e.g., 0,2,5)!", None

    embeddings_data = json.loads(embeddings_json)
    num_tokens = embeddings_data["num_tokens"]

    # Validate indices before moving anything to the device.
    parsed = _parse_swap_indices(swap_indices)
    if not parsed:
        return "No valid indices provided! Use format: 0,2,5", None

    valid_indices = list(dict.fromkeys(i for i in parsed if 0 <= i < num_tokens))
    if not valid_indices:
        return f"All indices out of range! Valid range: 0-{num_tokens-1}", None

    # Restore stored embeddings in the model's own precision (JSON round-trips as float).
    dtype = _param_dtype(model)
    t1_proj = torch.tensor(embeddings_data["t1_proj"], dtype=dtype, device=device)
    s1_cross = torch.tensor(embeddings_data["s1_cross"], dtype=dtype, device=device)
    s2_cross = torch.tensor(embeddings_data["s2_cross"], dtype=dtype, device=device)
    text_attention_mask = torch.tensor(embeddings_data["text_attention_mask"], device=device)

    # Perform swap
    s_swapped = s1_cross.clone()
    s_swapped[:, valid_indices, :] = s2_cross[:, valid_indices, :]

    # Generate swapped audio
    swapped_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s_swapped,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )
    swapped_path = _write_temp_wav(swapped_audio, SAMPLE_RATE)

    return f"Successfully swapped embeddings at token indices: {valid_indices}", swapped_path


def create_gradio_interface():
    """Create the Gradio interface and wire its buttons to the processing callbacks."""
    with gr.Blocks(
        title="Speech Editing with TextSyncMimi",
        theme=gr.themes.Soft(primary_hue="red"),
    ) as interface:
        gr.Markdown("# Speech Editing with TextSyncMimi")
        gr.Markdown("📝 **Blog Post**: [From Time to Text: Investigating Text-Synchronous Speech Representations](https://potsawee.github.io/textsync-speech-embed)")
        gr.Markdown("🤗 **Model Checkpoint**: https://huggingface.co/potsawee/TextSyncMimi-v1")
        gr.Markdown("*Demo for Speech Editing Ability with Text-Synchronous Speech Representations*: Generate two voice renditions using Kokoro TTS, then swap speech embeddings at token positions.")

        with gr.Row():
            with gr.Column():
                gr.Markdown("## TTS Configuration")
                transcript_text = gr.Textbox(
                    label="Text",
                    placeholder="Enter text to synthesize...",
                    lines=3,
                )
                with gr.Row():
                    voice1 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 1",
                        value="af_heart",
                    )
                    voice2 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 2",
                        value="am_adam",
                    )
                process_btn = gr.Button("Generate & Process", variant="primary")
                process_status = gr.Textbox(label="Status", interactive=False)

            with gr.Column():
                gr.Markdown("## Tokenization")
                tokens_display = gr.Textbox(
                    label="Tokens",
                    lines=16,
                    interactive=False,
                )

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Kokoro Synthesized Audio")
                gr.Markdown("### Audio 1 and Audio 2 (to be swapped)")
                generated_audio1 = gr.Audio(label="Generated Audio 1")
                generated_audio2 = gr.Audio(label="Generated Audio 2")

            with gr.Column():
                gr.Markdown("## Audio1 Reconstruction")
                gr.Markdown("### Reconstruction using TextSync-Mimi without embedding swap")
                baseline_audio = gr.Audio(label="Baseline Reconstruction")

                gr.Markdown("### Embedding Position(s) to be Swapped (Audio2 → Audio1)")
                swap_indices_input = gr.Textbox(
                    label="Token Indices to Swap (See token IDs above)",
                    placeholder="e.g., 0,2,5",
                )
                swap_btn = gr.Button("Perform Swap", variant="primary")
                swap_status = gr.Textbox(label="Swap Status", interactive=False)
                swapped_audio = gr.Audio(label="Swapped Result")

        # Hidden states
        embeddings_state = gr.State()
        audio1_state = gr.State()
        audio2_state = gr.State()

        # Event handlers
        process_btn.click(
            fn=process_inputs,
            inputs=[transcript_text, voice1, voice2],
            outputs=[process_status, tokens_display, generated_audio1, generated_audio2,
                     baseline_audio, embeddings_state, audio1_state, audio2_state],
        )

        swap_btn.click(
            fn=swap_embeddings,
            inputs=[embeddings_state, swap_indices_input],
            outputs=[swap_status, swapped_audio],
        )

    return interface


def main():
    """Parse CLI arguments, load all models, and launch the Gradio app."""
    parser = argparse.ArgumentParser(description="HuggingFace Space Demo for TextSyncMimi")
    parser.add_argument(
        "--model_id",
        type=str,
        default="potsawee/TextSyncMimi-v1",
        help="HuggingFace model ID",
    )
    parser.add_argument(
        "--tokenizer_id",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="HuggingFace tokenizer ID",
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        default=None,
        help="Hugging Face token (or set HF_TOKEN env var)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port for Gradio app",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create public share link",
    )
    args = parser.parse_args()

    # Get HF token: CLI flag takes precedence over the environment.
    hf_token = args.hf_token or os.getenv("HF_TOKEN")

    # Initialize models
    print(f"🚀 Initializing TextSyncMimi from HuggingFace Hub: {args.model_id}...")
    initialize_models(args.model_id, args.tokenizer_id, hf_token)
    print("🌐 Launching Gradio interface...")

    # Launch
    interface = create_gradio_interface()
    interface.launch(server_name="0.0.0.0", server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()