# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import asyncio
from dataclasses import dataclass
import random
import os
from pathlib import Path
import tarfile
import time
import secrets
import sys
import aiohttp
from aiohttp import web
from huggingface_hub import hf_hub_download
import numpy as np
import sentencepiece
import sphn
import torch
from .client_utils import make_log
from .models import loaders, MimiModel, LMModel, LMGen
def log(level: str, msg: str):
    """Emit one formatted server log line (via ``make_log``) to stdout."""
    line = make_log(level, msg)
    print(line)
def seed_all(seed):
    """Seed every RNG source used by the server with ``seed``.

    Covers Python's ``random``, numpy, torch CPU, and — when available —
    all CUDA devices. cudnn is deliberately left with its non-deterministic
    kernels enabled; only autotuning (``benchmark``) is switched off.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # Also seed every visible GPU for multi-GPU setups.
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False
@dataclass
class ServerState:
    """Bundles the loaded models and the per-connection streaming state.

    A single ``lock`` serializes websocket conversations: only one client
    can drive the streaming models at a time.
    """

    # NOTE(review): declared as a @dataclass but the hand-written __init__
    # below overrides the generated one; these annotations effectively just
    # document the instance attributes.
    mimi: MimiModel
    text_tokenizer: sentencepiece.SentencePieceProcessor
    lm_gen: LMGen
    lock: asyncio.Lock

    def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
                 lm: LMModel, device: str | torch.device):
        """Wrap the models and switch them to persistent streaming mode (batch size 1)."""
        self.mimi = mimi
        self.text_tokenizer = text_tokenizer
        self.lm_gen = LMGen(lm)
        self.device = device
        # Number of PCM samples that make up one codec frame.
        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        self.lock = asyncio.Lock()
        self.mimi.streaming_forever(1)
        self.lm_gen.streaming_forever(1)

    def warmup(self):
        """Push a few frames of silence through the full encode/step/decode
        pipeline so kernels and caches are initialized before the first
        real connection."""
        for chunk in range(4):
            # NOTE(review): the loop variable is immediately shadowed; the
            # ``range(4)`` only controls the number of warmup iterations.
            chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
            codes = self.mimi.encode(chunk)
            for c in range(codes.shape[-1]):
                tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                if tokens is None:
                    continue
                _ = self.mimi.decode(tokens[:, 1:])
        torch.cuda.synchronize()

    async def handle_chat(self, request):
        """aiohttp handler for the chat websocket.

        Wire protocol — the first byte of every binary message is a kind tag:
          client -> server: 0x01 = opus audio payload
          server -> client: 0x00 = handshake, 0x01 = opus audio,
                            0x02 = utf-8 text piece
        Three coroutines cooperate and all exit once ``close`` flips to True.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async def recv_loop():
            # Reads websocket messages and feeds opus payloads to the decoder.
            nonlocal close
            try:
                async for message in ws:
                    if message.type == aiohttp.WSMsgType.ERROR:
                        log("error", f"{ws.exception()}")
                        break
                    elif message.type == aiohttp.WSMsgType.CLOSED:
                        break
                    elif message.type != aiohttp.WSMsgType.BINARY:
                        log("error", f"unexpected message type {message.type}")
                        continue
                    message = message.data
                    if not isinstance(message, bytes):
                        log("error", f"unsupported message type {type(message)}")
                        continue
                    if len(message) == 0:
                        log("warning", "empty message")
                        continue
                    kind = message[0]
                    if kind == 1:  # audio
                        payload = message[1:]
                        opus_reader.append_bytes(payload)
                    else:
                        log("warning", f"unknown message kind {kind}")
            finally:
                # Whatever ended the receive loop, tell the other loops to stop.
                close = True
                log("info", "connection closed")

        async def opus_loop():
            # Accumulates decoded PCM, runs the models one frame at a time,
            # and pushes generated audio / text tokens back to the client.
            all_pcm_data = None
            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                pcm = opus_reader.read_pcm()
                if pcm.shape[-1] == 0:
                    continue
                if all_pcm_data is None:
                    all_pcm_data = pcm
                else:
                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
                while all_pcm_data.shape[-1] >= self.frame_size:
                    be = time.time()
                    chunk = all_pcm_data[: self.frame_size]
                    all_pcm_data = all_pcm_data[self.frame_size:]
                    chunk = torch.from_numpy(chunk)
                    chunk = chunk.to(device=self.device)[None, None]
                    codes = self.mimi.encode(chunk)
                    for c in range(codes.shape[-1]):
                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                        if tokens is None:
                            continue
                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
                        main_pcm = self.mimi.decode(tokens[:, 1:])
                        main_pcm = main_pcm.cpu()
                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
                        text_token = tokens[0, 0, 0].item()
                        # ids 0 and 3 are skipped — presumably special/pad
                        # tokens of the text tokenizer; TODO confirm.
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)  # type: ignore
                            _text = _text.replace("▁", " ")
                            msg = b"\x02" + bytes(_text, encoding="utf8")
                            log("info", f"text token '{_text}'")
                            await ws.send_bytes(msg)
                    log("info", f"frame handled in {1000 * (time.time() - be):.1f}ms")

        async def send_loop():
            # Drains opus-encoded model audio out to the client.
            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                msg = opus_writer.read_bytes()
                if len(msg) > 0:
                    await ws.send_bytes(b"\x01" + msg)

        log("info", "accepted connection")
        close = False
        # One conversation at a time: the codec streams and model streaming
        # state are per-connection resources.
        async with self.lock:
            opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate)
            opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate)
            self.mimi.reset_streaming()
            self.lm_gen.reset_streaming()
            # Send the handshake.
            await ws.send_bytes(b"\x00")
            await asyncio.gather(opus_loop(), recv_loop(), send_loop())
        log("info", "done with connection")
        return ws
def main():
    """Parse CLI arguments, load Mimi/Moshi and the tokenizer, then serve.

    Side effects: downloads model weights and static assets from the
    Hugging Face hub when no local paths are given, optionally opens a
    gradio tunnel, and blocks in ``web.run_app``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="localhost", type=str)
    parser.add_argument("--port", default=8998, type=int)
    parser.add_argument("--static", type=str)
    parser.add_argument("--gradio-tunnel", action='store_true', help='Activate a gradio tunnel.')
    parser.add_argument("--gradio-tunnel-token",
                        help='Provide a custom (secret) token here to keep getting the same URL.')
    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
                        help="HF repo to look into, defaults Moshiko. "
                             "Use this to select a different pre-trained model.")
    parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")
    args = parser.parse_args()
    seed_all(42424242)

    # Optional gradio tunnel, so the UI is reachable from outside the machine.
    setup_tunnel = None
    tunnel_token = ''
    if args.gradio_tunnel:
        try:
            from gradio import networking  # type: ignore
        except ImportError:
            log("error", "Cannot find gradio which is required to activate a tunnel. "
                "Please install with `pip install gradio`.")
            sys.exit(1)
        setup_tunnel = networking.setup_tunnel
        # A stable token keeps the same public URL across restarts.
        if args.gradio_tunnel_token is None:
            tunnel_token = secrets.token_urlsafe(32)
        else:
            tunnel_token = args.gradio_tunnel_token

    log("info", "loading mimi")
    if args.mimi_weight is None:
        args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
    mimi = loaders.get_mimi(args.mimi_weight, args.device)
    log("info", "mimi loaded")
    if args.tokenizer is None:
        args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME)
    text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer)  # type: ignore
    log("info", "loading moshi")
    if args.moshi_weight is None:
        args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME)
    lm = loaders.get_moshi_lm(args.moshi_weight, args.device)
    log("info", "moshi loaded")
    state = ServerState(mimi, text_tokenizer, lm, args.device)
    log("info", "warming up the model")
    state.warmup()

    app = web.Application()
    app.router.add_get("/api/chat", state.handle_chat)
    static_path: None | str = None
    if args.static is None:
        log("info", "retrieving the static content")
        dist_tgz = hf_hub_download("kyutai/moshi-artifacts", "dist.tgz")
        dist_tgz = Path(dist_tgz)
        dist = dist_tgz.parent / "dist"
        if not dist.exists():
            with tarfile.open(dist_tgz, "r:gz") as tar:
                # The archive is downloaded content: use the "data" filter to
                # reject path-traversal members (CVE-2007-4559). Fall back to
                # the legacy behavior on Python versions without ``filter``.
                try:
                    tar.extractall(path=dist_tgz.parent, filter="data")
                except TypeError:
                    tar.extractall(path=dist_tgz.parent)
        static_path = str(dist)
    elif args.static != "none":
        # When set to the "none" string, we don't serve any static content.
        static_path = args.static
    if static_path is not None:
        async def handle_root(_):
            return web.FileResponse(os.path.join(static_path, "index.html"))
        log("info", f"serving static content from {static_path}")
        app.router.add_get("/", handle_root)
        app.router.add_static(
            "/", path=static_path, follow_symlinks=True, name="static"
        )
    log("info", f"Access the Web UI directly at http://{args.host}:{args.port}")
    if setup_tunnel is not None:
        tunnel = setup_tunnel('localhost', args.port, tunnel_token, None)
        log("info", f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.")
        log("info", "Note that this tunnel goes through the US and you might experience high latency in Europe.")
    # Honor --host: previously it was only printed above while run_app bound
    # every interface. Binding the requested host also makes the default
    # (localhost) safe-by-default; pass --host 0.0.0.0 to expose the server.
    web.run_app(app, host=args.host, port=args.port)
if __name__ == "__main__":
    # Inference only: disable autograd for the whole server lifetime.
    # The guard keeps `import`ing this module from launching the server;
    # running via `python -m` is unchanged.
    with torch.no_grad():
        main()