"""
PersonaPlex 7B ZeroGPU Demo
Speech-to-speech demo using nvidia/personaplex-7b-v1 on ZeroGPU.
"""
import os
import tarfile
import tempfile
import spaces
import gradio as gr
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from sentencepiece import SentencePieceProcessor
# ============================================================================
# Configuration
# ============================================================================
HF_REPO = "nvidia/personaplex-7b-v1"
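# Mimi's native sample rate, kept for reference; the code below reads mimi.sample_rate at runtime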
SAMPLE_RATE = 24000
PERSONAS = {
"Helpful Assistant": "You are a helpful, friendly AI assistant.",
"Casual Friend": "You are a casual, laid-back friend having a conversation.",
"Professional": "You are a professional business consultant.",
"Teacher": "You are a patient, knowledgeable teacher explaining concepts.",
}
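# Persona text is wrapped in <system> tags and fed to the LM as a text prompt (see generate_response)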
# Map voice names to .pt files from voices.tgz
VOICES = {
"Natural Female": "NATF2.pt",
"Natural Male": "NATM0.pt",
"Variety Female": "VARF0.pt",
"Variety Male": "VARM0.pt",
}
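# Each .pt file appears to hold precomputed voice-prompt embeddings (loaded via
# lm_gen.load_voice_prompt_embeddings in generate_response below)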
# ============================================================================
# Download weights at startup
# ============================================================================
print("PersonaPlex Demo starting...")
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
print("Warning: HF_TOKEN not set.")
from moshi.models import loaders
print("Downloading model weights (this may take a while on first run)...")
try:
MIMI_WEIGHT_PATH = hf_hub_download(HF_REPO, loaders.MIMI_NAME, token=HF_TOKEN)
MOSHI_WEIGHT_PATH = hf_hub_download(HF_REPO, loaders.MOSHI_NAME, token=HF_TOKEN)
print(f"Mimi weights: {MIMI_WEIGHT_PATH}")
print(f"Moshi weights: {MOSHI_WEIGHT_PATH}")
# Download tokenizer for text prompts
TOKENIZER_PATH = hf_hub_download(HF_REPO, "tokenizer_spm_32k_3.model", token=HF_TOKEN)
print(f"Tokenizer: {TOKENIZER_PATH}")
# Download and extract voice prompts
voices_tgz = hf_hub_download(HF_REPO, "voices.tgz", token=HF_TOKEN)
VOICES_DIR = tempfile.mkdtemp(prefix="personaplex_voices_")
    with tarfile.open(voices_tgz, "r:gz") as tar:
        # The "data" filter blocks path traversal (Python 3.12+, backported to recent 3.8-3.11 releases)
        tar.extractall(VOICES_DIR, filter="data")
# Check if voices are in a subdirectory (some tarballs have nested structure)
pt_files = [f for f in os.listdir(VOICES_DIR) if f.endswith('.pt')]
if not pt_files:
subdirs = [d for d in os.listdir(VOICES_DIR) if os.path.isdir(os.path.join(VOICES_DIR, d))]
if subdirs:
VOICES_DIR = os.path.join(VOICES_DIR, subdirs[0])
pt_files = [f for f in os.listdir(VOICES_DIR) if f.endswith('.pt')]
print(f"Voices directory: {VOICES_DIR} ({len(pt_files)} voice files)")
# Load tokenizer
TEXT_TOKENIZER = SentencePieceProcessor(TOKENIZER_PATH)
print("No tokenizer found, using default" if TEXT_TOKENIZER is None else "Tokenizer loaded")
except Exception as e:
print(f"Error: {e}")
raise
print("Weight download complete! Models will load on GPU when needed.")
# ============================================================================
# GPU Inference
# ============================================================================
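# ZeroGPU only attaches a GPU inside functions decorated with @spaces.GPU;
# `duration` is the allocation budget in seconds for one call.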
@spaces.GPU(duration=120)
def generate_response(
audio_input: tuple,
persona: str,
voice: str,
temperature: float = 0.7,
top_k: int = 250,
max_duration: float = 10.0,
) -> tuple:
"""Generate speech response."""
from moshi.models import loaders
from moshi.models.lm import LMGen
if audio_input is None:
raise gr.Error("Please provide audio input")
    input_sr, input_audio = audio_input
    if len(input_audio) == 0:
        raise gr.Error("Audio input is empty")
    # Gradio may deliver stereo as (samples, channels); downmix to mono up front,
    # otherwise the resample and [B, C, T] reshape below operate on the wrong axis
    input_audio = input_audio.astype(np.float32)
    if input_audio.ndim == 2:
        input_audio = input_audio.mean(axis=1)
    print(f"Processing audio: {len(input_audio)} samples at {input_sr}Hz")
print(f"Persona: {persona}, Voice: {voice}")
print(f"Temperature: {temperature}, Top-k: {top_k}")
device = torch.device("cuda")
# Load Mimi codec to GPU
print("Loading Mimi codec to GPU...")
mimi = loaders.get_mimi(MIMI_WEIGHT_PATH, device=device)
mimi.eval()
mimi.set_num_codebooks(8)
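    # Moshi-family models consume 8 Mimi codebooks per audio stream, though the codec supports more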
# Load Moshi LM to GPU
print("Loading Moshi LM to GPU...")
lm = loaders.get_moshi_lm(MOSHI_WEIGHT_PATH, device=device)
lm.eval()
# Load a separate Mimi for decoding (to avoid streaming state conflicts)
print("Loading Mimi decoder...")
mimi_dec = loaders.get_mimi(MIMI_WEIGHT_PATH, device=device)
mimi_dec.eval()
mimi_dec.set_num_codebooks(8)
# Resample input audio to model sample rate if needed
if input_sr != mimi.sample_rate:
import torchaudio.functional as F
audio_t = torch.from_numpy(input_audio.astype(np.float32))
if audio_t.dim() == 1:
audio_t = audio_t.unsqueeze(0)
audio_t = F.resample(audio_t, input_sr, int(mimi.sample_rate))
input_audio = audio_t.squeeze().numpy()
    # Peak-normalize: int16-range input collapses to [-1, 1]; very quiet input is boosted
max_val = np.abs(input_audio).max()
if max_val > 1.0:
input_audio = input_audio / max_val
elif 0 < max_val < 0.1:
input_audio = input_audio / max_val * 0.5
# Convert to tensor [B, C, T]
audio_tensor = torch.from_numpy(input_audio).to(device)
if audio_tensor.dim() == 1:
audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)
elif audio_tensor.dim() == 2:
audio_tensor = audio_tensor.unsqueeze(0)
# Pad to frame boundary
frame_size = int(mimi.sample_rate / mimi.frame_rate)
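    # One codec frame = sample_rate / frame_rate samples (1920 at Mimi's usual 24 kHz / 12.5 Hz)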
pad_len = (frame_size - (audio_tensor.shape[-1] % frame_size)) % frame_size
if pad_len > 0:
audio_tensor = torch.nn.functional.pad(audio_tensor, (0, pad_len))
# Create LMGen with proper configuration
print("Initializing LMGen with prompts...")
lm_gen = LMGen(
lm,
device,
use_sampling=True,
temp=temperature,
temp_text=0.7,
top_k=top_k,
top_k_text=25,
check=False,
audio_silence_frame_cnt=int(0.5 * mimi.frame_rate),
sample_rate=int(mimi.sample_rate),
frame_rate=int(mimi.frame_rate),
)
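    # temp/top_k sample the audio codebooks; temp_text/top_k_text sample the text stream.
    # audio_silence_frame_cnt (~0.5 s here) presumably pads silence between prompt segments.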
# Load voice prompt
voice_file = VOICES.get(voice, "NATF2.pt")
voice_path = os.path.join(VOICES_DIR, voice_file)
if not os.path.exists(voice_path):
available = [f for f in os.listdir(VOICES_DIR) if f.endswith('.pt')]
raise gr.Error(f"Voice '{voice_file}' not found. Available: {available}")
print(f"Voice: {voice_file}")
lm_gen.load_voice_prompt_embeddings(voice_path)
# Set text prompt (persona)
text_prompt = PERSONAS.get(persona, PERSONAS["Helpful Assistant"])
text_tokens = TEXT_TOKENIZER.encode(f"<system> {text_prompt} </system>")
lm_gen.text_prompt_tokens = text_tokens
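    # These tokens are injected when step_system_prompts() runs below (PersonaPlex-specific API)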
print(f"Text prompt set: {text_prompt[:50]}...")
# Encode input audio
print("Encoding input audio...")
all_codes = []
with torch.no_grad(), mimi.streaming(batch_size=1):
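        # Streaming mode keeps the codec's internal state across frames; each ~80 ms
        # frame encodes to one [B, n_codebooks, 1] code tensor.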
for offset in range(0, audio_tensor.shape[-1], frame_size):
frame = audio_tensor[:, :, offset:offset + frame_size]
codes = mimi.encode(frame)
if codes.shape[-1] > 0:
all_codes.append(codes)
if not all_codes:
raise gr.Error("Failed to encode audio")
print(f"Encoded {len(all_codes)} frames from input")
# Generate response with prompts
print("Generating response...")
out_wav_chunks = []
max_steps = int(max_duration * mimi.frame_rate)
with torch.no_grad(), lm_gen.streaming(1), mimi_dec.streaming(1):
# First, inject system prompts (voice + silence + text + silence)
print("Stepping system prompts...")
lm_gen.step_system_prompts(mimi_dec)
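        # This primes the LM's streaming state with voice + persona prompts before any
        # user audio; mimi_dec is passed so its decoder state stays in sync (assumption).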
# Feed user audio codes and collect response
print("Processing user input...")
for code in all_codes:
tokens_out = lm_gen.step(code.to(device))
if tokens_out is not None:
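                # tokens_out is [B, 1 + n_codebooks, T]: stream 0 is the text token,
                # the rest are Mimi audio codes, hence the [:, 1:] slice (Moshi layout)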
wav_chunk = mimi_dec.decode(tokens_out[:, 1:])
out_wav_chunks.append(wav_chunk)
# Continue generating response
print("Generating additional response...")
for step in range(max_steps):
tokens_out = lm_gen.step(None)
if tokens_out is not None:
wav_chunk = mimi_dec.decode(tokens_out[:, 1:])
out_wav_chunks.append(wav_chunk)
# Early stop on extended silence
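            # Heuristic: mean |amplitude| < 0.001 over the last 10 chunks (~0.8 s at 12.5 Hz)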
if len(out_wav_chunks) > 30:
recent = torch.cat(out_wav_chunks[-10:], dim=-1)
if recent.abs().mean() < 0.001:
print(f"Stopping at step {step} (silence detected)")
break
if not out_wav_chunks:
raise gr.Error("No audio generated")
# Concatenate and normalize output
output_audio = torch.cat(out_wav_chunks, dim=-1)
output_audio = output_audio.squeeze().cpu().numpy()
max_val = np.abs(output_audio).max()
if max_val > 0:
output_audio = output_audio / max_val * 0.9
output_audio = (output_audio * 32767).astype(np.int16)
output_sr = int(mimi.sample_rate)
print(f"Output: {len(output_audio)} samples ({len(output_audio)/output_sr:.1f}s) at {output_sr}Hz")
return (output_sr, output_audio)
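# Hypothetical local smoke test (bypasses Gradio; needs a CUDA device and the weights above):
#   sr, wav = generate_response((24000, np.zeros(24000, dtype=np.float32)),
#                               "Helpful Assistant", "Natural Female", max_duration=2.0)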
# ============================================================================
# UI
# ============================================================================
def create_demo():
print("Creating Gradio demo...")
with gr.Blocks(title="PersonaPlex 7B") as demo:
gr.Markdown("# PersonaPlex 7B Demo\n\nSpeech-to-speech with [nvidia/personaplex-7b-v1](https://huggingface.co/nvidia/personaplex-7b-v1)")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(label="Input", sources=["microphone", "upload"], type="numpy")
persona = gr.Dropdown(label="Persona", choices=list(PERSONAS.keys()), value="Helpful Assistant")
voice = gr.Dropdown(label="Voice", choices=list(VOICES.keys()), value="Natural Female")
with gr.Accordion("Advanced", open=False):
temp = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.1)
top_k = gr.Slider(label="Top-K", minimum=50, maximum=500, value=250, step=50)
max_dur = gr.Slider(label="Max Duration (s)", minimum=1, maximum=30, value=10, step=1)
btn = gr.Button("Generate", variant="primary")
with gr.Column():
audio_output = gr.Audio(label="Output", type="numpy")
gr.Markdown("**Note:** First inference loads model to GPU (~10s). Voice and persona affect the response style.")
btn.click(generate_response, [audio_input, persona, voice, temp, top_k, max_dur], audio_output)
return demo
if __name__ == "__main__":
demo = create_demo()
print("Launching demo...")
demo.queue(default_concurrency_limit=1, max_size=16)
demo.launch()