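"""Websocket server for Moshi.

Streams Opus audio from a web client through the Mimi codec and the Moshi
language model, and streams the generated audio and text tokens back.
"""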
import argparse
import asyncio
import os
import random
import secrets
import sys
import tarfile
import time
from dataclasses import dataclass
from pathlib import Path

import aiohttp
from aiohttp import web
from huggingface_hub import hf_hub_download
import numpy as np
import sentencepiece
import sphn
import torch

from .client_utils import make_log
from .models import loaders, MimiModel, LMModel, LMGen


def log(level: str, msg: str):
    print(make_log(level, msg))


def seed_all(seed):
    # Seed every RNG in use so that runs are reproducible.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # also covers multi-GPU setups
    random.seed(seed)
    np.random.seed(seed)
    # `deterministic = False` would defeat the point of seeding; force
    # deterministic cuDNN kernels and disable benchmark autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


@dataclass
class ServerState:
    mimi: MimiModel
    text_tokenizer: sentencepiece.SentencePieceProcessor
    lm_gen: LMGen
    lock: asyncio.Lock

    def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
                 lm: LMModel, device: str | torch.device):
        self.mimi = mimi
        self.text_tokenizer = text_tokenizer
        self.lm_gen = LMGen(lm)
        self.device = device
        # One model frame is sample_rate / frame_rate samples of PCM.
        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        self.lock = asyncio.Lock()

        # Keep the streaming state alive across requests, with batch size 1.
        self.mimi.streaming_forever(1)
        self.lm_gen.streaming_forever(1)

    def warmup(self):
        # Push a few frames of silence through the full pipeline so that
        # kernels are compiled and caches are warm before the first client.
        for _ in range(4):
            chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
            codes = self.mimi.encode(chunk)
            for c in range(codes.shape[-1]):
                tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                if tokens is None:
                    continue
                _ = self.mimi.decode(tokens[:, 1:])
        torch.cuda.synchronize()
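
    # Wire framing used by handle_chat below; each binary websocket message
    # starts with a one-byte kind tag (as read off the handler code):
    #   0x00  server -> client: handshake, sent once the session is ready
    #   0x01  both directions:  Opus-encoded audio bytes
    #   0x02  server -> client: UTF-8 text piece decoded from the text token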
    async def handle_chat(self, request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async def recv_loop():
            # Reads incoming websocket messages and feeds audio payloads to
            # the Opus reader; any exit marks the connection as closed.
            nonlocal close
            try:
                async for message in ws:
                    if message.type == aiohttp.WSMsgType.ERROR:
                        log("error", f"{ws.exception()}")
                        break
                    elif message.type == aiohttp.WSMsgType.CLOSED:
                        break
                    elif message.type != aiohttp.WSMsgType.BINARY:
                        log("error", f"unexpected message type {message.type}")
                        continue
                    message = message.data
                    if not isinstance(message, bytes):
                        log("error", f"unsupported message type {type(message)}")
                        continue
                    if len(message) == 0:
                        log("warning", "empty message")
                        continue
                    kind = message[0]
                    if kind == 1:  # audio
                        payload = message[1:]
                        opus_reader.append_bytes(payload)
                    else:
                        log("warning", f"unknown message kind {kind}")
            finally:
                close = True
                log("info", "connection closed")

        async def opus_loop():
            # Decodes incoming Opus audio to PCM, runs Mimi and the LM one
            # frame at a time, and queues generated audio and text replies.
            all_pcm_data = None

            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                pcm = opus_reader.read_pcm()
                if pcm.shape[-1] == 0:
                    continue
                if all_pcm_data is None:
                    all_pcm_data = pcm
                else:
                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
                while all_pcm_data.shape[-1] >= self.frame_size:
                    be = time.time()
                    chunk = all_pcm_data[: self.frame_size]
                    all_pcm_data = all_pcm_data[self.frame_size:]
                    chunk = torch.from_numpy(chunk)
                    chunk = chunk.to(device=self.device)[None, None]
                    codes = self.mimi.encode(chunk)
                    for c in range(codes.shape[-1]):
                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
                        if tokens is None:
                            continue
                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
                        main_pcm = self.mimi.decode(tokens[:, 1:])
                        main_pcm = main_pcm.cpu()
                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
                        text_token = tokens[0, 0, 0].item()
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)
                            _text = _text.replace("▁", " ")
                            msg = b"\x02" + bytes(_text, encoding="utf8")
                            log("info", f"text token '{_text}'")
                            await ws.send_bytes(msg)
                    log("info", f"frame handled in {1000 * (time.time() - be):.1f}ms")

        async def send_loop():
            # Drains the Opus writer and ships encoded audio to the client.
            while True:
                if close:
                    return
                await asyncio.sleep(0.001)
                msg = opus_writer.read_bytes()
                if len(msg) > 0:
                    await ws.send_bytes(b"\x01" + msg)

        log("info", "accepted connection")
        close = False
        async with self.lock:
            opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate)
            opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate)
            self.mimi.reset_streaming()
            self.lm_gen.reset_streaming()

            await ws.send_bytes(b"\x00")
            await asyncio.gather(opus_loop(), recv_loop(), send_loop())
        log("info", "done with connection")
        return ws

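
# Typical invocation (the exact module path is an assumption based on the
# relative imports above):
#   python -m moshi.server --port 8998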
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="localhost", type=str)
    parser.add_argument("--port", default=8998, type=int)
    parser.add_argument("--static", type=str)
    parser.add_argument("--gradio-tunnel", action='store_true', help='Activate a gradio tunnel.')
    parser.add_argument("--gradio-tunnel-token",
                        help='Provide a custom (secret) token here to keep getting the same URL.')

    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
                        help="HF repo to look into, defaults to Moshiko. "
                             "Use this to select a different pre-trained model.")
    parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")

    args = parser.parse_args()
    seed_all(42424242)

    setup_tunnel = None
    tunnel_token = ''
    if args.gradio_tunnel:
        try:
            from gradio import networking
        except ImportError:
            log("error", "Cannot find gradio which is required to activate a tunnel. "
                "Please install with `pip install gradio`.")
            sys.exit(1)
        setup_tunnel = networking.setup_tunnel
        if args.gradio_tunnel_token is None:
            tunnel_token = secrets.token_urlsafe(32)
        else:
            tunnel_token = args.gradio_tunnel_token
    log("info", "loading mimi")
    if args.mimi_weight is None:
        args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
    mimi = loaders.get_mimi(args.mimi_weight, args.device)
    log("info", "mimi loaded")

    if args.tokenizer is None:
        args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME)
    text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer)

    log("info", "loading moshi")
    if args.moshi_weight is None:
        args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME)
    lm = loaders.get_moshi_lm(args.moshi_weight, args.device)
    log("info", "moshi loaded")
    state = ServerState(mimi, text_tokenizer, lm, args.device)
    log("info", "warming up the model")
    state.warmup()

    app = web.Application()
    app.router.add_get("/api/chat", state.handle_chat)

    static_path: None | str = None
    if args.static is None:
        log("info", "retrieving the static content")
        dist_tgz = hf_hub_download("kyutai/moshi-artifacts", "dist.tgz")
        dist_tgz = Path(dist_tgz)
        dist = dist_tgz.parent / "dist"
        if not dist.exists():
            with tarfile.open(dist_tgz, "r:gz") as tar:
                tar.extractall(path=dist_tgz.parent)
        static_path = str(dist)
    elif args.static != "none":
        # Passing `--static none` explicitly disables static serving.
        static_path = args.static

    if static_path is not None:
        async def handle_root(_):
            return web.FileResponse(os.path.join(static_path, "index.html"))

        log("info", f"serving static content from {static_path}")
        app.router.add_get("/", handle_root)
        app.router.add_static(
            "/", path=static_path, follow_symlinks=True, name="static"
        )

    log("info", f"Access the Web UI directly at http://{args.host}:{args.port}")
    if setup_tunnel is not None:
        tunnel = setup_tunnel('localhost', args.port, tunnel_token, None)
        log("info", f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.")
        log("info", "Note that this tunnel goes through the US and you might experience high latency in Europe.")
    # Pass the host through so that --host actually takes effect.
    web.run_app(app, host=args.host, port=args.port)


if __name__ == "__main__":
    # Guard the entry point so importing this module does not start the
    # server. Inference only, so no gradients are needed.
    with torch.no_grad():
        main()