Spaces: Running on Zero
| """ | |
| PersonaPlex 7B ZeroGPU Demo | |
| Speech-to-speech demo using nvidia/personaplex-7b-v1 on ZeroGPU. | |
| """ | |
| import os | |
| import tarfile | |
| import tempfile | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from sentencepiece import SentencePieceProcessor | |
# ============================================================================
# Configuration
# ============================================================================
# Model repository on the Hugging Face Hub (weights, tokenizer, voices.tgz).
HF_REPO = "nvidia/personaplex-7b-v1"
# Nominal codec sample rate in Hz. NOTE(review): the runtime code reads
# mimi.sample_rate directly; this constant is unused in the visible file.
SAMPLE_RATE = 24000
# Persona name (UI dropdown) -> system text prompt injected before generation.
PERSONAS = {
    "Helpful Assistant": "You are a helpful, friendly AI assistant.",
    "Casual Friend": "You are a casual, laid-back friend having a conversation.",
    "Professional": "You are a professional business consultant.",
    "Teacher": "You are a patient, knowledgeable teacher explaining concepts.",
}
# Voice name (UI dropdown) -> .pt voice-prompt file extracted from voices.tgz.
VOICES = {
    "Natural Female": "NATF2.pt",
    "Natural Male": "NATM0.pt",
    "Variety Female": "VARF0.pt",
    "Variety Male": "VARM0.pt",
}
# ============================================================================
# Download weights at startup
# ============================================================================
# Runs at import time so the large downloads happen once, before the first
# request; only the GPU load is deferred to inference time.
print("PersonaPlex Demo starting...")
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    # hf_hub_download(token=None) still works for public repos; gated repos
    # will fail later with an auth error.
    print("Warning: HF_TOKEN not set.")
from moshi.models import loaders
print("Downloading model weights (this may take a while on first run)...")
try:
    # Codec (Mimi) and language-model (Moshi) checkpoints, fetched under the
    # file names the moshi loaders expect.
    MIMI_WEIGHT_PATH = hf_hub_download(HF_REPO, loaders.MIMI_NAME, token=HF_TOKEN)
    MOSHI_WEIGHT_PATH = hf_hub_download(HF_REPO, loaders.MOSHI_NAME, token=HF_TOKEN)
    print(f"Mimi weights: {MIMI_WEIGHT_PATH}")
    print(f"Moshi weights: {MOSHI_WEIGHT_PATH}")
    # SentencePiece model used to tokenize the persona text prompt.
    TOKENIZER_PATH = hf_hub_download(HF_REPO, "tokenizer_spm_32k_3.model", token=HF_TOKEN)
    print(f"Tokenizer: {TOKENIZER_PATH}")
    # Download and extract the voice-prompt archive.
    voices_tgz = hf_hub_download(HF_REPO, "voices.tgz", token=HF_TOKEN)
    VOICES_DIR = tempfile.mkdtemp(prefix="personaplex_voices_")
    with tarfile.open(voices_tgz, "r:gz") as tar:
        try:
            # The "data" filter (PEP 706) rejects path traversal and other
            # unsafe members in the downloaded archive.
            tar.extractall(VOICES_DIR, filter="data")
        except TypeError:
            # Interpreter without the extraction-filter backport; nothing was
            # extracted before the TypeError, so a plain retry is safe.
            tar.extractall(VOICES_DIR)
    # Some tarballs nest the .pt files one directory down; descend one level.
    pt_files = [f for f in os.listdir(VOICES_DIR) if f.endswith('.pt')]
    if not pt_files:
        subdirs = [d for d in os.listdir(VOICES_DIR) if os.path.isdir(os.path.join(VOICES_DIR, d))]
        if subdirs:
            VOICES_DIR = os.path.join(VOICES_DIR, subdirs[0])
            pt_files = [f for f in os.listdir(VOICES_DIR) if f.endswith('.pt')]
    print(f"Voices directory: {VOICES_DIR} ({len(pt_files)} voice files)")
    # SentencePieceProcessor raises on a bad path, so reaching this line means
    # the tokenizer loaded (the previous `is None` ternary was dead code).
    TEXT_TOKENIZER = SentencePieceProcessor(TOKENIZER_PATH)
    print("Tokenizer loaded")
except Exception as e:
    # Surface the failure in the Space logs, then abort startup.
    print(f"Error: {e}")
    raise
print("Weight download complete! Models will load on GPU when needed.")
| # ============================================================================ | |
| # GPU Inference | |
| # ============================================================================ | |
# ZeroGPU only attaches a GPU to calls routed through the `spaces.GPU`
# decorator; without it, `torch.device("cuda")` below fails on a ZeroGPU
# Space. Resolve the decorator defensively so the module still imports when
# the `spaces` package is unavailable (e.g. running the file locally).
try:
    _gpu_task = spaces.GPU(duration=120)
except NameError:  # `spaces` not in scope outside HF Spaces
    def _gpu_task(fn):
        return fn


def _prepare_input_audio(input_audio, input_sr: int, target_sr: int) -> np.ndarray:
    """Convert raw Gradio audio to mono float32 at *target_sr*.

    Handles the shapes and dtypes Gradio's numpy audio delivers: (samples,)
    mono or (samples, channels) stereo, int16 PCM or float. Very quiet
    recordings are boosted so the codec sees a usable signal.
    """
    audio = np.asarray(input_audio)
    if audio.ndim == 2:
        # Stereo uploads arrive as (samples, channels); downmix to mono.
        # (The original code fed this shape straight to the codec as
        # [1, T, C], which is not the [B, C, T] layout it builds for mono.)
        audio = audio.astype(np.float32).mean(axis=1)
    if np.issubdtype(audio.dtype, np.integer):
        # Integer PCM (Gradio's default is int16) -> float in [-1, 1].
        audio = audio.astype(np.float32) / float(np.iinfo(audio.dtype).max)
    audio = audio.astype(np.float32)
    if input_sr != target_sr:
        import torchaudio.functional as F
        resampled = F.resample(torch.from_numpy(audio).unsqueeze(0), input_sr, target_sr)
        audio = resampled.squeeze(0).numpy()
    peak = float(np.abs(audio).max()) if audio.size else 0.0
    if peak > 1.0:
        audio = audio / peak
    elif 0.0 < peak < 0.1:
        # Very quiet input: scale up to roughly half full-scale.
        audio = audio / peak * 0.5
    return audio


@_gpu_task
def generate_response(
    audio_input: tuple,
    persona: str,
    voice: str,
    temperature: float = 0.7,
    top_k: int = 250,
    max_duration: float = 10.0,
) -> tuple:
    """Run one speech-to-speech turn and return (sample_rate, int16 samples).

    Args:
        audio_input: Gradio numpy audio tuple (sample_rate, samples).
        persona: Key into PERSONAS selecting the system text prompt.
        voice: Key into VOICES selecting the voice-prompt embedding file.
        temperature: Sampling temperature for audio tokens.
        top_k: Top-k for audio tokens.
        max_duration: Cap in seconds on the generated continuation.

    Raises:
        gr.Error: On missing/empty input, unknown voice file, encode failure,
            or empty output.
    """
    from moshi.models import loaders
    from moshi.models.lm import LMGen
    if audio_input is None:
        raise gr.Error("Please provide audio input")
    input_sr, input_audio = audio_input
    if len(input_audio) == 0:
        raise gr.Error("Audio input is empty")
    print(f"Processing audio: {len(input_audio)} samples at {input_sr}Hz")
    print(f"Persona: {persona}, Voice: {voice}")
    print(f"Temperature: {temperature}, Top-k: {top_k}")
    device = torch.device("cuda")
    # ZeroGPU releases the GPU between calls, so models are (re)loaded per
    # request rather than cached at module scope.
    print("Loading Mimi codec to GPU...")
    mimi = loaders.get_mimi(MIMI_WEIGHT_PATH, device=device)
    mimi.eval()
    mimi.set_num_codebooks(8)
    print("Loading Moshi LM to GPU...")
    lm = loaders.get_moshi_lm(MOSHI_WEIGHT_PATH, device=device)
    lm.eval()
    # A second Mimi instance decodes output so its streaming state never
    # collides with the encoder's.
    print("Loading Mimi decoder...")
    mimi_dec = loaders.get_mimi(MIMI_WEIGHT_PATH, device=device)
    mimi_dec.eval()
    mimi_dec.set_num_codebooks(8)
    # Downmix/scale/resample/normalize the raw Gradio audio.
    input_audio = _prepare_input_audio(input_audio, input_sr, int(mimi.sample_rate))
    # Shape to [B, C, T] and pad to a whole number of codec frames.
    audio_tensor = torch.from_numpy(input_audio).to(device).view(1, 1, -1)
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    pad_len = (frame_size - (audio_tensor.shape[-1] % frame_size)) % frame_size
    if pad_len > 0:
        audio_tensor = torch.nn.functional.pad(audio_tensor, (0, pad_len))
    print("Initializing LMGen with prompts...")
    lm_gen = LMGen(
        lm,
        device,
        use_sampling=True,
        temp=temperature,
        temp_text=0.7,
        top_k=top_k,
        top_k_text=25,
        check=False,
        audio_silence_frame_cnt=int(0.5 * mimi.frame_rate),
        sample_rate=int(mimi.sample_rate),
        frame_rate=int(mimi.frame_rate),
    )
    # Voice prompt: embedding file extracted from voices.tgz at startup.
    voice_file = VOICES.get(voice, "NATF2.pt")
    voice_path = os.path.join(VOICES_DIR, voice_file)
    if not os.path.exists(voice_path):
        available = [f for f in os.listdir(VOICES_DIR) if f.endswith('.pt')]
        raise gr.Error(f"Voice '{voice_file}' not found. Available: {available}")
    print(f"Voice: {voice_file}")
    lm_gen.load_voice_prompt_embeddings(voice_path)
    # Text prompt: the persona wrapped in <system> tags, tokenized with the
    # SentencePiece model loaded at startup.
    text_prompt = PERSONAS.get(persona, PERSONAS["Helpful Assistant"])
    lm_gen.text_prompt_tokens = TEXT_TOKENIZER.encode(f"<system> {text_prompt} </system>")
    print(f"Text prompt set: {text_prompt[:50]}...")
    # Encode the user audio frame-by-frame with the streaming codec.
    print("Encoding input audio...")
    all_codes = []
    with torch.no_grad(), mimi.streaming(batch_size=1):
        for offset in range(0, audio_tensor.shape[-1], frame_size):
            frame = audio_tensor[:, :, offset:offset + frame_size]
            codes = mimi.encode(frame)
            if codes.shape[-1] > 0:
                all_codes.append(codes)
    if not all_codes:
        raise gr.Error("Failed to encode audio")
    print(f"Encoded {len(all_codes)} frames from input")
    print("Generating response...")
    out_wav_chunks = []
    max_steps = int(max_duration * mimi.frame_rate)
    with torch.no_grad(), lm_gen.streaming(1), mimi_dec.streaming(1):
        # Inject the system prompts first (voice + silence + text + silence).
        print("Stepping system prompts...")
        lm_gen.step_system_prompts(mimi_dec)
        # Feed user codes; the model may already begin answering here.
        print("Processing user input...")
        for code in all_codes:
            tokens_out = lm_gen.step(code.to(device))
            if tokens_out is not None:
                out_wav_chunks.append(mimi_dec.decode(tokens_out[:, 1:]))
        # Let the model keep talking, up to max_duration seconds.
        print("Generating additional response...")
        for step in range(max_steps):
            tokens_out = lm_gen.step(None)
            if tokens_out is not None:
                out_wav_chunks.append(mimi_dec.decode(tokens_out[:, 1:]))
            # Stop early once the tail of the output goes silent.
            if len(out_wav_chunks) > 30:
                recent = torch.cat(out_wav_chunks[-10:], dim=-1)
                if recent.abs().mean() < 0.001:
                    print(f"Stopping at step {step} (silence detected)")
                    break
    if not out_wav_chunks:
        raise gr.Error("No audio generated")
    # Concatenate, peak-normalize to 0.9 full-scale, convert to int16 PCM.
    output_audio = torch.cat(out_wav_chunks, dim=-1).squeeze().cpu().numpy()
    max_val = np.abs(output_audio).max()
    if max_val > 0:
        output_audio = output_audio / max_val * 0.9
    output_audio = (output_audio * 32767).astype(np.int16)
    output_sr = int(mimi.sample_rate)
    print(f"Output: {len(output_audio)} samples ({len(output_audio)/output_sr:.1f}s) at {output_sr}Hz")
    return (output_sr, output_audio)
| # ============================================================================ | |
| # UI | |
| # ============================================================================ | |
def create_demo():
    """Build and return the Gradio Blocks UI for the demo."""
    print("Creating Gradio demo...")
    with gr.Blocks(title="PersonaPlex 7B") as demo:
        gr.Markdown("# PersonaPlex 7B Demo\n\nSpeech-to-speech with [nvidia/personaplex-7b-v1](https://huggingface.co/nvidia/personaplex-7b-v1)")
        with gr.Row():
            # Left column: input audio plus generation controls.
            with gr.Column():
                mic = gr.Audio(label="Input", sources=["microphone", "upload"], type="numpy")
                persona_dd = gr.Dropdown(label="Persona", choices=list(PERSONAS.keys()), value="Helpful Assistant")
                voice_dd = gr.Dropdown(label="Voice", choices=list(VOICES.keys()), value="Natural Female")
                with gr.Accordion("Advanced", open=False):
                    temp_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.1)
                    topk_slider = gr.Slider(label="Top-K", minimum=50, maximum=500, value=250, step=50)
                    dur_slider = gr.Slider(label="Max Duration (s)", minimum=1, maximum=30, value=10, step=1)
                generate_btn = gr.Button("Generate", variant="primary")
            # Right column: generated speech.
            with gr.Column():
                response_audio = gr.Audio(label="Output", type="numpy")
        gr.Markdown("**Note:** First inference loads model to GPU (~10s). Voice and persona affect the response style.")
        control_inputs = [mic, persona_dd, voice_dd, temp_slider, topk_slider, dur_slider]
        generate_btn.click(generate_response, control_inputs, response_audio)
    return demo
if __name__ == "__main__":
    # Build the UI, then enable the request queue BEFORE launching so limits
    # apply from the first request.
    demo = create_demo()
    print("Launching demo...")
    # Serialize inference (one job at a time) and queue at most 16 requests.
    demo.queue(default_concurrency_limit=1, max_size=16)
    demo.launch()