File size: 11,073 Bytes
77e6a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3b50ab
 
 
 
77e6a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3b50ab
 
 
8c0b009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d96db07
 
 
 
 
8c0b009
 
 
 
 
 
 
 
 
 
77e6a0d
 
d96db07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77e6a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4b11f0
77e6a0d
 
 
 
 
 
 
8c0b009
 
 
 
 
 
 
a4b11f0
77e6a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b840b20
 
 
 
 
 
 
 
 
 
 
 
 
 
d96db07
77e6a0d
 
 
 
 
 
 
 
 
 
 
 
a4b11f0
77e6a0d
 
 
 
 
4b1918b
 
 
 
 
77e6a0d
 
 
 
 
 
b840b20
77e6a0d
 
 
 
 
 
 
 
 
b840b20
77e6a0d
 
 
 
b840b20
 
 
 
77e6a0d
 
 
 
 
 
 
a4b11f0
77e6a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b0acf6
77e6a0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
import gradio as gr
import spaces
import torch
import numpy as np
import os
import tarfile
from pathlib import Path
from typing import Optional
from huggingface_hub import hf_hub_download
import sentencepiece

# Configuration
HF_REPO = "nvidia/personaplex-7b-v1"  # Hugging Face repo the weights, tokenizer and voices are fetched from
DEVICE = "cuda"
SAMPLE_RATE = 24000  # Hz

# Voice prompt identifiers available in PersonaPlex:
#   NAT* -> natural voices: NATF0-3 (female), NATM0-3 (male)
#   VAR* -> variety voices: VARF0-4 (female), VARM0-4 (male)
ALL_VOICES = [
    f"{family}{gender}{index}"
    for family, per_gender in (("NAT", 4), ("VAR", 5))
    for gender in "FM"
    for index in range(per_gender)
]

# Sample persona prompts taken from the PersonaPlex paper; the first one is the
# UI's default persona.
EXAMPLE_PERSONAS = [
    "You are a wise and friendly teacher. Answer questions or provide advice in a clear and engaging way.",
    "You enjoy having a good conversation.",
    "You work for CitySan Services which is a waste management company and your name is Ayelen Lucero.",
    "You enjoy having a good conversation. Have a technical discussion about fixing a reactor core on a spaceship to Mars. You are an astronaut on a Mars mission. Your name is Alex.",
]

# Import moshi after spaces to allow interception
from moshi.models import loaders, LMGen
from moshi.models.lm import load_audio, _iterate_audio, encode_from_sphn

# Pre-download model weights at startup. hf_hub_download caches files locally,
# so later runs reuse what was already downloaded.
print("Downloading model weights...")
MIMI_WEIGHT = hf_hub_download(HF_REPO, loaders.MIMI_NAME)
MOSHI_WEIGHT = hf_hub_download(HF_REPO, loaders.MOSHI_NAME)
TOKENIZER_PATH = hf_hub_download(HF_REPO, loaders.TEXT_TOKENIZER_NAME)
VOICES_TGZ = hf_hub_download(HF_REPO, "voices.tgz")

# Extract the voice-prompt archive next to the downloaded file. The code assumes
# the archive unpacks into a top-level "voices/" directory; extraction is skipped
# once that directory exists.
VOICES_DIR = Path(VOICES_TGZ).parent / "voices"
if not VOICES_DIR.exists():
    print("Extracting voice embeddings...")
    with tarfile.open(VOICES_TGZ, "r:gz") as tar:
        # The "data" extraction filter rejects absolute paths, ".." members and
        # links that escape the destination directory. Use it wherever the
        # interpreter provides it (3.12+, and backported security releases).
        # Otherwise fall back to the legacy unfiltered extraction.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=Path(VOICES_TGZ).parent, filter="data")
        else:
            tar.extractall(path=Path(VOICES_TGZ).parent)
print("Model weights ready.")

# Load text tokenizer (CPU only, no CUDA needed)
text_tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Global model cache, filled lazily by get_models() on the first GPU call.
_model_cache = {}


def get_models():
    """Return the shared model bundle, building it on the first call.

    On the first call this loads two Mimi codec instances and the Moshi LM onto
    ``DEVICE``, wraps the LM in ``LMGen``, and puts all three into streaming
    mode. It then runs a warmup pass and stores everything in ``_model_cache``.
    Subsequent calls return the cached dict as-is. It is meant to be called from
    inside the ``@spaces.GPU`` handler so that CUDA is available.

    Returns:
        The module-level cache dict with keys ``"mimi"``, ``"other_mimi"``,
        ``"lm_gen"``, ``"frame_size"`` and ``"initialized"``.
    """
    if "initialized" in _model_cache:
        return _model_cache

    print("Loading models to GPU...")

    # Two separate Mimi codec instances, each with its own streaming state
    # (elsewhere in this file `mimi` encodes user audio and `other_mimi`
    # decodes the agent's audio).
    mimi = loaders.get_mimi(MIMI_WEIGHT, DEVICE)
    other_mimi = loaders.get_mimi(MIMI_WEIGHT, DEVICE)

    lm = loaders.get_moshi_lm(MOSHI_WEIGHT, device=DEVICE)
    lm.eval()

    # Number of audio samples in one codec frame.
    frame_size = int(mimi.sample_rate / mimi.frame_rate)

    # Sampling settings for the audio and text streams.
    sampling = dict(temp=0.8, temp_text=0.7, top_k=250, top_k_text=25)
    lm_gen = LMGen(
        lm,
        audio_silence_frame_cnt=int(0.5 * mimi.frame_rate),
        sample_rate=mimi.sample_rate,
        device=DEVICE,
        frame_rate=mimi.frame_rate,
        **sampling,
    )

    for component in (mimi, other_mimi, lm_gen):
        component.streaming_forever(1)

    # Warmup is used to initialise CUDA graphs before the first real request.
    print("Running warmup...")
    _warmup_models(mimi, other_mimi, lm_gen, frame_size)
    print("Warmup complete.")

    _model_cache.update(
        mimi=mimi,
        other_mimi=other_mimi,
        lm_gen=lm_gen,
        frame_size=frame_size,
        initialized=True,
    )
    print("Models loaded successfully.")
    return _model_cache


def _warmup_models(mimi, other_mimi, lm_gen, frame_size):
    """Feed a few frames of silence through the codecs and LM, then reset them.

    Each of four passes builds one all-zero frame of ``frame_size`` samples on
    ``DEVICE``. The frame is encoded with both codecs, and the codes from
    ``mimi`` are stepped through ``lm_gen``. Any tokens produced have their
    audio codebooks (rows 1..8) decoded by both codecs. CUDA is then
    synchronised, and every component's streaming state is reset so the first
    real request starts clean. This is used to initialise CUDA graphs.
    """
    warmup_passes = 4
    for _pass in range(warmup_passes):
        silence = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=DEVICE)
        user_codes = mimi.encode(silence)
        other_mimi.encode(silence)
        for t in range(user_codes.shape[-1]):
            out_tokens = lm_gen.step(user_codes[:, :, t:t + 1])
            if out_tokens is None:
                continue
            audio_codes = out_tokens[:, 1:9]
            mimi.decode(audio_codes)
            other_mimi.decode(audio_codes)
    torch.cuda.synchronize()

    # Clear the state accumulated during warmup.
    for component in (mimi, other_mimi, lm_gen):
        component.reset_streaming()


def wrap_with_system_tags(text: str) -> str:
    """Enclose ``text`` in the ``<system>`` markers PersonaPlex prompts use.

    Leading and trailing whitespace is removed first. If the stripped text
    already begins and ends with ``<system>``, it is returned without adding
    another pair.
    """
    stripped = text.strip()
    already_wrapped = stripped.startswith("<system>") and stripped.endswith("<system>")
    if already_wrapped:
        return stripped
    return f"<system> {stripped} <system>"


def decode_tokens_to_pcm(mimi, other_mimi, tokens: torch.Tensor) -> np.ndarray:
    """Turn one step of generated tokens into agent PCM samples.

    Args:
        mimi: Not used; kept so existing call sites keep working.
        other_mimi: Codec whose ``decode`` turns the agent's audio codes into audio.
        tokens: Token tensor from ``LMGen.step``, indexed ``[batch, codebook, time]``.

    Returns:
        Channel 0 of batch item 0 of the decoded audio, as a NumPy array on CPU.
    """
    # Row 0 is the text stream; rows 1..8 hold the agent's audio codebooks.
    agent_codes = tokens[:, 1:9, :]
    waveform = other_mimi.decode(agent_codes)
    first_channel = waveform[0, 0]
    return first_channel.detach().cpu().numpy()


def _prepare_input_audio(audio, sr: int, target_sr: int) -> np.ndarray:
    """Convert a Gradio numpy recording to mono float32 PCM at ``target_sr``.

    Args:
        audio: Samples from ``gr.Audio(type="numpy")``, shaped ``(T,)`` or
            ``(T, channels)``, with an integer or floating dtype.
        sr: Sample rate of ``audio`` in Hz.
        target_sr: Sample rate the codec expects.

    Returns:
        A 1-D float array. Signed-integer PCM is scaled by its full-scale value,
        for example 32768 for int16. Float input outside [-1, 1] is
        peak-normalised.
    """
    audio = np.asarray(audio)
    # Read the dtype before the float cast; afterwards it is always float32.
    source_dtype = audio.dtype
    audio = audio.astype(np.float32)

    # Down-mix multi-channel input; samples are laid out as (T, channels).
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    if np.issubdtype(source_dtype, np.signedinteger):
        audio = audio / float(-np.iinfo(source_dtype).min)
    elif audio.max() > 1.0 or audio.min() < -1.0:
        audio = audio / np.abs(audio).max()

    if sr != target_sr:
        import sphn
        audio = sphn.resample(audio, sr, target_sr)
    return audio


@spaces.GPU(duration=120)
def generate_response(audio_input, persona: str, voice: str):
    """Run one PersonaPlex turn on a recorded message.

    The recording is converted to mono float PCM at the codec sample rate. It is
    padded with 2 s of leading and 8 s of trailing silence, then streamed frame
    by frame through ``LMGen`` after the voice prompt and optional persona text
    prompt have been applied. Output produced during the leading silence is
    discarded.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple from
            ``gr.Audio(type="numpy")``, or ``None`` if nothing was recorded.
        persona: Persona description. A blank value (or ``None``) means no text
            prompt.
        voice: Voice prompt name. ``VOICES_DIR/<voice>.pt`` must exist.

    Returns:
        ``((sample_rate, pcm), text)`` on success. On failure it returns
        ``(None, message)``, for example when the input is missing or empty,
        the voice is unknown, or nothing was generated.
    """
    if audio_input is None:
        return None, "Please record audio first."
    sr, audio = audio_input
    if audio is None or np.asarray(audio).size == 0:
        return None, "Please record audio first."

    # Validate the voice before loading models. The file must sit directly in
    # VOICES_DIR, so names with path components (e.g. "../x", "a/b") are
    # rejected. API callers can send arbitrary strings here.
    voice_path = VOICES_DIR / f"{voice}.pt"
    if voice_path.parent != VOICES_DIR or not voice_path.is_file():
        return None, f"Voice '{voice}' not found."

    # Get lazily loaded models
    models = get_models()
    mimi = models["mimi"]
    other_mimi = models["other_mimi"]
    lm_gen = models["lm_gen"]
    frame_size = models["frame_size"]

    audio = _prepare_input_audio(audio, sr, mimi.sample_rate)

    # Leading silence: lets the model say its default greeting, which is dropped
    # below. Trailing silence: gives the model time to finish its answer after
    # the user stops speaking.
    prepend_silence_duration = 2  # seconds
    append_silence_duration = 8  # seconds
    prepend_silence = np.zeros(int(prepend_silence_duration * mimi.sample_rate), dtype=np.float32)
    append_silence = np.zeros(int(append_silence_duration * mimi.sample_rate), dtype=np.float32)
    # [prepend_silence] + [user_audio] + [append_silence], plus a channel axis: (T,) -> (1, T)
    audio = np.concatenate([prepend_silence, audio, append_silence])[None, :]

    # Number of codec frames that cover the leading silence.
    frames_to_skip = int(prepend_silence_duration * mimi.frame_rate)

    lm_gen.load_voice_prompt_embeddings(str(voice_path))

    persona_text = (persona or "").strip()
    if persona_text:
        lm_gen.text_prompt_tokens = text_tokenizer.encode(wrap_with_system_tags(persona_text))
    else:
        lm_gen.text_prompt_tokens = None

    generated_frames = []
    generated_text = []
    with lm_gen.streaming(1):
        # Start every request from clean streaming state.
        mimi.reset_streaming()
        other_mimi.reset_streaming()
        lm_gen.reset_streaming()

        # Condition on the voice + text prompts, then clear the encoder state
        # before feeding user audio.
        lm_gen.step_system_prompts(mimi)
        mimi.reset_streaming()

        frame_count = 0  # input frames stepped so far; used to drop leading-silence output
        for user_encoded in encode_from_sphn(
            mimi,
            _iterate_audio(audio, sample_interval_size=frame_size, pad=True),
            max_batch=1,
        ):
            for c in range(user_encoded.shape[-1]):
                tokens = lm_gen.step(user_encoded[:, :, c:c + 1])
                frame_count += 1
                if tokens is None or frame_count <= frames_to_skip:
                    continue

                generated_frames.append(decode_tokens_to_pcm(mimi, other_mimi, tokens))

                # Row 0 carries the text token; ids 0 and 3 are skipped as special tokens.
                text_token = tokens[0, 0, 0].item()
                if text_token not in (0, 3):
                    generated_text.append(text_tokenizer.id_to_piece(text_token).replace("▁", " "))

    if not generated_frames:
        return None, "No audio generated. Try speaking more clearly."

    output_audio = np.concatenate(generated_frames, axis=-1)
    output_text = "".join(generated_text).strip()
    return (mimi.sample_rate, output_audio), output_text


# Build Gradio interface
# Layout follows the order in which components are created inside the nested
# Blocks / Row / Column contexts, so statement order in this block matters.
with gr.Blocks(title="PersonaPlex Demo", theme=gr.themes.Soft()) as demo:
    # Page header with links to the paper, source code and model card.
    gr.Markdown(
        """
        # 🎭 PersonaPlex
        **Voice and Role Control for Full Duplex Conversational Speech Models**
        
        [Paper](https://arxiv.org/abs/2503.04721) | [GitHub](https://github.com/NVIDIA/personaplex) | [Model](https://huggingface.co/nvidia/personaplex-7b-v1)
        
        ---
        
        Record your message, and PersonaPlex will respond with the configured persona and voice.
        """
    )
    
    with gr.Row():
        # Left column: conversation settings (persona text and voice prompt).
        with gr.Column(scale=1):
            # Defaults to the first example persona.
            persona = gr.Textbox(
                label="Persona Description",
                placeholder="Describe the assistant's persona...",
                value=EXAMPLE_PERSONAS[0],
                lines=4,
            )
            # Choices are ALL_VOICES; generate_response loads
            # VOICES_DIR/<voice>.pt for the selected value.
            voice = gr.Dropdown(
                choices=ALL_VOICES,
                value="NATF2",
                label="Voice"
            )
            # Persona presets; each example row has a single value that targets
            # the persona textbox.
            gr.Examples(
                examples=[[p] for p in EXAMPLE_PERSONAS],
                inputs=[persona],
                label="Example Personas"
            )
        
        # Right column (scale=2 relative to the settings column): input and outputs.
        with gr.Column(scale=2):
            # type="numpy": the handler receives a (sample_rate, samples) tuple.
            audio_input = gr.Audio(
                label="🎤 Record your message",
                sources=["microphone", "upload"],
                type="numpy",
            )
            generate_btn = gr.Button("Generate Response", variant="primary", size="lg")
            
            # Receives the (sample_rate, pcm) tuple returned by generate_response.
            audio_output = gr.Audio(
                label="🔊 PersonaPlex Response",
                type="numpy",
                autoplay=True,
            )
            # Read-only transcript of the generated text tokens.
            text_output = gr.Textbox(
                label="📝 Response Text",
                interactive=False,
            )
    
    # Clicking the button calls generate_response(audio, persona, voice); its two
    # return values map to the audio player and the text box, in that order.
    generate_btn.click(
        fn=generate_response,
        inputs=[audio_input, persona, voice],
        outputs=[audio_output, text_output],
    )


# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()