# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
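
# Moshi websocket server: bridges a web client and the Moshi models, streaming
# Opus audio in, running it through the Mimi codec and the Moshi language
# model, and streaming generated audio and text back out.
#
# Example launch (assuming this file is installed as the `moshi.server`
# module, as the relative imports below suggest):
#   python -m moshi.server --port 8998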
import argparse
import asyncio
from dataclasses import dataclass
import random
import os
from pathlib import Path
import tarfile
import time
import secrets
import sys

import aiohttp
from aiohttp import web
from huggingface_hub import hf_hub_download
import numpy as np
import sentencepiece
import sphn
import torch

from .client_utils import make_log
from .models import loaders, MimiModel, LMModel, LMGen


def log(level: str, msg: str):
    print(make_log(level, msg))


def seed_all(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
    random.seed(seed)
    np.random.seed(seed)
    # For reproducible runs, make cuDNN deterministic and disable its
    # benchmark-driven kernel autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


@dataclass
class ServerState:
    mimi: MimiModel
    text_tokenizer: sentencepiece.SentencePieceProcessor
    lm_gen: LMGen
    lock: asyncio.Lock

    def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
                 lm: LMModel, device: str | torch.device):
        self.mimi = mimi
        self.text_tokenizer = text_tokenizer
        self.lm_gen = LMGen(lm)
        self.device = device
        # Number of PCM samples per codec frame (e.g. 1920 with the default
        # Mimi settings of 24 kHz audio at 12.5 frames per second).
        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        self.lock = asyncio.Lock()
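        # Keep both models permanently in streaming mode with batch size 1 so
        # codec and LM state carry over from one frame to the next.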
        self.mimi.streaming_forever(1)
        self.lm_gen.streaming_forever(1)

    def warmup(self):
        # Push a few frames of silence through the whole pipeline so that
        # kernels are compiled and caches allocated before a client connects.
        for _ in range(4):
            chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
            codes = self.mimi.encode(chunk)
            for c in range(codes.shape[-1]):
                tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                if tokens is None:
                    continue
                _ = self.mimi.decode(tokens[:, 1:])
        if torch.cuda.is_available():
            torch.cuda.synchronize()
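
    # Binary websocket protocol: the first byte of each message is a kind tag.
    #   0x00: handshake, sent once by the server when the session is ready.
    #   0x01: Opus audio bytes (client input as well as server output).
    #   0x02: a UTF-8 text token generated by the server.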
    async def handle_chat(self, request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)
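        # A session is three concurrent tasks: recv_loop feeds client audio to
        # the Opus reader, opus_loop runs the models, send_loop streams the
        # generated audio back.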

        async def recv_loop():
            nonlocal close
            try:
                async for message in ws:
                    if message.type == aiohttp.WSMsgType.ERROR:
                        log("error", f"{ws.exception()}")
                        break
                    elif message.type == aiohttp.WSMsgType.CLOSED:
                        break
                    elif message.type != aiohttp.WSMsgType.BINARY:
                        log("error", f"unexpected message type {message.type}")
                        continue
                    message = message.data
                    if not isinstance(message, bytes):
                        log("error", f"unsupported message type {type(message)}")
                        continue
                    if len(message) == 0:
                        log("warning", "empty message")
                        continue
                    kind = message[0]
                    if kind == 1:  # audio
                        payload = message[1:]
                        opus_reader.append_bytes(payload)
                    else:
                        log("warning", f"unknown message kind {kind}")
            finally:
                close = True
                log("info", "connection closed")

        async def opus_loop():
            all_pcm_data = None

            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                pcm = opus_reader.read_pcm()
                if pcm.shape[-1] == 0:
                    continue
                if all_pcm_data is None:
                    all_pcm_data = pcm
                else:
                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
                # Process the buffered PCM one full codec frame at a time.
                while all_pcm_data.shape[-1] >= self.frame_size:
                    be = time.time()
                    chunk = all_pcm_data[: self.frame_size]
                    all_pcm_data = all_pcm_data[self.frame_size:]
                    chunk = torch.from_numpy(chunk)
                    chunk = chunk.to(device=self.device)[None, None]
                    codes = self.mimi.encode(chunk)
                    for c in range(codes.shape[-1]):
                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                        if tokens is None:
                            # No output yet: the LM is still within its
                            # internal delay.
                            continue
                        # Each step yields one text token plus dep_q audio codebooks.
                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
                        main_pcm = self.mimi.decode(tokens[:, 1:])
                        main_pcm = main_pcm.cpu()
                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
                        text_token = tokens[0, 0, 0].item()
                        # Skip the special padding tokens.
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)  # type: ignore
                            # sentencepiece marks word boundaries with "▁".
                            _text = _text.replace("▁", " ")
                            msg = b"\x02" + bytes(_text, encoding="utf8")
                            log("info", f"text token '{_text}'")
                            await ws.send_bytes(msg)
                    log("info", f"frame handled in {1000 * (time.time() - be):.1f}ms")

        async def send_loop():
            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                # Forward whatever the Opus encoder has produced so far.
                msg = opus_writer.read_bytes()
                if len(msg) > 0:
                    await ws.send_bytes(b"\x01" + msg)
log("info", "accepted connection")
close = False
async with self.lock:
opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate)
opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate)
self.mimi.reset_streaming()
self.lm_gen.reset_streaming()
# Send the handshake.
await ws.send_bytes(b"\x00")
await asyncio.gather(opus_loop(), recv_loop(), send_loop())
log("info", "done with connection")
return ws


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="localhost", type=str)
    parser.add_argument("--port", default=8998, type=int)
    parser.add_argument("--static", type=str)
    parser.add_argument("--gradio-tunnel", action='store_true', help='Activate a gradio tunnel.')
    parser.add_argument("--gradio-tunnel-token",
                        help='Provide a custom (secret) token here to keep getting the same URL.')
    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
                        help="HF repo to look into, defaults to Moshiko. "
                             "Use this to select a different pre-trained model.")
    parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")
    args = parser.parse_args()
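
    # Fixed seed so that generations are reproducible across restarts.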
    seed_all(42424242)

    setup_tunnel = None
    tunnel_token = ''
    if args.gradio_tunnel:
        try:
            from gradio import networking  # type: ignore
        except ImportError:
            log("error", "Cannot find gradio which is required to activate a tunnel. "
                         "Please install with `pip install gradio`.")
            sys.exit(1)
        setup_tunnel = networking.setup_tunnel
        if args.gradio_tunnel_token is None:
            tunnel_token = secrets.token_urlsafe(32)
        else:
            tunnel_token = args.gradio_tunnel_token
log("info", "loading mimi")
if args.mimi_weight is None:
args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
mimi = loaders.get_mimi(args.mimi_weight, args.device)
log("info", "mimi loaded")
if args.tokenizer is None:
args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME)
text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer) # type: ignore
log("info", "loading moshi")
if args.moshi_weight is None:
args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME)
lm = loaders.get_moshi_lm(args.moshi_weight, args.device)
log("info", "moshi loaded")
state = ServerState(mimi, text_tokenizer, lm, args.device)
log("info", "warming up the model")
state.warmup()

    app = web.Application()
    app.router.add_get("/api/chat", state.handle_chat)
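
    # Static web UI: by default the prebuilt client is downloaded from the
    # "kyutai/moshi-artifacts" repo on the Hugging Face hub. Passing
    # `--static none` disables static serving entirely; any other value is
    # treated as a local directory to serve.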
    static_path: None | str = None
    if args.static is None:
        log("info", "retrieving the static content")
        dist_tgz = hf_hub_download("kyutai/moshi-artifacts", "dist.tgz")
        dist_tgz = Path(dist_tgz)
        dist = dist_tgz.parent / "dist"
        if not dist.exists():
            with tarfile.open(dist_tgz, "r:gz") as tar:
                tar.extractall(path=dist_tgz.parent)
        static_path = str(dist)
    elif args.static != "none":
        # When set to the "none" string, we don't serve any static content.
        static_path = args.static

    if static_path is not None:
        async def handle_root(_):
            return web.FileResponse(os.path.join(static_path, "index.html"))

        log("info", f"serving static content from {static_path}")
        app.router.add_get("/", handle_root)
        app.router.add_static(
            "/", path=static_path, follow_symlinks=True, name="static"
        )
log("info", f"Access the Web UI directly at http://{args.host}:{args.port}")
if setup_tunnel is not None:
tunnel = setup_tunnel('localhost', args.port, tunnel_token, None)
log("info", f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.")
log("info", "Note that this tunnel goes through the US and you might experience high latency in Europe.")
web.run_app(app, port=args.port)


# Inference only: run the whole server without autograd bookkeeping.
with torch.no_grad():
    main()