File size: 13,912 Bytes
493de4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
"""
PersonaPlex HuggingFace Space — Speech-to-speech with 18 voices and persona control.

Uses ZeroGPU (@spaces.GPU) for dynamic H200 allocation.
Model weights, the text tokenizer, and voice prompts are downloaded at startup (CPU only);
the models themselves are instantiated on CUDA inside the GPU-decorated function on every call.
"""

import sys
import os
import random
import tarfile
import json
from pathlib import Path
from typing import Optional

sys.path.insert(0, ".")

import spaces
import gradio as gr
import torch
import numpy as np
import sentencepiece
import sphn
from huggingface_hub import hf_hub_download

from moshi.models import loaders, LMGen, MimiModel
from moshi.models.lm import (
    load_audio as lm_load_audio,
    _iterate_audio as lm_iterate_audio,
    encode_from_sphn as lm_encode_from_sphn,
)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
HF_REPO = "nvidia/personaplex-7b-v1"

# Voice-prompt catalogue: UI label -> file name inside the extracted voices/ dir.
# Each (family, gender) group contributes ``count`` numbered voices; labels are
# 1-based while file codes are 0-based, e.g. "Natural Female 1 (NATF0)" ->
# "NATF0.pt". Dict insertion order is the order shown in the voice dropdown.
VOICES = {
    f"{family} {gender} {index + 1} ({code}{index})": f"{code}{index}.pt"
    for family, gender, code, count in (
        ("Natural", "Female", "NATF", 4),
        ("Natural", "Male", "NATM", 4),
        ("Variety", "Female", "VARF", 5),
        ("Variety", "Male", "VARM", 5),
    )
    for index in range(count)
}

# Persona presets: preset name -> system-prompt text placed in the persona box.
# "Custom" starts empty so the user can type their own prompt.
PERSONAS = {
    "Assistant": (
        "You are a wise and friendly teacher. "
        "Answer questions or provide advice in a clear and engaging way."
    ),
    "Mars Astronaut": (
        "You enjoy having a good conversation. "
        "Have a technical discussion about fixing a reactor core on a spaceship to Mars. "
        "You are an astronaut on a Mars mission. "
        "Your name is Alex."
    ),
    "Restaurant": (
        "You work for Jerusalem Shakshuka which is a restaurant and your name is Owen Foster. "
        "Information: There are two shakshuka options: "
        "Classic (poached eggs, $9.50) and Spicy (scrambled eggs with jalapenos, $10.25)."
    ),
    "Casual Chat": "You enjoy having a good conversation.",
    "Custom": "",
}

# ---------------------------------------------------------------------------
# Model asset globals (populated on CPU at startup by _download_assets)
# ---------------------------------------------------------------------------
_mimi_weight_path: Optional[str] = None
_moshi_weight_path: Optional[str] = None
_tokenizer_path: Optional[str] = None
_voice_prompt_dir: Optional[str] = None
_text_tokenizer: Optional[sentencepiece.SentencePieceProcessor] = None


def _download_assets():
    """Download model weights, the text tokenizer and voice prompts from the Hub.

    Populates the module-level globals above:
    - the local paths of the Mimi and Moshi LM weights and of the tokenizer;
    - a loaded SentencePiece tokenizer;
    - the directory holding the extracted voice-prompt files.

    Raises:
        RuntimeError: if ``voices.tgz`` does not unpack to a ``voices/`` directory.
    """
    global _mimi_weight_path, _moshi_weight_path, _tokenizer_path
    global _voice_prompt_dir, _text_tokenizer

    print("[Init] Downloading config.json (download counter)...")
    hf_hub_download(HF_REPO, "config.json")

    print("[Init] Downloading Mimi weights...")
    _mimi_weight_path = hf_hub_download(HF_REPO, loaders.MIMI_NAME)

    print("[Init] Downloading Moshi LM weights...")
    _moshi_weight_path = hf_hub_download(HF_REPO, loaders.MOSHI_NAME)

    print("[Init] Downloading tokenizer...")
    _tokenizer_path = hf_hub_download(HF_REPO, loaders.TEXT_TOKENIZER_NAME)
    _text_tokenizer = sentencepiece.SentencePieceProcessor(_tokenizer_path)

    print("[Init] Downloading voice prompts...")
    voices_tgz = Path(hf_hub_download(HF_REPO, "voices.tgz"))
    voices_dir = voices_tgz.parent / "voices"
    if not voices_dir.exists():
        print(f"[Init] Extracting {voices_tgz} -> {voices_dir}")
        with tarfile.open(voices_tgz, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # PEP 706 'data' filter: rejects members that would land outside
                # the destination (absolute paths, '..' traversal, escaping
                # links) and special files such as devices.
                tar.extractall(path=voices_tgz.parent, filter="data")
            else:
                # Interpreter predates extraction filters; fall back to the
                # unfiltered behaviour.
                tar.extractall(path=voices_tgz.parent)
    if not voices_dir.is_dir():
        raise RuntimeError("voices.tgz did not contain a 'voices/' directory")
    _voice_prompt_dir = str(voices_dir)

    print("[Init] All assets downloaded successfully.")


# Download on import (CPU only, no GPU needed)
_download_assets()


# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def _resample_numpy(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample a 1-D numpy audio array from ``src_sr`` to ``dst_sr``.

    Uses linear interpolation between neighbouring samples. No anti-aliasing
    low-pass filter is applied, so downsampling can alias content above the
    new Nyquist frequency.

    Args:
        audio: 1-D waveform.
        src_sr: sample rate of ``audio``.
        dst_sr: desired output sample rate.

    Returns:
        A 1-D float32 array of ``len(audio) * dst_sr // src_sr`` samples, or
        ``audio`` as float32 when the two rates already match. Empty input
        yields an empty float32 array.
    """
    if src_sr == dst_sr:
        return np.asarray(audio, dtype=np.float32)
    n_src = len(audio)
    if n_src == 0:
        # np.interp raises on an empty sample-point array; nothing to resample.
        return np.zeros(0, dtype=np.float32)
    # Multiply before dividing: the float form int(n / src * dst) can land just
    # below an exact integer and truncate one sample short (29 samples at
    # 100 -> 200 gave 57 instead of 58).
    target_len = int(n_src * dst_sr // src_sr)
    positions = np.linspace(0, n_src - 1, target_len)
    return np.interp(positions, np.arange(n_src), audio).astype(np.float32)


def _wrap_system_tags(text: str) -> str:
    """Return *text* stripped and framed as ``<system> ... <system>``.

    The model expects the persona prompt enclosed in ``<system>`` markers.
    Input that, once stripped, already starts and ends with the marker is
    returned stripped but otherwise unchanged.
    """
    prompt = text.strip()
    if not (prompt.startswith("<system>") and prompt.endswith("<system>")):
        prompt = f"<system> {prompt} <system>"
    return prompt


# ---------------------------------------------------------------------------
# Inference (runs on GPU via ZeroGPU)
# ---------------------------------------------------------------------------
def _to_mono_float32(audio: np.ndarray) -> np.ndarray:
    """Convert a Gradio numpy clip to a 1-D float32 waveform.

    Integer PCM is rescaled to roughly [-1, 1]:
    - signed types are divided by 2**(bits-1) (32768 for int16), so the most
      negative sample maps to exactly -1.0;
    - unsigned types are re-centred around zero first.

    A 2-D array is treated as (samples, channels) and averaged over axis 1.
    """
    audio = np.asarray(audio)
    dtype = audio.dtype
    if np.issubdtype(dtype, np.signedinteger):
        scale = float(-np.iinfo(dtype).min)
        audio = audio.astype(np.float32) / scale
    elif np.issubdtype(dtype, np.unsignedinteger):
        mid = (float(np.iinfo(dtype).max) + 1.0) / 2.0
        audio = (audio.astype(np.float32) - mid) / mid
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    return audio.astype(np.float32)


def _seed_everything(seed) -> int:
    """Seed ``random``, NumPy and torch (CPU and CUDA); return the seed used.

    ``None`` (e.g. a cleared Gradio number field) or a negative value draws a
    fresh seed from ``random`` first.
    """
    if seed is None or seed < 0:
        actual_seed = random.randint(0, 2**31 - 1)
    else:
        actual_seed = int(seed)
    torch.manual_seed(actual_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(actual_seed)
        torch.cuda.manual_seed_all(actual_seed)
    random.seed(actual_seed)
    np.random.seed(actual_seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False
    return actual_seed


def _fit_length(pcm: np.ndarray, num_samples: int) -> np.ndarray:
    """Trim or zero-pad a 1-D array so it holds exactly ``num_samples`` values."""
    shortfall = num_samples - pcm.shape[-1]
    if shortfall <= 0:
        return pcm[:num_samples]
    return np.concatenate([pcm, np.zeros(shortfall, dtype=pcm.dtype)], axis=-1)


@spaces.GPU(duration=120)
def run_inference(audio_input, voice_name, persona_text, seed):
    """
    Run PersonaPlex speech-to-speech inference.

    Inputs are validated before any model is loaded, so bad requests fail
    without spending GPU time. Models are loaded onto CUDA fresh on every call.
    They are released, and the CUDA cache emptied, in a ``finally`` block, so
    this also happens when an error is raised.

    Args:
        audio_input: tuple (sample_rate, numpy_array) from the Gradio audio
            component, or None when nothing was recorded/uploaded.
        voice_name: key from VOICES (unknown keys fall back to "NATF2.pt").
        persona_text: persona system prompt string; blank means no text prompt.
        seed: int seed; any negative value (e.g. -1) or None picks a random seed.

    Returns:
        (sample_rate, numpy_array): generated audio at the Mimi sample rate,
            trimmed or zero-padded to the (resampled) input length.
        str: transcript decoded from the generated text tokens.

    Raises:
        gr.Error: on missing or empty input audio, a missing voice prompt file,
            or when no audio frames are generated.
    """
    # --- Cheap validation first: nothing below here until models load is GPU work.
    if audio_input is None:
        raise gr.Error("Please record or upload audio first.")

    input_sr, input_audio = audio_input
    input_audio = _to_mono_float32(input_audio)
    if input_audio.size == 0:
        raise gr.Error("The provided audio is empty. Please record or upload audio first.")

    voice_file = VOICES.get(voice_name, "NATF2.pt")
    voice_path = os.path.join(_voice_prompt_dir, voice_file)
    if not os.path.exists(voice_path):
        raise gr.Error(f"Voice prompt file not found: {voice_path}")

    _seed_everything(seed)

    device = "cuda"
    # Pre-bind so the ``finally`` cleanup is valid even if loading fails midway.
    mimi = other_mimi = lm = lm_gen = None
    try:
        # Load models fresh on GPU each call (ZeroGPU gives us a clean GPU)
        print("[Inference] Loading Mimi on CUDA...")
        mimi = loaders.get_mimi(_mimi_weight_path, device)
        # Second codec instance: it receives the same encode/decode calls as
        # ``mimi`` below, but its outputs are discarded in this app.
        other_mimi = loaders.get_mimi(_mimi_weight_path, device)
        print("[Inference] Mimi loaded.")

        print("[Inference] Loading Moshi LM on CUDA...")
        lm = loaders.get_moshi_lm(_moshi_weight_path, device=device)
        lm.eval()
        print("[Inference] Moshi LM loaded.")

        # Build LMGen
        frame_size = int(mimi.sample_rate / mimi.frame_rate)
        lm_gen = LMGen(
            lm,
            audio_silence_frame_cnt=int(0.5 * mimi.frame_rate),
            sample_rate=mimi.sample_rate,
            device=device,
            frame_rate=mimi.frame_rate,
            save_voice_prompt_embeddings=False,
            use_sampling=True,
            temp=0.8,
            temp_text=0.7,
            top_k=250,
            top_k_text=25,
        )

        # Streaming mode
        mimi.streaming_forever(1)
        other_mimi.streaming_forever(1)
        lm_gen.streaming_forever(1)

        # Warmup (CUDA graphs): push a few frames of silence through the pipeline.
        print("[Inference] Warming up...")
        for _ in range(4):
            chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=device)
            codes = mimi.encode(chunk)
            _ = other_mimi.encode(chunk)
            for c in range(codes.shape[-1]):
                tokens = lm_gen.step(codes[:, :, c : c + 1])
                if tokens is None:
                    continue
                _ = mimi.decode(tokens[:, 1:9])
                _ = other_mimi.decode(tokens[:, 1:9])
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print("[Inference] Warmup complete.")

        # Load voice prompt (path already validated above)
        if voice_path.endswith(".pt"):
            lm_gen.load_voice_prompt_embeddings(voice_path)
        else:
            lm_gen.load_voice_prompt(voice_path)

        # Encode text prompt
        if persona_text and persona_text.strip():
            lm_gen.text_prompt_tokens = _text_tokenizer.encode(
                _wrap_system_tags(persona_text)
            )
        else:
            lm_gen.text_prompt_tokens = None

        # Reset streaming and run system prompts
        mimi.reset_streaming()
        other_mimi.reset_streaming()
        lm_gen.reset_streaming()
        print("[Inference] Running system prompts (voice + text)...")
        lm_gen.step_system_prompts(mimi)
        mimi.reset_streaming()
        print("[Inference] System prompts complete.")

        # Resample input audio to the model's sample rate
        model_sr = int(mimi.sample_rate)
        user_pcm = _resample_numpy(input_audio, input_sr, model_sr)
        # Shape expected by lm helpers: (C, T)
        user_pcm_2d = user_pcm[np.newaxis, :]  # (1, T)
        total_target_samples = user_pcm_2d.shape[-1]

        # Stream user audio through the model
        print(f"[Inference] Processing {total_target_samples} samples ({total_target_samples / model_sr:.1f}s)...")
        generated_frames = []
        transcript_parts = []

        for user_encoded in lm_encode_from_sphn(
            mimi,
            lm_iterate_audio(user_pcm_2d, sample_interval_size=lm_gen._frame_size, pad=True),
            max_batch=1,
        ):
            for c in range(user_encoded.shape[-1]):
                tokens = lm_gen.step(user_encoded[:, :, c : c + 1])
                if tokens is None:
                    continue
                # tokens[:, 1:9] are fed to the codec as audio codes and
                # tokens[0, 0, 0] is read as the text token (layout presumably
                # (batch, 1 + codebooks, 1); confirm against LMGen).
                pcm = mimi.decode(tokens[:, 1:9])
                _ = other_mimi.decode(tokens[:, 1:9])
                generated_frames.append(pcm.detach().cpu().numpy()[0, 0])
                text_token = tokens[0, 0, 0].item()
                # Ids 0 and 3 are the padding tokens (EPAD / PAD) and carry no
                # text, so they are left out of the transcript.
                if text_token not in (0, 3):
                    piece = _text_tokenizer.id_to_piece(text_token)
                    # SentencePiece marks word boundaries with U+2581.
                    transcript_parts.append(piece.replace("\u2581", " "))

        if not generated_frames:
            raise gr.Error("No audio frames were generated. Try a longer input.")

        # Concatenate and trim/pad to match the input duration
        output_pcm = _fit_length(
            np.concatenate(generated_frames, axis=-1), total_target_samples
        )
        transcript = "".join(transcript_parts).strip()

        print(f"[Inference] Done. Output: {output_pcm.shape[-1]} samples, transcript: {len(transcript)} chars")
        return (model_sr, output_pcm), transcript
    finally:
        # Clean up GPU memory on success and on every error path.
        del lm_gen, lm, mimi, other_mimi
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Base(), title="PersonaPlex") as demo:
    # The voice count is derived from VOICES so the header always matches the
    # number of entries offered in the voice dropdown.
    gr.Markdown(
        "# PersonaPlex\n"
        f"Speech-to-speech with {len(VOICES)} voices and persona control. "
        "Powered by NVIDIA PersonaPlex on ZeroGPU."
    )

    with gr.Row():
        # Left column: generation settings.
        with gr.Column(scale=1):
            voice = gr.Dropdown(
                choices=list(VOICES.keys()),
                value="Natural Female 3 (NATF2)",
                label="Voice",
            )
            persona_preset = gr.Dropdown(
                choices=list(PERSONAS.keys()),
                value="Assistant",
                label="Persona Preset",
            )
            persona_text = gr.Textbox(
                value=PERSONAS["Assistant"],
                label="Persona Prompt",
                lines=3,
            )
            seed = gr.Number(
                value=42424242,
                label="Seed (-1 for random)",
                precision=0,
            )

        # Right column: input audio, trigger button and outputs.
        with gr.Column(scale=2):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Your Audio",
            )
            run_btn = gr.Button(
                "Generate Response",
                variant="primary",
                size="lg",
            )
            audio_output = gr.Audio(type="numpy", label="PersonaPlex Response")
            transcript = gr.Textbox(
                label="Transcript", lines=5, interactive=False
            )

    # Selecting a preset overwrites the persona text box; the text stays editable.
    persona_preset.change(
        fn=lambda p: PERSONAS.get(p, ""),
        inputs=persona_preset,
        outputs=persona_text,
    )

    run_btn.click(
        fn=run_inference,
        inputs=[audio_input, voice, persona_text, seed],
        outputs=[audio_output, transcript],
    )

demo.queue().launch()