andy hickl committed
Commit · ed99557
1 Parent(s): bfcda41
Uploaded from Github
Browse files

- LICENSE +23 -0
- LICENSE.audiocraft +21 -0
- MANIFEST.in +5 -0
- README.md +151 -0
- moshi/__init__.py +18 -0
- moshi/client.py +196 -0
- moshi/client_utils.py +211 -0
- moshi/models/__init__.py +14 -0
- moshi/models/compression.py +474 -0
- moshi/models/lm.py +488 -0
- moshi/models/loaders.py +159 -0
- moshi/modules/__init__.py +23 -0
- moshi/modules/conv.py +329 -0
- moshi/modules/gating.py +82 -0
- moshi/modules/resample.py +119 -0
- moshi/modules/rope.py +90 -0
- moshi/modules/seanet.py +395 -0
- moshi/modules/streaming.py +363 -0
- moshi/modules/transformer.py +750 -0
- moshi/quantization/__init__.py +13 -0
- moshi/quantization/base.py +170 -0
- moshi/quantization/core_vq.py +384 -0
- moshi/quantization/vq.py +340 -0
- moshi/server.py +256 -0
- moshi/utils/__init__.py +10 -0
- moshi/utils/autocast.py +45 -0
- moshi/utils/compile.py +284 -0
- moshi/utils/sampling.py +126 -0
- pyproject.toml +33 -0
- requirements.txt +10 -0
- setup.cfg +10 -0
LICENSE
ADDED
@@ -0,0 +1,23 @@
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:

The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
LICENSE.audiocraft
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Meta Platforms, Inc. and affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
MANIFEST.in
ADDED
@@ -0,0 +1,5 @@
include LICENSE*
include *.md
include *.cfg
include requirements.txt
include moshi/py.typed
README.md
ADDED
@@ -0,0 +1,151 @@
# Moshi - PyTorch

See the [top-level README.md][main_repo] for more information on Moshi.

[Moshi][moshi] is a speech-text foundation model and full-duplex spoken dialogue framework.
It uses [Mimi][moshi], a state-of-the-art streaming neural audio codec. Mimi operates at 12.5 Hz and compresses
24 kHz audio down to 1.1 kbps, in a fully streaming manner (latency of 80 ms, the frame size), yet performs better than existing non-streaming codecs.

This is the PyTorch implementation for Moshi and Mimi.

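As a quick sanity check on these figures (a minimal sketch, not part of the upstream README):

```python
# Mimi figures quoted above: 24 kHz audio, 12.5 Hz latent frame rate.
sample_rate = 24_000                           # input samples per second
frame_rate = 12.5                              # latent frames per second
frame_size = int(sample_rate / frame_rate)     # 1920 samples per frame
latency_ms = 1000 * frame_size / sample_rate   # 80.0 ms: one frame of latency
# 1.1 kbps is consistent with 12.5 frames/s * 8 codebooks * 11 bits per code
# (assuming 2048-entry codebooks), i.e. 1100 bits per second.
print(frame_size, latency_ms)
```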

## Requirements

You will need at least Python 3.10. We kept a minimal set of dependencies for the current project.
It was tested with PyTorch 2.2 or 2.4. If you need a specific CUDA version, please make sure
to have PyTorch properly installed before installing Moshi.

```bash
pip install moshi  # moshi PyTorch, from PyPI
# Or the bleeding edge versions for Moshi
pip install -e "git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi"
```

While we hope that the present codebase will work on Windows, we do not provide official support for it.
At the moment, we do not support quantization for the PyTorch version, so you will need a GPU with a significant amount of memory (24GB).


## Usage

This package provides a streaming version of the audio tokenizer (Mimi) and the LM model (Moshi).

In order to run in interactive mode, you need to start a server which will
run the model; you can then use either the web UI or a command line client.

Start the server with:
```bash
python -m moshi.server [--gradio-tunnel]
```

And then access the web UI on [localhost:8998](http://localhost:8998). If your GPU is on a distant machine
with no direct access, `--gradio-tunnel` will create a tunnel with a URL accessible from anywhere.
Keep in mind that this tunnel goes through the US and can add significant latency (up to 500 ms from Europe).
You can use `--gradio-tunnel-token` to set a fixed secret token and reuse the same address over time.
Alternatively, you might want to use SSH to redirect your connection.
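For instance (an illustrative command; the host name is made up):

```bash
# Forward the remote server's port 8998 to the same port locally, then
# browse http://localhost:8998 as if the server ran on this machine.
ssh -L 8998:localhost:8998 user@my-gpu-box
```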

You can use `--hf-repo` to select a different pretrained model, by setting the proper Hugging Face repository.
See [the model list](https://github.com/kyutai-labs/moshi?tab=readme-ov-file#models) for a reference of the available models.

Accessing a server that is not localhost via http may cause issues with using
the microphone in the web UI (in some browsers this is only allowed using
https).

A local client is also available, as
```bash
python -m moshi.client [--url URL_TO_GRADIO]
```
Note, however, that unlike the web UI, this client is bare-bones: it does not perform any echo cancellation,
nor does it try to compensate for a growing lag by skipping frames.


## API

You can use Mimi and Moshi programmatically as follows:
```python
from huggingface_hub import hf_hub_download
import torch

from moshi.models import loaders, LMGen

mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device='cpu')
mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.

wav = torch.randn(1, 1, 24000 * 10)  # should be [B, C=1, T]
with torch.no_grad():
    codes = mimi.encode(wav)  # [B, K = 8, T]
    decoded = mimi.decode(codes)

# Supports streaming too.
frame_size = int(mimi.sample_rate / mimi.frame_rate)
all_codes = []
with mimi.streaming(batch_size=1):
    for offset in range(0, wav.shape[-1], frame_size):
        frame = wav[:, :, offset: offset + frame_size]
        codes = mimi.encode(frame)
        assert codes.shape[-1] == 1, codes.shape
        all_codes.append(codes)
# Now if you have a GPU around.
mimi.cuda()
moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
moshi = loaders.get_moshi_lm(moshi_weight, device='cuda')
lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)  # this handles sampling params etc.
out_wav_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    for idx, code in enumerate(all_codes):
        tokens_out = lm_gen.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
        if tokens_out is not None:
            wav_chunk = mimi.decode(tokens_out[:, 1:])
            out_wav_chunks.append(wav_chunk)
        print(idx, end='\r')
out_wav = torch.cat(out_wav_chunks, dim=-1)
```
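
To listen to the result, a minimal follow-up sketch (assuming `torchaudio` is available in your environment):

```python
import torchaudio

# out_wav is [B=1, C=1, T]; torchaudio.save expects a [C, T] CPU tensor.
torchaudio.save("moshi_out.wav", out_wav[0].cpu(), mimi.sample_rate)
```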

## Development

If you wish to install from a clone of this repository, maybe to further develop Moshi, you can do the following:
```bash
# From the current folder (e.g. `moshi/`)
pip install -e '.[dev]'
pre-commit install
```

Once locally installed, Mimi can be tested with the following command, from **the root** of the repository:
```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
python scripts/mimi_streaming_test.py
```

Similarly, Moshi can be tested (with a GPU) with:
```bash
python scripts/moshi_benchmark.py
```


## License

The present code is provided under the MIT license.
Note that parts of this code are based on [AudioCraft](https://github.com/facebookresearch/audiocraft), released under
the MIT license.

## Citation

If you use either Mimi or Moshi, please cite the following paper:

```
@techreport{kyutai2024moshi,
    author = {Alexandre D\'efossez and Laurent Mazar\'e and Manu Orsini and Am\'elie Royer and
              Patrick P\'erez and Herv\'e J\'egou and Edouard Grave and Neil Zeghidour},
    title = {Moshi: a speech-text foundation model for real-time dialogue},
    institution = {Kyutai},
    year = {2024},
    month = {September},
    url = {http://kyutai.org/Moshi.pdf},
}
```

[moshi]: https://kyutai.org/Moshi.pdf
[main_repo]: https://github.com/kyutai-labs/moshi
moshi/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
moshi is the inference codebase for Kyutai audio generation models.

The code has been adapted from Audiocraft, see LICENSE.audiocraft
Copyright (c) Meta Platforms, Inc. and affiliates.
"""

# flake8: noqa
from . import utils
from . import modules
from . import models
from . import quantization

__version__ = "0.1.0"
moshi/client.py
ADDED
@@ -0,0 +1,196 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Client for the Moshi server."""

import argparse
import asyncio
import queue
import sys

import aiohttp
import numpy as np
import sphn
import sounddevice as sd

from .client_utils import AnyPrinter, Printer, RawPrinter


class Connection:
    def __init__(
        self,
        printer: AnyPrinter,
        websocket: aiohttp.ClientWebSocketResponse,
        sample_rate: float = 24000,
        channels: int = 1,
        frame_size: int = 1920,
    ) -> None:
        self.printer = printer
        self.websocket = websocket
        self.sample_rate = sample_rate
        self.frame_size = frame_size
        self.channels = channels

        self._done = False
        self._in_stream = sd.InputStream(
            samplerate=sample_rate,
            channels=channels,
            blocksize=self.frame_size,
            callback=self._on_audio_input,
        )

        self._out_stream = sd.OutputStream(
            samplerate=sample_rate,
            channels=channels,
            blocksize=frame_size,
            callback=self._on_audio_output,
        )
        self._opus_writer = sphn.OpusStreamWriter(sample_rate)
        self._opus_reader = sphn.OpusStreamReader(sample_rate)
        self._output_queue = queue.Queue()

    async def _queue_loop(self) -> None:
        while True:
            if self._done:
                return
            await asyncio.sleep(0.001)
            msg = self._opus_writer.read_bytes()
            if len(msg) > 0:
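                # Wire framing, as also handled in `_recv_loop` below: each binary
                # websocket message starts with one kind byte (1 = Opus audio,
                # 2 = text); the client itself only ever sends audio frames.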
                try:
                    await self.websocket.send_bytes(b"\x01" + msg)
                except Exception as e:
                    print(e)
                    self._lost_connection()
                    return

    async def _decoder_loop(self) -> None:
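        # Drain the Opus decoder, accumulate the PCM, and re-chunk it into
        # fixed `frame_size` frames that the audio output callback consumes.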
        all_pcm_data = None
        while True:
            if self._done:
                return
            await asyncio.sleep(0.001)
            pcm = self._opus_reader.read_pcm()
            if all_pcm_data is None:
                all_pcm_data = pcm
            else:
                all_pcm_data = np.concatenate((all_pcm_data, pcm))
            while all_pcm_data.shape[-1] >= self.frame_size:
                self._output_queue.put(all_pcm_data[: self.frame_size])
                all_pcm_data = np.array(all_pcm_data[self.frame_size :])

    async def _recv_loop(self) -> None:
        try:
            async for message in self.websocket:
                if message.type == aiohttp.WSMsgType.CLOSED:
                    self.printer.log("info", "Connection closed")
                    break
                elif message.type == aiohttp.WSMsgType.ERROR:
                    self.printer.log("error", f"{self.websocket.exception()}")
                    break
                elif message.type != aiohttp.WSMsgType.BINARY:
                    self.printer.log("error", f"received from server: {message.type}")
                    continue
                message = message.data
                if not isinstance(message, bytes):
                    self.printer.log(
                        "warning", f"unsupported message type {type(message)}"
                    )
                    continue
                if len(message) == 0:
                    self.printer.log("warning", "empty message")
                    continue
                kind = message[0]
                if kind == 1:  # audio
                    payload = message[1:]
                    self._opus_reader.append_bytes(payload)
                    self.printer.print_pending()
                elif kind == 2:  # text
                    payload = message[1:]
                    self.printer.print_token(payload.decode())
                else:
                    self.printer.log("warning", f"unknown message kind {kind}")
        except Exception as e:
            print(e)
            self._lost_connection()
            return

    def _lost_connection(self) -> None:
        if not self._done:
            self.printer.log("error", "Lost connection with the server!")
            self._done = True

    def _on_audio_input(self, in_data, frames, time_, status) -> None:
        assert in_data.shape == (self.frame_size, self.channels), in_data.shape
        self._opus_writer.append_pcm(in_data[:, 0])

    def _on_audio_output(self, out_data, frames, time_, status) -> None:
        assert out_data.shape == (self.frame_size, self.channels), out_data.shape
        try:
            pcm_data = self._output_queue.get(block=False)
            # TODO: handle other shapes by using some form of fifo/ring buffer.
            assert pcm_data.shape == (self.frame_size,), pcm_data.shape
            out_data[:, 0] = pcm_data
        except queue.Empty:
            out_data.fill(0)
            self.printer.print_lag()

    async def run(self) -> None:
        with self._in_stream, self._out_stream:
            await asyncio.gather(
                self._recv_loop(), self._decoder_loop(), self._queue_loop()
            )


async def run(printer: AnyPrinter, args):
    if args.url is None:
        proto = "ws"
        if args.https:
            proto += "s"
        uri = f"{proto}://{args.host}:{args.port}/api/chat"
    else:
        proto = "wss"
        if '://' in args.url:
            proto, without_proto = args.url.split('://', 1)
            if proto in ['ws', 'http']:
                proto = "ws"
            elif proto in ['wss', 'https']:
                proto = "wss"
            else:
                printer.log("error", f"The provided URL {args.url} seems to contain a protocol but it is unknown.")
                sys.exit(1)
        else:
            without_proto = args.url
        uri = f"{proto}://{without_proto}/api/chat"

    printer.log("info", f"Connecting to {uri}.")
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(uri) as ws:
            printer.log("info", "connected!")
            printer.print_header()
            connection = Connection(printer, ws)
            await connection.run()


def main():
    parser = argparse.ArgumentParser("client_opus")
    parser.add_argument("--host", default="localhost", type=str, help="Hostname to connect to.")
    parser.add_argument("--port", default=8998, type=int, help="Port to connect to.")
    parser.add_argument("--https", action='store_true',
                        help="Set this flag for using a https connection.")
    parser.add_argument("--url", type=str, help='Provides directly a URL, e.g. to a gradio tunnel.')
    args = parser.parse_args()
    printer: AnyPrinter

    if sys.stdout.isatty():
        printer = Printer()
    else:
        printer = RawPrinter()
    try:
        asyncio.run(run(printer, args))
    except KeyboardInterrupt:
        printer.log("warning", "Interrupting, exiting connection.")
    printer.log("info", "All done!")


if __name__ == "__main__":
    main()
moshi/client_utils.py
ADDED
@@ -0,0 +1,211 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities for the command line client, in particular for handling interactions with the terminal.
"""

from dataclasses import dataclass
import sys


def colorize(text, color):
    code = f"\033[{color}m"
    restore = "\033[0m"
    return "".join([code, text, restore])
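
# For reference: `color` is a raw ANSI SGR parameter string, so for example
# colorize("hi", "1;31") renders "hi" in bold red and then restores the style.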

def make_log(level: str, msg: str) -> str:
    if level == "warning":
        prefix = colorize("[Warn]", "1;31")
    elif level == "info":
        prefix = colorize("[Info]", "1;34")
    elif level == "error":
        prefix = colorize("[Err ]", "1;31")
    else:
        raise ValueError(f"Unknown level {level}")
    return prefix + " " + msg


class RawPrinter:
    def __init__(self, stream=sys.stdout, err_stream=sys.stderr):
        self.stream = stream
        self.err_stream = err_stream

    def print_header(self):
        pass

    def print_token(self, token: str):
        self.stream.write(token)
        self.stream.flush()

    def log(self, level: str, msg: str):
        print(f"{level.capitalize()}: {msg}", file=self.err_stream)

    def print_lag(self):
        self.err_stream.write(colorize(" [LAG]", "31"))
        self.err_stream.flush()

    def print_pending(self):
        pass


@dataclass
class LineEntry:
    msg: str
    color: str | None = None

    def render(self):
        if self.color is None:
            return self.msg
        else:
            return colorize(self.msg, self.color)

    def __len__(self):
        return len(self.msg)


class Line:
    def __init__(self, stream):
        self.stream = stream
        self._line: list[LineEntry] = []
        self._has_padding: bool = False
        self._max_line_length = 0

    def __bool__(self):
        return bool(self._line)

    def __len__(self):
        return sum(len(entry) for entry in self._line)

    def add(self, msg: str, color: str | None = None) -> int:
        entry = LineEntry(msg, color)
        return self._add(entry)

    def _add(self, entry: LineEntry) -> int:
        if self._has_padding:
            self.erase(count=0)
        self._line.append(entry)
        self.stream.write(entry.render())
        self._max_line_length = max(self._max_line_length, len(self))
        return len(entry)

    def erase(self, count: int = 1):
        if count:
            entries = list(self._line[:-count])
        else:
            entries = list(self._line)
        self._line.clear()
        self.stream.write("\r")
        for entry in entries:
            self._line.append(entry)
            self.stream.write(entry.render())

        self._has_padding = False

    def newline(self):
        missing = self._max_line_length - len(self)
        if missing > 0:
            self.stream.write(" " * missing)
        self.stream.write("\n")
        self._line.clear()
        self._max_line_length = 0
        self._has_padding = False

    def flush(self):
        missing = self._max_line_length - len(self)
        if missing > 0:
            self.stream.write(" " * missing)
            self._has_padding = True
        self.stream.flush()


class Printer:
    def __init__(self, max_cols: int = 80, stream=sys.stdout, err_stream=sys.stderr):
        self.max_cols = max_cols
        self.line = Line(stream)
        self.stream = stream
        self.err_stream = err_stream
        self._pending_count = 0
        self._pending_printed = False

    def print_header(self):
        self.line.add(" " + "-" * (self.max_cols) + " ")
        self.line.newline()
        self.line.flush()
        self.line.add("| ")

    def _remove_pending(self) -> bool:
        if self._pending_printed:
            self._pending_printed = False
            self.line.erase(1)
            return True
        return False

    def print_token(self, token: str, color: str | None = None):
        self._remove_pending()
        remaining = self.max_cols - len(self.line)
        if len(token) <= remaining:
            self.line.add(token, color)
        else:
            end = " " * remaining + " |"
            if token.startswith(" "):
                token = token.lstrip()
                self.line.add(end)
                self.line.newline()
                self.line.add("| ")
                self.line.add(token, color)
            else:
                assert color is None
                erase_count = None
                cumulated = ""
                for idx, entry in enumerate(self.line._line[::-1]):
                    if entry.color:
                        # probably a LAG message
                        erase_count = idx
                        break
                    if entry.msg.startswith(" "):
                        erase_count = idx + 1
                        cumulated = entry.msg + cumulated
                        break
                if erase_count is not None:
                    if erase_count > 0:
                        self.line.erase(erase_count)
                    remaining = self.max_cols - len(self.line)
                    end = " " * remaining + " |"
                    self.line.add(end)
                    self.line.newline()
                    self.line.add("| ")
                    token = cumulated.lstrip() + token
                    self.line.add(token)
                else:
                    self.line.add(token[:remaining])
                    self.line.add(" |")
                    self.line.newline()
                    self.line.add("| ")
                    self.line.add(token[remaining:])
        self.line.flush()

    def log(self, level: str, msg: str):
        msg = make_log(level, msg)
        self._remove_pending()
        if self.line:
            self.line.newline()
            self.line.flush()
        print(msg, file=self.err_stream)
        self.err_stream.flush()

    def print_lag(self):
        self.print_token(" [LAG]", "31")

    def print_pending(self):
        chars = ["|", "/", "-", "\\"]
        count = int(self._pending_count / 5)
        char = chars[count % len(chars)]
        colors = ["32", "33", "31"]
        self._remove_pending()
        self.line.add(char, colors[count % len(colors)])
        self._pending_printed = True
        self._pending_count += 1


AnyPrinter = Printer | RawPrinter
moshi/models/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Models for the compression model (Mimi) and the language model (Moshi).
"""

# flake8: noqa
from .compression import (
    CompressionModel,
    MimiModel,
)
from .lm import LMModel, LMGen
from .loaders import get_mimi, get_moshi_lm
moshi/models/compression.py
ADDED
@@ -0,0 +1,474 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of this file is adapted from encodec.py in https://github.com/facebookresearch/audiocraft
# released under the following license.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Compression models or wrappers around existing models. In particular, provides the implementation
for Mimi. Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""

from abc import abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass
import logging
import typing as tp

import torch
from torch import nn


from ..quantization import (
    QuantizedResult,
    BaseQuantizer,
    SplitResidualVectorQuantizer,
    ResidualVectorQuantizer,
)
from ..modules.resample import ConvDownsample1d, ConvTrUpsample1d
from ..modules.streaming import StreamingModule, State
from ..utils.compile import no_compile, CUDAGraphed


logger = logging.getLogger()


class CompressionModel(StreamingModule[State]):
    """Base API for all compression models that aim at being used as audio tokenizers
    with a language model.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> QuantizedResult: ...

    @abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """See `MimiModel.encode`."""
        ...

    @abstractmethod
    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """See `MimiModel.decode`."""
        ...

    @abstractmethod
    def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode from the discrete codes to continuous latent space."""
        ...

    @property
    @abstractmethod
    def channels(self) -> int: ...

    @property
    @abstractmethod
    def frame_rate(self) -> float: ...

    @property
    @abstractmethod
    def sample_rate(self) -> int: ...

    @property
    @abstractmethod
    def cardinality(self) -> int: ...

    @property
    @abstractmethod
    def num_codebooks(self) -> int: ...

    @property
    @abstractmethod
    def total_codebooks(self) -> int: ...

    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        ...


@dataclass
class _MimiState:
    graphed_tr_enc: CUDAGraphed | None
    graphed_tr_dec: CUDAGraphed | None

    def reset(self):
        pass


class MimiModel(CompressionModel[_MimiState]):
    """Mimi model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.BaseQuantizer): Quantizer network.
        frame_rate (float): Final frame rate of the quantized representation.
        encoder_frame_rate (float): frame rate of the encoder model. Note that if `frame_rate != encoder_frame_rate`,
            the latent will be resampled linearly to match the desired `frame_rate` before and after quantization.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        encoder_transformer (nn.Module or None): optional transformer for the encoder.
        decoder_transformer (nn.Module or None): optional transformer for the decoder.
        resample_method (str): method to use for resampling the latent space before the quantizer.
        upsample_channel_wise_bug (bool): controls whether the upsampling is channel wise.
            Defaults to true to reproduce bug in original implementation.
        freeze_encoder: whether to freeze the encoder weights.
        freeze_quantizer: whether to freeze the quantizer weights.
        freeze_quantizer_level: If positive, freeze the quantizer up to this level.
        torch_compile_encoder_decoder (bool): if True, uses torch.compile on the encoder / decoder.
            Deactivated by default for training as this is incompatible at the moment with weight norm.
            See https://github.com/pytorch/pytorch/issues/121902
            Also this seems to work well with 2.2.0, but completely fails with 2.4.0.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        quantizer: BaseQuantizer,
        frame_rate: float,
        encoder_frame_rate: float,
        sample_rate: int,
        channels: int,
        causal: bool = False,
        encoder_transformer: tp.Optional[nn.Module] = None,
        decoder_transformer: tp.Optional[nn.Module] = None,
        resample_method: str = "interpolate",
        upsample_channel_wise_bug: bool = True,
        freeze_encoder: bool = False,
        freeze_quantizer: bool = False,
        freeze_quantizer_level: int = -1,
        torch_compile_encoder_decoder: bool = False,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.encoder_transformer = encoder_transformer
        self.decoder_transformer = decoder_transformer
        self.quantizer = quantizer
        self._frame_rate = frame_rate
        self._sample_rate = sample_rate
        self._channels = channels
        self.encoder_frame_rate = encoder_frame_rate
        self.torch_compile_encoder_decoder = torch_compile_encoder_decoder

        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
            if self.encoder_transformer is not None:
                for p in self.encoder_transformer.parameters():
                    p.requires_grad = False
            for name, p in self.quantizer.named_parameters():
                if name.endswith("input_proj.weight"):
                    p.requires_grad = False
        if freeze_quantizer:
            self.quantizer.ema_frozen_(True)
        self.freeze_quantizer = freeze_quantizer
        self.freeze_quantizer_level = (
            freeze_quantizer_level
            if freeze_quantizer_level > 0
            else self.quantizer.num_codebooks
        )

        # We will need the dimension for the resampling. In general the encoder will be a SeanetEncoder
        # which exposes a `dimension` attribute.
        dimension = encoder.dimension
        assert isinstance(
            dimension, int
        ), f"Dimension should be int, got {dimension} of type {type(dimension)}."
        self.dimension = dimension

        assert resample_method in [
            "interpolate",
            "conv",
            "avg_pool",
        ], f"Invalid resample_method {resample_method}"
        self.resample_method = resample_method
        if encoder_frame_rate != frame_rate:
            assert not (
                causal and resample_method == "interpolate"
            ), "Cannot interpolate with causal model."
            if resample_method in ["conv", "avg_pool"]:
                assert (
                    self.encoder_frame_rate > self.frame_rate
                ), "Cannot upsample with conv."
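                # For example (illustrative numbers only): a 25 Hz encoder latent
                # resampled to a 12.5 Hz frame rate gives an integer stride of 2.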
                downsample_stride = self.encoder_frame_rate / self.frame_rate
                assert downsample_stride == int(
                    downsample_stride
                ), f"Only integer strides are supported, got {downsample_stride}"
                learnt = resample_method == "conv"
                self.downsample = ConvDownsample1d(
                    int(downsample_stride),
                    dimension=dimension,
                    learnt=learnt,
                    causal=causal,
                )
                if freeze_encoder:
                    for p in self.downsample.parameters():
                        p.requires_grad = False
                self.upsample = ConvTrUpsample1d(
                    int(downsample_stride),
                    dimension=dimension,
                    learnt=learnt,
                    causal=causal,
                    channel_wise=upsample_channel_wise_bug,
                )

    def _init_streaming_state(self, batch_size: int) -> _MimiState:
        device = next(self.parameters()).device
        disable = device.type != 'cuda'
        graphed_tr_dec = None
        graphed_tr_enc = None
        if self.encoder_transformer is not None:
            graphed_tr_enc = CUDAGraphed(self.encoder_transformer, disable=disable)
        if self.decoder_transformer is not None:
            graphed_tr_dec = CUDAGraphed(self.decoder_transformer, disable=disable)
        return _MimiState(graphed_tr_enc, graphed_tr_dec)

    @property
    def channels(self) -> int:
        return self._channels

    @property
    def frame_rate(self) -> float:
        return self._frame_rate

    @property
    def sample_rate(self) -> int:
        return self._sample_rate

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.cardinality

    def _to_framerate(self, x: torch.Tensor):
        # Convert from the encoder frame rate to the overall framerate.
        _, _, length = x.shape
        frame_rate = self.encoder_frame_rate
        new_frame_rate = self.frame_rate
        if frame_rate == new_frame_rate:
            return x
        if self.resample_method == "interpolate":
            target_length = int(length * new_frame_rate / frame_rate)
            return nn.functional.interpolate(x, size=target_length, mode="linear")
        else:
            return self.downsample(x)

    def _to_encoder_framerate(self, x: torch.Tensor):
        # Convert from overall framerate to the encoder frame rate.
        _, _, length = x.shape
        frame_rate = self.encoder_frame_rate
        new_frame_rate = self.frame_rate
        if frame_rate == new_frame_rate:
            return x
        if self.resample_method == "interpolate":
            target_length = int(length * new_frame_rate / frame_rate)
            return nn.functional.interpolate(x, size=target_length, mode="linear")
        else:
            return self.upsample(x)

    @property
    def _context_for_encoder_decoder(self):
        if self.torch_compile_encoder_decoder:
            return nullcontext()
        else:
            return no_compile()

    def forward(self, x: torch.Tensor) -> QuantizedResult:
        assert x.dim() == 3
        length = x.shape[-1]
        extra_metrics: tp.Dict[str, torch.Tensor] = {}

        if self.freeze_quantizer:
            if isinstance(self.quantizer, SplitResidualVectorQuantizer):
                self.quantizer.rvq_first.eval()
                for i in range(
                    self.freeze_quantizer_level - self.quantizer.n_q_semantic
                ):
                    self.quantizer.rvq_rest.vq.layers[i].eval()
            elif isinstance(self.quantizer, ResidualVectorQuantizer):
                for i in range(self.freeze_quantizer_level):
                    self.quantizer.vq.layers[i].eval()
            else:
                raise ValueError(f"Unsupported quantizer type {type(self.quantizer)}")

        with self._context_for_encoder_decoder:
            emb = self.encoder(x)
        if self.encoder_transformer is not None:
            (emb,) = self.encoder_transformer(emb)
        emb = self._to_framerate(emb)
        expected_length = self.frame_rate * length / self.sample_rate
        # Checking that we have the proper length given the advertised frame rate.
        assert abs(emb.shape[-1] - expected_length) < 1, (
            emb.shape[-1],
            expected_length,
        )

        q_res = self.quantizer(emb, self.frame_rate)
        emb = q_res.x
        emb = self._to_encoder_framerate(emb)
        if self.decoder_transformer is not None:
            (emb,) = self.decoder_transformer(emb)

        with self._context_for_encoder_decoder:
            out = self.decoder(emb)

        # remove extra padding added by the encoder and decoder
        assert out.shape[-1] >= length, (out.shape[-1], length)
        out = out[..., :length]

        q_res.x = out
        q_res.metrics.update(extra_metrics)
        return q_res

    def _encode_to_unquantized_latent(self, x: torch.Tensor) -> torch.Tensor:
        """Projects a batch of waveforms to unquantized latent space.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T].

        Returns:
            Unquantized embeddings.
        """
        assert (
            x.dim() == 3
        ), f"CompressionModel._encode_to_unquantized_latent expects audio of shape [B, C, T] but got {x.shape}"
        state = self._streaming_state
        with self._context_for_encoder_decoder:
            emb = self.encoder(x)
        if self.encoder_transformer is not None:
            if state is None:
                (emb,) = self.encoder_transformer(emb)
            else:
                assert state.graphed_tr_enc is not None
                (emb,) = state.graphed_tr_enc(emb)
        emb = self._to_framerate(emb)
        return emb

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the given input tensor to quantized representation.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes (torch.Tensor): an int tensor of shape [B, K, T]
                with K the number of codebooks used and T the timestep.
        """
        emb = self._encode_to_unquantized_latent(x)
        codes = self.quantizer.encode(emb)
        return codes

    def encode_to_latent(self, x: torch.Tensor, quantize: bool = True) -> torch.Tensor:
        """Projects a batch of waveforms to latent space.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T].

        Returns:
            Embeddings, either quantized or not.
        """
        emb = self._encode_to_unquantized_latent(x)
        if not quantize:
            return emb
        else:
            codes = self.quantizer.encode(emb)
            return self.decode_latent(codes)

    def decode(self, codes: torch.Tensor):
        """Decode the given codes to a reconstructed representation.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        state = self._streaming_state
        emb = self.decode_latent(codes)
        emb = self._to_encoder_framerate(emb)
        if self.decoder_transformer is not None:
            if state is None:
                (emb,) = self.decoder_transformer(emb)
            else:
                assert state.graphed_tr_dec is not None
                (emb,) = state.graphed_tr_dec(emb)
        with self._context_for_encoder_decoder:
            out = self.decoder(emb)
        # out contains extra padding added by the encoder and decoder
        return out

    def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)


class WrapperCompressionModel(CompressionModel[State]):
    """Base API for CompressionModel wrappers that do not depend on external frameworks."""

    def __init__(self, model: CompressionModel):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> QuantizedResult:
        return self.model.forward(x)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.encode(x)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        return self.model.decode(codes)

    def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
        return self.model.decode_latent(codes)

    def set_num_codebooks(self, n: int):
        self.model.set_num_codebooks(n)

    @property
    def quantizer(self):
        return self.model.quantizer

    @property
    def channels(self) -> int:
        return self.model.channels

    @property
    def frame_rate(self) -> float:
        return self.model.frame_rate

    @property
    def sample_rate(self) -> int:
        return self.model.sample_rate

    @property
    def cardinality(self) -> int:
        return self.model.cardinality

    @property
    def num_codebooks(self) -> int:
        return self.model.num_codebooks

    @property
    def total_codebooks(self) -> int:
        return self.model.total_codebooks
ADDED
|
@@ -0,0 +1,488 @@
|
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from functools import partial
import logging
import typing as tp

import torch
from torch import nn

from ..utils.sampling import sample_token
from ..utils.compile import CUDAGraphed
from ..modules.streaming import StreamingContainer, StreamingModule
from ..modules.transformer import (
    StreamingTransformer,
    create_norm_fn,
)


logger = logging.getLogger(__name__)


class ScaledEmbedding(nn.Embedding):
    """Boost learning rate for embeddings (with `scale`).

    Args:
        norm (bool): if True, uses a layer norm after the embedding.
        zero_idx (int): special value indicating that the output should be exactly 0.
    """

    def __init__(self, *args, norm: bool = False, zero_idx: int = -1, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = None
        if norm:
            self.norm = create_norm_fn("layer_norm", self.embedding_dim)
        assert zero_idx < 0, "Please use negative values for the zero_idx."
        self.zero_idx = zero_idx

    def forward(self, input, *args, **kwargs):
        is_zero = input == self.zero_idx
        zero = torch.zeros(1, dtype=input.dtype, device=input.device)
        input = input.clamp(min=0)
        y = super().forward(input, *args, **kwargs)
        if self.norm is not None:
            y = self.norm(y)
        y = torch.where(is_zero[..., None], zero, y)
        return y


class LMModel(StreamingContainer):
    """Transformer-based language model on multiple streams of codes.

    Args:
        n_q (int): Number of parallel streams to model as input.
        dep_q (int): Number of parallel streams to model in the depformer.
        card (int): Cardinality, vocabulary size.
        text_card (int): Cardinality of the text vocabulary.
        dim (int): Dimension of the transformer encoder.
        num_heads (int): Number of heads for the transformer encoder.
        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
        norm (str): Normalization method.
        norm_emb (bool): Whether to normalize embeddings.
        bias_proj (bool): Use bias for output projections.
        depformer_*: params used for the Depformer Transformer; all the others will be shared.
        depformer_multi_linear (bool): if True, uses one linear layer per codebook to project the
            output of the main transformer to the Depformer latent space.
        depformer_dim_feedforward (int | list[int] | None): If None, defaults to hidden_scale * depformer_dim.
        existing_text_padding_id (int or None): if provided, will use a different token for the
            initial text token and for the text padding token.
        same_initial (bool): if True, uses the same initial tokens for both text and audio mode.
        **kwargs: Additional parameters for the transformer encoder.
    """

    def __init__(
        self,
        delays: tp.List[int] = [0],
        n_q: int = 8,
        dep_q: int = 8,
        card: int = 1024,
        text_card: int = 32000,
        dim: int = 128,
        num_heads: int = 8,
        hidden_scale: int = 4,
        norm: str = "layer_norm",
        norm_emb: bool = False,
        bias_proj: bool = False,
        depformer_dim: int = 256,
        depformer_dim_feedforward: int | list[int] | None = None,
        depformer_multi_linear: bool = False,
        depformer_weights_per_step: bool = False,
        depformer_pos_emb: str = "sin",
        existing_text_padding_id: tp.Optional[int] = None,
        context: tp.Optional[int] = None,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__()
        self.n_q = n_q
        self.dep_q = dep_q
        self.card = card
        self.text_card = text_card
        assert len(delays) == self.num_codebooks, "unexpected number of delays"
        self.delays = delays
        self.dim = dim
        self.existing_text_padding_id = existing_text_padding_id
        self.context = context
        kwargs["context"] = context
        EmbeddingFactory = partial(
            ScaledEmbedding,
            norm=norm_emb,
            device=device,
            dtype=dtype,
            zero_idx=self.zero_token_id,
        )
        self.emb = nn.ModuleList(
            [EmbeddingFactory(self.card + 1, dim) for _ in range(n_q)]
        )
        # Text card + padding token (if not in the original tokenizer)
        extra_text = self.existing_text_padding_id is None
        # Unlike for audio, here we authorize the model to output the special token.
        self.text_emb = EmbeddingFactory(text_card + 1, dim)
        self.text_linear = nn.Linear(dim, text_card + extra_text, bias=bias_proj)
        depformer_prefix = "depformer_"
        main_kwargs = {
            k: v for k, v in kwargs.items() if not k.startswith(depformer_prefix)
        }
        self.transformer = StreamingTransformer(
            d_model=dim,
            num_heads=num_heads,
            dim_feedforward=int(hidden_scale * dim),
            norm=norm,
            device=device,
            dtype=dtype,
            **main_kwargs,
        )
        self.out_norm = create_norm_fn(norm, dim)
        self.depformer_multi_linear = depformer_multi_linear
        kwargs_dep = main_kwargs.copy()
        kwargs_dep.update(
            {
                k.removeprefix(depformer_prefix): v
                for k, v in kwargs.items()
                if k.startswith(depformer_prefix)
            }
        )
        kwargs_dep["positional_embedding"] = depformer_pos_emb
        kwargs_dep["context"] = None
        if depformer_weights_per_step:
            kwargs_dep["weights_per_step"] = dep_q
        if depformer_multi_linear:
            # One linear layer per codebook to project different information from the main model.
            self.depformer_in = nn.ModuleList(
                [nn.Linear(dim, depformer_dim, bias=False) for _ in range(dep_q)]
            )
        else:
            self.depformer_in = nn.ModuleList(
                [nn.Linear(dim, depformer_dim, bias=False)]
            )
        # Only using up to dep_q - 1 because the last codebook is never an input to Depformer.
        self.depformer_emb = nn.ModuleList(
            [EmbeddingFactory(self.card + 1, depformer_dim) for _ in range(dep_q - 1)]
        )
        self.depformer_text_emb = EmbeddingFactory(text_card + 1, depformer_dim)
        if depformer_dim_feedforward is None:
            depformer_dim_feedforward = int(hidden_scale * depformer_dim)
        self.depformer = StreamingTransformer(
            d_model=depformer_dim,
            dim_feedforward=depformer_dim_feedforward,
            norm=norm,
            device=device,
            dtype=dtype,
            **kwargs_dep,
        )
        self.depformer.set_streaming_propagate(False)
        dim = depformer_dim  # we will directly apply the next linears to the output of the Depformer.

        self.linears = nn.ModuleList(
            [nn.Linear(dim, self.card, bias=bias_proj) for _ in range(dep_q)]
        )

    @property
    def initial_token_id(self) -> int:
        """Token id for the start of sequence (audio)."""
        return self.card

    @property
    def text_initial_token_id(self) -> int:
        """Token id for the start of sequence (text)."""
        return self.text_card

    @property
    def text_padding_token_id(self) -> int:
        """Token id for text padding."""
        if self.existing_text_padding_id is None:
            return self.text_card
        else:
            return self.existing_text_padding_id

    @property
    def end_of_text_padding_id(self) -> int:
        """Token id for optionally marking the last padding step for a word."""
        return 0

    @property
    def zero_token_id(self) -> int:
        """Special value in the input tokens, indicating that no sampling should
        happen for that value, and no input should be given to the model."""
        return -1

    @property
    def ungenerated_token_id(self) -> int:
        """Special value that can be provided in the prompt to indicate that this specific
        value should be predicted and sampled. This allows for partial teacher forcing, by generating
        one modality, with the other one fixed.
        """
        return -2

    @property
    def device(self):
        first_param = next(iter(self.parameters()))
        return first_param.device

    @property
    def num_codebooks(self) -> int:
        return self.n_q + 1

    @property
    def num_audio_codebooks(self) -> int:
        return self.n_q

    @property
    def audio_offset(self) -> int:
        return 1

    def _get_initial_token(self) -> torch.Tensor:
        # Returns the initial token that will be fed to the model to predict the very first timestep.
        # The output shape will be [B, K, 1].
        device = next(iter(self.parameters())).device
        zero = torch.full(
            [1, 1, 1], self.zero_token_id, device=device, dtype=torch.long
        )
        special = torch.full_like(zero, self.initial_token_id)

        text_special = torch.full_like(zero, self.text_initial_token_id)
        audio_token = special
        text_token = text_special
        audio_token = audio_token.expand(-1, self.num_audio_codebooks, -1)
        token = torch.cat([text_token, audio_token], dim=1)
        return token

    def forward_text(
        self,
        sequence: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        B, K, S = sequence.shape
        assert (
            K == self.num_codebooks
        ), f"Sequence shape {sequence.shape} must match the number of codebooks."
        input_sequence = sequence
        input_ = None
        for cb_index in range(self.num_audio_codebooks):
            audio_emb = self.emb[cb_index](
                input_sequence[:, cb_index + self.audio_offset]
            )
            input_ = audio_emb if input_ is None else input_ + audio_emb
        text_emb = self.text_emb(input_sequence[:, 0])
        input_ = text_emb if input_ is None else input_ + text_emb
        transformer_out = self.transformer(input_)

        if self.out_norm:
            transformer_out = self.out_norm(transformer_out)
        assert isinstance(transformer_out, torch.Tensor)
        text_logits = self.text_linear(transformer_out)
        text_logits = text_logits[:, None]
        return transformer_out, text_logits

    def forward_depformer(
        self,
        depformer_cb_index: int,
        sequence: torch.Tensor,
        transformer_out: torch.Tensor,
    ) -> torch.Tensor:
        B, K, S = sequence.shape
        assert (
            K == 1
        ), f"Codebooks for Depformer streaming should be passed 1 by 1, got {K}."
        assert (
            S == 1
        ), f"Steps for Depformer streaming should be passed 1 by 1, got {S}."
        assert (
            transformer_out.shape[1] == 1
        ), "Transformer out should be for a single step."
        last_token_input: tp.Optional[torch.Tensor] = None
        depformer_input = transformer_out
        if self.depformer_multi_linear:
            depformer_input = self.depformer_in[depformer_cb_index](depformer_input)
        else:
            depformer_input = self.depformer_in[0](depformer_input)
        if depformer_cb_index == 0:
            last_token_input = self.depformer_text_emb(sequence[:, 0])
        else:
            last_token_input = self.depformer_emb[depformer_cb_index - 1](
                sequence[:, 0]
            )
        depformer_input = depformer_input + last_token_input
        assert depformer_input.shape[1] == 1
        # depformer_input is [B, 1, depformer_dim].
        # The streaming state of the depformer ensures that the proper layer is run.
        dep_output = self.depformer(depformer_input)
        logits = self.linears[depformer_cb_index](dep_output)
        logits = logits[:, None]
        assert logits.dim() == 4, logits.shape  # [B, Ka, S, card]
        return logits


@dataclass
class _LMGenState:
    cache: torch.Tensor
    initial: torch.Tensor
    graphed_main: CUDAGraphed
    graphed_depth: CUDAGraphed
    offset: int = 0

    def reset(self):
        self.offset = 0


class LMGen(StreamingModule[_LMGenState]):
    def __init__(
        self,
        lm_model: LMModel,
        use_sampling: bool = True,
        temp: float = 0.8,
        temp_text: float = 0.7,
        top_k: int = 250,
        top_k_text: int = 25,
        check: bool = False,
    ):
        assert not lm_model.training, "generation shouldn't be used in training mode."
        super().__init__()

        self.lm_model = lm_model
        self.use_sampling = use_sampling
        self.temp = temp
        self.temp_text = temp_text
        self.top_k = top_k
        self.top_k_text = top_k_text
        self.check = check
        self.max_delay = max(
            lm_model.delays
        )  # with delays, we need to generate a few more time steps.
        self.delays_cuda = torch.tensor(
            lm_model.delays, device=lm_model.device, dtype=torch.long
        )

    def _init_streaming_state(self, batch_size: int) -> _LMGenState:
        lm_model = self.lm_model
        initial = lm_model._get_initial_token()
        cache = torch.full(
            (batch_size, self.lm_model.num_codebooks, self.max_delay + 2),
            lm_model.ungenerated_token_id,
            device=lm_model.device,
            dtype=torch.long,
        )

        disable = lm_model.device.type != 'cuda'
        graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable)
        graphed_depth = CUDAGraphed(self.depformer_step, disable=disable)

        return _LMGenState(cache, initial, graphed_main, graphed_depth)

    @torch.no_grad()
    def step(self, input_tokens: torch.Tensor) -> torch.Tensor | None:
        state = self._streaming_state
        if state is None:
            raise RuntimeError(
                "You should wrap those calls with a `with lm_gen.streaming(): ...`."
            )
        lm_model = self.lm_model

        assert input_tokens.dim() == 3, "Shape should be [B, K, T]."
        B, Ki, S = input_tokens.shape
        assert S == 1, "Only support being given steps one by one."
        needed_tokens = lm_model.num_codebooks - lm_model.dep_q - 1
        assert (
            Ki == needed_tokens
        ), f"We expect {needed_tokens} tokens from the user stream, got {Ki}."

        CT = state.cache.shape[2]

        for q_other in range(input_tokens.shape[1]):
            k = lm_model.dep_q + 1 + q_other
            delay = lm_model.delays[k]
            write_position = (state.offset + delay) % CT
            state.cache[:, k, write_position : write_position + 1] = input_tokens[
                :, q_other
            ]

        position = state.offset % CT
        for k, delay in enumerate(lm_model.delays):
            # Only at the very beginning, we extend the initial token for the acoustic
            # tokens that are delayed, and thus have no good value to take.
            if state.offset <= delay:
                state.cache[:, k, position] = state.initial[:, k, 0]
        input_ = state.cache[:, :, position : position + 1]

        if self.check:
            # Check that we are not feeding in any value that is not generated yet.
            assert not (input_ == lm_model.ungenerated_token_id).any(), (
                state.offset,
                input_,
            )
            assert (input_[:, lm_model.audio_offset :] <= lm_model.card).all(), input_
            assert (input_[:, :1] <= lm_model.text_card).all()

        transformer_out, text_logits = state.graphed_main(input_)
        # Shape of text_logits should be [B, K_text=1, T=1, Card_text]
        text_token = sample_token(
            text_logits.float(),
            self.use_sampling,
            self.temp_text,
            self.top_k_text,
        )
        assert text_token.dim() == 3, text_token.shape
        assert text_token.shape[2] == 1
        assert text_token.shape[1] == 1, "Only one text stream supported."
        text_token = text_token[:, 0, 0]  # shape is [B]
        audio_tokens = state.graphed_depth(text_token, transformer_out)

        # ensure we don't overwrite prompt tokens; we only write over ungenerated tokens
        state.offset += 1
        position = state.offset % CT
        state.cache[:, 0, position] = text_token
        state.cache[:, 1 : lm_model.dep_q + 1, position] = audio_tokens

        if state.offset <= self.max_delay:
            return None
        B = state.cache.shape[0]
        gen_delays_cuda = self.delays_cuda[: lm_model.dep_q + 1]
        index = (
            ((state.offset - self.max_delay + gen_delays_cuda) % CT)
            .view(1, -1, 1)
            .expand(B, -1, 1)
        )
        out = state.cache.gather(dim=2, index=index)
        return out

    def depformer_step(
        self,
        text_token: torch.Tensor,
        transformer_out: torch.Tensor,
    ) -> torch.Tensor:
        (B,) = text_token.shape
        prev_token = text_token
        lm_model = self.lm_model
        depformer_tokens: list[torch.Tensor] = []
        assert not lm_model.depformer.is_streaming
        with lm_model.depformer.streaming(B):
            for cb_index in range(lm_model.dep_q):
                input_ = prev_token[:, None, None]
                logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
                next_token = sample_token(
                    logits.float(),
                    self.use_sampling,
                    self.temp,
                    self.top_k,
                )
                assert next_token.shape == (B, 1, 1)
                next_token = next_token[:, 0, 0]  # shape is B
                depformer_tokens.append(next_token)
                prev_token = next_token

        assert len(depformer_tokens) == lm_model.dep_q, (
            len(depformer_tokens),
            lm_model.dep_q,
        )
        out = torch.stack(depformer_tokens, dim=1)
        assert out.shape == (B, lm_model.dep_q), out.shape
        return out
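A minimal driving-loop sketch for `LMGen`, illustrative only: `lm` is assumed to be an already loaded `LMModel` in eval mode, and `user_frames` a hypothetical iterator of user audio-token frames from the Mimi encoder. The shapes follow the contracts asserted in `step` above:

import torch

lm_gen = LMGen(lm, temp=0.8, temp_text=0.7)
needed = lm.num_codebooks - lm.dep_q - 1   # codebooks expected from the user stream

with lm_gen.streaming(1):                  # batch size 1
    for frame in user_frames:              # frame: [1, needed, 1], long, on lm.device
        out = lm_gen.step(frame)
        if out is None:
            continue                       # still inside the initial delay window
        text_token = out[:, 0, 0]          # [B] text token ids
        audio_tokens = out[:, 1:, 0]       # [B, dep_q], e.g. to feed Mimi's decoder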
moshi/models/loaders.py
ADDED
@@ -0,0 +1,159 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Retrieves the pretrained models for Moshi and Mimi."""
from pathlib import Path

from safetensors.torch import load_model
import torch

from .compression import MimiModel
from .lm import LMModel
from ..modules import SEANetEncoder, SEANetDecoder, transformer
from ..quantization import SplitResidualVectorQuantizer

SAMPLE_RATE = 24000
FRAME_RATE = 12.5

TEXT_TOKENIZER_NAME = 'tokenizer_spm_32k_3.model'
MOSHI_NAME = 'model.safetensors'
MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'


_seanet_kwargs = {
    "channels": 1,
    "dimension": 512,
    "causal": True,
    "n_filters": 64,
    "n_residual_layers": 1,
    "activation": "ELU",
    "compress": 2,
    "dilation_base": 2,
    "disable_norm_outer_blocks": 0,
    "kernel_size": 7,
    "residual_kernel_size": 3,
    "last_kernel_size": 3,
    # We train using weight_norm but then the weights are pre-processed for inference so
    # that we can use a normal convolution.
    "norm": "none",
    "pad_mode": "constant",
    "ratios": [8, 6, 5, 4],
    "true_skip": True,
}
_quantizer_kwargs = {
    "dimension": 256,
    "n_q": 32,
    "bins": 2048,
    "input_dimension": _seanet_kwargs["dimension"],
    "output_dimension": _seanet_kwargs["dimension"],
}
_transformer_kwargs = {
    "d_model": _seanet_kwargs["dimension"],
    "num_heads": 8,
    "num_layers": 8,
    "causal": True,
    "layer_scale": 0.01,
    "context": 250,
    "conv_layout": True,
    "max_period": 10000,
    "gating": "none",
    "norm": "layer_norm",
    "positional_embedding": "rope",
    "dim_feedforward": 2048,
    "input_dimension": _seanet_kwargs["dimension"],
    "output_dimensions": [_seanet_kwargs["dimension"]],
}

_lm_kwargs = {
    "dim": 4096,
    "text_card": 32000,
    "existing_text_padding_id": 3,
    "n_q": 16,
    "dep_q": 8,
    "card": _quantizer_kwargs["bins"],
    "num_heads": 32,
    "num_layers": 32,
    "hidden_scale": 4.125,
    "causal": True,
    "layer_scale": None,
    "context": 3000,
    "max_period": 10000,
    "gating": "silu",
    "norm": "rms_norm_f32",
    "positional_embedding": "rope",
    "depformer_dim": 1024,
    "depformer_dim_feedforward": int(4.125 * 1024),
    "depformer_num_heads": 16,
    "depformer_num_layers": 6,
    "depformer_causal": True,
    "depformer_layer_scale": None,
    "depformer_multi_linear": True,
    "depformer_context": 8,
    "depformer_max_period": 10000,
    "depformer_gating": "silu",
    "depformer_pos_emb": "none",
    "depformer_weights_per_step": True,
    "delays": [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1],
}


def _is_safetensors(path: Path | str) -> bool:
    return Path(path).suffix in (".safetensors", ".sft", ".sfts")


def get_mimi(filename: str | Path,
             device: torch.device | str = 'cpu') -> MimiModel:
    """Return a pretrained Mimi model."""
    encoder = SEANetEncoder(**_seanet_kwargs)
    decoder = SEANetDecoder(**_seanet_kwargs)
    encoder_transformer = transformer.ProjectedTransformer(
        device=device, **_transformer_kwargs
    )
    decoder_transformer = transformer.ProjectedTransformer(
        device=device, **_transformer_kwargs
    )
    quantizer = SplitResidualVectorQuantizer(
        **_quantizer_kwargs,
    )
    model = MimiModel(
        encoder,
        decoder,
        quantizer,
        channels=1,
        sample_rate=SAMPLE_RATE,
        frame_rate=FRAME_RATE,
        encoder_frame_rate=SAMPLE_RATE / encoder.hop_length,
        causal=True,
        resample_method="conv",
        encoder_transformer=encoder_transformer,
        decoder_transformer=decoder_transformer,
    ).to(device=device)
    model.eval()
    if _is_safetensors(filename):
        load_model(model, filename)
    else:
        pkg = torch.load(filename, "cpu")
        model.load_state_dict(pkg["model"])
    model.set_num_codebooks(8)
    return model


def get_moshi_lm(filename: str | Path,
                 device: torch.device | str = 'cpu') -> LMModel:
    dtype = torch.bfloat16
    model = LMModel(
        device=device,
        dtype=dtype,
        **_lm_kwargs,
    ).to(device=device, dtype=dtype)
    model.eval()
    if _is_safetensors(filename):
        load_model(model, filename)
    else:
        pkg = torch.load(
            filename,
            "cpu",
        )
        model.load_state_dict(pkg["fsdp_best_state"]["model"])
    return model
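A loading sketch tying these helpers together. The `hf_hub_download` calls and the `mimi.encode` interface are assumptions (the latter per the usual compression-model API in `compression.py`); only the constants and loaders above are given:

import torch
from huggingface_hub import hf_hub_download

mimi_path = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
moshi_path = hf_hub_download(DEFAULT_REPO, MOSHI_NAME)

mimi = get_mimi(mimi_path, device='cpu')        # audio tokenizer, restricted to 8 codebooks
moshi = get_moshi_lm(moshi_path, device='cpu')  # bf16 weights, ~7B params

wav = torch.zeros(1, 1, SAMPLE_RATE)            # [B, C, T]: one second of silence
with torch.no_grad():
    codes = mimi.encode(wav)                    # [1, 8, T'] with about FRAME_RATE frames/s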
moshi/modules/__init__.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Modules used for building the models."""

# flake8: noqa
from .conv import (
    NormConv1d,
    NormConvTranspose1d,
    StreamingConv1d,
    StreamingConvTranspose1d,
    pad_for_conv1d,
    pad1d,
    unpad1d,
)
from .seanet import SEANetEncoder, SEANetDecoder
from .transformer import StreamingTransformer
moshi/modules/conv.py
ADDED
@@ -0,0 +1,329 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm

from .streaming import RawStreamingConv1d, RawStreamingConvTranspose1d, StreamingModule


CONV_NORMALIZATIONS = frozenset(["none", "weight_norm"])


class TransposedLayerNorm(nn.Module):
    """LayerNorm for [B, C, T] inputs."""

    def __init__(self, **kwargs):
        super().__init__()
        self.layer_norm = nn.LayerNorm(**kwargs)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return x.transpose(1, 2)


def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
    assert norm in CONV_NORMALIZATIONS
    if norm == "weight_norm":
        return weight_norm(module)
    else:
        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other choice
        # doesn't need reparametrization.
        return module


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once padding is removed, we are missing one time step!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small inputs.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, properly handling zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(
            RawStreamingConv1d(*args, **kwargs), norm
        )
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            RawStreamingConvTranspose1d(*args, **kwargs), norm
        )
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        return x


@dataclass
class _StreamingConv1dState:
    padding_to_add: int
    original_padding_to_add: int

    def reset(self):
        self.padding_to_add = self.original_padding_to_add


class StreamingConv1d(StreamingModule[_StreamingConv1dState]):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "StreamingConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    @property
    def _stride(self) -> int:
        return self.conv.conv.stride[0]

    @property
    def _kernel_size(self) -> int:
        return self.conv.conv.kernel_size[0]

    @property
    def _effective_kernel_size(self) -> int:
        dilation = self.conv.conv.dilation[0]
        return (
            self._kernel_size - 1
        ) * dilation + 1  # effective kernel size with dilations

    @property
    def _padding_total(self) -> int:
        return self._effective_kernel_size - self._stride

    def _init_streaming_state(self, batch_size: int) -> _StreamingConv1dState:
        assert self.causal, "streaming is only supported for causal convs"
        return _StreamingConv1dState(self._padding_total, self._padding_total)

    def forward(self, x):
        B, C, T = x.shape
        padding_total = self._padding_total
        extra_padding = get_extra_padding_for_conv1d(
            x, self._effective_kernel_size, self._stride, padding_total
        )
        state = self._streaming_state
        if state is None:
            if self.causal:
                # Left padding for causal
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
            else:
                # Asymmetric padding required for odd strides
                padding_right = padding_total // 2
                padding_left = padding_total - padding_right
                x = pad1d(
                    x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
                )
        else:
            if state.padding_to_add > 0 and x.shape[-1] > 0:
                x = pad1d(x, (state.padding_to_add, 0), mode=self.pad_mode)
                state.padding_to_add = 0
        return self.conv(x)


@dataclass
class _StreamingConvTr1dState:
    pass

    def reset(self):
        pass


class StreamingConvTranspose1d(StreamingModule[_StreamingConvTr1dState]):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0

    def _init_streaming_state(self, batch_size: int) -> _StreamingConvTr1dState:
        assert self.causal, "streaming is only supported for causal convtrs"
        return _StreamingConvTr1dState()

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        if not self.is_streaming:
            # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
            # removed at the very end, when keeping only the right length for the output,
            # as removing it here would require also passing the length at the matching layer
            # in the encoder.
            if self.causal:
                # Trim the padding on the right according to the specified ratio
                # if trim_right_ratio = 1.0, trim everything from right
                padding_right = math.ceil(padding_total * self.trim_right_ratio)
                padding_left = padding_total - padding_right
                y = unpad1d(y, (padding_left, padding_right))
            else:
                # Asymmetric padding required for odd strides
                padding_right = padding_total // 2
                padding_left = padding_total - padding_right
                y = unpad1d(y, (padding_left, padding_right))
        return y
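The streaming state above exists so that a causal conv fed chunk by chunk reproduces its one-shot output. A sanity-check sketch of that invariant (sizes are arbitrary, and the exact-match expectation is an assumption based on the class's stated purpose, not a guarantee from this file alone):

import torch

conv = StreamingConv1d(4, 8, kernel_size=7, stride=2, causal=True, pad_mode="constant")
x = torch.randn(1, 4, 64)

y_offline = conv(x)   # one shot: pads padding_total = 7 - 2 = 5 zeros on the left
with conv.streaming(1):
    y_stream = torch.cat([conv(chunk) for chunk in x.split(16, dim=-1)], dim=-1)

# Streamed frames should line up with (a prefix of) the offline result,
# since offline may add extra right padding to complete the last window.
T = y_stream.shape[-1]
assert torch.allclose(y_offline[..., :T], y_stream, atol=1e-6)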
moshi/modules/gating.py
ADDED
@@ -0,0 +1,82 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from ..utils.compile import torch_compile_lazy


@torch_compile_lazy
def gating_forward_kernel(
    weight_in: torch.Tensor, weight_out: torch.Tensor, activation, x: torch.Tensor
):
    x = F.linear(x, weight_in)
    B, T, _ = x.shape
    x = x.view(B, T, 2, -1)
    x = activation(x[..., 0, :]) * x[..., 1, :]
    x = F.linear(x, weight_out)
    return x


class ActivationGating(nn.Module):
    """
    Gating FFN layer, using the given activation.
    Args:
        dim (int): dimension of the input and output of the transformer.
        activation (any callable Tensor to Tensor): activation function to use.
        **factory_kwargs: other kwargs passed to the linear layer, in particular device and dtype.
    """

    _fsdp_final = True

    def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
        super().__init__()
        # We should have 8 d^2 params, instead we will have
        # 2 * h * d + h * d = 3 h * d = 8 d^2
        # so h = 8 d / 3 but following Hervé's advice we use 21 / 8 as an approx.
        if dim_feedforward == 4 * dim:
            hidden = (21 * dim) // 8
        else:
            hidden = (2 * dim_feedforward) // 3
        self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
        self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        return gating_forward_kernel(
            self.linear_in.weight, self.linear_out.weight, self.activation, x
        )


def _get_activation(name: str):
    if name in ["sigmoid", "tanh", "relu"]:
        return getattr(torch, name)
    elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
        return getattr(torch.nn.functional, name)
    elif name == "identity":
        return torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {name}")


def _make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    return ActivationGating(
        dim, dim_feedforward, _get_activation(name), **factory_kwargs
    )


def make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs)
    max_params = 2 * dim * dim_feedforward
    params = sum(p.numel() for p in gating.parameters())
    assert (
        params <= max_params
    ), f"{name} gating has {params} params, max is {max_params}"
    return gating
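To spell out the parameter-budget comment in `ActivationGating`: a plain FFN with `dim_feedforward = 4 * dim` uses `2 * dim * (4 * dim) = 8 * dim^2` weights, while the gated layer uses `2 * hidden * dim + hidden * dim = 3 * hidden * dim`, so `hidden = (21 * dim) // 8 ≈ 8 * dim / 3` roughly matches it. A quick check using only the names above (the numbers are purely illustrative):

dim = 1024
gating = make_gating("silu", dim, dim_feedforward=4 * dim)

hidden = (21 * dim) // 8                      # 2688 for dim = 1024
n_params = sum(p.numel() for p in gating.parameters())
assert n_params == 3 * hidden * dim           # 8_257_536
print(n_params, "gated vs", 8 * dim * dim)    # just under the 8 d^2 plain-FFN budget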
moshi/modules/resample.py
ADDED
@@ -0,0 +1,119 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from einops import rearrange
import torch
from torch import nn

from .conv import StreamingConv1d, StreamingConvTranspose1d


class ConvDownsample1d(nn.Module):
    """
    Downsampling by some integer amount `stride` using convolutions
    with a kernel size of twice the stride.
    If `causal` is True, the output uses a causal convolution.
    """

    def __init__(
        self,
        stride: int,
        dimension: tp.Optional[int] = None,
        causal: bool = False,
        learnt: bool = False,
        channel_wise: bool = False,
    ):
        super().__init__()
        self.learnt = learnt
        self.channel_wise = channel_wise
        groups = 1
        if learnt:
            assert dimension is not None, "Dimension required for learnt convolutions."
            in_channels = dimension
            out_channels = dimension
            if channel_wise:
                groups = dimension
        else:
            in_channels = 1
            out_channels = 1

        self.conv = StreamingConv1d(
            in_channels,
            out_channels,
            kernel_size=2 * stride,
            stride=stride,
            causal=causal,
            groups=groups,
            bias=False,
            pad_mode="replicate",
        )
        if not learnt:
            actual_conv = self.conv.conv.conv
            actual_conv.weight.requires_grad_(False)
            actual_conv.weight.data.fill_(1.0 / (2 * stride))

    def forward(self, x: torch.Tensor):
        batch_size = len(x)
        if not self.learnt:
            x = rearrange(x, "b c t -> (b c) () t")
        y = self.conv(x)
        if not self.learnt:
            y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
        return y


class ConvTrUpsample1d(nn.Module):
    """
    Upsample by some integer amount `stride` using transposed convolutions.
    """

    def __init__(
        self,
        stride: int,
        dimension: tp.Optional[int] = None,
        causal: bool = False,
        learnt: bool = False,
        channel_wise: bool = False,
    ):
        super().__init__()
        self.learnt = learnt
        self.channel_wise = channel_wise
        groups = 1
        if learnt:
            assert dimension is not None, "Dimension required for learnt convolutions."
            in_channels = dimension
            out_channels = dimension
            if channel_wise:
                groups = dimension
        else:
            in_channels = 1
            out_channels = 1

        self.convtr = StreamingConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size=2 * stride,
            stride=stride,
            causal=causal,
            groups=groups,
            bias=False,
        )
        if not learnt:
            actual_convtr = self.convtr.convtr.convtr
            actual_convtr.weight.requires_grad_(False)
            actual_convtr.weight.data.fill_(1.0)

    def forward(self, x: torch.Tensor):
        batch_size = len(x)
        if not self.learnt:
            x = rearrange(x, "b c t -> (b c) () t")
        y = self.convtr(x)
        if not self.learnt:
            x_for_normalization = torch.ones_like(x[:1])
            normalization = self.convtr(x_for_normalization)
            y = y / normalization
            y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
        return y
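A shape-level sketch of the two resamplers above; with `learnt=False` they apply one frozen averaging (or unit) kernel per channel, so any channel count works without a `dimension` argument (all values illustrative):

import torch

down = ConvDownsample1d(stride=2, causal=True)  # learnt=False: frozen averaging kernel
up = ConvTrUpsample1d(stride=2, causal=True)    # learnt=False: frozen, renormalized output

x = torch.randn(3, 16, 24)                      # [B, C, T]
y = down(x)                                     # -> [3, 16, 12]
z = up(y)                                       # -> [3, 16, 24]
print(y.shape, z.shape)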
moshi/modules/rope.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn
import math
import torch
from ..utils.compile import torch_compile_lazy


@torch_compile_lazy
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: torch.Tensor,
    max_period: float = 10_000,
    time_before_heads: bool = False,
):
    """
    Args:
        q (torch.Tensor): queries, shape `[B, T, H, D]`.
        k (torch.Tensor): keys, shape `[B, T, H, D]`.
        offset (torch.Tensor): current offset, e.g. when streaming.
        max_period (float): maximum period for the cos and sin.
        time_before_heads (bool): if True, expects `[B, T, H, D]`, else `[B, H, T, D]`.
    """

    if time_before_heads:
        B, T, H, D = q.shape
    else:
        B, H, T, D = q.shape
    assert k.shape == q.shape
    assert D > 0
    assert D % 2 == 0
    assert max_period > 0

    ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
    freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))
    ts = offset.float() + torch.arange(T, device=q.device, dtype=torch.float32)
    if time_before_heads:
        ts = ts.view(-1, 1, 1)
    else:
        ts = ts.view(1, -1, 1)

    dims = q.shape[:-1]
    q = q.view(*dims, D // 2, 2)
    k = k.view(*dims, D // 2, 2)

    # convention is `r` suffix is real part, `i` is imaginary.
    qr = q[..., 0].float()
    qi = q[..., 1].float()

    kr = k[..., 0].float()
    ki = k[..., 1].float()

    rotr = torch.cos(freqs * ts)
    roti = torch.sin(freqs * ts)
    qor = qr * rotr - qi * roti
    qoi = qr * roti + qi * rotr

    kor = kr * rotr - ki * roti
    koi = kr * roti + ki * rotr

    dtype = q.dtype
    qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
    ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)

    return qo.view(*dims, D), ko.view(*dims, D)


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).

    Args:
        max_period (float): Maximum period of the rotation frequencies.
    """

    def __init__(self, max_period: float = 10000.0):
        super().__init__()
        self.max_period = max_period

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        offset: torch.Tensor,
        time_before_heads: bool = False,
    ):
        """Apply rope rotation to query and key tensors."""
        return apply_rope(q, k, offset, self.max_period, time_before_heads)
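A sketch of the defining RoPE property the code above implements: after rotation, the q·k inner products depend only on relative positions, so applying the same extra `offset` to both queries and keys leaves the attention scores unchanged (shapes and tolerance are illustrative):

import torch

B, H, T, D = 1, 2, 5, 16
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)

rope = RotaryEmbedding(max_period=10000.0)
q0, k0 = rope(q, k, offset=torch.zeros(1))
q7, k7 = rope(q, k, offset=torch.full((1,), 7.0))

scores0 = torch.einsum("bhtd,bhsd->bhts", q0, k0)
scores7 = torch.einsum("bhtd,bhsd->bhts", q7, k7)
assert torch.allclose(scores0, scores7, atol=1e-4)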
moshi/modules/seanet.py
ADDED
@@ -0,0 +1,395 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
import torch.nn as nn

from .conv import StreamingConv1d, StreamingConvTranspose1d
from .streaming import StreamingContainer, StreamingAdd
from ..utils.compile import torch_compile_lazy


class SEANetResnetBlock(StreamingContainer):
    """Residual block from SEANet model.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """

    def __init__(
        self,
        dim: int,
        kernel_sizes: tp.List[int] = [3, 1],
        dilations: tp.List[int] = [1, 1],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        norm: str = "none",
        norm_params: tp.Dict[str, tp.Any] = {},
        causal: bool = False,
        pad_mode: str = "reflect",
        compress: int = 2,
        true_skip: bool = True,
    ):
        super().__init__()
        assert len(kernel_sizes) == len(
            dilations
        ), "Number of kernel sizes should match number of dilations"
        act = getattr(nn, activation)
        hidden = dim // compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [
                act(**activation_params),
                StreamingConv1d(
                    in_chs,
                    out_chs,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                ),
            ]
        self.block = nn.Sequential(*block)
        self.add = StreamingAdd()
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = StreamingConv1d(
                dim,
                dim,
                kernel_size=1,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )

    def forward(self, x):
        u, v = self.shortcut(x), self.block(x)
        return self.add(u, v)


class SEANetEncoder(StreamingContainer):
    """SEANet encoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
        mask_fn (nn.Module): Optional mask function to apply after convolution layers.
        mask_position (int): Position of the mask function, with mask_position == 0 for the first convolution layer,
            mask_position == 1 for the first conv block, etc.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 3,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        norm: str = "none",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = True,
        compress: int = 2,
        disable_norm_outer_blocks: int = 0,
        mask_fn: tp.Optional[nn.Module] = None,
        mask_position: tp.Optional[int] = None,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = int(np.prod(self.ratios))
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert (
            self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks
        ), (
            "Number of blocks for which to disable norm is invalid. "
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
        )

        act = getattr(nn, activation)
        mult = 1
        model: tp.List[nn.Module] = [
            StreamingConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]
        if mask_fn is not None and mask_position == 0:
            model += [mask_fn]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        norm=block_norm,
                        norm_params=norm_params,
                        activation=activation,
                        activation_params=activation_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamingConv1d(
                    mult * n_filters,
                    mult * n_filters * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=block_norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                ),
            ]
            mult *= 2
            if mask_fn is not None and mask_position == i + 1:
                model += [mask_fn]

        model += [
            act(**activation_params),
            StreamingConv1d(
                mult * n_filters,
                dimension,
                last_kernel_size,
                norm=(
                    "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
                ),
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]

        self.model = nn.Sequential(*model)

    @torch_compile_lazy
    def forward(self, x):
        return self.model(x)


class SEANetDecoder(StreamingContainer):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 3,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        final_activation: tp.Optional[str] = None,
        final_activation_params: tp.Optional[dict] = None,
        norm: str = "none",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = True,
        compress: int = 2,
        disable_norm_outer_blocks: int = 0,
        trim_right_ratio: float = 1.0,
    ):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = int(np.prod(self.ratios))
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert (
            self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks
        ), (
            "Number of blocks for which to disable norm is invalid. "
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
        )

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamingConv1d(
                dimension,
                mult * n_filters,
                kernel_size,
                norm=(
                    "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
                ),
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = (
                "none"
                if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
                else norm
            )
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamingConvTranspose1d(
                    mult * n_filters,
                    mult * n_filters // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=block_norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    trim_right_ratio=trim_right_ratio,
                ),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters // 2,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        activation=activation,
                        activation_params=activation_params,
                        norm=block_norm,
                        norm_params=norm_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamingConv1d(
                n_filters,
                channels,
                last_kernel_size,
                norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]
        # Add optional final activation to decoder (e.g. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [final_act(**final_activation_params)]
        self.model = nn.Sequential(*model)

    @torch_compile_lazy
    def forward(self, z):
        y = self.model(z)
        return y
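A round-trip sketch for reviewers (illustrative, not part of the upload; exact lengths depend on how `StreamingConv1d` pads, so the shapes below are approximate):

import torch

encoder = SEANetEncoder(causal=True, pad_mode="constant")
decoder = SEANetDecoder(causal=True, pad_mode="constant")
# With the default ratios [8, 5, 4, 2], hop_length = 8 * 5 * 4 * 2 = 320,
# so 32000 mono samples map to about 32000 / 320 = 100 latent frames of dim 128.
wav = torch.randn(1, 1, 32000)
latent = encoder(wav)    # approximately [1, 128, 100]
recon = decoder(latent)  # approximately [1, 1, 32000]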
moshi/modules/streaming.py
ADDED
@@ -0,0 +1,363 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Streaming module API that should be implemented by all Streaming components.
"""

import abc
from contextlib import contextmanager
from dataclasses import dataclass
import itertools
import math
import typing as tp
from torch import nn
import torch


class Resetable(tp.Protocol):
    def reset(self) -> None:
        pass


State = tp.TypeVar("State", bound=Resetable)


class StreamingModule(abc.ABC, nn.Module, tp.Generic[State]):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming(batch_size):
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming children modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` below.
    """

    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State | None = None
        self._streaming_propagate: bool = True

    @property
    def is_streaming(self):
        return self._streaming_state is not None

    def set_streaming_propagate(self, streaming_propagate: bool):
        self._streaming_propagate = streaming_propagate

    def _apply_named_streaming(self, fn: tp.Any):
        def _handle_module(prefix: str, module: nn.Module, recurse: bool = True):
            propagate = True
            if isinstance(module, StreamingModule):
                if module._streaming_propagate:
                    fn(prefix, module)
                else:
                    propagate = False
            if not recurse:
                return
            if propagate:
                for name, child in module.named_children():
                    _handle_module(prefix + "." + name, child)

        _handle_module("", self, recurse=False)
        for name, child in self.named_children():
            _handle_module(name, child)

    def _start_streaming(self, batch_size: int):
        def _start_streaming(name: str, module: StreamingModule):
            module._streaming_state = module._init_streaming_state(batch_size)

        self._apply_named_streaming(_start_streaming)

    def _stop_streaming(self):
        def _stop_streaming(name: str, module: StreamingModule):
            module._streaming_state = None

        self._apply_named_streaming(_stop_streaming)

    @abc.abstractmethod
    def _init_streaming_state(self, batch_size: int) -> State: ...

    def streaming_forever(self, batch_size: int):
        self._start_streaming(batch_size)

    @contextmanager
    def streaming(self, batch_size: int):
        """Context manager to enter streaming mode. Reset streaming state on exit."""
        self._start_streaming(batch_size)
        try:
            yield
        finally:
            self._stop_streaming()

    def reset_streaming(self):
        """Reset the streaming state."""

        def _reset(name: str, module: StreamingModule):
            state = module._streaming_state
            if state is None:
                raise ValueError(
                    f"Trying to reset streaming, but {name} wasn't streaming."
                )
            state.reset()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> dict[str, tp.Any]:
        """Return the complete streaming state, including that of sub-modules."""
        state: dict[str, tp.Any] = {}

        def _add(name: str, module: StreamingModule):
            state[name] = module._streaming_state

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: dict[str, tp.Any]):
        """Set the streaming state, including that of sub-modules."""
        state = dict(state)

        def _set(name: str, module: StreamingModule):
            if name in state:
                module._streaming_state = state[name]
                state.pop(name)
            else:
                raise RuntimeError(f"Expected to find a streaming state for {name}.")

        self._apply_named_streaming(_set)
        if state:
            raise RuntimeError(f"Some states were not consumed: {list(state.keys())}")


@dataclass
class _NullState:
    def reset(self) -> None:
        pass


class StreamingContainer(StreamingModule[_NullState]):
    def _init_streaming_state(self, batch_size: int) -> _NullState:
        return _NullState()


@dataclass
class _StreamingAddState:
    previous_x: torch.Tensor | None = None
    previous_y: torch.Tensor | None = None

    def reset(self):
        self.previous_x = None
        self.previous_y = None


class StreamingAdd(StreamingModule[_StreamingAddState]):
    def _init_streaming_state(self, batch_size: int) -> _StreamingAddState:
        return _StreamingAddState()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        if self._streaming_state is None:
            return x + y
        else:
            prev_x = self._streaming_state.previous_x
            prev_y = self._streaming_state.previous_y
            if prev_x is not None:
                x = torch.cat([prev_x, x], dim=-1)
            if prev_y is not None:
                y = torch.cat([prev_y, y], dim=-1)
            m_l = min(x.shape[-1], y.shape[-1])
            self._streaming_state.previous_x = x[..., m_l:]
            self._streaming_state.previous_y = y[..., m_l:]
            return x[..., :m_l] + y[..., :m_l]


@dataclass
class _StreamingConvState:
    previous: torch.Tensor | None = None

    def reset(self):
        self.previous = None


class RawStreamingConv1d(nn.Conv1d, StreamingModule[_StreamingConvState]):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.padding[0] == 0, "Padding should be handled outside."
        assert (
            self.stride[0] <= self.kernel_size[0]
        ), "stride must be less than or equal to kernel_size."

    def _init_streaming_state(self, batch_size: int) -> _StreamingConvState:
        return _StreamingConvState()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        stride = self.stride[0]
        # Effective kernel size accounting for dilation.
        kernel = (self.kernel_size[0] - 1) * self.dilation[0] + 1
        if self._streaming_state is None:
            return super().forward(input)
        else:
            # Due to the potential overlap, we might have some cache of the previous time steps.
            previous = self._streaming_state.previous
            if previous is not None:
                input = torch.cat([previous, input], dim=-1)
            B, C, T = input.shape
            # We now compute the number of full convolution frames, i.e. the frames
            # that are ready to be computed.
            num_frames = max(0, int(math.floor((T - kernel) / stride) + 1))
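            # Worked example (illustrative, not in the original source): with
            # kernel_size=4, dilation=1 and stride=2, kernel=4. If the cache plus
            # the new chunk give T=9 samples, num_frames = floor((9 - 4) / 2) + 1
            # = 3 outputs; offset = 6 below, so samples 0..5 are consumed and
            # samples 6..8 are carried over to the next call.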
            offset = num_frames * stride
            # We will compute `num_frames` outputs, and we are advancing by `stride`
            # for each of the frames, so we know the data before `stride * num_frames`
            # will never be used again.
            self._streaming_state.previous = input[..., offset:]
            if num_frames > 0:
                input_length = (num_frames - 1) * stride + kernel
                out = super().forward(input[..., :input_length])
            else:
                # Not enough data at this point to output some new frames.
                out = torch.empty(
                    B, self.out_channels, 0, device=input.device, dtype=input.dtype
                )
            return out


@dataclass
class _StreamingConvTrState:
    partial: torch.Tensor | None = None

    def reset(self):
        self.partial = None


class RawStreamingConvTranspose1d(
    nn.ConvTranspose1d, StreamingModule[_StreamingConvTrState]
):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.padding[0] == 0, "Padding should be handled outside."
        assert self.dilation[0] == 1, "No dilation for now"
        assert (
            self.stride[0] <= self.kernel_size[0]
        ), "stride must be less than or equal to kernel_size."
        assert self.output_padding[0] == 0, "Output padding not supported."

    def _init_streaming_state(self, batch_size: int) -> _StreamingConvTrState:
        return _StreamingConvTrState()

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        B, C, T = x.shape
        stride = self.stride[0]
        kernel = self.kernel_size[0]
        if self._streaming_state is None:
            return super().forward(x)
        else:
            if T == 0:
                return torch.empty(
                    B, self.out_channels, 0, device=x.device, dtype=x.dtype
                )
            out = super().forward(x)
            OT = out.shape[-1]
            partial = self._streaming_state.partial
            if partial is not None:
                # Due to the potential overlap, the rightmost output of the conv transpose is not
                # ready to be output, as it will receive contributions from the next input frames.
                # Here we recover those `partial` output frames. We know that the first time step
                # of the `partial` tensor corresponds to the first time step of `out` as anything
                # coming before the first time step of `out` would have been already flushed.
                PT = partial.shape[-1]
                if self.bias is not None:
                    out[..., :PT] += partial - self.bias[:, None]
                else:
                    out[..., :PT] += partial
            # The input is T, the output is S * (T - 1) + K.
            # The offset of the left of the next frame will be S * T,
            # so everything between 0 and S * T is ready to be output, and we need
            # to keep in the internal state everything beyond that, i.e. S (T - 1) + K - S T = K - S.
            invalid_steps = kernel - stride
            partial = out[..., OT - invalid_steps :]
            out = out[..., : OT - invalid_steps]
            self._streaming_state.partial = partial
            return out


def test():
    torch.manual_seed(1234)
    device = "cpu"
    if torch.cuda.is_available():
        # Avoid the cuda optimizations that would take place on single precision
        # floats for convolutions.
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        device = "cuda:0"

    kernel_sizes = [1, 3, 4, 8, 15, 16]
    strides = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    chin = 6
    chout = 12

    for kernel, stride in itertools.product(kernel_sizes, strides):
        if stride > kernel:
            continue
        conv = RawStreamingConv1d(chin, chout, kernel, stride).to(device)
        convtr = RawStreamingConvTranspose1d(chout, chin, kernel, stride).to(device)

        for length in [4, 8, 32, 54, 65, 128, 1043]:
            print(f"ksize {kernel} strides {stride} len {length}")
            if length < kernel:
                continue
            batch_size = 3
            x = torch.randn(batch_size, chin, length).to(device)
            y = conv(x)
            z = convtr(y)
            for chunk_size in [1, 3, 5, 8]:
                ys = []
                zs = []
                with conv.streaming(batch_size), convtr.streaming(batch_size):
                    for offset in range(0, length, chunk_size):
                        chunk = x[..., offset : offset + chunk_size]
                        ys.append(conv(chunk))
                        zs.append(convtr(ys[-1]))
                y_stream = torch.cat(ys, dim=-1)
                z_stream = torch.cat(zs, dim=-1)
                y = y[..., : y_stream.shape[-1]]
                z = z[..., : z_stream.shape[-1]]
                assert y.shape == y_stream.shape, (y.shape, y_stream.shape)
                delta = (y_stream - y).norm() / y.norm()
                assert delta <= 1e-6, delta
                num_frames = int((length - kernel) / stride) + 1
                assert num_frames == y_stream.shape[-1]

                assert z.shape == z_stream.shape, (z.shape, z_stream.shape)
                delta = (z_stream - z).norm() / z.norm()
                assert delta <= 1e-6, (delta, (z_stream - z).abs().mean(dim=(0, 1)))


if __name__ == "__main__":
    with torch.no_grad():
        test()
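One more equivalence sketch (illustrative, not part of the upload): `StreamingAdd` reconciles two streams arriving in uneven chunks, matching the offline sum once the outputs are concatenated.

import torch

add = StreamingAdd()
x = torch.randn(1, 4, 10)
y = torch.randn(1, 4, 10)
ref = x + y
with add.streaming(batch_size=1):
    parts = [add(x[..., :3], y[..., :7]), add(x[..., 3:], y[..., 7:])]
assert torch.allclose(torch.cat(parts, dim=-1), ref)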
moshi/modules/transformer.py
ADDED
@@ -0,0 +1,750 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Transformer model, with streaming support and CUDA graph compatibility.
Optimized for inference.

See `StreamingTransformer` for more information.
"""

from contextlib import ExitStack
from dataclasses import dataclass
import typing as tp

from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F

from ..utils.compile import no_compile
from .gating import make_gating
from .rope import RotaryEmbedding
from .streaming import StreamingModule, StreamingContainer


class LayerNormF32(nn.LayerNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x_f32 = input.float()
        out_f32 = super().forward(x_f32)
        return out_f32.to(input.dtype)


def _rms_norm(
    x: torch.Tensor,
    alpha: torch.Tensor,
    dtype: tp.Optional[torch.dtype],
    eps: float,
):
    assert x.dim() == 3, f"RMSNorm expects 3D inputs but got {x.shape}"
    x_dtype = x.dtype
    if dtype is not None:
        x = x.to(dtype)
    var = eps + torch.mean(x**2, dim=2, keepdim=True)
    y = (x * (alpha.to(var) * torch.rsqrt(var))).to(x_dtype)
    return y


class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        eps: float = 1e-5,
        dtype: tp.Optional[torch.dtype] = None,
        device=None,
    ):
        super().__init__()
        self.eps = eps
        self.dtype = dtype
        self.alpha = nn.Parameter(
            torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype)
        )

    def forward(self, x: torch.Tensor):
        return _rms_norm(x, self.alpha, self.dtype, self.eps)


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally the residual outputs close to 0, with a learnt scale.

    Args:
        channels (int): Number of channels.
        init (float): Initial scale.
        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
        device (torch.device or str, optional): Device on which to initialize the module.
        dtype (torch.dtype, optional): dtype to use to initialize the module.
    """

    def __init__(
        self,
        channels: int,
        init: float = 1e-4,
        channel_last: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(
            torch.full(
                (channels,), init, requires_grad=True, device=device, dtype=dtype
            )
        )

    def forward(self, x: torch.Tensor):
        if self.channel_last:
            return self.scale * x
        else:
            return self.scale[:, None] * x


def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Create normalization module for transformer encoder layer.

    Args:
        norm_type (str): Normalization method.
        dim (int): Dimension of the normalized layer.
        **kwargs (dict): Additional parameters for normalization layer.
    Returns:
        nn.Module: Normalization module.
    """
    if norm_type == "layer_norm":
        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
    elif norm_type == "layer_norm_f32":
        kwargs.pop("dtype", None)
        return LayerNormF32(dim, eps=1e-8, **kwargs)
    elif norm_type in {"rms_norm"}:
        return RMSNorm(dim, eps=1e-5, **kwargs)
    elif norm_type in {"rms_norm_f32"}:
        kwargs.pop("dtype", None)
        return RMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs)
    else:
        raise ValueError(f"Unknown norm type: {norm_type}")


def create_sin_embedding(
    positions: torch.Tensor,
    dim: int,
    max_period: float = 10000,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    Args:
        positions (torch.Tensor): LongTensor of positions.
        dim (int): Dimension of the embedding.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full(
        [], max_period, device=positions.device, dtype=dtype
    )  # avoid sync point
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

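# Usage sketch (illustrative, not in the original source):
#   pos = torch.arange(10).view(1, -1, 1)        # positions, [B=1, T=10, 1]
#   emb = create_sin_embedding(pos, dim=512)     # -> [1, 10, 512]
# The first dim // 2 channels hold cos(phase), the last dim // 2 hold sin(phase).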

def multi_linear(
    num_linear: int,
    weight: torch.Tensor,
    x: torch.Tensor,
    offset: int,
):
    """Utility to apply a multi linear layer to the given input. A multi linear layer
    applies a different set of weights for each time step.

    Args:
        num_linear (int): Number of possible time steps and so number of linears.
        weight (torch.Tensor): Weight tensor, with shape `[num_linear * chout, chin]`.
        x (torch.Tensor): Input tensor, with shape `[B, T, C]`.
        offset (int): offset for the current time step, in particular for decoding, with
            time steps provided one by one.
    """
    B, T, C = x.shape
    ys = []
    chout, chin = weight.shape
    weight = weight.view(num_linear, -1, chin)
    for t in range(T):
        y = F.linear(x[:, t], weight[t + offset])
        ys.append(y)
    out = torch.stack(ys, 1)
    return out

|
| 183 |
+
"""Deactivates or changes the context span (in time steps) in a model.
|
| 184 |
+
Args:
|
| 185 |
+
model (nn.Module): model over which to look for attentions.
|
| 186 |
+
context (int or None): new temporary context value.
|
| 187 |
+
|
| 188 |
+
..Note:: this is not a context manager but a plain function changing the context forever.
|
| 189 |
+
Initially, it was a context manager, but that led to interesting bugs when using
|
| 190 |
+
activation checkpointing, with the context being inconsistent between the forward
|
| 191 |
+
and backward.
|
| 192 |
+
"""
|
| 193 |
+
for module in model.modules():
|
| 194 |
+
if isinstance(module, StreamingMultiheadAttention):
|
| 195 |
+
module.context = context
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class KVCacheResult(tp.NamedTuple):
|
| 199 |
+
keys: torch.Tensor
|
| 200 |
+
values: torch.Tensor
|
| 201 |
+
positions: torch.Tensor
|
| 202 |
+
|
| 203 |
+
@staticmethod
|
| 204 |
+
def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
|
| 205 |
+
B, H, T, D = keys.shape
|
| 206 |
+
assert tuple(values.shape[:-1]) == (B, H, T)
|
| 207 |
+
positions = torch.arange(T, device=keys.device, dtype=torch.long)
|
| 208 |
+
return KVCacheResult(keys, values, positions)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class RingKVCache:
|
| 212 |
+
"""Efficient streaming KVCache to be compatible with Cuda Graph.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
batch_size (int): Batch size.
|
| 216 |
+
num_heads (int): Number of heads in the attention.
|
| 217 |
+
dim_per_head (int): Dimension per head.
|
| 218 |
+
device (torch.device): Device on which to initialize the cache.
|
| 219 |
+
dtype (torch.dtype): dtype to use for the cache.
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
def __init__(
|
| 223 |
+
self,
|
| 224 |
+
batch_size: int,
|
| 225 |
+
num_heads: int,
|
| 226 |
+
dim_per_head: int,
|
| 227 |
+
capacity: int,
|
| 228 |
+
device: torch.device = torch.device("cuda"),
|
| 229 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 230 |
+
):
|
| 231 |
+
self.capacity = capacity
|
| 232 |
+
self.cache = torch.zeros(
|
| 233 |
+
(2, batch_size, num_heads, capacity, dim_per_head),
|
| 234 |
+
device=device,
|
| 235 |
+
dtype=dtype,
|
| 236 |
+
)
|
| 237 |
+
self.end_offset = torch.zeros(1, device=device, dtype=torch.long)
|
| 238 |
+
|
| 239 |
+
def reset(self):
|
| 240 |
+
self.end_offset.zero_()
|
| 241 |
+
|
| 242 |
+
def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
|
| 243 |
+
assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape)
|
| 244 |
+
B, H, T, D = k.shape
|
| 245 |
+
indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset
|
| 246 |
+
indexes = indexes % self.capacity
|
| 247 |
+
self.cache[0].index_copy_(2, indexes, k)
|
| 248 |
+
self.cache[1].index_copy_(2, indexes, v)
|
| 249 |
+
self.end_offset.add_(T)
|
| 250 |
+
|
| 251 |
+
keys = self.cache[0]
|
| 252 |
+
values = self.cache[1]
|
| 253 |
+
|
| 254 |
+
indexes = torch.arange(
|
| 255 |
+
self.capacity, device=self.end_offset.device, dtype=torch.long
|
| 256 |
+
)
|
| 257 |
+
invalid = indexes >= self.end_offset
|
| 258 |
+
|
| 259 |
+
end_index = self.end_offset % self.capacity
|
| 260 |
+
delta = indexes - end_index
|
| 261 |
+
|
| 262 |
+
# If last key is for step S, and capacity is C, last key was written at index S % C.
|
| 263 |
+
# then end_offset = S + 1, and end_index = (S + 1) % C.
|
| 264 |
+
# Then for index = (S % C), delta = -1, and the next code gives us:
|
| 265 |
+
# position(index) = (S + 1) - 1 = S, all good.
|
| 266 |
+
# Now the time step at end_offset is actually the oldest in the KVCache, e.g., its
|
| 267 |
+
# position should be (S - self.capacity + 1).
|
| 268 |
+
# The following code gives us:
|
| 269 |
+
# position(index + 1) = S + 1 + 0 - self.capacity.
|
| 270 |
+
|
| 271 |
+
positions = torch.where(
|
| 272 |
+
delta <= 0,
|
| 273 |
+
self.end_offset + delta,
|
| 274 |
+
self.end_offset + delta - self.capacity,
|
| 275 |
+
)
|
| 276 |
+
positions = torch.where(invalid, torch.full_like(positions, -1), positions)
|
| 277 |
+
|
| 278 |
+
return KVCacheResult(keys, values, positions)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
@dataclass
|
| 282 |
+
class _MHAState:
|
| 283 |
+
kv_cache: RingKVCache
|
| 284 |
+
offset: torch.Tensor
|
| 285 |
+
offset_cpu: int
|
| 286 |
+
|
| 287 |
+
def reset(self):
|
| 288 |
+
self.kv_cache.reset()
|
| 289 |
+
self.offset.zero_()
|
| 290 |
+
self.offset_cpu = 0
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class StreamingMultiheadAttention(StreamingModule[_MHAState]):
|
| 294 |
+
"""Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
embed_dim (int): Dimension to project to.
|
| 298 |
+
num_heads (int): Number of heads.
|
| 299 |
+
causal (bool): Causal mask applied automatically.
|
| 300 |
+
context (int, optional): Number of time steps the attention can access to.
|
| 301 |
+
When causal, can access `context` time steps into the past, and when non causal,
|
| 302 |
+
can access `context // 2` steps in the past, and the same in the future.
|
| 303 |
+
rope (`RotaryEmbedding`, optional): Rope embedding to use.
|
| 304 |
+
weights_per_step (int): use different weights per time step. If non zero, should correspond to the
|
| 305 |
+
number of possible time steps.
|
| 306 |
+
device (torch.device, optional): Device on which to initialize.
|
| 307 |
+
dtype (torch.dtype, optional): dtype to use.
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
_fsdp_final = True
|
| 311 |
+
|
| 312 |
+
def __init__(
|
| 313 |
+
self,
|
| 314 |
+
embed_dim: int,
|
| 315 |
+
num_heads: int,
|
| 316 |
+
causal: bool = False,
|
| 317 |
+
context: tp.Optional[int] = None,
|
| 318 |
+
rope: tp.Optional[RotaryEmbedding] = None,
|
| 319 |
+
weights_per_step: int = 0,
|
| 320 |
+
device=None,
|
| 321 |
+
dtype=None,
|
| 322 |
+
):
|
| 323 |
+
super().__init__()
|
| 324 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 325 |
+
|
| 326 |
+
self.embed_dim = embed_dim
|
| 327 |
+
self.causal = causal
|
| 328 |
+
self.context = context
|
| 329 |
+
self.rope = rope
|
| 330 |
+
self.num_heads = num_heads
|
| 331 |
+
|
| 332 |
+
out_dim = embed_dim
|
| 333 |
+
out_dim = 3 * embed_dim
|
| 334 |
+
mult = 1
|
| 335 |
+
self.weights_per_step = weights_per_step
|
| 336 |
+
if weights_per_step:
|
| 337 |
+
mult = weights_per_step
|
| 338 |
+
in_proj = nn.Linear(embed_dim, mult * out_dim, bias=False, **factory_kwargs)
|
| 339 |
+
# We try to follow the default PyTorch MHA convention, to easily compare results.
|
| 340 |
+
self.in_proj_weight = in_proj.weight
|
| 341 |
+
self.in_proj_bias = in_proj.bias
|
| 342 |
+
self.out_proj = nn.Linear(
|
| 343 |
+
embed_dim, mult * embed_dim, bias=False, **factory_kwargs
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
def _init_streaming_state(self, batch_size: int) -> _MHAState:
|
| 347 |
+
if self.context is None:
|
| 348 |
+
if self.weights_per_step:
|
| 349 |
+
capacity = self.weights_per_step
|
| 350 |
+
else:
|
| 351 |
+
raise RuntimeError(
|
| 352 |
+
"Cannot create a streaming KVCache without a context to estimate capacity."
|
| 353 |
+
)
|
| 354 |
+
else:
|
| 355 |
+
capacity = self.context
|
| 356 |
+
device = self.in_proj_weight.device
|
| 357 |
+
# TODO: the following estimation will not work great with FSDP.
|
| 358 |
+
dtype = self.in_proj_weight.dtype
|
| 359 |
+
dim_per_head = self.embed_dim // self.num_heads
|
| 360 |
+
kv_cache = RingKVCache(
|
| 361 |
+
batch_size, self.num_heads, dim_per_head, capacity, device, dtype
|
| 362 |
+
)
|
| 363 |
+
return _MHAState(
|
| 364 |
+
kv_cache,
|
| 365 |
+
offset=torch.zeros(1, device=device, dtype=torch.long),
|
| 366 |
+
offset_cpu=0,
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
def _complete_kv(self, k, v) -> KVCacheResult:
|
| 370 |
+
state = self._streaming_state
|
| 371 |
+
if state is None:
|
| 372 |
+
return KVCacheResult.from_kv(k, v)
|
| 373 |
+
else:
|
| 374 |
+
return state.kv_cache.complete(k, v)
|
| 375 |
+
|
| 376 |
+
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
| 377 |
+
state = self._streaming_state
|
| 378 |
+
T = query.shape[1]
|
| 379 |
+
|
| 380 |
+
if state is None:
|
| 381 |
+
offset = torch.zeros(1, device=query.device, dtype=torch.long)
|
| 382 |
+
offset_cpu = 0
|
| 383 |
+
else:
|
| 384 |
+
assert self.causal, "Streaming only available for causal"
|
| 385 |
+
offset = state.offset
|
| 386 |
+
offset_cpu = state.offset_cpu
|
| 387 |
+
|
| 388 |
+
if self.weights_per_step:
|
| 389 |
+
projected = multi_linear(
|
| 390 |
+
self.weights_per_step, self.in_proj_weight, query, offset_cpu
|
| 391 |
+
)
|
| 392 |
+
else:
|
| 393 |
+
projected = nn.functional.linear(query, self.in_proj_weight)
|
| 394 |
+
q, k, v = rearrange(
|
| 395 |
+
projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
if self.rope:
|
| 399 |
+
q, k = self.rope(q, k, offset, time_before_heads=False)
|
| 400 |
+
|
| 401 |
+
k, v, pos_k = self._complete_kv(k, v)
|
| 402 |
+
if self.causal:
|
| 403 |
+
pos_k = pos_k.view(1, -1)
|
| 404 |
+
pos_q = offset + torch.arange(T, device=q.device, dtype=torch.long).view(
|
| 405 |
+
-1, 1
|
| 406 |
+
)
|
| 407 |
+
delta = pos_q - pos_k
|
| 408 |
+
attn_bias = (pos_k >= 0) & (delta >= 0)
|
| 409 |
+
if self.context is not None:
|
| 410 |
+
attn_bias = attn_bias & (delta < self.context)
|
| 411 |
+
else:
|
| 412 |
+
attn_bias = None
|
| 413 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
|
| 414 |
+
|
| 415 |
+
x = rearrange(x, "b h t d -> b t (h d)")
|
| 416 |
+
if self.weights_per_step:
|
| 417 |
+
x = multi_linear(self.weights_per_step, self.out_proj.weight, x, offset_cpu)
|
| 418 |
+
else:
|
| 419 |
+
x = self.out_proj(x)
|
| 420 |
+
if state is not None:
|
| 421 |
+
state.offset.add_(T)
|
| 422 |
+
state.offset_cpu += T
|
| 423 |
+
return x
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
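# Streaming sketch (illustrative, not in the original source): with a context
# larger than the sequence, feeding one step at a time should match the offline
# causal pass up to numerical tolerance, e.g.:
#   mha = StreamingMultiheadAttention(embed_dim=64, num_heads=4, causal=True, context=32)
#   x = torch.randn(1, 16, 64)
#   ref = mha(x, x, x)
#   with mha.streaming(batch_size=1):
#       steps = [mha(x[:, t : t + 1], x[:, t : t + 1], x[:, t : t + 1]) for t in range(16)]
#   assert torch.allclose(torch.cat(steps, dim=1), ref, atol=1e-5)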

@dataclass
class _LayerState:
    offset_cpu: int

    def reset(self):
        self.offset_cpu = 0


class StreamingTransformerLayer(StreamingModule[_LayerState]):
    """TransformerLayer with Streaming / Causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        causal (bool): Causal mask applied automatically.
        context (int, optional): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        norm (str): Normalization to use. Currently, only 'layer_norm' is supported.
        layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale.
        gating (str): if provided, replaces FFN with special gating, like GLU, GSiGLU etc.
        weights_per_step (int): use different weights per time step. If non zero, should correspond to the
            number of possible time steps.
        skip_self_attn (bool): If true, skips the self attention module and the norm.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
    """

    _fsdp_final = True

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int | list[int] = 2048,
        causal: bool = False,
        context: tp.Optional[int] = None,
        rope: tp.Optional[RotaryEmbedding] = None,
        norm: str = "layer_norm",
        layer_scale: tp.Optional[float] = None,
        gating: str = "none",
        weights_per_step: int = 0,
        activation=F.gelu,
        skip_self_attn: bool = False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # Redefine self_attn to our streaming multi-head attention
        attn_kwargs: tp.Dict[str, tp.Any] = {
            "embed_dim": d_model,
            "num_heads": num_heads,
        }
        if not skip_self_attn:
            self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
                causal=causal,
                context=context,
                rope=rope,
                weights_per_step=weights_per_step,
                **attn_kwargs,  # type: ignore
                **factory_kwargs,  # type: ignore
            )  # type: ignore
            self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)
        # Redefine feedforward layers to expose bias parameter
        self.weights_per_step = weights_per_step
        self.gating: tp.Optional[nn.Module] = None
        self.linear1: tp.Optional[nn.Module] = None
        self.linear2: tp.Optional[nn.Module] = None
        self.activation = activation
        self.skip_self_attn = skip_self_attn

        if isinstance(dim_feedforward, list):
            assert dim_feedforward
            assert len(dim_feedforward) == weights_per_step, (
                "Length of dim_feedforward must match weights_per_step,"
                f" got {len(dim_feedforward)} != {weights_per_step}"
            )
        if gating == "none":
            assert (
                not weights_per_step
            ), "weights_per_step without gating not supported for now."
            assert not isinstance(
                dim_feedforward, list
            ), "List dim_feedforward without gating not supported for now."
            self.linear1 = nn.Linear(
                d_model, dim_feedforward, bias=False, **factory_kwargs
            )
            self.linear2 = nn.Linear(
                dim_feedforward, d_model, bias=False, **factory_kwargs
            )
        else:
            self.linear1 = None
            self.linear2 = None
            if weights_per_step:
                if isinstance(dim_feedforward, int):
                    dim_feedforward = [dim_feedforward] * weights_per_step
                assert isinstance(dim_feedforward, list), dim_feedforward
                self.gating = nn.ModuleList(
                    [
                        make_gating(gating, d_model, dim, **factory_kwargs)
                        for dim in dim_feedforward
                    ]
                )
            else:
                assert isinstance(dim_feedforward, int)
                self.gating = make_gating(
                    gating, d_model, dim_feedforward, **factory_kwargs
                )

        self.layer_scale_1: nn.Module
        self.layer_scale_2: nn.Module
        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)  # type: ignore
|
| 545 |
+
self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore
|
| 546 |
+
|
| 547 |
+
def _init_streaming_state(self, batch_size: int) -> _LayerState:
|
| 548 |
+
return _LayerState(offset_cpu=0)
|
| 549 |
+
|
| 550 |
+
# feed forward block
|
| 551 |
+
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
|
| 552 |
+
state = self._streaming_state
|
| 553 |
+
offset = 0
|
| 554 |
+
if state is not None:
|
| 555 |
+
offset = state.offset_cpu
|
| 556 |
+
x_orig = x
|
| 557 |
+
x = self.norm2(x)
|
| 558 |
+
if self.gating is None:
|
| 559 |
+
assert self.linear1 is not None
|
| 560 |
+
assert self.linear2 is not None
|
| 561 |
+
update = self.linear2(self.activation(self.linear1(x)))
|
| 562 |
+
else:
|
| 563 |
+
if self.weights_per_step:
|
| 564 |
+
assert isinstance(self.gating, nn.ModuleList)
|
| 565 |
+
B, T, D = x.shape
|
| 566 |
+
ys = []
|
| 567 |
+
for t in range(T):
|
| 568 |
+
y = self.gating[offset + t](x[:, t : t + 1])
|
| 569 |
+
ys.append(y)
|
| 570 |
+
update = torch.cat(ys, dim=1)
|
| 571 |
+
else:
|
| 572 |
+
update = self.gating(x)
|
| 573 |
+
return x_orig + self.layer_scale_2(update)
|
| 574 |
+
|
| 575 |
+
def _sa_block(self, x: torch.Tensor):
|
| 576 |
+
if self.skip_self_attn:
|
| 577 |
+
return x
|
| 578 |
+
x_orig = x
|
| 579 |
+
x = self.norm1(x)
|
| 580 |
+
update = self.self_attn(x, x, x)
|
| 581 |
+
return x_orig + self.layer_scale_1(update)
|
| 582 |
+
|
| 583 |
+
def forward(self, x: torch.Tensor):
|
| 584 |
+
with ExitStack() as stack:
|
| 585 |
+
if x.device.type != 'cuda':
|
| 586 |
+
stack.enter_context(no_compile())
|
| 587 |
+
x = self._sa_block(x)
|
| 588 |
+
x = self._ff_block(x)
|
| 589 |
+
state = self._streaming_state
|
| 590 |
+
if state:
|
| 591 |
+
state.offset_cpu += x.shape[1]
|
| 592 |
+
return x
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
@dataclass
|
| 596 |
+
class _TransformerState:
|
| 597 |
+
offset: torch.Tensor
|
| 598 |
+
|
| 599 |
+
def reset(self):
|
| 600 |
+
self.offset.zero_()
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
class StreamingTransformer(StreamingModule[_TransformerState]):
|
| 604 |
+
"""Transformer with Streaming / Causal support.
|
| 605 |
+
|
| 606 |
+
Args:
|
| 607 |
+
d_model (int): Dimension of the data.
|
| 608 |
+
num_heads (int): Number of heads.
|
| 609 |
+
dim_feedforward (int): Intermediate dimension of FF module.
|
| 610 |
+
causal (bool): Causal mask applied automatically.
|
| 611 |
+
context (int, optional): Receptive field for the causal mask, infinite if None.
|
| 612 |
+
layer_scale (float, optional): If not None, LayerScale will be used
|
| 613 |
+
with the given value as initial scale.
|
| 614 |
+
positional_embedding (str): Positional embedding strategy (sin, rope, sin_rope, or none).
|
| 615 |
+
max_period (float): Maximum period of the time embedding.
|
| 616 |
+
positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
|
| 617 |
+
layer_class: (subclass of `StreamingTransformerLayer): class to use
|
| 618 |
+
to initialize the layers, allowing further customization outside of AudioCraft.
|
| 619 |
+
device (torch.device, optional): Device on which to initialize.
|
| 620 |
+
dtype (torch.dtype, optional): dtype to use.
|
| 621 |
+
**kwargs: See `StreamingTransformerLayer`.
|
| 622 |
+
"""
|
| 623 |
+
|
| 624 |
+
def __init__(
|
| 625 |
+
self,
|
| 626 |
+
d_model: int,
|
| 627 |
+
num_heads: int,
|
| 628 |
+
num_layers: int,
|
| 629 |
+
dim_feedforward: int | list[int] = 2048,
|
| 630 |
+
causal: bool = False,
|
| 631 |
+
context: tp.Optional[int] = None,
|
| 632 |
+
positional_embedding: str = "sin",
|
| 633 |
+
max_period: float = 10_000,
|
| 634 |
+
positional_scale: float = 1.0,
|
| 635 |
+
betas: tp.Optional[tp.Tuple[float, float]] = None,
|
| 636 |
+
layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
|
| 637 |
+
device=None,
|
| 638 |
+
dtype=None,
|
| 639 |
+
**kwargs,
|
| 640 |
+
):
|
| 641 |
+
super().__init__()
|
| 642 |
+
assert d_model % num_heads == 0
|
| 643 |
+
|
| 644 |
+
self.positional_embedding = positional_embedding
|
| 645 |
+
self.max_period = max_period
|
| 646 |
+
self.positional_scale = positional_scale
|
| 647 |
+
self.betas = betas
|
| 648 |
+
|
| 649 |
+
assert positional_embedding in {"sin", "rope", "sin_rope", "none"}
|
| 650 |
+
self.rope: tp.Optional[RotaryEmbedding] = None
|
| 651 |
+
if self.positional_embedding in {"rope", "sin_rope"}:
|
| 652 |
+
self.rope = RotaryEmbedding(max_period=max_period)
|
| 653 |
+
|
| 654 |
+
self.layers = nn.ModuleList()
|
| 655 |
+
for _ in range(num_layers):
|
| 656 |
+
self.layers.append(
|
| 657 |
+
layer_class(
|
| 658 |
+
d_model=d_model,
|
| 659 |
+
num_heads=num_heads,
|
| 660 |
+
dim_feedforward=dim_feedforward,
|
| 661 |
+
causal=causal,
|
| 662 |
+
context=context,
|
| 663 |
+
rope=self.rope,
|
| 664 |
+
device=device,
|
| 665 |
+
dtype=dtype,
|
| 666 |
+
**kwargs,
|
| 667 |
+
)
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
def _init_streaming_state(self, batch_size: int) -> _TransformerState:
|
| 671 |
+
device = next(self.parameters()).device
|
| 672 |
+
return _TransformerState(offset=torch.zeros(1, device=device, dtype=torch.long))
|
| 673 |
+
|
| 674 |
+
def forward(self, x: torch.Tensor, *args, **kwargs):
|
| 675 |
+
B, T, C = x.shape
|
| 676 |
+
|
| 677 |
+
state = self._streaming_state
|
| 678 |
+
if state is None:
|
| 679 |
+
offset = torch.zeros(1, dtype=torch.long, device=x.device)
|
| 680 |
+
else:
|
| 681 |
+
offset = state.offset
|
| 682 |
+
|
| 683 |
+
if self.positional_embedding in {"sin", "sin_rope"}:
|
| 684 |
+
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
| 685 |
+
positions = positions + offset.view(-1, 1, 1)
|
| 686 |
+
pos_emb = create_sin_embedding(
|
| 687 |
+
positions, C, max_period=self.max_period, dtype=x.dtype
|
| 688 |
+
)
|
| 689 |
+
x = x + self.positional_scale * pos_emb
|
| 690 |
+
|
| 691 |
+
for layer in self.layers:
|
| 692 |
+
x = layer(x, *args, **kwargs)
|
| 693 |
+
|
| 694 |
+
if state is not None:
|
| 695 |
+
state.offset.add_(T)
|
| 696 |
+
return x
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
class ProjectedTransformer(StreamingContainer):
|
| 700 |
+
"""Transformer with optional projections of the input and output to different dimensions when needed.
|
| 701 |
+
Supports multiple outputs.
|
| 702 |
+
|
| 703 |
+
Args:
|
| 704 |
+
input_dimension (int): dimension of the input.
|
| 705 |
+
output_dimensions (tuple[int]): dimensions of the outputs.
|
| 706 |
+
d_model (int): inner dimension of the Transformer.
|
| 707 |
+
conv_layout (bool): If True, expects `[B, C, T]` shaped tensors, otherwise, `[B, T, C]`.
|
| 708 |
+
Similarly, the output will have the same layout.
|
| 709 |
+
"""
|
| 710 |
+
|
| 711 |
+
def __init__(
|
| 712 |
+
self,
|
| 713 |
+
input_dimension: int,
|
| 714 |
+
output_dimensions: tp.Tuple[int, ...],
|
| 715 |
+
d_model: int,
|
| 716 |
+
*,
|
| 717 |
+
conv_layout: bool = False,
|
| 718 |
+
**kwargs,
|
| 719 |
+
):
|
| 720 |
+
super().__init__()
|
| 721 |
+
self.transformer = StreamingTransformer(d_model=d_model, **kwargs)
|
| 722 |
+
self.input_dimension = input_dimension
|
| 723 |
+
self.output_dimensions = output_dimensions
|
| 724 |
+
self.conv_layout = conv_layout
|
| 725 |
+
self.input_proj = None
|
| 726 |
+
if d_model != input_dimension:
|
| 727 |
+
self.input_proj = nn.Linear(input_dimension, d_model, bias=False)
|
| 728 |
+
|
| 729 |
+
self.output_projs = nn.ModuleList()
|
| 730 |
+
for output_dimension in output_dimensions:
|
| 731 |
+
if d_model == output_dimension:
|
| 732 |
+
self.output_projs.append(nn.Identity())
|
| 733 |
+
else:
|
| 734 |
+
self.output_projs.append(
|
| 735 |
+
nn.Linear(d_model, output_dimension, bias=False)
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
def forward(self, x, *args, **kwargs):
|
| 739 |
+
if self.conv_layout:
|
| 740 |
+
x = x.transpose(1, 2)
|
| 741 |
+
if self.input_proj is not None:
|
| 742 |
+
x = self.input_proj(x)
|
| 743 |
+
z = self.transformer(x, *args, **kwargs)
|
| 744 |
+
ys = []
|
| 745 |
+
for output_proj in self.output_projs:
|
| 746 |
+
y = output_proj(z)
|
| 747 |
+
if self.conv_layout:
|
| 748 |
+
y = y.transpose(1, 2)
|
| 749 |
+
ys.append(y)
|
| 750 |
+
return ys
|
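A minimal usage sketch of the streaming path (illustrative, not part of the file above; it assumes the `streaming(batch_size)` context manager exposed by `StreamingModule` in `moshi.modules.streaming`): feeding a causal `StreamingTransformer` one frame at a time should reproduce the full-sequence output, since the state carries the absolute offset used for positional embeddings and the attention KV cache.

import torch

model = StreamingTransformer(d_model=64, num_heads=4, num_layers=2, causal=True, context=100)
model.eval()
x = torch.randn(1, 16, 64)

with torch.no_grad():
    y_full = model(x)                    # one pass over the whole sequence
    frames = []
    with model.streaming(batch_size=1):  # assumed StreamingModule context manager
        for t in range(x.shape[1]):
            frames.append(model(x[:, t : t + 1]))  # one frame at a time
    y_stream = torch.cat(frames, dim=1)

print(torch.allclose(y_full, y_stream, atol=1e-5))  # expected: True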
moshi/quantization/__init__.py ADDED
@@ -0,0 +1,13 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""RVQ."""
# flake8: noqa
from .vq import ResidualVectorQuantizer, SplitResidualVectorQuantizer
from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
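The re-exports above define the package's public surface, so downstream code can import everything from `moshi.quantization` directly, e.g.:

from moshi.quantization import (
    ResidualVectorQuantizer, SplitResidualVectorQuantizer,
    BaseQuantizer, DummyQuantizer, QuantizedResult,
)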
moshi/quantization/base.py ADDED
@@ -0,0 +1,170 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Base class for all quantizers.
"""

from dataclasses import dataclass, field
import typing as tp

import torch
from torch import nn


@dataclass
class QuantizedResult:
    x: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


class BaseQuantizer(nn.Module):
    """Base class for quantizers."""

    def __init__(self):
        super().__init__()
        self._ema_frozen = False

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """
        Given input tensor x, returns first the quantized (or approximately quantized)
        representation along with quantized codes, bandwidth, and any penalty term for the loss.
        Finally, this returns a dict of metrics to update logging etc.
        Frame rate must be passed so that the bandwidth is properly computed.
        """
        raise NotImplementedError()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth."""
        raise NotImplementedError()

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        raise NotImplementedError()

    @property
    def cardinality(self) -> int:
        """Cardinality of each codebook."""
        raise NotImplementedError()

    @property
    def total_codebooks(self) -> int:
        """Total number of codebooks."""
        raise NotImplementedError()

    @property
    def num_codebooks(self) -> int:
        """Number of active codebooks."""
        raise NotImplementedError()

    @property
    def semantic_quantizer(self) -> 'BaseQuantizer':
        """This returns the quantizer that models the first level of the hierarchy (typically semantic).

        In this case, it's the quantizer itself.
        """
        return self

    @property
    def acoustic_quantizer(self) -> 'BaseQuantizer':
        """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic).

        In this case, it's the quantizer itself.
        """
        return self

    def set_num_codebooks(self, n: int) -> None:
        """Set the number of active codebooks."""
        raise NotImplementedError()

    @property
    def ema_frozen(self) -> bool:
        """Whether to apply ema to the codebooks."""
        return self._ema_frozen

    def ema_frozen_(self, ema_frozen: bool) -> None:
        """Set whether ema should be applied to the codebooks."""
        self._ema_frozen = ema_frozen


class DummyQuantizer(BaseQuantizer):
    """Fake quantizer that actually does not perform any quantization."""

    def __init__(
        self,
        dimension: int,
        input_dimension: tp.Optional[int] = None,
        output_dimension: tp.Optional[int] = None,
    ):
        super().__init__()
        self.dimension = dimension
        self.input_dimension = input_dimension or dimension
        self.output_dimension = output_dimension or dimension
        self.input_proj: torch.nn.Module
        self.output_proj: torch.nn.Module
        if self.input_dimension == self.dimension:
            self.input_proj = torch.nn.Identity()
        else:
            self.input_proj = torch.nn.Conv1d(
                self.input_dimension, self.dimension, 1, bias=False
            )
        if self.output_dimension == self.dimension:
            self.output_proj = torch.nn.Identity()
        else:
            self.output_proj = torch.nn.Conv1d(
                self.dimension, self.output_dimension, 1, bias=False
            )

    def forward(self, x: torch.Tensor, frame_rate: int):
        q = x.unsqueeze(1)
        x = self.output_proj(self.input_proj(x))
        return QuantizedResult(
            x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        x = self.input_proj(x)
        return x.unsqueeze(1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        y = codes.squeeze(1)
        return self.output_proj(y)

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        return 1

    @property
    def num_codebooks(self):
        """Number of active codebooks."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise AttributeError(
            "Cannot override the number of codebooks for the dummy quantizer"
        )

    @property
    def cardinality(self) -> int:
        """Cardinality of each codebook."""
        return 1
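An illustrative sanity check of the interface above (values chosen arbitrarily): `DummyQuantizer` passes the input through, uses the input itself as "codes", and reports the bandwidth of raw float32 frames following the `q.numel() * 32 * frame_rate / 1000 / len(x)` formula in `forward`.

import torch

dummy = DummyQuantizer(dimension=128)
x = torch.randn(2, 128, 50)      # [B, C, T]
res = dummy(x, frame_rate=50)
print(res.x.shape)               # torch.Size([2, 128, 50])
print(res.codes.shape)           # torch.Size([2, 1, 128, 50])
print(float(res.bandwidth))      # 12800 * 32 * 50 / 1000 / 2 = 10240.0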
moshi/quantization/core_vq.py ADDED
@@ -0,0 +1,384 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from einops import rearrange
import torch
from torch import nn
from torch import distributed
import torch.nn.functional as F


class _CodebookForwardResult(tp.NamedTuple):
    quantized: torch.Tensor
    codes: torch.Tensor
    metrics: tp.Dict[str, torch.Tensor]


class _VQForwardResult(tp.NamedTuple):
    quantized: torch.Tensor
    codes: torch.Tensor
    loss: torch.Tensor
    metrics: tp.Dict[str, torch.Tensor]


def _ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def _uniform_init(*shape: int) -> torch.Tensor:
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def _sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def _compute_entropy(usage: torch.Tensor) -> torch.Tensor:
    # Usage is some unnormalized distribution.
    proba = usage / usage.sum()
    p_log_p = torch.where(
        proba == 0, zero_scalar(usage.device), proba * torch.log(proba)
    )
    return -p_log_p.sum()


def _is_distributed() -> bool:
    # Checks if we need to use distributed routines.
    return distributed.is_initialized() and distributed.get_world_size() > 1


def zero_scalar(device) -> torch.Tensor:
    """Returns a 0. value on the given device without introducing a synchronization point."""
    return torch.zeros([1], device=device)[0]


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
            is replaced. This is expressed as a fraction of the usage a centroid would get under
            a uniform distribution, so that it doesn't depend on the batch size etc.
        replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
            to avoid the centroid getting replaced too quickly.
        check_unused_every (int): Check for unused centroids every `check_unused_every` iterations.
            This is to avoid too many synchronization points.

    Buffers:
        cluster_usage (torch.Tensor): EMA of the cluster usage per batch, e.g. this will
            be dependent on the batch size etc.
        embedding_sum (torch.Tensor): EMA of the sum of the assigned points to each cluster.
            In particular, this can be normalized by `cluster_usage` to obtain the
            actual cluster centroids.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_usage_ratio: float = 0.1,
        replaced_usage_ratio: float = 1.0,
        check_unused_every: int = 5,
    ):
        super().__init__()
        self.decay = decay
        embedding = torch.zeros(codebook_size, dim)

        self.dim = dim
        self.codebook_size = codebook_size

        self.epsilon = epsilon
        self.threshold_usage_ratio = threshold_usage_ratio
        self.replaced_usage_ratio = replaced_usage_ratio
        self.check_unused_every = check_unused_every
        self._next_unused_check = check_unused_every

        self.register_buffer("_initialized", torch.tensor([False], dtype=torch.float))
        self.register_buffer("cluster_usage", torch.ones(codebook_size))
        self.register_buffer("embedding_sum", embedding)
        self.register_buffer("_embedding", None, persistent=False)
        self._cached_initialized = False

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs) -> None:
        # Mapping old names to new names
        mappings = {
            "inited": "_initialized",
            "cluster_size": "cluster_usage",
            "embed_avg": "embedding_sum",
            "embed_sum": "embedding_sum",
        }
        for old_name, new_name in mappings.items():
            old_name = prefix + old_name
            if old_name in state_dict:
                value = state_dict.pop(old_name)
                if new_name is not None:
                    state_dict[prefix + new_name] = value
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @property
    def embedding(self) -> torch.Tensor:
        if self._embedding is None:
            embedding = (
                self.embedding_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None]
            )
            self.register_buffer("_embedding", embedding, persistent=False)
            return embedding
        return self._embedding

    def _broadcast_buffers(self) -> None:
        if _is_distributed():
            for buffer in self.buffers():
                distributed.broadcast(buffer, 0)

    def _replace_expired_codes(self, samples: torch.Tensor, mask: torch.Tensor) -> None:
        # Replaces expired centroids, as indicated by `mask` (a true value indicates the code needs to be replaced).
        # The new codes are sampled from the batch `samples`.
        new_vectors = _sample_vectors(samples, self.codebook_size)
        replace_cluster_usage = (
            self.replaced_usage_ratio * self.cluster_usage.sum() / self.codebook_size
        )
        self.embedding_sum[:] = torch.where(
            mask[:, None], replace_cluster_usage * new_vectors, self.embedding_sum
        )
        self.cluster_usage[:] = torch.where(
            mask, replace_cluster_usage, self.cluster_usage
        )

    def _reshape_input(self, x: torch.Tensor) -> torch.Tensor:
        # Flattens all the dimensions but the last one, e.g. return a vector of shape `[N, D]`.
        x = rearrange(x, "... d -> (...) d")
        return x

    def _reshape_codes(self, codes: torch.Tensor, shape: torch.Size) -> torch.Tensor:
        return codes.view(*shape[:-1])

    def _quantize(self, x: torch.Tensor) -> torch.Tensor:
        # Projects each vector in `x` over the nearest centroid and return its index.
        # `x` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
        assert x.dim() == 2
        dists = torch.cdist(x[None], self.embedding[None], p=2)[0]
        codes = dists.argmin(dim=-1)
        return codes

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Given a tensor `x` of shape `[*, D]`, returns a tensor of integer codes of shape `[*]`.
        The codes are defined as the indexes of the centroids nearest to each vector in `x`.
        """
        assert x.dtype.is_floating_point, f"Input should be floats, got {x.dtype}"
        shape = x.shape
        x = self._reshape_input(x)
        codes = self._quantize(x)
        codes = self._reshape_codes(codes, shape)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Given a tensor of codes of shape `[*]`, returns a tensor of shape `[*, D]`,
        corresponding to the centroids associated to each code index.
        """
        assert (
            not codes.dtype.is_floating_point
        ), f"Codes should be integers, got {codes.dtype}"
        quantized = F.embedding(codes, self.embedding)
        return quantized

    def forward(
        self, x: torch.Tensor, initialize: bool = True
    ) -> _CodebookForwardResult:
        shape = x.shape
        x = self._reshape_input(x)

        flat_codes = self._quantize(x)
        codes = self._reshape_codes(flat_codes, shape)
        quantized = self.decode(codes)
        metrics: tp.Dict[str, torch.Tensor] = {}

        return _CodebookForwardResult(quantized, codes, metrics)


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
            is replaced. This is expressed as a fraction of the usage a centroid would get under
            a uniform distribution, so that it doesn't depend on the batch size etc.
        replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
            to avoid the centroid getting replaced too quickly.
        check_unused_every (int): Check for unused centroids every `check_unused_every` iterations.
            This is to avoid too many synchronization points.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_usage_ratio: float = 0.1,
        **kwargs,
    ):
        super().__init__()
        if codebook_dim is None:
            codebook_dim = dim

        requires_projection = codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        )
        self.epsilon = epsilon
        self._codebook = EuclideanCodebook(
            dim=codebook_dim,
            codebook_size=codebook_size,
            decay=decay,
            epsilon=epsilon,
            threshold_usage_ratio=threshold_usage_ratio,
            **kwargs,
        )
        self.codebook_size = codebook_size

    @property
    def embedding(self):
        return self._codebook.embedding

    def _rearrange_input(self, x):
        x = rearrange(x, "b d n -> b n d")
        return x

    def _rearrange_output(self, quantized):
        quantized = rearrange(quantized, "b n d -> b d n")
        return quantized

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encodes `x` into discrete integer codes."""
        x = self._rearrange_input(x)
        x = self.project_in(x)
        codes = self._codebook.encode(x)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Converts integer codes into quantized vectors."""
        quantized = self._codebook.decode(codes)
        quantized = self.project_out(quantized)
        quantized = self._rearrange_output(quantized)
        return quantized

    def forward(self, x: torch.Tensor, initialize: bool = True) -> _VQForwardResult:
        x = self._rearrange_input(x)
        quantized, codes, metrics = self._codebook(x, initialize=initialize)

        loss = zero_scalar(x.device)

        quantized = self.project_out(quantized)
        quantized = self._rearrange_output(quantized)

        return _VQForwardResult(quantized, codes, loss, metrics)


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers: int, codebook_offset: int, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )
        self.codebook_offset = codebook_offset

    def forward(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None
    ) -> _VQForwardResult:
        """
        Args:
            x (torch.Tensor): input tensor to quantize, of shape `[B, C, T]`.
            n_q (int or None): if provided, number of codebook levels to use in RVQ.
        """

        quantized_out = zero_scalar(x.device)
        residual = x

        all_losses = []
        all_codes = []
        all_metrics: tp.Dict[str, torch.Tensor] = {}

        n_q = n_q or len(self.layers)
        previous_layer_is_initialized = True

        for i, layer in enumerate(self.layers[:n_q]):  # type: ignore
            quantized, codes, loss, metrics = layer(
                residual, initialize=previous_layer_is_initialized
            )

            quantized = quantized.detach()
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_codes.append(codes)
            all_losses.append(loss)

            for key, value in metrics.items():
                if key in all_metrics:
                    all_metrics[key] += value / n_q
                else:
                    all_metrics[key] = value / n_q
                all_metrics[key + f"_{i + self.codebook_offset}"] = value

        out_losses, out_codes = map(torch.stack, (all_losses, all_codes))
        return _VQForwardResult(quantized_out, out_codes, out_losses, all_metrics)

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encodes `x` into discrete integer codes. If `n_q` is provided, only uses the first `n_q` codebook levels."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:  # type: ignore
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Converts the integer codes into quantized vectors."""
        quantized = zero_scalar(codes.device)
        for idx, layer_codes in enumerate(codes):
            layer = self.layers[idx]
            quantized = quantized + layer.decode(layer_codes)
        return quantized
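A short illustrative run of the residual scheme above (freshly initialized, untrained codebooks, so the shapes are the point here): each level quantizes the residual left by the previous ones, and `decode` sums the per-level centroid lookups back together.

import torch

rvq = ResidualVectorQuantization(
    num_quantizers=8, codebook_offset=0, dim=32, codebook_size=256
)
x = torch.randn(4, 32, 100)   # [B, C=dim, T]
codes = rvq.encode(x)         # [K=8, B, T], one integer code per level and frame
recon = rvq.decode(codes)     # [B, C, T], sum of the per-level centroids
print(codes.shape, recon.shape)
# With trained codebooks, keeping more levels (a larger n_q) lowers the
# reconstruction error; here the codebooks start at zero, so recon is zero.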
moshi/quantization/vq.py ADDED
@@ -0,0 +1,340 @@
| 1 |
+
# Copyright (c) Kyutai, all rights reserved.
|
| 2 |
+
# This source code is licensed under the license found in the
|
| 3 |
+
# LICENSE file in the root directory of this source tree.
|
| 4 |
+
|
| 5 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 6 |
+
# All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# This source code is licensed under the license found in the
|
| 9 |
+
# LICENSE file in the root directory of this source tree.
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import typing as tp
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from .base import BaseQuantizer, QuantizedResult
|
| 17 |
+
from .core_vq import ResidualVectorQuantization
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ResidualVectorQuantizer(BaseQuantizer):
|
| 21 |
+
"""Residual Vector Quantizer.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
dimension (int): Dimension of the codebooks.
|
| 25 |
+
input_dimension (None or int): dimension of the input, defaults to `dimension` if not provided.
|
| 26 |
+
output_dimension (None or int): dimension of the output, defaults to `dimension` if not provided.
|
| 27 |
+
n_q (int): Number of vector quantizers used.
|
| 28 |
+
q_dropout (bool): Random quantizer drop out at train time.
|
| 29 |
+
no_quantization_rate (float): Gives the probability of applying no quantization at all
|
| 30 |
+
at train time. The RVQ codebooks will still get the input value to learn the proper codebook.
|
| 31 |
+
bins (int): Codebook size.
|
| 32 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
| 33 |
+
threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
|
| 34 |
+
is replaced. This is expressed as a fraction of the usage a centroid would get under
|
| 35 |
+
a uniform distribution, so that it doesn't depend on the batch size etc.
|
| 36 |
+
replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
|
| 37 |
+
to avoid the centroid getting replaced too quickly.
|
| 38 |
+
codebook_offset (int): Offset to use for the codebook indices. This is useful when using multiple quantizers
|
| 39 |
+
such as in SplitResidualVectorQuantizer.
|
| 40 |
+
force_projection (bool): Whether to force input and output projections even when dimension is constant.
|
| 41 |
+
generator_seed (int or None): seed used to initialize the RNG used for no quantization.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
dimension: int = 128,
|
| 47 |
+
input_dimension: tp.Optional[int] = None,
|
| 48 |
+
output_dimension: tp.Optional[int] = None,
|
| 49 |
+
n_q: int = 8,
|
| 50 |
+
q_dropout: bool = False,
|
| 51 |
+
q_first_only_proba: float = 0.0,
|
| 52 |
+
no_quantization_rate: float = 0.0,
|
| 53 |
+
bins: int = 1024,
|
| 54 |
+
decay: float = 0.99,
|
| 55 |
+
threshold_usage_ratio: float = 0.1,
|
| 56 |
+
replaced_usage_ratio: float = 1.0,
|
| 57 |
+
codebook_offset: int = 0,
|
| 58 |
+
force_projection: bool = False,
|
| 59 |
+
generator_seed: tp.Optional[int] = None,
|
| 60 |
+
):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.max_n_q = n_q
|
| 63 |
+
self.n_q = n_q
|
| 64 |
+
self.q_dropout = q_dropout
|
| 65 |
+
self.no_quantization_rate = no_quantization_rate
|
| 66 |
+
self.q_first_only_proba = q_first_only_proba
|
| 67 |
+
self.dimension = dimension
|
| 68 |
+
self.input_dimension = input_dimension or dimension
|
| 69 |
+
self.output_dimension = output_dimension or dimension
|
| 70 |
+
self.bins = bins
|
| 71 |
+
self.decay = decay
|
| 72 |
+
self.input_proj: torch.nn.Module
|
| 73 |
+
self.output_proj: torch.nn.Module
|
| 74 |
+
self.generator = None
|
| 75 |
+
if generator_seed is not None:
|
| 76 |
+
self.generator = torch.Generator(
|
| 77 |
+
device="cuda" if torch.cuda.is_available() else "cpu"
|
| 78 |
+
)
|
| 79 |
+
self.generator.manual_seed(generator_seed)
|
| 80 |
+
if self.input_dimension == self.dimension and not force_projection:
|
| 81 |
+
self.input_proj = torch.nn.Identity()
|
| 82 |
+
else:
|
| 83 |
+
self.input_proj = torch.nn.Conv1d(
|
| 84 |
+
self.input_dimension, self.dimension, 1, bias=False
|
| 85 |
+
)
|
| 86 |
+
if self.output_dimension == self.dimension and not force_projection:
|
| 87 |
+
self.output_proj = torch.nn.Identity()
|
| 88 |
+
else:
|
| 89 |
+
self.output_proj = torch.nn.Conv1d(
|
| 90 |
+
self.dimension, self.output_dimension, 1, bias=False
|
| 91 |
+
)
|
| 92 |
+
self.vq = ResidualVectorQuantization(
|
| 93 |
+
dim=self.dimension,
|
| 94 |
+
codebook_size=self.bins,
|
| 95 |
+
num_quantizers=self.n_q,
|
| 96 |
+
decay=self.decay,
|
| 97 |
+
threshold_usage_ratio=threshold_usage_ratio,
|
| 98 |
+
replaced_usage_ratio=replaced_usage_ratio,
|
| 99 |
+
codebook_offset=codebook_offset,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def forward(self, x: torch.Tensor, frame_rate: int):
|
| 103 |
+
"""
|
| 104 |
+
Args:
|
| 105 |
+
x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels.
|
| 106 |
+
frame_rate (int): frame rate of the input (e.g `T = frame_rate * duration`), used to compute
|
| 107 |
+
the bandwidth.
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
QuantizedResult: Quantized result with the following attributes:
|
| 111 |
+
- `x` (torch.Tensor): Quantized tensor of shape [B, C, T].
|
| 112 |
+
- `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks.
|
| 113 |
+
- `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second.
|
| 114 |
+
- `penalty` (torch.Tensor): Commitment loss.
|
| 115 |
+
- `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy.
|
| 116 |
+
"""
|
| 117 |
+
n_q = self.n_q
|
| 118 |
+
x = self.input_proj(x)
|
| 119 |
+
|
| 120 |
+
bw_per_q = math.log2(self.bins) * frame_rate / 1000
|
| 121 |
+
quantized, codes, commit_loss, metrics = self.vq(x, n_q=n_q)
|
| 122 |
+
B, _, _ = quantized.shape
|
| 123 |
+
quantized = self.output_proj(quantized)
|
| 124 |
+
codes = codes.transpose(0, 1)
|
| 125 |
+
# codes is [B, K, T], with T frames, K nb of codebooks.
|
| 126 |
+
bw = torch.tensor(n_q * bw_per_q).to(x)
|
| 127 |
+
return QuantizedResult(
|
| 128 |
+
quantized, codes, bw, penalty=torch.mean(commit_loss), metrics=metrics
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 132 |
+
"""Encode a given input tensor with the specified frame rate at the given bandwidth.
|
| 133 |
+
The RVQ encode method sets the appropriate number of quantizer to use
|
| 134 |
+
and returns indices for each quantizer.
|
| 135 |
+
"""
|
| 136 |
+
n_q = self.n_q
|
| 137 |
+
if x.shape[-1] == 0:
|
| 138 |
+
return torch.empty((x.shape[0], n_q, 0), device=x.device, dtype=torch.int64)
|
| 139 |
+
|
| 140 |
+
x = self.input_proj(x)
|
| 141 |
+
codes = self.vq.encode(x, n_q=n_q)
|
| 142 |
+
codes = codes.transpose(0, 1)
|
| 143 |
+
# codes is [B, K, T], with T frames, K nb of codebooks.
|
| 144 |
+
return codes
|
| 145 |
+
|
| 146 |
+
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
| 147 |
+
"""Decode the given codes to the quantized representation."""
|
| 148 |
+
# codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
|
| 149 |
+
codes = codes.transpose(0, 1)
|
| 150 |
+
quantized = self.vq.decode(codes)
|
| 151 |
+
quantized = self.output_proj(quantized)
|
| 152 |
+
return quantized
|
| 153 |
+
|
| 154 |
+
@property
|
| 155 |
+
def total_codebooks(self):
|
| 156 |
+
return self.max_n_q
|
| 157 |
+
|
| 158 |
+
@property
|
| 159 |
+
def num_codebooks(self):
|
| 160 |
+
return self.n_q
|
| 161 |
+
|
| 162 |
+
def set_num_codebooks(self, n: int):
|
| 163 |
+
assert n >= 0 and n <= self.max_n_q
|
| 164 |
+
self.n_q = n
|
| 165 |
+
|
| 166 |
+
@property
|
| 167 |
+
def cardinality(self) -> int:
|
| 168 |
+
return self.bins
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class SplitResidualVectorQuantizer(BaseQuantizer):
|
| 172 |
+
"""Residual Vector Quantizer with separate projections for the first quantizer and the rest.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
n_q (int): Number of residual vector quantizers used.
|
| 176 |
+
n_semantic_q (int): Number of residual vector quantizers used for the semantic quantizer.
|
| 177 |
+
no_quantization_mode (str): if 'true_skip', when doing no quantization, the input will not go
|
| 178 |
+
through the sub quantizers. If `independent`, independent decisions are taken by
|
| 179 |
+
the semantic and acoustic quantizers. If `same` (the default), the same decision is taken by both.
|
| 180 |
+
**kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both.
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
def __init__(
|
| 184 |
+
self,
|
| 185 |
+
*,
|
| 186 |
+
n_q: int = 8,
|
| 187 |
+
no_quantization_rate: float = 0.0,
|
| 188 |
+
no_quantization_mode: str = "same",
|
| 189 |
+
n_q_semantic: int = 1,
|
| 190 |
+
**kwargs,
|
| 191 |
+
):
|
| 192 |
+
super().__init__()
|
| 193 |
+
assert n_q > n_q_semantic, (
|
| 194 |
+
f"Number of quantizers {n_q} must be larger "
|
| 195 |
+
f"than the number of semantic quantizers {n_q_semantic}."
|
| 196 |
+
)
|
| 197 |
+
self.max_n_q = n_q
|
| 198 |
+
self.n_q_semantic = n_q_semantic
|
| 199 |
+
self.n_q_acoustic = n_q - n_q_semantic
|
| 200 |
+
if no_quantization_mode == "true_skip":
|
| 201 |
+
self.no_quantization_rate = no_quantization_rate
|
| 202 |
+
# Setting to zero for the underlying RVQ.
|
| 203 |
+
no_quantization_rate = 0.0
|
| 204 |
+
else:
|
| 205 |
+
self.no_quantization_rate = 0.0
|
| 206 |
+
if no_quantization_mode == "same":
|
| 207 |
+
kwargs["generator_seed"] = 1234
|
| 208 |
+
kwargs["no_quantization_rate"] = no_quantization_rate
|
| 209 |
+
q_dropout = kwargs.pop("q_dropout", False)
|
| 210 |
+
self.rvq_first = ResidualVectorQuantizer(
|
| 211 |
+
n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs
|
| 212 |
+
)
|
| 213 |
+
self.rvq_rest = ResidualVectorQuantizer(
|
| 214 |
+
n_q=n_q - n_q_semantic,
|
| 215 |
+
codebook_offset=1,
|
| 216 |
+
force_projection=True,
|
| 217 |
+
q_dropout=q_dropout,
|
| 218 |
+
**kwargs,
|
| 219 |
+
)
|
| 220 |
+
if no_quantization_mode == "true_skip":
|
| 221 |
+
assert self.rvq_first.input_dimension == self.rvq_first.output_dimension
|
| 222 |
+
assert self.rvq_rest.input_dimension == self.rvq_rest.output_dimension
|
| 223 |
+
|
| 224 |
+
def _renorm_and_add(
|
| 225 |
+
self,
|
| 226 |
+
first_val: torch.Tensor,
|
| 227 |
+
rest_val: torch.Tensor,
|
| 228 |
+
n_q_semantic: int,
|
| 229 |
+
n_q_acoustic: int,
|
| 230 |
+
):
|
| 231 |
+
"""Renormalizes values from `rvq_first` and `rvq_rest` and adds them.
|
| 232 |
+
|
| 233 |
+
This allows correcting statistics that are normalized by the number of quantizers. To renormalize, we use the
|
| 234 |
+
number of quantizers that are actually used, e.g. taking into account quantizer dropout.
|
| 235 |
+
"""
|
| 236 |
+
n_q = n_q_semantic + n_q_acoustic
|
| 237 |
+
renorm_first_val = first_val * n_q_semantic / n_q
|
| 238 |
+
renorm_rest_val = rest_val * n_q_acoustic / n_q
|
| 239 |
+
return renorm_first_val + renorm_rest_val
|
| 240 |
+
|
| 241 |
+
def forward(self, x: torch.Tensor, frame_rate: int):
|
| 242 |
+
"""
|
| 243 |
+
Args:
|
| 244 |
+
x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels.
|
| 245 |
+
frame_rate (int): frame rate of the input (e.g `T = frame_rate * duration`), used to compute
|
| 246 |
+
the bandwidth.
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
QuantizedResult: Quantized result with the following attributes:
|
| 250 |
+
- `x` (torch.Tensor): Quantized tensor of shape [B, C, T].
|
| 251 |
+
- `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks.
|
| 252 |
+
- `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second.
|
| 253 |
+
- `penalty` (torch.Tensor): Commitment loss.
|
| 254 |
+
- `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy.
|
| 255 |
+
"""
|
| 256 |
+
semantic_result = self.rvq_first(x, frame_rate)
|
+        if self.n_q == self.n_q_semantic:
+            return semantic_result
+        acoustic_result = self.rvq_rest(x, frame_rate)
+        full_quantized_emb = semantic_result.x + acoustic_result.x
+        full_quantized_codes = torch.cat(
+            [semantic_result.codes, acoustic_result.codes], dim=1
+        )
+        # This is the actual number of quantizers used, e.g. taking into account quantizer dropout.
+        n_q_semantic = semantic_result.codes.shape[1]
+        n_q_acoustic = acoustic_result.codes.shape[1]
+        full_quantized_bandwidth = semantic_result.bandwidth + acoustic_result.bandwidth
+        full_quantized_penalty = self._renorm_and_add(
+            semantic_result.penalty, acoustic_result.penalty, n_q_semantic, n_q_acoustic
+        )
+        full_quantized_metrics = semantic_result.metrics
+        for key, value in acoustic_result.metrics.items():
+            if key in full_quantized_metrics:
+                full_quantized_metrics[key] = self._renorm_and_add(
+                    full_quantized_metrics[key], value, n_q_semantic, n_q_acoustic
+                )
+            else:
+                full_quantized_metrics[key] = value
+        return QuantizedResult(
+            full_quantized_emb,
+            full_quantized_codes,
+            full_quantized_bandwidth,
+            penalty=full_quantized_penalty,
+            metrics=full_quantized_metrics,
+        )
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified frame rate at the given bandwidth.
+        The RVQ encode method sets the appropriate number of quantizers to use
+        and returns indices for each quantizer.
+        """
+        codes = self.rvq_first.encode(x)
+        if self.n_q > self.n_q_semantic:
+            acoustic_codes = self.rvq_rest.encode(x)
+            codes = torch.cat([codes, acoustic_codes], dim=1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        return codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation."""
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        quantized = self.rvq_first.decode(codes[:, : self.n_q_semantic])
+        if codes.shape[1] > self.n_q_semantic:
+            quantized += self.rvq_rest.decode(codes[:, self.n_q_semantic :])
+        return quantized
+
+    @property
+    def total_codebooks(self):
+        return self.rvq_first.max_n_q + self.rvq_rest.max_n_q
+
+    @property
+    def num_codebooks(self):
+        return self.rvq_first.num_codebooks + self.rvq_rest.num_codebooks
+
+    @property
+    def n_q(self):
+        return self.rvq_first.n_q + self.rvq_rest.n_q
+
+    @property
+    def dimension(self):
+        return self.rvq_first.dimension
+
+    @property
+    def semantic_quantizer(self) -> ResidualVectorQuantizer:
+        """This returns the quantizer that models the first level of the hierarchy (typically semantic)."""
+        return self.rvq_first
+
+    @property
+    def acoustic_quantizer(self) -> ResidualVectorQuantizer:
+        """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic)."""
+        return self.rvq_rest
+
+    def set_num_codebooks(self, n: int):
+        assert n >= self.n_q_semantic and n <= self.total_codebooks
+        self.rvq_rest.set_num_codebooks(n - self.n_q_semantic)
+
+    @property
+    def cardinality(self) -> int:
+        assert self.rvq_rest.cardinality == self.rvq_first.cardinality
+        return self.rvq_first.cardinality
moshi/server.py
ADDED
@@ -0,0 +1,256 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import asyncio
+from dataclasses import dataclass
+import random
+import os
+from pathlib import Path
+import tarfile
+import time
+import secrets
+import sys
+
+import aiohttp
+from aiohttp import web
+from huggingface_hub import hf_hub_download
+import numpy as np
+import sentencepiece
+import sphn
+import torch
+
+
+from .client_utils import make_log
+from .models import loaders, MimiModel, LMModel, LMGen
+
+
+def log(level: str, msg: str):
+    print(make_log(level, msg))
+
+
+def seed_all(seed):
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.backends.cudnn.deterministic = False
+    torch.backends.cudnn.benchmark = False
+
+
+@dataclass
+class ServerState:
+    mimi: MimiModel
+    text_tokenizer: sentencepiece.SentencePieceProcessor
+    lm_gen: LMGen
+    lock: asyncio.Lock
+
+    def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
+                 lm: LMModel, device: str | torch.device):
+        self.mimi = mimi
+        self.text_tokenizer = text_tokenizer
+        self.lm_gen = LMGen(lm)
+
+        self.device = device
+        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
+        self.lock = asyncio.Lock()
+
+        self.mimi.streaming_forever(1)
+        self.lm_gen.streaming_forever(1)
+
+    def warmup(self):
+        for _ in range(4):
+            chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
+            codes = self.mimi.encode(chunk)
+            for c in range(codes.shape[-1]):
+                tokens = self.lm_gen.step(codes[:, :, c: c + 1])
+                if tokens is None:
+                    continue
+                _ = self.mimi.decode(tokens[:, 1:])
+        torch.cuda.synchronize()
+
+    async def handle_chat(self, request):
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async def recv_loop():
+            nonlocal close
+            try:
+                async for message in ws:
+                    if message.type == aiohttp.WSMsgType.ERROR:
+                        log("error", f"{ws.exception()}")
+                        break
+                    elif message.type == aiohttp.WSMsgType.CLOSED:
+                        break
+                    elif message.type != aiohttp.WSMsgType.BINARY:
+                        log("error", f"unexpected message type {message.type}")
+                        continue
+                    message = message.data
+                    if not isinstance(message, bytes):
+                        log("error", f"unsupported message type {type(message)}")
+                        continue
+                    if len(message) == 0:
+                        log("warning", "empty message")
+                        continue
+                    kind = message[0]
+                    if kind == 1:  # audio
+                        payload = message[1:]
+                        opus_reader.append_bytes(payload)
+                    else:
+                        log("warning", f"unknown message kind {kind}")
+            finally:
+                close = True
+                log("info", "connection closed")
+
+        async def opus_loop():
+            all_pcm_data = None
+
+            while True:
+                if close:
+                    return
+                await asyncio.sleep(0.001)
+                pcm = opus_reader.read_pcm()
+                if pcm.shape[-1] == 0:
+                    continue
+                if all_pcm_data is None:
+                    all_pcm_data = pcm
+                else:
+                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
+                while all_pcm_data.shape[-1] >= self.frame_size:
+                    be = time.time()
+                    chunk = all_pcm_data[: self.frame_size]
+                    all_pcm_data = all_pcm_data[self.frame_size:]
+                    chunk = torch.from_numpy(chunk)
+                    chunk = chunk.to(device=self.device)[None, None]
+                    codes = self.mimi.encode(chunk)
+                    for c in range(codes.shape[-1]):
+                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
+                        if tokens is None:
+                            continue
+                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
+                        main_pcm = self.mimi.decode(tokens[:, 1:])
+                        main_pcm = main_pcm.cpu()
+                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
+                        text_token = tokens[0, 0, 0].item()
+                        if text_token not in (0, 3):
+                            _text = self.text_tokenizer.id_to_piece(text_token)  # type: ignore
+                            _text = _text.replace("▁", " ")
+                            msg = b"\x02" + bytes(_text, encoding="utf8")
+                            log("info", f"text token '{_text}'")
+                            await ws.send_bytes(msg)
+                    log("info", f"frame handled in {1000 * (time.time() - be):.1f}ms")
+
+        async def send_loop():
+            while True:
+                if close:
+                    return
+                await asyncio.sleep(0.001)
+                msg = opus_writer.read_bytes()
+                if len(msg) > 0:
+                    await ws.send_bytes(b"\x01" + msg)
+
+        log("info", "accepted connection")
+        close = False
+        async with self.lock:
+            opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate)
+            opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate)
+            self.mimi.reset_streaming()
+            self.lm_gen.reset_streaming()
+            # Send the handshake.
+            await ws.send_bytes(b"\x00")
+            await asyncio.gather(opus_loop(), recv_loop(), send_loop())
+        log("info", "done with connection")
+        return ws
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", default="localhost", type=str)
+    parser.add_argument("--port", default=8998, type=int)
+    parser.add_argument("--static", type=str)
+    parser.add_argument("--gradio-tunnel", action='store_true', help='Activate a gradio tunnel.')
+    parser.add_argument("--gradio-tunnel-token",
+                        help='Provide a custom (secret) token here to keep getting the same URL.')
+
+    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
+    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
+    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
+    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
+                        help="HF repo to look into, defaults to Moshiko. "
+                             "Use this to select a different pre-trained model.")
+    parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")
+
+    args = parser.parse_args()
+    seed_all(42424242)
+
+    setup_tunnel = None
+    tunnel_token = ''
+    if args.gradio_tunnel:
+        try:
+            from gradio import networking  # type: ignore
+        except ImportError:
+            log("error", "Cannot find gradio which is required to activate a tunnel. "
+                "Please install with `pip install gradio`.")
+            sys.exit(1)
+        setup_tunnel = networking.setup_tunnel
+        if args.gradio_tunnel_token is None:
+            tunnel_token = secrets.token_urlsafe(32)
+        else:
+            tunnel_token = args.gradio_tunnel_token
+
+    log("info", "loading mimi")
+    if args.mimi_weight is None:
+        args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
+    mimi = loaders.get_mimi(args.mimi_weight, args.device)
+    log("info", "mimi loaded")
+
+    if args.tokenizer is None:
+        args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME)
+    text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer)  # type: ignore
+
+    log("info", "loading moshi")
+    if args.moshi_weight is None:
+        args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME)
+    lm = loaders.get_moshi_lm(args.moshi_weight, args.device)
+    log("info", "moshi loaded")
+
+    state = ServerState(mimi, text_tokenizer, lm, args.device)
+    log("info", "warming up the model")
+    state.warmup()
+    app = web.Application()
+    app.router.add_get("/api/chat", state.handle_chat)
+    static_path: None | str = None
+    if args.static is None:
+        log("info", "retrieving the static content")
+        dist_tgz = hf_hub_download("kyutai/moshi-artifacts", "dist.tgz")
+        dist_tgz = Path(dist_tgz)
+        dist = dist_tgz.parent / "dist"
+        if not dist.exists():
+            with tarfile.open(dist_tgz, "r:gz") as tar:
+                tar.extractall(path=dist_tgz.parent)
+        static_path = str(dist)
+    elif args.static != "none":
+        # When set to the "none" string, we don't serve any static content.
+        static_path = args.static
+    if static_path is not None:
+        async def handle_root(_):
+            return web.FileResponse(os.path.join(static_path, "index.html"))
+
+        log("info", f"serving static content from {static_path}")
+        app.router.add_get("/", handle_root)
+        app.router.add_static(
+            "/", path=static_path, follow_symlinks=True, name="static"
+        )
+    log("info", f"Access the Web UI directly at http://{args.host}:{args.port}")
+    if setup_tunnel is not None:
+        tunnel = setup_tunnel('localhost', args.port, tunnel_token, None)
+        log("info", f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.")
+        log("info", "Note that this tunnel goes through the US and you might experience high latency in Europe.")
+    web.run_app(app, port=args.port)
+
+
+with torch.no_grad():
+    main()
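
The websocket protocol above tags every binary frame with a one-byte kind: 0x00 for the initial handshake, 0x01 for Opus audio, and 0x02 for UTF-8 text tokens. A minimal sketch of client-side frame parsing under that convention (the function name is hypothetical, not part of this package):

    def parse_frame(frame: bytes) -> tuple[str, bytes]:
        # First byte is the kind tag, the rest is the payload.
        kind, payload = frame[0], frame[1:]
        if kind == 0:
            return "handshake", b""
        if kind == 1:
            return "audio", payload  # Opus bytes, e.g. to feed a sphn.OpusStreamReader
        if kind == 2:
            return "text", payload   # UTF-8 encoded text token
        raise ValueError(f"unknown frame kind {kind}")

Only kind 0x01 is accepted in the client-to-server direction; everything else is logged and dropped by recv_loop.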
moshi/utils/__init__.py
ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utilities."""
moshi/utils/autocast.py
ADDED
@@ -0,0 +1,45 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class TorchAutocast:
+    """TorchAutocast utility class.
+    Allows you to enable and disable autocast. This is especially useful
+    when dealing with different architectures and clusters with different
+    levels of support.
+
+    Args:
+        enabled (bool): Whether to enable torch.autocast or not.
+        args: Additional args for torch.autocast.
+        kwargs: Additional kwargs for torch.autocast.
+    """
+
+    def __init__(self, enabled: bool, *args, **kwargs):
+        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+    def __enter__(self):
+        if self.autocast is None:
+            return
+        try:
+            self.autocast.__enter__()
+        except RuntimeError:
+            device = self.autocast.device
+            dtype = self.autocast.fast_dtype
+            raise RuntimeError(
+                f"There was an error autocasting with dtype={dtype} device={device}\n"
+                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+            )
+
+    def __exit__(self, *args, **kwargs):
+        if self.autocast is None:
+            return
+        self.autocast.__exit__(*args, **kwargs)
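
A minimal usage sketch: because TorchAutocast takes `enabled` separately from the torch.autocast arguments, the same call site can run with or without mixed precision (the `model` and `x` below are hypothetical):

    import torch

    use_amp = torch.cuda.is_available()  # assumption: autocast only wanted on GPU
    with TorchAutocast(enabled=use_amp, device_type="cuda", dtype=torch.bfloat16):
        y = model(x)  # runs under bf16 autocast when enabled, full precision otherwise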
moshi/utils/compile.py
ADDED
@@ -0,0 +1,284 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Provides some extra utilities around torch compile, in particular a way
+to fully deactivate it easily with a context manager.
+Provides a simple activation checkpointing that is compatible with FSDP and torch compile.
+Finally, provides some utilities for CUDA graphing functions.
+"""
+from contextlib import contextmanager
+from functools import wraps
+import inspect
+import os
+import typing as tp
+
+import torch
+from torch import cuda
+
+
+_compile_disabled: bool = False
+
+
+@contextmanager
+def no_compile():
+    """Disable torch.compile locally. PyTorch 2.4 now provides a function to do that."""
+    global _compile_disabled
+
+    prev_disabled = _compile_disabled
+    _compile_disabled = True
+    try:
+        yield
+    finally:
+        _compile_disabled = prev_disabled
+
+
+def torch_compile_lazy(fun):
+    """torch.compile creates a huge pool of processes, even when not using the function at all,
+    e.g. with Dora. This can pollute stderr when doing CTRL+C. So we do it in a lazy way.
+    """
+    if os.environ.get("NO_TORCH_COMPILE"):
+        return fun
+    fun_compiled = None
+
+    @wraps(fun)
+    def _wrapped(*args, **kwargs):
+        nonlocal fun_compiled
+        if _compile_disabled:
+            return fun(*args, **kwargs)
+        if fun_compiled is None:
+            fun_compiled = torch.compile(fun)
+        return fun_compiled(*args, **kwargs)
+
+    return _wrapped
+
+
+class Checkpoint(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, function, *args) -> tp.Any:
+        to_save = []
+        ctx.others = []
+        ctx.function = function
+        # Sources will indicate whether the arg in position N is
+        # a tensor stored in ctx.save_for_backward, or inside ctx.others.
+        ctx.sources = []
+        new_args = []
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                to_save.append(arg)
+                ctx.sources.append("tensor")
+                new_args.append(arg.detach())
+            else:
+                ctx.sources.append("other")
+                ctx.others.append(arg)
+                new_args.append(arg)
+        ctx.save_for_backward(*to_save)
+        # During the forward, we just make a pass with no gradient computed.
+        with torch.no_grad():
+            res = function(*new_args)
+        return res
+
+    @staticmethod
+    def backward(ctx, *grads) -> tp.Tuple[tp.Optional[torch.Tensor], ...]:
+        pseudo_tensors = []
+        with torch.set_grad_enabled(True):
+            # We create leaf tensors to collect the output gradients.
+            # We call them pseudo_tensors because they are pretending to be the input
+            # to `function` but are not directly the original inputs.
+            for tensor in ctx.saved_tensors:
+                pseudo_tensor = tensor.detach()
+                pseudo_tensor.requires_grad_(True)
+                pseudo_tensors.append(pseudo_tensor)
+            pseudo_tensors_copy = list(pseudo_tensors)
+            args = []
+            for source in ctx.sources:
+                if source == "other":
+                    args.append(ctx.others.pop(0))
+                else:
+                    assert source == "tensor"
+                    args.append(pseudo_tensors_copy.pop(0))
+            res = ctx.function(*args)
+            # The second forward with grad computation allows us to connect the input leaf tensors
+            # inside pseudo_tensors, to the outputs of the function called.
+            if not isinstance(res, tuple):
+                res = (res,)
+            # Now we just ask Torch to compute the derivative of `res` given the gradient coming from above
+            # `grads`. The computed gradient will end up in the `pseudo_tensors` grad attributes.
+            torch.autograd.backward(res, grads)
+        out: tp.List[tp.Optional[torch.Tensor]] = [None]
+        for source in ctx.sources:
+            # We still need to output `None` values for non tensor parameters.
+            if source == "other":
+                out.append(None)
+            else:
+                assert source == "tensor"
+                out.append(pseudo_tensors.pop(0).grad)
+        return tuple(out)
+
+
+def simple_checkpoint(module: torch.nn.Module, *args, **kwargs):
+    """Custom implementation of checkpointing in PyTorch as the builtin implementation is broken
+    when using torch compile. Only supports wrapping a `nn.Module` with a forward with no `*args` or `**kwargs`.
+
+    https://github.com/pytorch/pytorch/issues/97436.
+    Should be resolved in nightlies, but it is quite fun and simple to code it ourselves.
+    """
+    if hasattr(module, "_fsdp_wrapped_module"):
+        module_for_sig = module._fsdp_wrapped_module
+    else:
+        module_for_sig = module
+    sig = inspect.signature(module_for_sig.forward)
+    # We first flatten all arguments to use only *args, to make things easier and because
+    # torch.autograd.Function has weird support for kwargs.
+    bounded = sig.bind(*args, **kwargs)
+    new_args = []
+    for name, param in sig.parameters.items():
+        if param.kind in {
+            inspect.Parameter.VAR_POSITIONAL,
+            inspect.Parameter.VAR_KEYWORD,
+        }:
+            raise RuntimeError("simple_checkpoint doesn't support var args.")
+        if name not in bounded.arguments:
+            break
+        new_args.append(bounded.arguments[name])
+    return Checkpoint.apply(module, *new_args)
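
A minimal sketch of calling simple_checkpoint, assuming `block` is some nn.Module whose forward takes plain positional/keyword arguments (no *args/**kwargs):

    import torch
    from torch import nn

    block = nn.Linear(16, 16)
    x = torch.randn(2, 16, requires_grad=True)
    # Activations of `block` are recomputed during backward instead of being stored.
    y = simple_checkpoint(block, x)
    y.sum().backward()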
+
+
+_in_cuda_graph = False
+_disable_cuda_graph = False
+
+
+def in_cuda_graph() -> bool:
+    """Indicate whether we are in a function that is CUDA Graphed (or will be soon)."""
+    return _in_cuda_graph
+
+
+@contextmanager
+def _set_in_cuda_graph():
+    global _in_cuda_graph
+    assert not _in_cuda_graph
+    _in_cuda_graph = True
+    try:
+        yield
+    finally:
+        _in_cuda_graph = False
+
+
+def _is_cuda_graph_enabled() -> bool:
+    if _disable_cuda_graph:
+        return False
+    no_cuda_graph = os.environ.get("NO_CUDA_GRAPH", "")
+    if no_cuda_graph.lower() not in {"0", "no", "n", ""}:
+        return False
+    return True
+
+
+@contextmanager
+def no_cuda_graph():
+    """Deactivate CUDA Graphing for all the calls in this context manager."""
+    global _disable_cuda_graph
+    old_value = _disable_cuda_graph
+    _disable_cuda_graph = True
+    try:
+        yield
+    finally:
+        _disable_cuda_graph = old_value
+
+
+class CUDAGraphed:
+    """Allow simple CUDA Graphing of a function.
+
+    Args:
+        func: callable, taking any number of arguments. Its tensor arguments should
+            be top level args, not nested in structures (tuples, dicts, etc). Keyword
+            arguments are NOT supported for simplicity.
+        warmup_steps: how many calls to make normally before CUDA Graphing. In particular, this
+            allows torch.compiled functions to get properly compiled.
+        disable: if True, just call the func directly, useful to quickly deactivate on CPU.
+    """
+
+    def __init__(self, func: tp.Callable, warmup_steps: int = 1, disable: bool = False):
+        self.func = func
+        self.warmup_steps = warmup_steps
+        self.disable = disable
+        self._graph: cuda.CUDAGraph | None = None
+        self._output: tuple | None = None
+        self._args: tuple | None = None
+
+    def reset(self, warmup_steps: int = 0) -> None:
+        """Reset the state, meaning the next call will get CUDA Graphed again. Useful if some
+        shapes have changed, or external state (e.g. KVCache) has changed."""
+        self.warmup_steps = warmup_steps
+        self._graph = None
+        self._output = None
+        self._args = None
+
+    def __call__(self, *args, **kwargs) -> tp.Any:
+        if kwargs:
+            raise RuntimeError("Named arguments not supported for now.")
+        if self.disable or not _is_cuda_graph_enabled() or in_cuda_graph():
+            return self.func(*args, **kwargs)
+
+        def _clone_tensors(args: tuple) -> tuple:
+            out: list = []
+            for arg in args:
+                if isinstance(arg, torch.Tensor):
+                    arg = arg.clone()
+                out.append(arg)
+            return tuple(out)
+
+        def _match_values_copy_tensors(args: tuple, target_args: tuple) -> None:
+            if len(args) != len(target_args):
+                raise ValueError(
+                    f"Expected {len(target_args)} arguments, but got {len(args)} for CUDA Graphed function."
+                )
+            for idx, (source, target) in enumerate(zip(args, target_args)):
+                if isinstance(target, torch.Tensor):
+                    if not isinstance(source, torch.Tensor):
+                        raise ValueError(
+                            f"Argument #{idx} was a tensor, and is no longer (now {source})."
+                        )
+                    if source.shape != target.shape:
+                        raise ValueError(
+                            f"Argument #{idx} had shape {target.shape}, but got shape {source.shape}"
+                        )
+                    target.copy_(source)
+                else:
+                    if isinstance(source, torch.Tensor):
+                        raise ValueError(
+                            f"Argument #{idx} was not a tensor {target}, but is now one."
+                        )
+                    if source is not target and source != target:
+                        raise ValueError(
+                            f"Argument #{idx} changed value from {target} to {source}."
+                        )
+
+        with _set_in_cuda_graph():
+            # Prevent anyone under us from trying to CUDA Graph things.
+            if self._graph is None:
+                if self.warmup_steps <= 0:
+                    self._graph = cuda.CUDAGraph()
+                    # Making a copy just to ensure those are not used elsewhere.
+                    self._args = _clone_tensors(args)
+                    with cuda.graph(self._graph):
+                        self._output = self.func(*self._args)
+                    # At this point nothing really happened, so we have to make it run for real.
+                    self._graph.replay()
+                    return self._output
+                else:
+                    self.warmup_steps -= 1
+                    return self.func(*args)
+            else:
+                assert self._args is not None
+                assert self._output is not None
+                _match_values_copy_tensors(args, self._args)
+                self._graph.replay()
+                return self._output
+
+
+def cuda_graph(func: tp.Callable, warmup_steps: int = 1):
+    """Just calls `CUDAGraphed` on the given function."""
+    if not _is_cuda_graph_enabled():
+        return func
+    return CUDAGraphed(func, warmup_steps)
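
A minimal sketch of CUDA-graphing a fixed-shape step function (the toy `step` is hypothetical; on a machine without CUDA one would pass disable=True instead):

    import torch

    def step(x: torch.Tensor) -> torch.Tensor:
        return x * 2 + 1

    graphed = CUDAGraphed(step, warmup_steps=1)
    x = torch.ones(4, device="cuda")
    y = graphed(x)        # warmup call, runs eagerly
    y = graphed(x + 1)    # capture, then replay for real
    y = graphed(x + 2)    # pure replay: args are copied into the captured buffers

Because replays reuse the same output buffers, callers that keep results across calls should clone them, and reset() must be called whenever shapes or external state (e.g. a KV cache) change.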
moshi/utils/sampling.py
ADDED
@@ -0,0 +1,126 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+
+
+def multinomial(
+    input: torch.Tensor, num_samples: int, replacement=False, *, generator=None
+):
+    """torch.multinomial with an arbitrary number of dimensions, and the number of candidates on the last dimension.
+
+    Args:
+        input (torch.Tensor): The input tensor containing probabilities.
+        num_samples (int): Number of samples to draw.
+        replacement (bool): Whether to draw with replacement or not.
+    Keyword args:
+        generator (torch.Generator): A pseudorandom number generator for sampling.
+    Returns:
+        torch.Tensor: Last dimension contains num_samples indices
+            sampled from the multinomial probability distribution
+            located in the last dimension of tensor input.
+    """
+    input_ = input.reshape(-1, input.shape[-1])
+    # We should probably be able to remove this once the following PR has landed:
+    # https://github.com/pytorch/pytorch/pull/134818/files
+    # In the meantime, we specialize the case no-replacement, nsamples=1 so as not
+    # to have a synchronization point.
+    if replacement or num_samples != 1:
+        output_ = torch.multinomial(
+            input_,
+            num_samples=num_samples,
+            replacement=replacement,
+            generator=generator,
+        )
+    else:
+        q = torch.empty_like(input_).exponential_(1, generator=generator)
+        q = input_ / q
+        output_ = q.argmax(dim=-1, keepdim=True)
+    output = output_.reshape(*list(input.shape[:-1]), -1)
+    return output
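
Why the specialized branch is exact rather than approximate: with independent E_i ~ Exp(1), one has E_i / w_i ~ Exp(w_i), and the minimum of independent exponential variables falls on index k with probability proportional to its rate, hence

    \Pr\Big[\arg\max_i \frac{w_i}{E_i} = k\Big] = \Pr\Big[\arg\min_i \frac{E_i}{w_i} = k\Big] = \frac{w_k}{\sum_j w_j}.

So taking the argmax of input_ / q draws one index from the normalized probabilities while staying entirely on device, without the synchronization point torch.multinomial incurs.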
+
+
+def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
+    """Sample next token from top K values along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        k (int): The k in “top-k”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    probs, indices = torch.topk(probs, k, dim=-1)
+    next_token = multinomial(probs, num_samples=1)
+    next_token = indices.gather(-1, next_token)
+    return next_token
+
+
+def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
+    """Sample next token from top P probabilities along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        p (float): The p in “top-p”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort *= (~mask).float()
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+def sample_token(
+    logits: torch.Tensor,
+    use_sampling: bool = False,
+    temp: float = 1.0,
+    top_k: int = 0,
+    top_p: float = 0.0,
+) -> torch.Tensor:
+    """Given logits of shape [*, Card], returns a LongTensor of shape [*]."""
+    # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
+    if use_sampling and temp > 0.0:
+        probs = torch.softmax(logits / temp, dim=-1)
+        if top_p > 0.0:
+            next_token = sample_top_p(probs, p=top_p)
+        elif top_k > 0:
+            next_token = sample_top_k(probs, k=top_k)
+        else:
+            next_token = multinomial(probs, num_samples=1)
+    else:
+        next_token = torch.argmax(logits, dim=-1, keepdim=True)
+    assert next_token.shape[-1] == 1
+    return next_token[..., 0]
+
+
+if __name__ == "__main__":
+    torch.manual_seed(1234)
+    device = "cpu"
+    if torch.cuda.is_available():
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
+        device = "cuda:0"
+
+    ps = torch.tensor([5.0, 2.0, 12.0, 6.0, 8.0, 1.0, 0.0, 4.0], device=device)
+    cnts = torch.zeros(ps.shape, dtype=torch.long, device=device)
+    total_samples = 1000
+    for _ in range(total_samples):
+        vs = multinomial(ps, num_samples=1, replacement=False)
+        cnts[vs] += 1
+    diff = cnts / cnts.sum() - ps / ps.sum()
+    max_diff = diff.abs().max().cpu().item()
+    print(ps / ps.sum())
+    print(cnts / cnts.sum())
+    assert max_diff < 1.5e-2
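
A minimal usage sketch tying the three strategies together (the logits tensor is hypothetical):

    import torch

    logits = torch.randn(2, 8, 32000)  # [*, Card]
    greedy = sample_token(logits)                                            # argmax
    nucleus = sample_token(logits, use_sampling=True, temp=0.8, top_p=0.95)
    topk = sample_token(logits, use_sampling=True, temp=0.8, top_k=50)
    assert greedy.shape == nucleus.shape == topk.shape == (2, 8)

Note the precedence in sample_token: when both are set, top_p takes priority over top_k.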
pyproject.toml
ADDED
@@ -0,0 +1,33 @@
+[project]
+name = "moshi"
+requires-python = ">= 3.10"
+description = "Moshi is moshi"
+dependencies = [
+    "numpy >= 1.26, < 2.2",
+    "safetensors >= 0.4.0, < 0.5",
+    "huggingface-hub >= 0.24, < 0.25",
+    "einops == 0.7",
+    "sentencepiece == 0.2",
+    "sounddevice == 0.5",
+    "sphn >= 0.1.4",
+    "torch >= 2.2.0, < 2.5",
+    "aiohttp>=3.10.5, <3.11",
+]
+authors = [{name="Laurent Mazaré", email="laurent@kyutai.org"}]
+maintainers = [{name="Laurent Mazaré", email="laurent@kyutai.org"}]
+license = {text = "MIT"}
+dynamic = ["version"]
+
+[tool.setuptools.dynamic]
+version = {attr = "moshi.__version__"}
+
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
+[project.optional-dependencies]
+dev = [
+    "pyright",
+    "flake8",
+    "pre-commit",
+]
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+einops==0.7.0
+safetensors==0.4.4
+sentencepiece==0.2.0
+sounddevice==0.5.0
+soundfile==0.12.1
+sphn==0.1.4
+torch==2.2.0
+numpy==1.26.4
+aiohttp>=3.10.5, <3.11
+huggingface-hub==0.24.6
setup.cfg
ADDED
@@ -0,0 +1,10 @@
+[pep8]
+max-line-length = 120
+
+[flake8]
+max-line-length = 120
+ignore = E203,E704
+exclude =
+    dist
+    build
+