Spaces:
Runtime error
Runtime error
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: MIT | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a | |
| # copy of this software and associated documentation files (the "Software"), | |
| # to deal in the Software without restriction, including without limitation | |
| # the rights to use, copy, modify, merge, publish, distribute, sublicense, | |
| # and/or sell copies of the Software, and to permit persons to whom the | |
| # Software is furnished to do so, subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in | |
| # all copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL | |
| # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | |
| # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | |
| # DEALINGS IN THE SOFTWARE. | |
| # Copyright (c) Kyutai, all rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import asyncio | |
| from dataclasses import dataclass | |
| import random | |
| import os | |
| from pathlib import Path | |
| import tarfile | |
| import time | |
| import secrets | |
| import sys | |
| from typing import Literal, Optional | |
| import aiohttp | |
| from aiohttp import web | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| import sentencepiece | |
| import sphn | |
| import torch | |
| import random | |
| from .client_utils import make_log, colorize | |
| from .models import loaders, MimiModel, LMModel, LMGen | |
| from .utils.connection import create_ssl_context, get_lan_ip | |
| from .utils.logging import setup_logger, ColorizedLog | |
# Logger for this module, built by the project's logging helper.
logger = setup_logger(__name__)

# Device names accepted on the command line ("mps" is currently not offered).
DeviceString = Literal["cuda", "cpu"]


def torch_auto_device(requested: Optional[DeviceString] = None) -> torch.device:
    """Choose the torch device to run on.

    An explicit ``requested`` name always wins. Without one, CUDA is picked when
    ``torch.cuda.is_available()`` reports it, otherwise the CPU.
    (Automatic Apple MPS selection is deliberately left disabled.)
    """
    if requested is None:
        requested = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(requested)
def seed_all(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and, if present, all CUDA devices).

    Also switches cuDNN's ``deterministic`` flag and autotuner (``benchmark``) off.
    NOTE(review): with ``deterministic = False`` cuDNN may still choose non-deterministic
    kernels, so equal seeds are not guaranteed to give bit-identical GPU results.
    """
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # every visible GPU
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = False
def wrap_with_system_tags(text: str) -> str:
    """Return ``text`` wrapped in ``<system>`` tags, as the model expects.

    Surrounding whitespace is stripped first. Text that already starts AND ends with
    ``<system>`` is returned unchanged. A lone opening or closing tag is removed before
    wrapping, so it is not duplicated.

    Example: "<system> You enjoy having a good conversation. Have a deep conversation about technology. Your name is Jane. <system>"
    """
    tag = "<system>"
    cleaned = text.strip()
    if cleaned.startswith(tag) and cleaned.endswith(tag):
        return cleaned
    # Only one side was tagged (or neither): drop the stray tag so the wrap below
    # doesn't produce "<system> <system> ...".
    cleaned = cleaned.removeprefix(tag).removesuffix(tag).strip()
    return f"{tag} {cleaned} {tag}"
class ServerState:
    """Models and synchronization shared by every ``/api/chat`` websocket connection.

    A single pair of Mimi codecs and a single ``LMGen`` serve all clients; ``lock``
    makes sure only one session drives them at a time.
    """

    mimi: MimiModel
    other_mimi: MimiModel
    text_tokenizer: sentencepiece.SentencePieceProcessor
    lm_gen: LMGen
    lock: asyncio.Lock

    def __init__(self, mimi: MimiModel, other_mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
                 lm: LMModel, device: str | torch.device, voice_prompt_dir: str | None = None,
                 save_voice_prompt_embeddings: bool = False):
        """Wrap ``lm`` in an ``LMGen`` and put all models into streaming mode.

        Args:
            mimi: codec used to encode client audio and decode generated audio.
            other_mimi: second codec instance that receives the same encode/decode calls.
            text_tokenizer: SentencePiece model for text prompts and generated text pieces.
            lm: language model driven through ``LMGen``.
            device: device that models and input chunks are placed on.
            voice_prompt_dir: directory that client-requested voice prompt names are
                resolved against; ``None`` disables voice prompt selection.
            save_voice_prompt_embeddings: forwarded to ``LMGen``.
        """
        self.mimi = mimi
        self.other_mimi = other_mimi
        self.text_tokenizer = text_tokenizer
        self.device = device
        self.voice_prompt_dir = voice_prompt_dir
        # PCM samples per model frame.
        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        self.lm_gen = LMGen(
            lm,
            audio_silence_frame_cnt=int(0.5 * self.mimi.frame_rate),  # half a second of frames
            sample_rate=self.mimi.sample_rate,
            device=device,
            frame_rate=self.mimi.frame_rate,
            save_voice_prompt_embeddings=save_voice_prompt_embeddings,
        )
        self.lock = asyncio.Lock()
        # NOTE(review): the argument is presumably the streaming batch size — confirm.
        self.mimi.streaming_forever(1)
        self.other_mimi.streaming_forever(1)
        self.lm_gen.streaming_forever(1)

    def warmup(self):
        """Push four frames of silence through encode -> LM step -> decode.

        Presumably done to trigger lazy initialisation before the first client connects.
        """
        for _ in range(4):
            chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
            codes = self.mimi.encode(chunk)
            _ = self.other_mimi.encode(chunk)
            for c in range(codes.shape[-1]):
                tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                if tokens is None:
                    continue
                _ = self.mimi.decode(tokens[:, 1:9])
                _ = self.other_mimi.decode(tokens[:, 1:9])
        # ``device`` may be given as a plain string; normalise before reading ``.type``.
        if torch.device(self.device).type == 'cuda':
            torch.cuda.synchronize()

    def _resolve_voice_prompt_path(self, voice_prompt_filename: str | None) -> str | None:
        """Map a client-supplied voice prompt name to an existing file in ``voice_prompt_dir``.

        Returns ``None`` when no voice prompt directory is configured.

        Raises:
            web.HTTPNotFound: the name is missing/empty, points outside the directory
                (e.g. ``../x`` or an absolute path), or the file does not exist.
        """
        if self.voice_prompt_dir is None:
            return None
        if voice_prompt_filename:
            base_dir = os.path.abspath(self.voice_prompt_dir)
            candidate = os.path.abspath(os.path.join(base_dir, voice_prompt_filename))
            # The name comes straight from the query string; only accept paths that stay
            # inside the configured directory (the file may later be unpickled/loaded).
            if candidate.startswith(os.path.join(base_dir, "")) and os.path.isfile(candidate):
                return candidate
        raise web.HTTPNotFound(
            text=f"Requested voice prompt '{voice_prompt_filename}' not found in '{self.voice_prompt_dir}'"
        )

    def _apply_prompts(self, voice_prompt_path: str | None, text_prompt: str) -> None:
        """Load the voice prompt into ``lm_gen`` (only if it changed) and set the text prompt tokens.

        Must be called while holding ``self.lock``: both prompts live on the shared ``LMGen``.
        """
        if voice_prompt_path is not None and self.lm_gen.voice_prompt != voice_prompt_path:
            if voice_prompt_path.endswith('.pt'):
                # Load pre-saved voice prompt embeddings
                self.lm_gen.load_voice_prompt_embeddings(voice_prompt_path)
            else:
                self.lm_gen.load_voice_prompt(voice_prompt_path)
        self.lm_gen.text_prompt_tokens = (
            self.text_tokenizer.encode(wrap_with_system_tags(text_prompt)) if len(text_prompt) > 0 else None
        )

    async def handle_chat(self, request):
        """Serve one websocket chat session.

        Query parameters: ``voice_prompt`` (file name inside ``voice_prompt_dir``),
        ``text_prompt`` (wrapped in ``<system>`` tags and tokenized) and optional
        ``seed`` (``-1`` means "do not reseed").

        Wire format as implemented below: the server sends ``b"\\x00"`` once the system
        prompts are processed; ``b"\\x01" + opus`` carries audio in both directions;
        the server sends generated text pieces as ``b"\\x02" + utf8``.

        Raises:
            web.HTTPBadRequest: ``seed`` is not an integer.
            web.HTTPNotFound: the requested voice prompt cannot be used.
        """
        query = request.query
        text_prompt = query.get("text_prompt", "")
        voice_prompt_filename = query.get("voice_prompt")
        seed = None
        if "seed" in query:
            try:
                seed = int(query["seed"])
            except ValueError:
                raise web.HTTPBadRequest(text=f"Invalid seed {query['seed']!r}") from None
        # Validate before the websocket upgrade so a bad request gets a plain HTTP error.
        voice_prompt_path = self._resolve_voice_prompt_path(voice_prompt_filename)
        # NOTE: audio_temperature / text_temperature / text_topk / audio_topk query
        # parameters are not applied; LMGen keeps its configured sampling settings.

        ws = web.WebSocketResponse()
        await ws.prepare(request)
        clog = ColorizedLog.randomize()
        peer = request.remote  # IP
        transport = request.transport
        peername = transport.get_extra_info("peername") if transport is not None else None
        peer_port = peername[1] if isinstance(peername, tuple) else "?"
        clog.log("info", f"Incoming connection from {peer}:{peer_port}")

        async def recv_loop():
            # Feed incoming audio frames (kind byte 1) to the opus decoder until the socket closes.
            nonlocal close
            try:
                async for message in ws:
                    if message.type == aiohttp.WSMsgType.ERROR:
                        clog.log("error", f"{ws.exception()}")
                        break
                    elif message.type == aiohttp.WSMsgType.CLOSED:
                        break
                    elif message.type == aiohttp.WSMsgType.CLOSE:
                        break
                    elif message.type != aiohttp.WSMsgType.BINARY:
                        clog.log("error", f"unexpected message type {message.type}")
                        continue
                    message = message.data
                    if not isinstance(message, bytes):
                        clog.log("error", f"unsupported message type {type(message)}")
                        continue
                    if len(message) == 0:
                        clog.log("warning", "empty message")
                        continue
                    kind = message[0]
                    if kind == 1:  # audio
                        opus_reader.append_bytes(message[1:])
                    else:
                        clog.log("warning", f"unknown message kind {kind}")
            finally:
                close = True
                clog.log("info", "connection closed")

        async def opus_loop():
            # Decode client PCM, run it frame by frame through Mimi + LM, queue output audio/text.
            all_pcm_data = None
            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                pcm = opus_reader.read_pcm()
                if pcm.shape[-1] == 0:
                    continue
                if all_pcm_data is None:
                    all_pcm_data = pcm
                else:
                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
                while all_pcm_data.shape[-1] >= self.frame_size:
                    chunk = all_pcm_data[: self.frame_size]
                    all_pcm_data = all_pcm_data[self.frame_size:]
                    chunk = torch.from_numpy(chunk)
                    chunk = chunk.to(device=self.device)[None, None]
                    codes = self.mimi.encode(chunk)
                    _ = self.other_mimi.encode(chunk)
                    for c in range(codes.shape[-1]):
                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                        if tokens is None:
                            continue
                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
                        main_pcm = self.mimi.decode(tokens[:, 1:9])
                        _ = self.other_mimi.decode(tokens[:, 1:9])
                        main_pcm = main_pcm.cpu()
                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
                        text_token = tokens[0, 0, 0].item()
                        # Skip ids 0 and 3 (presumably the EPAD and PAD tokens — confirm
                        # against the tokenizer).
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)  # type: ignore
                            _text = _text.replace("▁", " ")
                            msg = b"\x02" + bytes(_text, encoding="utf8")
                            await ws.send_bytes(msg)

        async def send_loop():
            # Forward encoded opus bytes produced by opus_loop to the client.
            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                msg = opus_writer.read_bytes()
                if len(msg) > 0:
                    await ws.send_bytes(b"\x01" + msg)

        clog.log("info", "accepted connection")
        if len(text_prompt) > 0:
            clog.log("info", f"text prompt: {text_prompt}")
        if voice_prompt_filename:
            clog.log("info", f"voice prompt: {voice_prompt_path} (requested: {voice_prompt_filename})")
        close = False
        async with self.lock:
            # Prompts live on the shared LMGen, so set them only while owning the lock;
            # otherwise a client waiting here would overwrite the running session's prompts.
            self._apply_prompts(voice_prompt_path, text_prompt)
            if seed is not None and seed != -1:
                seed_all(seed)
            opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate)
            opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate)
            self.mimi.reset_streaming()
            self.other_mimi.reset_streaming()
            self.lm_gen.reset_streaming()

            async def is_alive():
                # NOTE(review): this calls ws.receive(), so a data message arriving here is consumed.
                if close or ws.closed:
                    return False
                try:
                    # Check for disconnect without waiting too long
                    msg = await asyncio.wait_for(ws.receive(), timeout=0.01)
                    if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
                        return False
                except asyncio.TimeoutError:
                    # No messages -> client probably still alive
                    return True
                except aiohttp.ClientConnectionError:
                    return False
                return True

            # Reuse mimi for encoding voice prompt and then reset it before conversation starts
            await self.lm_gen.step_system_prompts_async(self.mimi, is_alive=is_alive)
            self.mimi.reset_streaming()
            clog.log("info", "done with system prompts")
            # Send the handshake.
            if await is_alive():
                await ws.send_bytes(b"\x00")
                clog.log("info", "sent handshake bytes")
                tasks = [
                    asyncio.create_task(recv_loop()),
                    asyncio.create_task(opus_loop()),
                    asyncio.create_task(send_loop()),
                ]
                try:
                    await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                finally:
                    # Runs even if this handler is cancelled: stop the remaining loops and
                    # surface (rather than silently drop) any exception a loop raised.
                    for task in tasks:
                        task.cancel()  # no-op for tasks that already finished
                    results = await asyncio.gather(*tasks, return_exceptions=True)
                    for result in results:
                        if isinstance(result, Exception):
                            clog.log("error", f"session task failed: {result!r}")
                    await ws.close()
                clog.log("info", "session closed")
        clog.log("info", "done with connection")
        return ws
| def _get_voice_prompt_dir(voice_prompt_dir: Optional[str], hf_repo: str) -> Optional[str]: | |
| """ | |
| If voice_prompt_dir is None: | |
| - download voices.tgz from HF | |
| - extract it once | |
| - return extracted directory | |
| If voice_prompt_dir is provided: | |
| - just return it | |
| """ | |
| if voice_prompt_dir is not None: | |
| return voice_prompt_dir | |
| logger.info("retrieving voice prompts") | |
| voices_tgz = hf_hub_download(hf_repo, "voices.tgz") | |
| voices_tgz = Path(voices_tgz) | |
| voices_dir = voices_tgz.parent / "voices" | |
| if not voices_dir.exists(): | |
| logger.info(f"extracting {voices_tgz} to {voices_dir}") | |
| with tarfile.open(voices_tgz, "r:gz") as tar: | |
| tar.extractall(path=voices_tgz.parent) | |
| if not voices_dir.exists(): | |
| raise RuntimeError("voices.tgz did not contain a 'voices/' directory") | |
| return str(voices_dir) | |
| def _get_static_path(static: Optional[str]) -> Optional[str]: | |
| if static is None: | |
| logger.info("retrieving the static content") | |
| dist_tgz = hf_hub_download("nvidia/personaplex-7b-v1", "dist.tgz") | |
| dist_tgz = Path(dist_tgz) | |
| dist = dist_tgz.parent / "dist" | |
| if not dist.exists(): | |
| with tarfile.open(dist_tgz, "r:gz") as tar: | |
| tar.extractall(path=dist_tgz.parent) | |
| return str(dist) | |
| elif static != "none": | |
| # When set to the "none" string, we don't serve any static content. | |
| return static | |
| return None | |
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the chat server."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="localhost", type=str)
    parser.add_argument("--port", default=8998, type=int)
    parser.add_argument("--static", type=str)
    parser.add_argument("--gradio-tunnel", action='store_true', help='Activate a gradio tunnel.')
    parser.add_argument("--gradio-tunnel-token",
                        help='Provide a custom (secret) token here to keep getting the same URL.')
    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
                        help="HF repo to look into, defaults PersonaPlex. "
                             "Use this to select a different pre-trained model.")
    parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")
    parser.add_argument("--cpu-offload", action="store_true",
                        help="Offload LM model layers to CPU when GPU memory is insufficient. "
                             "Requires 'accelerate' package.")
    parser.add_argument(
        "--voice-prompt-dir",
        type=str,
        help=(
            "Directory containing voice prompt files. "
            "If omitted, voices.tgz is downloaded from HF and extracted. "
            "Voice prompt filenames from client requests will be joined with this directory path."
        ),
    )
    parser.add_argument(
        "--ssl",
        type=str,
        help=(
            "use https instead of http, this flag should point to a directory "
            "that contains valid key.pem and cert.pem files"
        ),
    )
    return parser


@torch.no_grad()
def main():
    """Entry point: parse flags, fetch/load the models, then serve the UI and ``/api/chat``.

    Autograd is disabled for the whole run (decorator) because the server only does
    inference. This no longer depends on how the function is invoked.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    args.voice_prompt_dir = _get_voice_prompt_dir(
        args.voice_prompt_dir,
        args.hf_repo,
    )
    # Use parser.error rather than assert: asserts are stripped under `python -O`.
    if args.voice_prompt_dir is not None and not os.path.exists(args.voice_prompt_dir):
        parser.error(f"Directory missing: {args.voice_prompt_dir}")
    logger.info(f"voice_prompt_dir = {args.voice_prompt_dir}")

    static_path: None | str = _get_static_path(args.static)
    if static_path is not None and not os.path.exists(static_path):
        parser.error(f"Static path does not exist: {static_path}.")
    logger.info(f"static_path = {static_path}")

    args.device = torch_auto_device(args.device)
    seed_all(42424242)

    setup_tunnel = None
    tunnel_token = ''
    if args.gradio_tunnel:
        try:
            from gradio import networking  # type: ignore
        except ImportError:
            logger.error("Cannot find gradio which is required to activate a tunnel. "
                         "Please install with `pip install gradio`.")
            sys.exit(1)
        setup_tunnel = networking.setup_tunnel
        if args.gradio_tunnel_token is None:
            tunnel_token = secrets.token_urlsafe(32)
        else:
            tunnel_token = args.gradio_tunnel_token

    # Download config.json to increment download counter
    # No worries about double-counting since config.json will be cached the second time
    hf_hub_download(args.hf_repo, "config.json")

    logger.info("loading mimi")
    if args.mimi_weight is None:
        args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
    mimi = loaders.get_mimi(args.mimi_weight, args.device)
    other_mimi = loaders.get_mimi(args.mimi_weight, args.device)
    logger.info("mimi loaded")

    if args.tokenizer is None:
        args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME)
    text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer)  # type: ignore

    logger.info("loading moshi")
    if args.moshi_weight is None:
        args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME)
    lm = loaders.get_moshi_lm(args.moshi_weight, device=args.device, cpu_offload=args.cpu_offload)
    lm.eval()
    logger.info("moshi loaded")

    state = ServerState(
        mimi=mimi,
        other_mimi=other_mimi,
        text_tokenizer=text_tokenizer,
        lm=lm,
        device=args.device,
        voice_prompt_dir=args.voice_prompt_dir,
        save_voice_prompt_embeddings=False,
    )
    logger.info("warming up the model")
    state.warmup()

    app = web.Application()
    app.router.add_get("/api/chat", state.handle_chat)
    if static_path is not None:
        async def handle_root(_):
            return web.FileResponse(os.path.join(static_path, "index.html"))

        logger.info(f"serving static content from {static_path}")
        app.router.add_get("/", handle_root)
        app.router.add_static(
            "/", path=static_path, follow_symlinks=True, name="static"
        )

    protocol = "http"
    ssl_context = None
    if args.ssl is not None:
        ssl_context, protocol = create_ssl_context(args.ssl)
    host_ip = args.host if args.host not in ("0.0.0.0", "::", "localhost") else get_lan_ip()
    logger.info(f"Access the Web UI directly at {protocol}://{host_ip}:{args.port}")
    if setup_tunnel is not None:
        tunnel = setup_tunnel('localhost', args.port, tunnel_token, None)
        logger.info(f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.")
    # NOTE(review): --host only affects the URL printed above. run_app receives no host,
    # so aiohttp listens on all interfaces. Passing host=args.host would change that
    # for existing deployments; decide deliberately.
    web.run_app(app, port=args.port, ssl_context=ssl_context)
if __name__ == "__main__":
    # Only start the server when executed as a script (e.g. `python -m ...`), not on import.
    # Inference only: keep autograd disabled for the whole server run.
    with torch.no_grad():
        main()