MohamedRashad committed on
Commit
77e6a0d
·
1 Parent(s): 77728be

Add main application logic and requirements for PersonaPlex

Browse files
Files changed (2) hide show
  1. main.py +273 -0
  2. requirements.txt +9 -0
main.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import spaces
import torch
import numpy as np
import os
import tarfile
from pathlib import Path
from typing import Optional
from huggingface_hub import hf_hub_download
import sentencepiece

# PersonaPlex model imports - installed via: pip install git+https://github.com/NVIDIA/personaplex.git#subdirectory=moshi
from moshi.models import loaders, LMGen
from moshi.models.lm import load_audio, _iterate_audio, encode_from_sphn

# Configuration
HF_REPO = "nvidia/personaplex-7b-v1"  # Hugging Face repo holding all model assets
DEVICE = "cuda"  # inference device; on HF Spaces, CUDA is available inside @spaces.GPU calls
SAMPLE_RATE = 24000  # NOTE(review): unused below — the code reads mimi.sample_rate instead; confirm and remove or use

# Available voices in PersonaPlex
ALL_VOICES = [
    "NATF0", "NATF1", "NATF2", "NATF3",  # Natural Female
    "NATM0", "NATM1", "NATM2", "NATM3",  # Natural Male
    "VARF0", "VARF1", "VARF2", "VARF3", "VARF4",  # Variety Female
    "VARM0", "VARM1", "VARM2", "VARM3", "VARM4",  # Variety Male
]

# Example persona prompts from PersonaPlex paper
EXAMPLE_PERSONAS = [
    "You are a wise and friendly teacher. Answer questions or provide advice in a clear and engaging way.",
    "You enjoy having a good conversation.",
    "You work for CitySan Services which is a waste management company and your name is Ayelen Lucero.",
    "You enjoy having a good conversation. Have a technical discussion about fixing a reactor core on a spaceship to Mars. You are an astronaut on a Mars mission. Your name is Alex.",
]

# Pre-download model weights at startup (cached by huggingface_hub)
print("Downloading model weights...")
MIMI_WEIGHT = hf_hub_download(HF_REPO, loaders.MIMI_NAME)  # Mimi audio codec weights
MOSHI_WEIGHT = hf_hub_download(HF_REPO, loaders.MOSHI_NAME)  # Moshi language-model weights
TOKENIZER_PATH = hf_hub_download(HF_REPO, loaders.TEXT_TOKENIZER_NAME)  # SentencePiece text tokenizer
VOICES_TGZ = hf_hub_download(HF_REPO, "voices.tgz")  # archive of voice prompt embeddings

# Extract voices archive next to the downloaded tarball (skipped on warm restarts)
VOICES_DIR = Path(VOICES_TGZ).parent / "voices"
if not VOICES_DIR.exists():
    print("Extracting voice embeddings...")
    with tarfile.open(VOICES_TGZ, "r:gz") as tar:
        # NOTE(review): extractall without a filter trusts archive member paths;
        # acceptable for a first-party HF artifact, but consider filter="data" (Py 3.12+).
        tar.extractall(path=Path(VOICES_TGZ).parent)
print("Model weights ready.")

# Global model cache, populated lazily by get_models() on the first GPU call.
_model_cache = {}
+
55
+
56
def get_models():
    """Load all inference components on first use and cache them globally.

    Returns the module-level ``_model_cache`` dict containing the two Mimi
    codec instances, the LMGen generation wrapper, the text tokenizer, and
    the per-frame sample count.
    """
    if "initialized" in _model_cache:
        return _model_cache

    print("Loading models to GPU...")

    # Two Mimi instances: one encodes user audio, the other decodes the
    # agent's audio, each with its own independent streaming state.
    user_mimi = loaders.get_mimi(MIMI_WEIGHT, DEVICE)
    agent_mimi = loaders.get_mimi(MIMI_WEIGHT, DEVICE)

    # Moshi language model, inference mode only.
    moshi_lm = loaders.get_moshi_lm(MOSHI_WEIGHT, device=DEVICE)
    moshi_lm.eval()

    # SentencePiece tokenizer for the text stream.
    sp_tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

    # Number of audio samples per generation frame.
    samples_per_frame = int(user_mimi.sample_rate / user_mimi.frame_rate)

    # Generation wrapper with the sampling configuration for this demo.
    generator = LMGen(
        moshi_lm,
        audio_silence_frame_cnt=int(0.5 * user_mimi.frame_rate),
        sample_rate=user_mimi.sample_rate,
        device=DEVICE,
        frame_rate=user_mimi.frame_rate,
        temp=0.8,
        temp_text=0.7,
        top_k=250,
        top_k_text=25,
    )

    # Put every component into persistent streaming mode (batch size 1).
    for component in (user_mimi, agent_mimi, generator):
        component.streaming_forever(1)

    _model_cache.update(
        mimi=user_mimi,
        other_mimi=agent_mimi,
        lm_gen=generator,
        tokenizer=sp_tokenizer,
        frame_size=samples_per_frame,
        initialized=True,
    )
    print("Models loaded successfully.")

    return _model_cache
102
+
103
+
104
def wrap_with_system_tags(text: str) -> str:
    """Ensure the persona text is wrapped in ``<system>`` tags as PersonaPlex expects."""
    stripped = text.strip()
    already_tagged = stripped.startswith("<system>") and stripped.endswith("<system>")
    return stripped if already_tagged else f"<system> {stripped} <system>"
110
+
111
+
112
def decode_tokens_to_pcm(mimi, other_mimi, tokens: torch.Tensor) -> np.ndarray:
    """Decode one step of generated tokens into a mono PCM waveform.

    ``tokens`` has shape [B, num_codebooks, 1]; codebook 0 is the text
    stream and codebooks 1:9 carry the agent's audio tokens. The ``mimi``
    parameter is unused here but kept for interface compatibility.
    """
    agent_codes = tokens[:, 1:9, :]
    waveform = other_mimi.decode(agent_codes)
    # Drop the batch and channel dimensions -> 1-D numpy array on CPU.
    return waveform[0, 0].cpu().numpy()
119
+
120
+
121
@spaces.GPU(duration=120)
def generate_response(audio_input, persona: str, voice: str):
    """Run one PersonaPlex turn: encode the user's recording, condition on
    the persona prompt and voice embedding, and generate the agent's reply.

    Args:
        audio_input: Gradio numpy audio tuple ``(sample_rate, samples)``, or None.
        persona: Free-text persona description; wrapped in <system> tags when non-empty.
        voice: Voice id, expected to be one of ALL_VOICES.

    Returns:
        ``((sample_rate, waveform), text)`` on success, or ``(None, error_message)``.
    """
    if audio_input is None:
        return None, "Please record audio first."

    models = get_models()
    mimi = models["mimi"]
    other_mimi = models["other_mimi"]
    lm_gen = models["lm_gen"]
    tokenizer = models["tokenizer"]
    frame_size = models["frame_size"]

    # --- Input audio preprocessing ------------------------------------
    sr, audio = audio_input

    # BUG FIX: the original converted to float32 first and only then tested
    # ``audio.dtype == np.int16`` — a condition that could never be true —
    # so int16 PCM from gradio was peak-normalized instead of scaled by the
    # int16 full-scale value. Scale integer PCM BEFORE the float conversion.
    if np.issubdtype(audio.dtype, np.integer):
        scale = float(np.iinfo(audio.dtype).max) + 1.0  # 32768.0 for int16
        audio = audio.astype(np.float32) / scale
    else:
        audio = audio.astype(np.float32)

    # Convert to mono if stereo.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Guard: an empty recording would crash .max() below.
    if audio.size == 0:
        return None, "Please record audio first."

    # Bring any remaining out-of-range float samples into [-1, 1].
    peak = float(np.abs(audio).max())
    if peak > 1.0:
        audio = audio / peak

    # Resample to the model's sample rate if needed.
    if sr != mimi.sample_rate:
        import sphn
        audio = sphn.resample(audio, sr, mimi.sample_rate)

    # Add channel dimension: (T,) -> (1, T).
    if audio.ndim == 1:
        audio = audio[None, :]

    # --- Conditioning ---------------------------------------------------
    # Load the voice prompt embedding for the selected speaker.
    voice_path = str(VOICES_DIR / f"{voice}.pt")
    if not os.path.exists(voice_path):
        return None, f"Voice '{voice}' not found."
    lm_gen.load_voice_prompt_embeddings(voice_path)

    # Set (or clear) the persona text prompt.
    if persona.strip():
        lm_gen.text_prompt_tokens = tokenizer.encode(wrap_with_system_tags(persona))
    else:
        lm_gen.text_prompt_tokens = None

    # Reset streaming state left over from any previous request.
    mimi.reset_streaming()
    other_mimi.reset_streaming()
    lm_gen.reset_streaming()

    # Run system prompts (voice + text conditioning) before user audio.
    with lm_gen.streaming(1):
        lm_gen.step_system_prompts(mimi)
        mimi.reset_streaming()

    # --- Generation loop --------------------------------------------------
    # Feed user audio frame by frame; collect agent audio and text output.
    generated_frames = []
    generated_text = []

    for user_encoded in encode_from_sphn(
        mimi,
        _iterate_audio(audio, sample_interval_size=frame_size, pad=True),
        max_batch=1,
    ):
        for c in range(user_encoded.shape[-1]):
            step_in = user_encoded[:, :, c:c+1]
            tokens = lm_gen.step(step_in)

            # The generator may need several input frames before emitting.
            if tokens is None:
                continue

            # Decode agent audio (codebooks 1:9).
            pcm = decode_tokens_to_pcm(mimi, other_mimi, tokens)
            generated_frames.append(pcm)

            # Decode text token; 0 and 3 are skipped as special tokens
            # (presumably pad/eos — confirm against the tokenizer config).
            text_token = tokens[0, 0, 0].item()
            if text_token not in (0, 3):
                text_piece = tokenizer.id_to_piece(text_token).replace("▁", " ")
                generated_text.append(text_piece)

    if not generated_frames:
        return None, "No audio generated. Try speaking more clearly."

    # Concatenate per-frame PCM chunks into one waveform.
    output_audio = np.concatenate(generated_frames, axis=-1)
    output_text = "".join(generated_text).strip()

    return (mimi.sample_rate, output_audio), output_text
211
+
212
+
213
# Build Gradio interface
with gr.Blocks(title="PersonaPlex Demo", theme=gr.themes.Soft()) as demo:
    # Header: title, one-line description, and external links.
    gr.Markdown(
        """
# 🎭 PersonaPlex
**Voice and Role Control for Full Duplex Conversational Speech Models**

[Paper](https://arxiv.org/abs/2503.04721) | [GitHub](https://github.com/NVIDIA/personaplex) | [Model](https://huggingface.co/nvidia/personaplex-7b-v1)

---

Record your message, and PersonaPlex will respond with the configured persona and voice.
"""
    )

    with gr.Row():
        # Left column: persona description, voice picker, and canned examples.
        with gr.Column(scale=1):
            persona = gr.Textbox(
                label="Persona Description",
                placeholder="Describe the assistant's persona...",
                value=EXAMPLE_PERSONAS[0],
                lines=4,
            )
            voice = gr.Dropdown(
                choices=ALL_VOICES,
                value="NATF2",
                label="Voice"
            )
            gr.Examples(
                examples=[[p] for p in EXAMPLE_PERSONAS],
                inputs=[persona],
                label="Example Personas"
            )

        # Right column: microphone input, generate button, and outputs.
        with gr.Column(scale=2):
            audio_input = gr.Audio(
                label="🎤 Record your message",
                sources=["microphone"],
                type="numpy",
            )
            generate_btn = gr.Button("Generate Response", variant="primary", size="lg")

            audio_output = gr.Audio(
                label="🔊 PersonaPlex Response",
                type="numpy",
                autoplay=True,
            )
            text_output = gr.Textbox(
                label="📝 Response Text",
                interactive=False,
            )

    # Wire the button to the GPU-decorated inference function.
    generate_btn.click(
        fn=generate_response,
        inputs=[audio_input, persona, voice],
        outputs=[audio_output, text_output],
    )


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ spaces
3
+ torch
4
+ numpy
5
+ huggingface_hub
6
+ sentencepiece
7
+ sphn
8
+ safetensors
9
+ git+https://github.com/NVIDIA/personaplex.git#subdirectory=moshi