# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: MIT # # Permission is hereby granted, free of charge, to any person obtaining a # copy of this software and associated documentation files (the "Software"), # to deal in the Software without restriction, including without limitation # the rights to use, copy, modify, merge, publish, distribute, sublicense, # and/or sell copies of the Software, and to permit persons to whom the # Software is furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import argparse
import asyncio
from dataclasses import dataclass
import random
import os
from pathlib import Path
import tarfile
import time
import secrets
import sys
from typing import Literal, Optional

import aiohttp
from aiohttp import web
from huggingface_hub import hf_hub_download
import numpy as np
import sentencepiece
import sphn
import torch

from .client_utils import make_log, colorize
from .models import loaders, MimiModel, LMModel, LMGen
from .utils.connection import create_ssl_context, get_lan_ip
from .utils.logging import setup_logger, ColorizedLog


logger = setup_logger(__name__)
DeviceString = Literal["cuda"] | Literal["cpu"]  # | Literal["mps"]


def torch_auto_device(requested: Optional[DeviceString] = None) -> torch.device:
    """Return a torch.device based on the requested string or availability.

    An explicit ``requested`` value always wins. Otherwise CUDA is used when
    available, falling back to the CPU.
    """
    if requested is not None:
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    #     return torch.device("mps")
    return torch.device("cpu")


def seed_all(seed):
    """Seed the torch (CPU and every CUDA device), ``random`` and NumPy RNGs.

    cuDNN is explicitly left non-deterministic (with autotuning disabled), so
    GPU results are not guaranteed to be bit-identical across runs.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False


def wrap_with_system_tags(text: str) -> str:
    """Add system tags as the model expects if they are missing.

    Example: " You enjoy having a good conversation. Have a deep conversation
    about technology. Your name is Jane. "
    """
    # NOTE(review): the tag literals below are empty strings. As a result the
    # startswith/endswith check is always true, the function only strips
    # whitespace, and the f-string fallback is unreachable. The tag text looks
    # like it was lost (e.g. an angle-bracket tag stripped in transit). Confirm
    # the intended tag and restore it in all four places.
    cleaned = text.strip()
    if cleaned.startswith("") and cleaned.endswith(""):
        return cleaned
    return f" {cleaned} "


def _resolve_voice_prompt_path(voice_prompt_dir: str, filename: str) -> str:
    """Return ``os.path.join(voice_prompt_dir, filename)`` if it is a file inside the directory.

    ``filename`` comes straight from the client's query string. Absolute paths
    and ``..`` components that would escape ``voice_prompt_dir`` are rejected.
    The check is lexical, so symlinks placed inside the directory still work.

    Raises:
        FileNotFoundError: if ``filename`` is empty, points outside
            ``voice_prompt_dir``, or does not name an existing file.
    """
    candidate = os.path.join(voice_prompt_dir, filename)
    base = os.path.abspath(voice_prompt_dir)
    resolved = os.path.abspath(candidate)
    try:
        inside = resolved != base and os.path.commonpath([base, resolved]) == base
    except ValueError:  # e.g. paths on different drives on Windows
        inside = False
    if not filename or not inside or not os.path.isfile(candidate):
        raise FileNotFoundError(
            f"Requested voice prompt '{filename}' not found in '{voice_prompt_dir}'"
        )
    return candidate


@dataclass
class ServerState:
    """Model state shared by every chat connection.

    One set of Mimi codecs and one ``LMGen`` serve all clients. ``lock``
    serializes sessions so that only one connection drives them at a time.
    """
    mimi: MimiModel
    other_mimi: MimiModel
    text_tokenizer: sentencepiece.SentencePieceProcessor
    lm_gen: LMGen
    lock: asyncio.Lock

    def __init__(self, mimi: MimiModel, other_mimi: MimiModel,
                 text_tokenizer: sentencepiece.SentencePieceProcessor,
                 lm: LMModel, device: str | torch.device,
                 voice_prompt_dir: str | None = None,
                 save_voice_prompt_embeddings: bool = False):
        self.mimi = mimi
        self.other_mimi = other_mimi
        self.text_tokenizer = text_tokenizer
        self.device = device
        self.voice_prompt_dir = voice_prompt_dir
        # Number of PCM samples in one codec frame.
        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        self.lm_gen = LMGen(lm,
                            audio_silence_frame_cnt=int(0.5 * self.mimi.frame_rate),
                            sample_rate=self.mimi.sample_rate,
                            device=device,
                            frame_rate=self.mimi.frame_rate,
                            save_voice_prompt_embeddings=save_voice_prompt_embeddings,
        )

        self.lock = asyncio.Lock()
        self.mimi.streaming_forever(1)
        self.other_mimi.streaming_forever(1)
        self.lm_gen.streaming_forever(1)

    def warmup(self):
        """Push a few frames of silence through both codecs and the LM.

        Presumably this front-loads lazy initialization and kernel selection
        before the first client connects.
        """
        for _ in range(4):
            chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
            codes = self.mimi.encode(chunk)
            _ = self.other_mimi.encode(chunk)
            for c in range(codes.shape[-1]):
                tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                if tokens is None:
                    continue
                _ = self.mimi.decode(tokens[:, 1:9])
                _ = self.other_mimi.decode(tokens[:, 1:9])

        # self.device may be a str or a torch.device (see __init__'s signature).
        if torch.device(self.device).type == 'cuda':
            torch.cuda.synchronize()

    async def handle_chat(self, request):
        """Serve one full-duplex voice chat session over a websocket.

        Query parameters:
            voice_prompt: filename inside ``voice_prompt_dir``.
            text_prompt: system text prompt; may be empty.
            seed: optional integer; -1 means "do not reseed".

        Binary frames start with a kind byte. The client sends ``1`` + Opus
        audio. The server sends ``0`` (handshake, once the system prompts are
        processed), ``1`` + Opus audio, and ``2`` + a UTF-8 text piece.

        The codecs and ``lm_gen`` are shared streaming objects, so all of
        their state is touched only while holding ``self.lock``.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        clog = ColorizedLog.randomize()
        peer = request.remote  # IP
        transport = request.transport
        peername = transport.get_extra_info("peername") if transport is not None else None
        # TCP transports report an (host, port, ...) tuple; tolerate others.
        peer_port = peername[1] if isinstance(peername, tuple) and len(peername) > 1 else None  # Port
        clog.log("info", f"Incoming connection from {peer}:{peer_port}")

        # self.lm_gen.temp = float(request.query["audio_temperature"])
        # self.lm_gen.temp_text = float(request.query["text_temperature"])
        # self.lm_gen.top_k_text = max(1, int(request.query["text_topk"]))
        # self.lm_gen.top_k = max(1, int(request.query["audio_topk"]))

        text_prompt = request.query.get("text_prompt", "")
        voice_prompt_filename = request.query.get("voice_prompt", "")

        seed = None
        if "seed" in request.query:
            try:
                seed = int(request.query["seed"])
            except ValueError:
                clog.log("warning", f"ignoring invalid seed {request.query['seed']!r}")

        # Construct full voice prompt path
        requested_voice_prompt_path = None
        voice_prompt_path = None
        if self.voice_prompt_dir is not None:
            requested_voice_prompt_path = os.path.join(self.voice_prompt_dir, voice_prompt_filename)
            try:
                voice_prompt_path = _resolve_voice_prompt_path(self.voice_prompt_dir, voice_prompt_filename)
            except FileNotFoundError as err:
                # Reject the session cleanly rather than letting the exception
                # escape a handler whose websocket is already upgraded.
                clog.log("error", str(err))
                await ws.close()
                return ws

        # Tokenize now; the tokens are installed on the shared LMGen under the lock.
        text_prompt_tokens = (
            self.text_tokenizer.encode(wrap_with_system_tags(text_prompt))
            if len(text_prompt) > 0
            else None
        )

        close = False

        async def recv_loop():
            nonlocal close
            try:
                async for message in ws:
                    if message.type == aiohttp.WSMsgType.ERROR:
                        clog.log("error", f"{ws.exception()}")
                        break
                    elif message.type == aiohttp.WSMsgType.CLOSED:
                        break
                    elif message.type == aiohttp.WSMsgType.CLOSE:
                        break
                    elif message.type != aiohttp.WSMsgType.BINARY:
                        clog.log("error", f"unexpected message type {message.type}")
                        continue
                    message = message.data
                    if not isinstance(message, bytes):
                        clog.log("error", f"unsupported message type {type(message)}")
                        continue
                    if len(message) == 0:
                        clog.log("warning", "empty message")
                        continue
                    kind = message[0]
                    if kind == 1:  # audio
                        payload = message[1:]
                        opus_reader.append_bytes(payload)
                    else:
                        clog.log("warning", f"unknown message kind {kind}")
            finally:
                # Signals opus_loop and send_loop to stop.
                close = True
                clog.log("info", "connection closed")

        async def opus_loop():
            all_pcm_data = None

            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                pcm = opus_reader.read_pcm()
                if pcm.shape[-1] == 0:
                    continue
                if all_pcm_data is None:
                    all_pcm_data = pcm
                else:
                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
                # Consume whole codec frames; any remainder waits for more audio.
                while all_pcm_data.shape[-1] >= self.frame_size:
                    chunk = all_pcm_data[: self.frame_size]
                    all_pcm_data = all_pcm_data[self.frame_size:]
                    chunk = torch.from_numpy(chunk)
                    chunk = chunk.to(device=self.device)[None, None]
                    codes = self.mimi.encode(chunk)
                    _ = self.other_mimi.encode(chunk)
                    for c in range(codes.shape[-1]):
                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                        if tokens is None:
                            continue
                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
                        main_pcm = self.mimi.decode(tokens[:, 1:9])
                        _ = self.other_mimi.decode(tokens[:, 1:9])
                        main_pcm = main_pcm.cpu()
                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
                        text_token = tokens[0, 0, 0].item()
                        # Ids 0 and 3 are not forwarded to the client
                        # (presumably padding ids; confirm against the tokenizer).
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)  # type: ignore
                            _text = _text.replace("▁", " ")
                            msg = b"\x02" + bytes(_text, encoding="utf8")
                            await ws.send_bytes(msg)

        async def send_loop():
            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                msg = opus_writer.read_bytes()
                if len(msg) > 0:
                    await ws.send_bytes(b"\x01" + msg)

        clog.log("info", "accepted connection")
        if len(text_prompt) > 0:
            clog.log("info", f"text prompt: {text_prompt}")
        if len(voice_prompt_filename) > 0:
            clog.log("info", f"voice prompt: {voice_prompt_path} (requested: {requested_voice_prompt_path})")

        async with self.lock:
            # Prompt state lives on the shared LMGen. Set it only while holding
            # the lock so that a waiting connection cannot overwrite the
            # prompts of the session that is currently running.
            if voice_prompt_path is not None and self.lm_gen.voice_prompt != voice_prompt_path:
                if voice_prompt_path.endswith('.pt'):
                    # Load pre-saved voice prompt embeddings
                    self.lm_gen.load_voice_prompt_embeddings(voice_prompt_path)
                else:
                    self.lm_gen.load_voice_prompt(voice_prompt_path)
            self.lm_gen.text_prompt_tokens = text_prompt_tokens

            if seed is not None and seed != -1:
                seed_all(seed)

            opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate)
            opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate)
            self.mimi.reset_streaming()
            self.other_mimi.reset_streaming()
            self.lm_gen.reset_streaming()

            async def is_alive():
                if close or ws.closed:
                    return False
                try:
                    # Check for disconnect without waiting too long.
                    # NOTE(review): any non-close message received here is
                    # consumed and dropped.
                    msg = await asyncio.wait_for(ws.receive(), timeout=0.01)
                    if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
                        return False
                except asyncio.TimeoutError:
                    # No messages → client probably still alive
                    return True
                except aiohttp.ClientConnectionError:
                    return False
                return True

            # Reuse mimi for encoding voice prompt and then reset it before conversation starts
            await self.lm_gen.step_system_prompts_async(self.mimi, is_alive=is_alive)
            self.mimi.reset_streaming()
            clog.log("info", "done with system prompts")
            # Send the handshake.
            if await is_alive():
                await ws.send_bytes(b"\x00")
                clog.log("info", "sent handshake bytes")

            # Run the three loops until the first one finishes, then stop the rest.
            tasks = [
                asyncio.create_task(recv_loop()),
                asyncio.create_task(opus_loop()),
                asyncio.create_task(send_loop()),
            ]

            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            # Force-kill remaining tasks
            for task in pending:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
            # Report the failure (if any) that ended the session. Otherwise
            # asyncio only warns that the task exception was never retrieved.
            for task in done:
                if not task.cancelled() and task.exception() is not None:
                    clog.log("error", f"session task failed: {task.exception()!r}")
            await ws.close()
            clog.log("info", "session closed")
        clog.log("info", "done with connection")
        return ws


def _extract_tgz(archive: Path, dest: Path) -> None:
    """Extract the gzip-compressed tarball ``archive`` into ``dest``.

    The tarfile ``"data"`` extraction filter is used when this Python provides
    it. That filter rejects members or links that would land outside ``dest``,
    as well as special files such as devices.
    """
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest, filter="data")
        else:
            tar.extractall(path=dest)


def _get_voice_prompt_dir(voice_prompt_dir: Optional[str], hf_repo: str) -> Optional[str]:
    """Return the directory that holds the voice prompt files.

    If voice_prompt_dir is None:
        - download voices.tgz from HF
        - extract it once (next to the downloaded archive)
        - return the extracted ``voices`` directory
    If voice_prompt_dir is provided:
        - just return it

    Raises:
        RuntimeError: if the archive did not produce a ``voices/`` directory.
    """
    if voice_prompt_dir is not None:
        return voice_prompt_dir

    logger.info("retrieving voice prompts")

    voices_tgz = Path(hf_hub_download(hf_repo, "voices.tgz"))
    voices_dir = voices_tgz.parent / "voices"

    if not voices_dir.exists():
        logger.info(f"extracting {voices_tgz} to {voices_dir}")
        _extract_tgz(voices_tgz, voices_tgz.parent)

    if not voices_dir.exists():
        raise RuntimeError("voices.tgz did not contain a 'voices/' directory")

    return str(voices_dir)


def _get_static_path(static: Optional[str]) -> Optional[str]:
    """Resolve the directory of static web UI files to serve.

    None downloads and extracts ``dist.tgz`` from the HF repo. The string
    "none" disables static serving and returns None. Any other value is
    returned unchanged.
    """
    if static is None:
        logger.info("retrieving the static content")
        dist_tgz = Path(hf_hub_download("nvidia/personaplex-7b-v1", "dist.tgz"))
        dist = dist_tgz.parent / "dist"
        if not dist.exists():
            _extract_tgz(dist_tgz, dist_tgz.parent)
        return str(dist)
    elif static != "none":
        # When set to the "none" string, we don't serve any static content.
        return static
    return None


def main():
    """Parse CLI arguments, load the models, and run the aiohttp chat server."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="localhost", type=str)
    parser.add_argument("--port", default=8998, type=int)
    parser.add_argument("--static", type=str)
    parser.add_argument("--gradio-tunnel", action='store_true', help='Activate a gradio tunnel.')
    parser.add_argument("--gradio-tunnel-token",
                        help='Provide a custom (secret) token here to keep getting the same URL.')

    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
                        help="HF repo to look into, defaults PersonaPlex. "
                             "Use this to select a different pre-trained model.")
    parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")
    parser.add_argument("--cpu-offload", action="store_true",
                        help="Offload LM model layers to CPU when GPU memory is insufficient. "
                             "Requires 'accelerate' package.")
    parser.add_argument(
        "--voice-prompt-dir",
        type=str,
        help=(
            "Directory containing voice prompt files. "
            "If omitted, voices.tgz is downloaded from HF and extracted. "
            "Voice prompt filenames from client requests will be joined with this directory path."
        )
    )
    parser.add_argument(
        "--ssl",
        type=str,
        help=(
            "use https instead of http, this flag should point to a directory "
            "that contains valid key.pem and cert.pem files"
        )
    )

    args = parser.parse_args()
    args.voice_prompt_dir = _get_voice_prompt_dir(
        args.voice_prompt_dir,
        args.hf_repo,
    )
    # Explicit checks rather than assert, which is stripped under `python -O`.
    if args.voice_prompt_dir is not None and not os.path.exists(args.voice_prompt_dir):
        parser.error(f"Directory missing: {args.voice_prompt_dir}")
    logger.info(f"voice_prompt_dir = {args.voice_prompt_dir}")

    static_path: None | str = _get_static_path(args.static)
    if static_path is not None and not os.path.exists(static_path):
        parser.error(f"Static path does not exist: {static_path}.")
    logger.info(f"static_path = {static_path}")
    args.device = torch_auto_device(args.device)

    seed_all(42424242)

    setup_tunnel = None
    tunnel_token = ''
    if args.gradio_tunnel:
        try:
            from gradio import networking  # type: ignore
        except ImportError:
            logger.error("Cannot find gradio which is required to activate a tunnel. "
                         "Please install with `pip install gradio`.")
            sys.exit(1)
        setup_tunnel = networking.setup_tunnel
        if args.gradio_tunnel_token is None:
            tunnel_token = secrets.token_urlsafe(32)
        else:
            tunnel_token = args.gradio_tunnel_token

    # Download config.json to increment download counter
    # No worries about double-counting since config.json will be cached the second time
    hf_hub_download(args.hf_repo, "config.json")

    logger.info("loading mimi")
    if args.mimi_weight is None:
        args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
    mimi = loaders.get_mimi(args.mimi_weight, args.device)
    other_mimi = loaders.get_mimi(args.mimi_weight, args.device)
    logger.info("mimi loaded")

    if args.tokenizer is None:
        args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME)
    text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer)  # type: ignore

    logger.info("loading moshi")
    if args.moshi_weight is None:
        args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME)
    lm = loaders.get_moshi_lm(args.moshi_weight, device=args.device, cpu_offload=args.cpu_offload)
    lm.eval()
    logger.info("moshi loaded")
    state = ServerState(
        mimi=mimi,
        other_mimi=other_mimi,
        text_tokenizer=text_tokenizer,
        lm=lm,
        device=args.device,
        voice_prompt_dir=args.voice_prompt_dir,
        save_voice_prompt_embeddings=False,
    )
    logger.info("warming up the model")
    state.warmup()
    app = web.Application()
    app.router.add_get("/api/chat", state.handle_chat)
    if static_path is not None:
        async def handle_root(_):
            return web.FileResponse(os.path.join(static_path, "index.html"))

        logger.info(f"serving static content from {static_path}")
        app.router.add_get("/", handle_root)
        app.router.add_static(
            "/", path=static_path, follow_symlinks=True, name="static"
        )
    protocol = "http"
    ssl_context = None
    if args.ssl is not None:
        ssl_context, protocol = create_ssl_context(args.ssl)
    host_ip = args.host if args.host not in ("0.0.0.0", "::", "localhost") else get_lan_ip()
    logger.info(f"Access the Web UI directly at {protocol}://{host_ip}:{args.port}")
    if setup_tunnel is not None:
        tunnel = setup_tunnel('localhost', args.port, tunnel_token, None)
        logger.info(f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.")
    # NOTE(review): --host only affects the URL logged above. run_app is not
    # given a host, so aiohttp's default host binding applies. Passing
    # host=args.host would change the behavior of the default ("localhost")
    # setup, so confirm the intent before changing it.
    web.run_app(app, port=args.port, ssl_context=ssl_context)


# NOTE(review): the server starts as soon as this module is executed or
# imported. It is kept unguarded to preserve that behavior. The usual
# `if __name__ == "__main__":` guard is preferable if nothing relies on
# importing this module to start the server.
with torch.no_grad():
    main()