Spaces:
Runtime error
Runtime error
| """ | |
| PersonaPlex HuggingFace Space — Speech-to-speech with 18 voices and persona control. | |
| Uses ZeroGPU (@spaces.GPU) for dynamic H200 allocation. | |
| Model assets are downloaded on CPU at startup; the models themselves are loaded onto CUDA inside the GPU-decorated function. | |
| """ | |
| import sys | |
| import os | |
| import random | |
| import tarfile | |
| import json | |
| from pathlib import Path | |
| from typing import Optional | |
| sys.path.insert(0, ".") | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import sentencepiece | |
| import sphn | |
| from huggingface_hub import hf_hub_download | |
| from moshi.models import loaders, LMGen, MimiModel | |
| from moshi.models.lm import ( | |
| load_audio as lm_load_audio, | |
| _iterate_audio as lm_iterate_audio, | |
| encode_from_sphn as lm_encode_from_sphn, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------------------------- | |
# Hugging Face Hub repository the weights, tokenizer and voice prompts come from.
HF_REPO = "nvidia/personaplex-7b-v1"

# (display label, file-name prefix, number of voices) for each voice family.
_VOICE_FAMILIES = (
    ("Natural Female", "NATF", 4),
    ("Natural Male", "NATM", 4),
    ("Variety Female", "VARF", 5),
    ("Variety Male", "VARM", 5),
)

# Maps the label shown in the UI to the voice-prompt file name,
# e.g. "Natural Female 1 (NATF0)" -> "NATF0.pt". Labels are 1-based,
# file names are 0-based.
VOICES = {
    f"{label} {index + 1} ({prefix}{index})": f"{prefix}{index}.pt"
    for label, prefix, count in _VOICE_FAMILIES
    for index in range(count)
}

# Preset persona prompts offered in the UI; "Custom" starts from an empty prompt.
PERSONAS = {
    "Assistant": (
        "You are a wise and friendly teacher. "
        "Answer questions or provide advice in a clear and engaging way."
    ),
    "Mars Astronaut": (
        "You enjoy having a good conversation. "
        "Have a technical discussion about fixing a reactor core on a spaceship to Mars. "
        "You are an astronaut on a Mars mission. "
        "Your name is Alex."
    ),
    "Restaurant": (
        "You work for Jerusalem Shakshuka which is a restaurant and your name is Owen Foster. "
        "Information: There are two shakshuka options: "
        "Classic (poached eggs, $9.50) and Spicy (scrambled eggs with jalapenos, $10.25)."
    ),
    "Casual Chat": "You enjoy having a good conversation.",
    "Custom": "",
}
| # --------------------------------------------------------------------------- | |
| # Asset paths and text tokenizer (downloaded / loaded on CPU at startup) | |
| # --------------------------------------------------------------------------- | |
# Local paths of the downloaded assets, plus the loaded text tokenizer.
# All of them stay None until _download_assets() has run.
_mimi_weight_path: Optional[str] = None    # Mimi checkpoint
_moshi_weight_path: Optional[str] = None   # Moshi LM checkpoint
_tokenizer_path: Optional[str] = None      # SentencePiece model file
_voice_prompt_dir: Optional[str] = None    # directory extracted from voices.tgz
_text_tokenizer: Optional[sentencepiece.SentencePieceProcessor] = None


def _extract_voices(archive: Path, dest: Path) -> None:
    """Extract the gzip-compressed tar `archive` into `dest`.

    The archive is fetched over the network, so the "data" extraction filter
    is applied whenever the running Python provides it; that filter rejects
    members that would be written outside `dest`. Older interpreters without
    extraction filters fall back to a plain extract.
    """
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest, filter="data")
        else:
            tar.extractall(path=dest)


def _download_assets():
    """Download all model weights and voice prompts from HuggingFace Hub.

    Populates the module-level path globals and `_text_tokenizer`.
    `voices.tgz` is unpacked next to the downloaded archive, and only when
    the `voices/` directory is not already there.

    Raises:
        RuntimeError: if `voices/` does not exist after extracting voices.tgz.
    """
    global _mimi_weight_path, _moshi_weight_path, _tokenizer_path
    global _voice_prompt_dir, _text_tokenizer

    print("[Init] Downloading config.json (download counter)...")
    hf_hub_download(HF_REPO, "config.json")

    print("[Init] Downloading Mimi weights...")
    _mimi_weight_path = hf_hub_download(HF_REPO, loaders.MIMI_NAME)

    print("[Init] Downloading Moshi LM weights...")
    _moshi_weight_path = hf_hub_download(HF_REPO, loaders.MOSHI_NAME)

    print("[Init] Downloading tokenizer...")
    _tokenizer_path = hf_hub_download(HF_REPO, loaders.TEXT_TOKENIZER_NAME)
    _text_tokenizer = sentencepiece.SentencePieceProcessor(_tokenizer_path)

    print("[Init] Downloading voice prompts...")
    voices_tgz = Path(hf_hub_download(HF_REPO, "voices.tgz"))
    voices_dir = voices_tgz.parent / "voices"
    if not voices_dir.exists():
        print(f"[Init] Extracting {voices_tgz} -> {voices_dir}")
        _extract_voices(voices_tgz, voices_tgz.parent)
    if not voices_dir.exists():
        raise RuntimeError("voices.tgz did not contain a 'voices/' directory")
    _voice_prompt_dir = str(voices_dir)

    print("[Init] All assets downloaded successfully.")


# Download on import (CPU only, no GPU needed)
_download_assets()
| # --------------------------------------------------------------------------- | |
| # Audio helpers | |
| # --------------------------------------------------------------------------- | |
def _resample_numpy(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample a 1-D numpy audio array from src_sr to dst_sr using linear interpolation.

    Args:
        audio: 1-D array of samples.
        src_sr: sample rate of `audio`.
        dst_sr: sample rate to convert to.

    Returns:
        `audio` itself (untouched) when both rates are equal; otherwise a new
        float32 array of `int(len(audio) / src_sr * dst_sr)` samples. Empty
        input yields an empty float32 array.
    """
    if src_sr == dst_sr:
        return audio
    n_src = len(audio)
    if n_src == 0:
        # np.interp raises on an empty set of sample points, so there is
        # nothing to interpolate from: return an empty result instead.
        return np.zeros(0, dtype=np.float32)
    duration = n_src / src_sr
    target_len = int(duration * dst_sr)
    # Sample positions expressed in source-index units, spanning first to last sample.
    indices = np.linspace(0, n_src - 1, target_len)
    return np.interp(indices, np.arange(n_src), audio).astype(np.float32)
def _wrap_system_tags(text: str) -> str:
    """Surround a persona prompt with the `<system>` markers.

    Surrounding whitespace is stripped first. Text that already starts and
    ends with `<system>` is returned as-is (stripped); anything else is
    wrapped as `<system> ... <system>`.
    """
    tag = "<system>"
    body = text.strip()
    already_wrapped = body.startswith(tag) and body.endswith(tag)
    return body if already_wrapped else f"{tag} {body} {tag}"
| # --------------------------------------------------------------------------- | |
| # Inference (runs on GPU via ZeroGPU) | |
| # --------------------------------------------------------------------------- | |
# Text-token ids that the original token table named "EPAD" (0) and "PAD" (3);
# they carry no text and are left out of the transcript.
_PAD_TOKEN_IDS = (0, 3)


def _prepare_input_audio(audio_input):
    """Validate the Gradio audio value and return (sample_rate, mono float32 samples).

    Raises:
        gr.Error: if no audio was provided or the audio contains no samples.
    """
    if audio_input is None:
        raise gr.Error("Please record or upload audio first.")
    input_sr, input_audio = audio_input
    input_audio = np.asarray(input_audio)
    # Scale signed-integer PCM of any width by the maximum of its dtype.
    if np.issubdtype(input_audio.dtype, np.signedinteger):
        input_audio = input_audio.astype(np.float32) / np.iinfo(input_audio.dtype).max
    # Down-mix 2-D input by averaging over axis 1.
    if input_audio.ndim == 2:
        input_audio = input_audio.mean(axis=1)
    input_audio = input_audio.astype(np.float32)
    if input_audio.size == 0:
        raise gr.Error("The audio is empty. Please record or upload audio first.")
    return input_sr, input_audio


def _seed_everything(seed):
    """Seed torch, random and numpy; None or a negative seed picks a random one."""
    if seed is None or seed < 0:
        actual_seed = random.randint(0, 2**31 - 1)
    else:
        actual_seed = int(seed)
    torch.manual_seed(actual_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(actual_seed)
        torch.cuda.manual_seed_all(actual_seed)
    random.seed(actual_seed)
    np.random.seed(actual_seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False


def _warm_up(mimi, other_mimi, lm_gen, frame_size, device):
    """Push four silent frames through both codecs and the LM before real use."""
    for _ in range(4):
        chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=device)
        codes = mimi.encode(chunk)
        _ = other_mimi.encode(chunk)
        for c in range(codes.shape[-1]):
            tokens = lm_gen.step(codes[:, :, c : c + 1])
            if tokens is None:
                continue
            _ = mimi.decode(tokens[:, 1:9])
            _ = other_mimi.decode(tokens[:, 1:9])
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _stream_response(mimi, other_mimi, lm_gen, user_pcm_2d):
    """Feed the user audio through the model one step at a time.

    Returns:
        (frames, pieces): decoded audio frames as numpy arrays, and the text
        pieces of every generated text token that is not a padding token.
    """
    frames = []
    pieces = []
    audio_chunks = lm_iterate_audio(
        user_pcm_2d, sample_interval_size=lm_gen._frame_size, pad=True
    )
    for user_encoded in lm_encode_from_sphn(mimi, audio_chunks, max_batch=1):
        for c in range(user_encoded.shape[-1]):
            tokens = lm_gen.step(user_encoded[:, :, c : c + 1])
            if tokens is None:
                continue
            # Decode agent audio; other_mimi is stepped with the same tokens.
            pcm = mimi.decode(tokens[:, 1:9])
            _ = other_mimi.decode(tokens[:, 1:9])
            frames.append(pcm.detach().cpu().numpy()[0, 0])
            # Decode the text token unless it is padding.
            text_token = tokens[0, 0, 0].item()
            if text_token not in _PAD_TOKEN_IDS:
                piece = _text_tokenizer.id_to_piece(text_token)
                pieces.append(piece.replace("\u2581", " "))
    return frames, pieces


def _fit_length(pcm, target_len):
    """Trim `pcm`, or right-pad it with zeros, to exactly `target_len` samples."""
    current_len = pcm.shape[-1]
    if current_len > target_len:
        return pcm[:target_len]
    if current_len < target_len:
        padding = np.zeros(target_len - current_len, dtype=pcm.dtype)
        return np.concatenate([pcm, padding], axis=-1)
    return pcm


# The body targets "cuda" directly, so on ZeroGPU it must run inside a
# @spaces.GPU function. duration is the requested GPU time in seconds; raise it
# if model loading plus generation does not fit.
@spaces.GPU(duration=120)
def run_inference(audio_input, voice_name, persona_text, seed):
    """
    Run PersonaPlex speech-to-speech inference.

    The models are loaded onto CUDA on every call and released again before
    returning, including when an error is raised.

    Args:
        audio_input: tuple (sample_rate, numpy_array) from Gradio audio component
        voice_name: key from VOICES dict (unknown keys fall back to "NATF2.pt")
        persona_text: persona system prompt string; blank means no text prompt
        seed: int seed (-1 or None for random)
    Returns:
        (sample_rate, numpy_array): output audio, trimmed or zero-padded to
            the length of the resampled input
        str: transcript text
    Raises:
        gr.Error: if the audio is missing or empty, the voice prompt file is
            not found, or no audio frames were generated.
    """
    input_sr, input_audio = _prepare_input_audio(audio_input)
    _seed_everything(seed)

    device = "cuda"
    # Pre-bind so the cleanup in `finally` is safe even if loading fails midway.
    lm_gen = lm = mimi = other_mimi = None
    try:
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            # Load models fresh on GPU each call (ZeroGPU gives us a clean GPU)
            print("[Inference] Loading Mimi on CUDA...")
            mimi = loaders.get_mimi(_mimi_weight_path, device)
            other_mimi = loaders.get_mimi(_mimi_weight_path, device)
            print("[Inference] Mimi loaded.")
            print("[Inference] Loading Moshi LM on CUDA...")
            lm = loaders.get_moshi_lm(_moshi_weight_path, device=device)
            lm.eval()
            print("[Inference] Moshi LM loaded.")

            # Build LMGen
            frame_size = int(mimi.sample_rate / mimi.frame_rate)
            lm_gen = LMGen(
                lm,
                audio_silence_frame_cnt=int(0.5 * mimi.frame_rate),
                sample_rate=mimi.sample_rate,
                device=device,
                frame_rate=mimi.frame_rate,
                save_voice_prompt_embeddings=False,
                use_sampling=True,
                temp=0.8,
                temp_text=0.7,
                top_k=250,
                top_k_text=25,
            )

            # Streaming mode
            mimi.streaming_forever(1)
            other_mimi.streaming_forever(1)
            lm_gen.streaming_forever(1)

            # Warmup (CUDA graphs)
            print("[Inference] Warming up...")
            _warm_up(mimi, other_mimi, lm_gen, frame_size, device)
            print("[Inference] Warmup complete.")

            # Load voice prompt
            voice_file = VOICES.get(voice_name, "NATF2.pt")
            voice_path = os.path.join(_voice_prompt_dir, voice_file)
            if not os.path.exists(voice_path):
                raise gr.Error(f"Voice prompt file not found: {voice_path}")
            if voice_path.endswith(".pt"):
                lm_gen.load_voice_prompt_embeddings(voice_path)
            else:
                lm_gen.load_voice_prompt(voice_path)

            # Encode text prompt
            if persona_text and persona_text.strip():
                lm_gen.text_prompt_tokens = _text_tokenizer.encode(
                    _wrap_system_tags(persona_text)
                )
            else:
                lm_gen.text_prompt_tokens = None

            # Reset streaming and run system prompts
            mimi.reset_streaming()
            other_mimi.reset_streaming()
            lm_gen.reset_streaming()
            print("[Inference] Running system prompts (voice + text)...")
            lm_gen.step_system_prompts(mimi)
            mimi.reset_streaming()
            print("[Inference] System prompts complete.")

            # Resample input audio to the model sample rate
            model_sr = int(mimi.sample_rate)
            user_pcm = _resample_numpy(input_audio, input_sr, model_sr)
            # Shape expected by lm helpers: (C, T)
            user_pcm_2d = user_pcm[np.newaxis, :]  # (1, T)
            total_target_samples = user_pcm_2d.shape[-1]

            # Stream user audio through the model
            print(f"[Inference] Processing {total_target_samples} samples ({total_target_samples / model_sr:.1f}s)...")
            generated_frames, text_pieces = _stream_response(
                mimi, other_mimi, lm_gen, user_pcm_2d
            )
            if not generated_frames:
                raise gr.Error("No audio frames were generated. Try a longer input.")
    finally:
        # Clean up GPU memory, also on the error paths above.
        del lm_gen, lm, mimi, other_mimi
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Concatenate and trim/pad to match input duration
    output_pcm = _fit_length(
        np.concatenate(generated_frames, axis=-1), total_target_samples
    )
    transcript = "".join(text_pieces).strip()
    print(f"[Inference] Done. Output: {output_pcm.shape[-1]} samples, transcript: {len(transcript)} chars")
    return (model_sr, output_pcm), transcript
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
# Page layout: voice / persona / seed controls in the left column, audio input,
# run button and outputs in the right column.
with gr.Blocks(theme=gr.themes.Base(), title="PersonaPlex") as demo:
    gr.Markdown(
        "# PersonaPlex\n"
        # Derive the count from VOICES so the header cannot drift from the list.
        f"Speech-to-speech with {len(VOICES)} voices and persona control. "
        "Powered by NVIDIA PersonaPlex on ZeroGPU."
    )
    with gr.Row():
        with gr.Column(scale=1):
            voice = gr.Dropdown(
                choices=list(VOICES.keys()),
                value="Natural Female 3 (NATF2)",
                label="Voice",
            )
            persona_preset = gr.Dropdown(
                choices=list(PERSONAS.keys()),
                value="Assistant",
                label="Persona Preset",
            )
            persona_text = gr.Textbox(
                value=PERSONAS["Assistant"],
                label="Persona Prompt",
                lines=3,
            )
            seed = gr.Number(
                value=42424242,
                label="Seed (-1 for random)",
                precision=0,
            )
        with gr.Column(scale=2):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Your Audio",
            )
            run_btn = gr.Button(
                "Generate Response",
                variant="primary",
                size="lg",
            )
            audio_output = gr.Audio(type="numpy", label="PersonaPlex Response")
            transcript = gr.Textbox(
                label="Transcript", lines=5, interactive=False
            )
    # Wire preset dropdown to update persona text
    persona_preset.change(
        fn=lambda p: PERSONAS.get(p, ""),
        inputs=persona_preset,
        outputs=persona_text,
    )
    run_btn.click(
        fn=run_inference,
        inputs=[audio_input, voice, persona_text, seed],
        outputs=[audio_output, transcript],
    )

demo.queue().launch()