"""Gradio demo for NVIDIA PersonaPlex.

The user records or uploads an utterance; PersonaPlex responds with the
persona described in a text prompt and the selected voice.
"""

import gradio as gr
import spaces
import torch
import numpy as np
import os
import tarfile
from pathlib import Path
from typing import Optional
from huggingface_hub import hf_hub_download
import sentencepiece

# Configuration
HF_REPO = "nvidia/personaplex-7b-v1"
DEVICE = "cuda"
SAMPLE_RATE = 24000

# Delimiter placed on both sides of the persona (system) prompt.
SYSTEM_TAG = "<system>"

# Available voices in PersonaPlex
ALL_VOICES = [
    "NATF0", "NATF1", "NATF2", "NATF3",  # Natural Female
    "NATM0", "NATM1", "NATM2", "NATM3",  # Natural Male
    "VARF0", "VARF1", "VARF2", "VARF3", "VARF4",  # Variety Female
    "VARM0", "VARM1", "VARM2", "VARM3", "VARM4",  # Variety Male
]

# Example persona prompts from PersonaPlex paper
EXAMPLE_PERSONAS = [
    "You are a wise and friendly teacher. Answer questions or provide advice in a clear and engaging way.",
    "You enjoy having a good conversation.",
    "You work for CitySan Services which is a waste management company and your name is Ayelen Lucero.",
    "You enjoy having a good conversation. Have a technical discussion about fixing a reactor core on a spaceship to Mars. You are an astronaut on a Mars mission. Your name is Alex.",
]

# Import moshi after spaces to allow interception
from moshi.models import loaders, LMGen
from moshi.models.lm import load_audio, _iterate_audio, encode_from_sphn

# Pre-download model weights at startup (cached by huggingface_hub)
print("Downloading model weights...")
MIMI_WEIGHT = hf_hub_download(HF_REPO, loaders.MIMI_NAME)
MOSHI_WEIGHT = hf_hub_download(HF_REPO, loaders.MOSHI_NAME)
TOKENIZER_PATH = hf_hub_download(HF_REPO, loaders.TEXT_TOKENIZER_NAME)
VOICES_TGZ = hf_hub_download(HF_REPO, "voices.tgz")

# Extract the voices archive next to the downloaded tarball, once.
# Assumes the archive contains a top-level "voices/" directory (that is where
# VOICES_DIR points) — TODO confirm against the archive layout.
VOICES_DIR = Path(VOICES_TGZ).parent / "voices"
if not VOICES_DIR.exists():
    print("Extracting voice embeddings...")
    with tarfile.open(VOICES_TGZ, "r:gz") as tar:
        # Use the "data" extraction filter where this Python supports it
        # (3.12+ and security backports): it rejects absolute paths, ".."
        # components and links that would escape the target directory.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        tar.extractall(path=Path(VOICES_TGZ).parent, **extract_kwargs)
print("Model weights ready.")

# Load text tokenizer (CPU only, no CUDA needed)
text_tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Global model cache - models are loaded lazily inside @spaces.GPU calls.
_model_cache = {}


def get_models():
    """Load the Mimi codecs and the Moshi LM onto the GPU on first use.

    Returns:
        The module-level cache dict with keys "mimi", "other_mimi",
        "lm_gen", "frame_size" and "initialized". Later calls return the
        same dict without reloading.
    """
    # The cache dict is mutated in place, so no `global` statement is needed.
    if "initialized" not in _model_cache:
        print("Loading models to GPU...")

        # Two Mimi instances: `mimi` encodes the user stream and `other_mimi`
        # decodes agent audio tokens, each with its own streaming state.
        mimi = loaders.get_mimi(MIMI_WEIGHT, DEVICE)
        other_mimi = loaders.get_mimi(MIMI_WEIGHT, DEVICE)

        # Load Moshi LM
        lm = loaders.get_moshi_lm(MOSHI_WEIGHT, device=DEVICE)
        lm.eval()

        # Create LMGen wrapper; frame_size is the number of PCM samples per
        # codec frame (sample_rate / frame_rate).
        frame_size = int(mimi.sample_rate / mimi.frame_rate)
        lm_gen = LMGen(
            lm,
            audio_silence_frame_cnt=int(0.5 * mimi.frame_rate),
            sample_rate=mimi.sample_rate,
            device=DEVICE,
            frame_rate=mimi.frame_rate,
            temp=0.8,
            temp_text=0.7,
            top_k=250,
            top_k_text=25,
        )

        # Enable streaming mode
        mimi.streaming_forever(1)
        other_mimi.streaming_forever(1)
        lm_gen.streaming_forever(1)

        # Run warmup to initialize CUDA graphs (improves performance)
        print("Running warmup...")
        _warmup_models(mimi, other_mimi, lm_gen, frame_size)
        print("Warmup complete.")

        _model_cache.update({
            "mimi": mimi,
            "other_mimi": other_mimi,
            "lm_gen": lm_gen,
            "frame_size": frame_size,
            "initialized": True,
        })
        print("Models loaded successfully.")
    return _model_cache


def _warmup_models(mimi, other_mimi, lm_gen, frame_size):
    """Run a few silent encode/step/decode passes, then reset streaming state.

    Intended to initialize CUDA graphs before the first real request.
    """
    for _ in range(4):
        chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=DEVICE)
        codes = mimi.encode(chunk)
        _ = other_mimi.encode(chunk)
        for c in range(codes.shape[-1]):
            tokens = lm_gen.step(codes[:, :, c:c + 1])
            if tokens is not None:
                _ = mimi.decode(tokens[:, 1:9])
                _ = other_mimi.decode(tokens[:, 1:9])
    torch.cuda.synchronize()

    # Reset after warmup so real requests start from a clean state.
    mimi.reset_streaming()
    other_mimi.reset_streaming()
    lm_gen.reset_streaming()


def wrap_with_system_tags(text: str) -> str:
    """Wrap a persona prompt in system tags, as PersonaPlex expects.

    The text is stripped. If it already starts and ends with SYSTEM_TAG,
    it is returned unchanged; otherwise the result is
    "<system> {text} <system>".
    """
    # Comparing against non-empty tags matters: startswith("") is always
    # True, which would silently skip the wrapping.
    text = text.strip()
    if text.startswith(SYSTEM_TAG) and text.endswith(SYSTEM_TAG):
        return text
    return f"{SYSTEM_TAG} {text} {SYSTEM_TAG}"


def decode_tokens_to_pcm(mimi, other_mimi, tokens: torch.Tensor) -> np.ndarray:
    """Decode one step of agent audio tokens to a PCM waveform.

    Args:
        mimi: Unused; kept for call-site compatibility.
        other_mimi: The Mimi instance used for decoding agent audio.
        tokens: Output of ``lm_gen.step``; shape [B, num_codebooks, 1].
            Agent audio is in codebooks 1:9.
    Returns:
        The decoded samples of the first batch item and channel, as numpy.
    """
    agent_audio_tokens = tokens[:, 1:9, :]
    pcm = other_mimi.decode(agent_audio_tokens)
    return pcm[0, 0].detach().cpu().numpy()


def _prepare_input_audio(audio: np.ndarray) -> np.ndarray:
    """Convert a Gradio numpy waveform to a mono float32 signal in [-1, 1].

    Args:
        audio: Samples from ``gr.Audio(type="numpy")``, either 1-D or
            (samples, channels); channels are averaged.
    Returns:
        A 1-D float32 array. Integer PCM is scaled by its full-scale value
        (32768 for int16). Float input is peak-normalized only when it
        exceeds [-1, 1]. Empty input yields an empty array.
    """
    audio = np.asarray(audio)
    # Capture the dtype before casting: after astype() it is always float32,
    # and the integer scaling below would never trigger.
    source_dtype = audio.dtype
    audio = audio.astype(np.float32)

    # Convert to mono if stereo.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    if np.issubdtype(source_dtype, np.signedinteger):
        audio = audio / float(-np.iinfo(source_dtype).min)
    elif np.issubdtype(source_dtype, np.unsignedinteger):
        # Unsigned PCM (e.g. 8-bit WAV) is centered at half scale.
        half_scale = (float(np.iinfo(source_dtype).max) + 1.0) / 2.0
        audio = (audio - half_scale) / half_scale

    # Float input outside [-1, 1]: fall back to peak normalization.
    # The size check avoids reducing over an empty array (ValueError).
    if audio.size:
        peak = float(np.abs(audio).max())
        if peak > 1.0:
            audio = audio / peak
    return audio.astype(np.float32, copy=False)


@spaces.GPU(duration=120)
def generate_response(audio_input, persona: str, voice: str):
    """Process recorded audio and generate a PersonaPlex response.

    Args:
        audio_input: ``(sample_rate, samples)`` from ``gr.Audio(type="numpy")``,
            or None when nothing was recorded.
        persona: Persona description used as the text prompt; blank disables it.
        voice: One of ALL_VOICES; selects the voice-prompt embedding file.
    Returns:
        ``((sample_rate, waveform), text)`` on success, or ``(None, message)``
        when input is missing/invalid or nothing was generated.
    """
    if audio_input is None:
        return None, "Please record audio first."

    # The voice name is interpolated into a file path that gets loaded, and
    # Gradio endpoints can be called with arbitrary values, so only accept
    # the voices offered in the UI.
    if voice not in ALL_VOICES:
        return None, f"Voice '{voice}' not found."

    sr, raw_audio = audio_input
    audio = _prepare_input_audio(raw_audio)
    if audio.size == 0:
        return None, "Recorded audio is empty. Please record audio first."

    # Get lazily loaded models
    models = get_models()
    mimi = models["mimi"]
    other_mimi = models["other_mimi"]
    lm_gen = models["lm_gen"]
    frame_size = models["frame_size"]

    # Resample to model's sample rate if needed
    if sr != mimi.sample_rate:
        import sphn
        audio = sphn.resample(audio, sr, mimi.sample_rate)

    # PREPEND SILENCE: let the model say its default greeting during this
    # time (that output is discarded below).
    prepend_silence_duration = 2  # seconds
    prepend_silence = np.zeros(int(prepend_silence_duration * mimi.sample_rate), dtype=np.float32)

    # APPEND SILENCE: give the model time to finish responding after the
    # user stops speaking.
    append_silence_duration = 8  # seconds
    append_silence = np.zeros(int(append_silence_duration * mimi.sample_rate), dtype=np.float32)

    # Final audio: [prepend_silence] + [user_audio] + [append_silence]
    audio = np.concatenate([prepend_silence, audio, append_silence])

    # Number of model steps covered by the prepended silence, derived from
    # the codec's own frame rate rather than a hard-coded value.
    frames_to_skip = int(prepend_silence_duration * mimi.frame_rate)

    # Add channel dimension: (T,) -> (1, T)
    if audio.ndim == 1:
        audio = audio[None, :]

    # Load voice prompt
    voice_path = str(VOICES_DIR / f"{voice}.pt")
    if not os.path.exists(voice_path):
        return None, f"Voice '{voice}' not found."
    lm_gen.load_voice_prompt_embeddings(voice_path)

    # Set text prompt
    if persona and persona.strip():
        lm_gen.text_prompt_tokens = text_tokenizer.encode(wrap_with_system_tags(persona))
    else:
        lm_gen.text_prompt_tokens = None

    # Run system prompts (voice + text conditioning)
    with lm_gen.streaming(1):
        # Reset streaming state inside the context
        mimi.reset_streaming()
        other_mimi.reset_streaming()
        lm_gen.reset_streaming()
        lm_gen.step_system_prompts(mimi)
        mimi.reset_streaming()

        # Process user audio frames
        generated_frames = []
        generated_text = []
        frame_count = 0  # Model step index, used to skip prepend-silence output

        for user_encoded in encode_from_sphn(
            mimi,
            _iterate_audio(audio, sample_interval_size=frame_size, pad=True),
            max_batch=1,
        ):
            for c in range(user_encoded.shape[-1]):
                step_in = user_encoded[:, :, c:c + 1]
                tokens = lm_gen.step(step_in)
                frame_count += 1
                if tokens is None:
                    continue

                # Skip frames generated during prepend silence (model's default greeting)
                if frame_count <= frames_to_skip:
                    continue

                # Decode agent audio
                pcm = decode_tokens_to_pcm(mimi, other_mimi, tokens)
                generated_frames.append(pcm)

                # Decode text token
                text_token = tokens[0, 0, 0].item()
                if text_token not in (0, 3):  # Skip special tokens
                    text_piece = text_tokenizer.id_to_piece(text_token).replace("▁", " ")
                    generated_text.append(text_piece)

    if not generated_frames:
        return None, "No audio generated. Try speaking more clearly."

    # Concatenate output audio
    output_audio = np.concatenate(generated_frames, axis=-1)
    output_text = "".join(generated_text).strip()
    return (mimi.sample_rate, output_audio), output_text


# Build Gradio interface
with gr.Blocks(title="PersonaPlex Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🎭 PersonaPlex
        **Voice and Role Control for Full Duplex Conversational Speech Models**

        [Paper](https://arxiv.org/abs/2503.04721) | [GitHub](https://github.com/NVIDIA/personaplex) | [Model](https://huggingface.co/nvidia/personaplex-7b-v1)

        ---

        Record your message, and PersonaPlex will respond with the configured persona and voice.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            persona = gr.Textbox(
                label="Persona Description",
                placeholder="Describe the assistant's persona...",
                value=EXAMPLE_PERSONAS[0],
                lines=4,
            )
            voice = gr.Dropdown(
                choices=ALL_VOICES,
                value="NATF2",
                label="Voice",
            )
            gr.Examples(
                examples=[[p] for p in EXAMPLE_PERSONAS],
                inputs=[persona],
                label="Example Personas",
            )

        with gr.Column(scale=2):
            audio_input = gr.Audio(
                label="🎤 Record your message",
                sources=["microphone", "upload"],
                type="numpy",
            )
            generate_btn = gr.Button("Generate Response", variant="primary", size="lg")
            audio_output = gr.Audio(
                label="🔊 PersonaPlex Response",
                type="numpy",
                autoplay=True,
            )
            text_output = gr.Textbox(
                label="📝 Response Text",
                interactive=False,
            )

    generate_btn.click(
        fn=generate_response,
        inputs=[audio_input, persona, voice],
        outputs=[audio_output, text_output],
    )


if __name__ == "__main__":
    demo.launch()