Spaces: Running on Zero
| #!/usr/bin/env python3 | |
| """ | |
| HuggingFace Space Demo for TextSyncMimi | |
| Speech Editing with Token-Level Embedding Swapping | |
| This demo loads the model from HuggingFace Hub and allows: | |
| - Generating speech with different voices using Kokoro TTS | |
| - Swapping speech embeddings at specific token positions | |
| - Real-time speech editing | |
| Prerequisites: | |
| - Model will be loaded from HuggingFace Hub | |
| """ | |
import os

# Import spaces for GPU support BEFORE torch or other CUDA packages
try:
    import spaces
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

    # Fallback stand-in for local (non-Space) runs. The real `spaces.GPU`
    # supports both bare decoration (`@spaces.GPU`) and parameterized
    # decoration (`@spaces.GPU(duration=...)`), so the dummy must too.
    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            if func is None:
                # Called as `spaces.GPU(duration=...)`: return a no-op decorator.
                return lambda f: f
            # Called as `spaces.GPU` directly on the function.
            return func
| import json | |
| import tempfile | |
| import argparse | |
| from typing import List, Tuple, Optional | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import soundfile as sf | |
| import gradio as gr | |
| from kokoro import KPipeline | |
| from transformers import ( | |
| AutoModel, | |
| AutoFeatureExtractor, | |
| AutoTokenizer, | |
| MimiModel, | |
| ) | |
# Constants
SAMPLE_RATE = 24000  # audio sample rate (Hz) used for Kokoro output and Mimi input
FRAME_RATE = 12.5  # Mimi frame rate (frames/sec) — not referenced in this chunk; TODO confirm use elsewhere
# Kokoro voice identifiers offered in the UI (af_* = American female, am_* = American male)
TTS_VOICES = [
    "af_heart", "af_alloy", "af_aoede", "af_bella", "af_jessica", "af_kore", "af_nicole", "af_nova", "af_river", "af_sky",
    "am_adam", "am_echo", "am_eric", "am_fenrir", "am_liam", "am_michael", "am_onyx", "am_puck", "am_santa"
]
MAX_Z_TOKENS = 50  # cap on autoregressively generated latent tokens per utterance
END_TOKEN_THRESHOLD = 0.5  # sigmoid probability at which AR generation stops

# Global variables — populated once by initialize_models()
model = None  # TextSyncMimi model (loaded with trust_remote_code)
mimi_model = None  # Mimi neural codec used for decoding latents to audio
tokenizer = None  # text tokenizer
feature_extractor = None  # audio feature extractor matching the Mimi model
device = None  # "cuda" or "cpu"
kokoro_pipeline = None  # Kokoro TTS pipeline
def load_audio_to_inputs(feature_extractor, audio_path: str, sample_rate: int) -> torch.Tensor:
    """Read an audio file, resample to mono at `sample_rate`, and featurize.

    Returns the `input_values` tensor produced by `feature_extractor`,
    ready to be fed to the speech encoder.
    """
    import librosa

    waveform, _sr = librosa.load(audio_path, sr=sample_rate, mono=True)
    features = feature_extractor(
        raw_audio=waveform,
        return_tensors="pt",
        sampling_rate=sample_rate,
    )
    return features.input_values
def initialize_models(model_id: str, tokenizer_id: str = "meta-llama/Llama-3.1-8B-Instruct", hf_token: Optional[str] = None):
    """Initialize all models from HuggingFace Hub.

    Populates the module-level globals (model, mimi_model, tokenizer,
    feature_extractor, device, kokoro_pipeline). Must be called once before
    any processing function is used.

    Args:
        model_id: HF Hub ID of the TextSyncMimi checkpoint (uses custom code).
        tokenizer_id: HF Hub ID of the text tokenizer.
        hf_token: optional HF access token for gated repos (e.g. Llama tokenizer).
    """
    global model, mimi_model, tokenizer, feature_extractor, device, kokoro_pipeline
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    print(f"Loading TextSyncMimi model from {model_id}...")
    # trust_remote_code: the model class is defined inside the checkpoint repo
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token
    )
    model.to(device)
    model.eval()
    # Get mimi_model_id from config; fall back to the public kyutai/mimi codec
    mimi_model_id = model.config.mimi_model_id if hasattr(model.config, 'mimi_model_id') else "kyutai/mimi"
    print("Loading Mimi model...")
    mimi_model = MimiModel.from_pretrained(mimi_model_id, token=hf_token)
    mimi_model.to(device)
    mimi_model.eval()
    print(f"Loading tokenizer from {tokenizer_id}...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, token=hf_token)
    # Some tokenizers (e.g. Llama) ship without a pad token; reuse EOS
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Loading feature extractor...")
    # Feature extractor must match the Mimi codec, not the TextSyncMimi model
    feature_extractor = AutoFeatureExtractor.from_pretrained(mimi_model_id, token=hf_token)
    print("Initializing Kokoro pipeline...")
    # lang_code='a' selects American English in Kokoro — TODO confirm
    kokoro_pipeline = KPipeline(lang_code='a')
    print("β All models loaded successfully!")
def compute_cross_attention_s(
    model,
    text_embeddings: torch.Tensor,
    input_values: torch.Tensor,
    device: str
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute projected text embeddings and cross-attended speech embeddings.

    Args:
        model: TextSyncMimi model exposing `encode_audio_to_representation`,
            `text_proj`, and `cross_attention_transformer`.
        text_embeddings: text token embeddings; assumed (1, text_len, dim)
            since the masks below are built for batch size 1 — TODO confirm.
        input_values: featurized audio from the feature extractor.
        device: torch device string ("cuda"/"cpu").

    Returns:
        Tuple of (projected text embeddings, cross-attended speech
        embeddings aligned to text tokens, text attention mask).
    """
    # All-ones masks: both audio and text arrive as single unpadded sequences.
    audio_attention_mask = torch.ones(1, input_values.shape[-1], dtype=torch.bool, device=device)
    text_attention_mask = torch.ones(1, text_embeddings.shape[1], dtype=torch.bool, device=device)
    # Encode speech; transpose swaps the last two dims, presumably to
    # (batch, time, channels) — TODO confirm encoder output layout.
    speech_embeddings = model.encode_audio_to_representation(
        input_values.to(device),
        audio_attention_mask=audio_attention_mask,
    ).transpose(1, 2)
    # Project text embeddings into the model's shared hidden dimension.
    text_proj = model.text_proj(text_embeddings.to(device))
    # Build a causal (lower-triangular) self-attention mask for the text side,
    # combined with the padding mask, in additive form (0 = attend, -inf = block).
    batch_size, text_seq_len = text_proj.shape[:2]
    causal_mask = torch.tril(torch.ones(text_seq_len, text_seq_len, device=device, dtype=text_proj.dtype))
    causal_mask = causal_mask.view(1, 1, text_seq_len, text_seq_len).expand(batch_size, -1, -1, -1)
    pad_mask = text_attention_mask.view(batch_size, 1, 1, text_seq_len)
    formatted_text_attention_mask = torch.where((causal_mask * pad_mask).bool(), 0.0, float("-inf"))
    # Speech side: full (non-causal) attention over every speech frame.
    speech_seq_len = speech_embeddings.shape[1]
    speech_mask = torch.ones(batch_size, speech_seq_len, dtype=torch.bool, device=device)
    formatted_speech_attention_mask = torch.where(
        speech_mask.view(batch_size, 1, 1, speech_seq_len), 0.0, float("-inf")
    )
    # Cross attention: text tokens are the queries, speech frames the
    # keys/values — yields one speech-conditioned vector per text token.
    cross_out = model.cross_attention_transformer(
        hidden_states=text_proj,
        encoder_hidden_states=speech_embeddings,
        attention_mask=formatted_text_attention_mask,
        encoder_attention_mask=formatted_speech_attention_mask,
        alignment_chunk_sizes=None,
    ).last_hidden_state
    return text_proj, cross_out, text_attention_mask
def ar_generate_and_decode(
    model,
    mimi_model,
    text_proj: torch.Tensor,
    s_tokens: torch.Tensor,
    text_attention_mask: torch.Tensor,
    max_z_tokens: int,
    end_token_threshold: float,
    device: str
) -> np.ndarray:
    """Generate audio autoregressively and decode to waveform.

    Builds the AR input sequence as
    [latent_marker, (text_i, speech_i) pairs, start_marker], then repeatedly
    runs the AR transformer, appending each predicted latent until the
    end-token classifier fires (sigmoid >= end_token_threshold) or
    max_z_tokens is reached. The collected latents are decoded through the
    Mimi codec (upsample -> decoder transformer -> decoder) into a waveform.

    Returns:
        1-D numpy waveform (NaN/Inf sanitized); near-silence placeholder
        if no latent tokens were generated.
    """
    batch_size, text_seq_len = text_proj.shape[:2]
    # Learned marker embeddings; index 0 is the only entry used here.
    text_speech_latent_emb = model.text_speech_latent_embed(torch.zeros(1, dtype=torch.long, device=device))
    time_speech_start_emb = model.time_speech_start_embed(torch.zeros(1, dtype=torch.long, device=device))
    time_speech_end_emb = model.time_speech_end_embed(torch.zeros(1, dtype=torch.long, device=device))
    # NOTE(review): this list is shared across the batch loop below, so for
    # batch_size > 1 latents from different items would be stacked into a
    # single decode batch. Appears to assume batch_size == 1 — confirm.
    generated_z_tokens: List[torch.Tensor] = []
    for b in range(batch_size):
        # Only iterate over non-padded text positions.
        if text_attention_mask is not None:
            valid_text_len = int(text_attention_mask[b].sum().item())
        else:
            valid_text_len = text_seq_len
        # Interleave text and speech embeddings: [marker, t_0, s_0, t_1, s_1, ...]
        sequence: List[torch.Tensor] = [text_speech_latent_emb]
        for i in range(valid_text_len):
            t_i = text_proj[b, i:i+1]
            s_i = s_tokens[b, i:i+1]
            sequence.extend([t_i, s_i])
        sequence.append(time_speech_start_emb)
        z_count = 0
        while z_count < max_z_tokens:
            # Re-run the full transformer on the growing sequence each step
            # (no KV cache) — O(n^2) in generated length but simple.
            current_sequence = torch.cat(sequence, dim=0).unsqueeze(0)
            ar_attention_mask = torch.ones(1, current_sequence.shape[1], dtype=torch.bool, device=device)
            ar_outputs = model.ar_transformer(
                hidden_states=current_sequence,
                attention_mask=ar_attention_mask,
            )
            last_prediction = ar_outputs.last_hidden_state[0, -1:, :]
            # Stop when the end-of-speech classifier crosses the threshold.
            end_token_logit = model.end_token_classifier(last_prediction).squeeze(-1)
            end_token_prob = torch.sigmoid(end_token_logit).item()
            if end_token_prob >= end_token_threshold:
                break
            sequence.append(last_prediction)
            generated_z_tokens.append(last_prediction.squeeze(0))
            z_count += 1
        sequence.append(time_speech_end_emb)
    # Decode z tokens to audio
    if len(generated_z_tokens) == 0:
        # No latents generated: emit a short silent placeholder tensor.
        audio_tensor = torch.zeros(1, 1, 1000, device=device)
    else:
        # (num_z, dim) -> (1, num_z, dim) -> (1, dim, num_z) for conv layers.
        z_tokens_batch = torch.stack(generated_z_tokens, dim=0).unsqueeze(0)
        embeddings_bct = z_tokens_batch.transpose(1, 2)
        embeddings_upsampled = mimi_model.upsample(embeddings_bct)
        decoder_outputs = mimi_model.decoder_transformer(embeddings_upsampled.transpose(1, 2), return_dict=True)
        embeddings_after_dec = decoder_outputs.last_hidden_state.transpose(1, 2)
        audio_tensor = mimi_model.decoder(embeddings_after_dec)
    audio_numpy = audio_tensor.squeeze().detach().cpu().numpy()
    # Sanitize any numerical blow-ups before writing to a WAV file.
    if np.isnan(audio_numpy).any() or np.isinf(audio_numpy).any():
        audio_numpy = np.nan_to_num(audio_numpy)
    if audio_numpy.ndim > 1:
        audio_numpy = audio_numpy.flatten()
    return audio_numpy
def generate_tts_audio(text: str, voice: str) -> str:
    """Generate TTS audio with Kokoro and return the path of a temp WAV file.

    Args:
        text: text to synthesize.
        voice: Kokoro voice identifier (see TTS_VOICES).

    Returns:
        Path to a 24 kHz WAV file (caller is responsible for cleanup).

    Raises:
        RuntimeError: if initialize_models() has not been called yet.
    """
    # Explicit None check — truthiness of arbitrary pipeline objects is not
    # a reliable "initialized" signal.
    if kokoro_pipeline is None:
        raise RuntimeError("Kokoro pipeline not initialized")
    # Kokoro's KPipeline yields (graphemes, phonemes, audio) chunks,
    # splitting the input text on blank-line boundaries.
    generator = kokoro_pipeline(text, voice=voice, speed=1.0, split_pattern=r'\n+')
    audio_chunks = [audio for _gs, _ps, audio in generator if audio is not None]
    if audio_chunks:
        audio_np = np.concatenate(audio_chunks)
    else:
        # Fallback: one second of silence so downstream code still gets a file.
        audio_np = np.zeros(SAMPLE_RATE)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        # Use the shared SAMPLE_RATE constant (24 kHz) instead of a magic number.
        sf.write(temp_file.name, audio_np, SAMPLE_RATE)
        return temp_file.name
def process_inputs(transcript_text: str, voice1: str, voice2: str):
    """Process inputs and generate audio.

    Tokenizes the transcript, synthesizes it with two Kokoro voices,
    computes cross-attended speech embeddings for both renditions, and
    regenerates a baseline reconstruction from voice 1's embeddings.

    Returns an 8-tuple matching the Gradio outputs wiring:
    (status, token display, audio1 path, audio2 path, baseline path,
    embeddings JSON for gr.State, audio1 path again, audio2 path again).
    Error paths return the status string plus seven Nones.
    """
    # All models must have been loaded by initialize_models() first.
    if not all([model, mimi_model, tokenizer, feature_extractor, kokoro_pipeline]):
        return "Please initialize models first!", None, None, None, None, None, None, None
    if not transcript_text.strip():
        return "Please provide a transcript!", None, None, None, None, None, None, None
    if not voice1 or not voice2:
        return "Please select voices for both audio samples!", None, None, None, None, None, None, None
    # Tokenize without special tokens so token indices line up with the
    # per-token embedding positions shown in the UI.
    tokens = tokenizer(transcript_text.strip(), return_tensors="pt", add_special_tokens=False)
    text_token_ids_cpu = tokens.input_ids.squeeze(0).tolist()
    text_token_strs = tokenizer.convert_ids_to_tokens(text_token_ids_cpu)
    text_token_ids = tokens.input_ids.to(device)
    # Human-readable index -> token listing for choosing swap positions.
    token_display = ""
    for i, tok in enumerate(text_token_strs):
        token_display += f"Token {i}: {tok}\n"
    # Generate TTS audio
    print(f"Generating TTS audio with voice '{voice1}'...")
    audio1_path = generate_tts_audio(transcript_text.strip(), voice1)
    print(f"Generating TTS audio with voice '{voice2}'...")
    audio2_path = generate_tts_audio(transcript_text.strip(), voice2)
    # Load audio
    input_values_utt1 = load_audio_to_inputs(feature_extractor, audio1_path, SAMPLE_RATE)
    input_values_utt2 = load_audio_to_inputs(feature_extractor, audio2_path, SAMPLE_RATE)
    # Get text embeddings using model's built-in text_token_embedding
    with torch.no_grad():
        text_embeddings = model.text_token_embedding(text_token_ids)
    # Compute cross-attention embeddings for both renditions; the text
    # projection and mask are identical, so they are kept only from the first.
    t1_proj, s1_cross, text_attention_mask = compute_cross_attention_s(
        model, text_embeddings, input_values_utt1, device
    )
    _, s2_cross, _ = compute_cross_attention_s(
        model, text_embeddings, input_values_utt2, device
    )
    # Generate baseline audio (reconstruction without any swap)
    baseline_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s1_cross,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        sf.write(f.name, baseline_audio, SAMPLE_RATE)
        baseline_path = f.name
    return (
        "Processing completed successfully!",
        token_display,
        audio1_path,
        audio2_path,
        baseline_path,
        # NOTE(review): embeddings round-trip through JSON (nested Python
        # lists) to live inside gr.State — lossless but memory-heavy for
        # long inputs; consider storing tensors directly.
        json.dumps({
            "t1_proj": t1_proj.cpu().numpy().tolist(),
            "s1_cross": s1_cross.cpu().numpy().tolist(),
            "s2_cross": s2_cross.cpu().numpy().tolist(),
            "text_attention_mask": text_attention_mask.cpu().numpy().tolist(),
            "num_tokens": len(text_token_strs)
        }),
        audio1_path,
        audio2_path
    )
def swap_embeddings(embeddings_json: str, swap_indices: str):
    """Replace Audio1's speech embeddings with Audio2's at chosen positions.

    Reads the cached embeddings (JSON produced by process_inputs), swaps
    the cross-attended speech embeddings at the requested token indices,
    regenerates audio, and returns (status_message, wav_path_or_None).
    """
    if not embeddings_json:
        return "Please process inputs first!", None
    if not swap_indices.strip():
        return "Please specify token indices to swap (e.g., 0,2,5)!", None

    # Restore tensors from their JSON round-trip.
    data = json.loads(embeddings_json)
    t1_proj = torch.tensor(data["t1_proj"]).to(device)
    s1_cross = torch.tensor(data["s1_cross"]).to(device)
    s2_cross = torch.tensor(data["s2_cross"]).to(device)
    text_attention_mask = torch.tensor(data["text_attention_mask"]).to(device)
    num_tokens = data["num_tokens"]

    # Parse the comma-separated index list; non-numeric entries are dropped.
    parsed = [int(piece.strip()) for piece in swap_indices.split(",") if piece.strip().isdigit()]
    if not parsed:
        return "No valid indices provided! Use format: 0,2,5", None
    valid_indices = [idx for idx in parsed if 0 <= idx < num_tokens]
    if not valid_indices:
        return f"All indices out of range! Valid range: 0-{num_tokens-1}", None

    # Copy voice-1 embeddings, then overwrite the selected token slots
    # with voice 2's embeddings.
    s_swapped = s1_cross.clone()
    for idx in valid_indices:
        s_swapped[:, idx:idx+1, :] = s2_cross[:, idx:idx+1, :]

    # Re-synthesize audio from the edited embedding sequence.
    swapped_audio = ar_generate_and_decode(
        model=model,
        mimi_model=mimi_model,
        text_proj=t1_proj,
        s_tokens=s_swapped,
        text_attention_mask=text_attention_mask,
        max_z_tokens=MAX_Z_TOKENS,
        end_token_threshold=END_TOKEN_THRESHOLD,
        device=device,
    )
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out_file:
        sf.write(out_file.name, swapped_audio, SAMPLE_RATE)
        swapped_path = out_file.name
    return f"Successfully swapped embeddings at token indices: {valid_indices}", swapped_path
def create_gradio_interface():
    """Create the Gradio interface.

    Layout: a TTS configuration column (text + two voice dropdowns), a
    token listing column, generated/baseline audio players, and the
    swap controls. Embeddings and audio paths are kept in hidden
    gr.State components between the two button callbacks.
    """
    with gr.Blocks(
        title="Speech Editing with TextSyncMimi",
        theme=gr.themes.Soft(primary_hue="red"),
    ) as interface:
        gr.Markdown("# Speech Editing with TextSyncMimi")
        gr.Markdown("π **Blog Post**: [From Time to Text: Investigating Text-Synchronous Speech Representations](https://potsawee.github.io/textsync-speech-embed)")
        gr.Markdown("π€ **Model Checkpoint**: https://huggingface.co/potsawee/TextSyncMimi-v1")
        gr.Markdown("*Demo for Speech Editing Ability with Text-Synchronous Speech Representations*: Generate two voice renditions using Kokoro TTS, then swap speech embeddings at token positions.")
        # Row 1: inputs (left) and tokenization preview (right)
        with gr.Row():
            with gr.Column():
                gr.Markdown("## TTS Configuration")
                transcript_text = gr.Textbox(
                    label="Text",
                    placeholder="Enter text to synthesize...",
                    lines=3
                )
                with gr.Row():
                    voice1 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 1",
                        value="af_heart"
                    )
                    voice2 = gr.Dropdown(
                        choices=TTS_VOICES,
                        label="Voice 2",
                        value="am_adam"
                    )
                process_btn = gr.Button("Generate & Process", variant="primary")
                process_status = gr.Textbox(label="Status", interactive=False)
            with gr.Column():
                gr.Markdown("## Tokenization")
                tokens_display = gr.Textbox(
                    label="Tokens",
                    lines=16,
                    interactive=False
                )
        # Row 2: synthesized audio (left), reconstruction + swap controls (right)
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Kokoro Synthesized Audio")
                gr.Markdown("### Audio 1 and Audio 2 (to be swapped)")
                generated_audio1 = gr.Audio(label="Generated Audio 1")
                generated_audio2 = gr.Audio(label="Generated Audio 2")
            with gr.Column():
                gr.Markdown("## Audio1 Reconstruction")
                gr.Markdown("### Reconstruction using TextSync-Mimi without embedding swap")
                baseline_audio = gr.Audio(label="Baseline Reconstruction")
                gr.Markdown("### Embedding Position(s) to be Swapped (Audio2 β Audio1)")
                swap_indices_input = gr.Textbox(
                    label="Token Indices to Swap (See token IDs above)",
                    placeholder="e.g., 0,2,5"
                )
                swap_btn = gr.Button("Perform Swap", variant="primary")
                swap_status = gr.Textbox(label="Swap Status", interactive=False)
                swapped_audio = gr.Audio(label="Swapped Result")
        # Hidden states carrying data from process_inputs to swap_embeddings
        embeddings_state = gr.State()
        audio1_state = gr.State()
        audio2_state = gr.State()
        # Event handlers — output order must match process_inputs' return tuple
        process_btn.click(
            fn=process_inputs,
            inputs=[transcript_text, voice1, voice2],
            outputs=[process_status, tokens_display, generated_audio1, generated_audio2,
                     baseline_audio, embeddings_state, audio1_state, audio2_state]
        )
        swap_btn.click(
            fn=swap_embeddings,
            inputs=[embeddings_state, swap_indices_input],
            outputs=[swap_status, swapped_audio]
        )
    return interface
def main():
    """CLI entry point: parse arguments, load models, launch the demo."""
    parser = argparse.ArgumentParser(description="HuggingFace Space Demo for TextSyncMimi")
    parser.add_argument("--model_id", type=str,
                        default="potsawee/TextSyncMimi-v1",
                        help="HuggingFace model ID")
    parser.add_argument("--tokenizer_id", type=str,
                        default="meta-llama/Llama-3.1-8B-Instruct",
                        help="HuggingFace tokenizer ID")
    parser.add_argument("--hf_token", type=str, default=None,
                        help="Hugging Face token (or set HF_TOKEN env var)")
    parser.add_argument("--port", type=int, default=7860,
                        help="Port for Gradio app")
    parser.add_argument("--share", action="store_true",
                        help="Create public share link")
    args = parser.parse_args()

    # CLI flag takes precedence over the HF_TOKEN environment variable.
    hf_token = args.hf_token or os.getenv("HF_TOKEN")

    print(f"π Initializing TextSyncMimi from HuggingFace Hub: {args.model_id}...")
    initialize_models(args.model_id, args.tokenizer_id, hf_token)

    print("π Launching Gradio interface...")
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()