andy hickl committed on
Commit ed99557 · 1 Parent(s): bfcda41

Uploaded from GitHub
LICENSE ADDED
@@ -0,0 +1,23 @@
+ Permission is hereby granted, free of charge, to any
+ person obtaining a copy of this software and associated
+ documentation files (the "Software"), to deal in the
+ Software without restriction, including without
+ limitation the rights to use, copy, modify, merge,
+ publish, distribute, sublicense, and/or sell copies of
+ the Software, and to permit persons to whom the Software
+ is furnished to do so, subject to the following
+ conditions:
+
+ The above copyright notice and this permission notice
+ shall be included in all copies or substantial portions
+ of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+ TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+ PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+ SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+ IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ DEALINGS IN THE SOFTWARE.
LICENSE.audiocraft ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
MANIFEST.in ADDED
@@ -0,0 +1,5 @@
+ include LICENSE*
+ include *.md
+ include *.cfg
+ include requirements.txt
+ include moshi/py.typed
README.md ADDED
@@ -0,0 +1,151 @@
+ # Moshi - PyTorch
+
+ See the [top-level README.md][main_repo] for more information on Moshi.
+
+ [Moshi][moshi] is a speech-text foundation model and full-duplex spoken dialogue framework.
+ It uses [Mimi][moshi], a state-of-the-art streaming neural audio codec. Mimi operates at 12.5 Hz and compresses
+ 24 kHz audio down to 1.1 kbps in a fully streaming manner (latency of 80 ms, the frame size), yet outperforms existing non-streaming codecs.
+
+ This is the PyTorch implementation for Moshi and Mimi.
+
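The 1.1 kbps figure follows directly from the frame rate and the codebook configuration. As a quick sanity check (the 8 codebooks match the `mimi.set_num_codebooks(8)` call later in this README; the codebook size of 2048 entries is an assumption):

```python
import math

frame_rate = 12.5        # token frames per second
num_codebooks = 8        # active codebooks, as used with Moshi
cardinality = 2048       # assumed entries per codebook, i.e. 11 bits each

# bits/second = frames/second * codebooks/frame * bits/codebook
bits_per_second = frame_rate * num_codebooks * math.log2(cardinality)
print(bits_per_second / 1000, "kbps")  # 1.1 kbps
```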
+
+ ## Requirements
+
+ You will need at least Python 3.10. We keep a minimal set of dependencies for the current project.
+ It has been tested with PyTorch 2.2 and 2.4. If you need a specific CUDA version, please make sure
+ to have PyTorch properly installed before installing Moshi.
+
+ ```bash
+ pip install moshi  # moshi PyTorch, from PyPI
+ # Or the bleeding-edge version for Moshi
+ pip install -e "git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi"
+ ```
+
+ While we hope that the present codebase will work on Windows, we do not provide official support for it.
+ At the moment, we do not support quantization for the PyTorch version, so you will need a GPU with a significant amount of memory (24 GB).
+
+
+ ## Usage
+
+ This package provides a streaming version of the audio tokenizer (Mimi) and of the language model (Moshi).
+
+ To run in interactive mode, you need to start a server which will
+ run the model; you can then use either the web UI or a command-line client.
+
+ Start the server with:
+ ```bash
+ python -m moshi.server [--gradio-tunnel]
+ ```
+
+ Then access the web UI at [localhost:8998](http://localhost:8998). If your GPU is on a remote machine
+ with no direct access, `--gradio-tunnel` will create a tunnel with a URL accessible from anywhere.
+ Keep in mind that this tunnel goes through the US and can add significant latency (up to 500 ms from Europe).
+ You can use `--gradio-tunnel-token` to set a fixed secret token and reuse the same address over time.
+ Alternatively, you might want to use SSH to redirect your connection.
+
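For instance, a local port forward makes the remote server reachable as if it ran locally (`user` and `remote-gpu-host` are placeholders):

```shell
# Forward local port 8998 to port 8998 on the machine running moshi.server.
ssh -N -L 8998:localhost:8998 user@remote-gpu-host
# Then open http://localhost:8998 in your local browser.
```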
+ You can use `--hf-repo` to select a different pretrained model by setting the proper Hugging Face repository.
+ See [the model list](https://github.com/kyutai-labs/moshi?tab=readme-ov-file#models) for a reference of the available models.
+
+ Accessing a server that is not localhost via http may cause issues with using
+ the microphone in the web UI (some browsers only allow this over https).
+
+ A local client is also available, as
+ ```bash
+ python -m moshi.client [--url URL_TO_GRADIO]
+ ```
+ Note, however, that unlike the web browser, this client is bare-bones: it does not perform any echo cancellation,
+ nor does it try to compensate for a growing lag by skipping frames.
+
+
+ ## API
+
+ You can use Mimi and Moshi programmatically as follows:
+ ```python
+ from huggingface_hub import hf_hub_download
+ import torch
+
+ from moshi.models import loaders, LMGen
+
+ mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
+ mimi = loaders.get_mimi(mimi_weight, device='cpu')
+ mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.
+
+ wav = torch.randn(1, 1, 24000 * 10)  # should be [B, C=1, T]
+ with torch.no_grad():
+     codes = mimi.encode(wav)  # [B, K = 8, T]
+     decoded = mimi.decode(codes)
+
+     # Supports streaming too.
+     frame_size = int(mimi.sample_rate / mimi.frame_rate)
+     all_codes = []
+     with mimi.streaming(batch_size=1):
+         for offset in range(0, wav.shape[-1], frame_size):
+             frame = wav[:, :, offset: offset + frame_size]
+             codes = mimi.encode(frame)
+             assert codes.shape[-1] == 1, codes.shape
+             all_codes.append(codes)
+
+ # Now if you have a GPU around.
+ mimi.cuda()
+ moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
+ moshi = loaders.get_moshi_lm(moshi_weight, device='cuda')
+ lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)  # this handles sampling params etc.
+ out_wav_chunks = []
+ # Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
+ with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
+     for idx, code in enumerate(all_codes):
+         tokens_out = lm_gen.step(code.cuda())
+         # tokens_out is [B, 1 + 8, 1]: tokens_out[:, 0] is the text token,
+         # tokens_out[:, 1:] are the audio tokens.
+         if tokens_out is not None:
+             wav_chunk = mimi.decode(tokens_out[:, 1:])
+             out_wav_chunks.append(wav_chunk)
+         print(idx, end='\r')
+ out_wav = torch.cat(out_wav_chunks, dim=-1)
+ ```
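The chunking arithmetic in the streaming loop can be checked without loading any model; the 24 kHz sample rate and 12.5 Hz frame rate come from the introduction:

```python
# Streaming chunk arithmetic for Mimi (no model needed; figures from the text).
sample_rate = 24_000   # Hz, input audio
frame_rate = 12.5      # Hz, one column of codes per frame
frame_size = int(sample_rate / frame_rate)   # samples fed per streaming step

num_samples = sample_rate * 10               # the 10-second example above
num_steps = num_samples // frame_size        # streaming iterations

print(frame_size, num_steps)  # 1920 125
```

So each `mimi.encode(frame)` call above consumes 1920 samples (80 ms) and produces exactly one timestep of codes.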
+
+ ## Development
+
+ If you wish to install from a clone of this repository, for instance to further develop Moshi, you can do the following:
+ ```bash
+ # From the current folder (e.g. `moshi/`)
+ pip install -e '.[dev]'
+ pre-commit install
+ ```
+
+ Once locally installed, Mimi can be tested with the following commands, from **the root** of the repository:
+ ```bash
+ wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
+ python scripts/mimi_streaming_test.py
+ ```
+
+ Similarly, Moshi can be tested (with a GPU) with
+ ```bash
+ python scripts/moshi_benchmark.py
+ ```
+
+
+ ## License
+
+ The present code is provided under the MIT license.
+ Note that parts of this code are based on [AudioCraft](https://github.com/facebookresearch/audiocraft), released under
+ the MIT license.
+
+ ## Citation
+
+ If you use either Mimi or Moshi, please cite the following paper:
+
+ ```
+ @techreport{kyutai2024moshi,
+     author = {Alexandre D\'efossez and Laurent Mazar\'e and Manu Orsini and Am\'elie Royer and
+               Patrick P\'erez and Herv\'e J\'egou and Edouard Grave and Neil Zeghidour},
+     title = {Moshi: a speech-text foundation model for real-time dialogue},
+     institution = {Kyutai},
+     year = {2024},
+     month = {September},
+     url = {http://kyutai.org/Moshi.pdf},
+ }
+ ```
+
+ [moshi]: https://kyutai.org/Moshi.pdf
+ [main_repo]: https://github.com/kyutai-labs/moshi
moshi/__init__.py ADDED
@@ -0,0 +1,18 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ moshi is the inference codebase for Kyutai audio generation models.
+
+ The code has been adapted from Audiocraft, see LICENSE.audiocraft.
+ Copyright (c) Meta Platforms, Inc. and affiliates.
+ """
+
+ # flake8: noqa
+ from . import utils
+ from . import modules
+ from . import models
+ from . import quantization
+
+ __version__ = "0.1.0"
moshi/client.py ADDED
@@ -0,0 +1,196 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Client for the Moshi server."""
+
+ import argparse
+ import asyncio
+ import queue
+ import sys
+
+ import aiohttp
+ import numpy as np
+ import sphn
+ import sounddevice as sd
+
+ from .client_utils import AnyPrinter, Printer, RawPrinter
+
+
+ class Connection:
+     def __init__(
+         self,
+         printer: AnyPrinter,
+         websocket: aiohttp.ClientWebSocketResponse,
+         sample_rate: float = 24000,
+         channels: int = 1,
+         frame_size: int = 1920,
+     ) -> None:
+         self.printer = printer
+         self.websocket = websocket
+         self.sample_rate = sample_rate
+         self.frame_size = frame_size
+         self.channels = channels
+
+         self._done = False
+         self._in_stream = sd.InputStream(
+             samplerate=sample_rate,
+             channels=channels,
+             blocksize=self.frame_size,
+             callback=self._on_audio_input,
+         )
+
+         self._out_stream = sd.OutputStream(
+             samplerate=sample_rate,
+             channels=channels,
+             blocksize=frame_size,
+             callback=self._on_audio_output,
+         )
+         self._opus_writer = sphn.OpusStreamWriter(sample_rate)
+         self._opus_reader = sphn.OpusStreamReader(sample_rate)
+         self._output_queue = queue.Queue()
+
+     async def _queue_loop(self) -> None:
+         while True:
+             if self._done:
+                 return
+             await asyncio.sleep(0.001)
+             msg = self._opus_writer.read_bytes()
+             if len(msg) > 0:
+                 try:
+                     await self.websocket.send_bytes(b"\x01" + msg)
+                 except Exception as e:
+                     print(e)
+                     self._lost_connection()
+                     return
+
+     async def _decoder_loop(self) -> None:
+         all_pcm_data = None
+         while True:
+             if self._done:
+                 return
+             await asyncio.sleep(0.001)
+             pcm = self._opus_reader.read_pcm()
+             if all_pcm_data is None:
+                 all_pcm_data = pcm
+             else:
+                 all_pcm_data = np.concatenate((all_pcm_data, pcm))
+             while all_pcm_data.shape[-1] >= self.frame_size:
+                 self._output_queue.put(all_pcm_data[: self.frame_size])
+                 all_pcm_data = np.array(all_pcm_data[self.frame_size :])
+
+     async def _recv_loop(self) -> None:
+         try:
+             async for message in self.websocket:
+                 if message.type == aiohttp.WSMsgType.CLOSED:
+                     self.printer.log("info", "Connection closed")
+                     break
+                 elif message.type == aiohttp.WSMsgType.ERROR:
+                     self.printer.log("error", f"{self.websocket.exception()}")
+                     break
+                 elif message.type != aiohttp.WSMsgType.BINARY:
+                     self.printer.log("error", f"received from server: {message.type}")
+                     continue
+                 message = message.data
+                 if not isinstance(message, bytes):
+                     self.printer.log(
+                         "warning", f"unsupported message type {type(message)}"
+                     )
+                     continue
+                 if len(message) == 0:
+                     self.printer.log("warning", "empty message")
+                     continue
+                 kind = message[0]
+                 if kind == 1:  # audio
+                     payload = message[1:]
+                     self._opus_reader.append_bytes(payload)
+                     self.printer.print_pending()
+                 elif kind == 2:  # text
+                     payload = message[1:]
+                     self.printer.print_token(payload.decode())
+                 else:
+                     self.printer.log("warning", f"unknown message kind {kind}")
+         except Exception as e:
+             print(e)
+             self._lost_connection()
+             return
+
+     def _lost_connection(self) -> None:
+         if not self._done:
+             self.printer.log("error", "Lost connection with the server!")
+             self._done = True
+
+     def _on_audio_input(self, in_data, frames, time_, status) -> None:
+         assert in_data.shape == (self.frame_size, self.channels), in_data.shape
+         self._opus_writer.append_pcm(in_data[:, 0])
+
+     def _on_audio_output(self, out_data, frames, time_, status) -> None:
+         assert out_data.shape == (self.frame_size, self.channels), out_data.shape
+         try:
+             pcm_data = self._output_queue.get(block=False)
+             # TODO: handle other shapes by using some form of fifo/ring buffer.
+             assert pcm_data.shape == (self.frame_size,), pcm_data.shape
+             out_data[:, 0] = pcm_data
+         except queue.Empty:
+             out_data.fill(0)
+             self.printer.print_lag()
+
+     async def run(self) -> None:
+         with self._in_stream, self._out_stream:
+             await asyncio.gather(
+                 self._recv_loop(), self._decoder_loop(), self._queue_loop()
+             )
+
+
+ async def run(printer: AnyPrinter, args):
+     if args.url is None:
+         proto = "ws"
+         if args.https:
+             proto += "s"
+         uri = f"{proto}://{args.host}:{args.port}/api/chat"
+     else:
+         proto = "wss"
+         if '://' in args.url:
+             proto, without_proto = args.url.split('://', 1)
+             if proto in ['ws', 'http']:
+                 proto = "ws"
+             elif proto in ['wss', 'https']:
+                 proto = "wss"
+             else:
+                 printer.log("error", f"The provided URL {args.url} seems to contain a protocol but it is unknown.")
+                 sys.exit(1)
+         else:
+             without_proto = args.url
+         uri = f"{proto}://{without_proto}/api/chat"
+
+     printer.log("info", f"Connecting to {uri}.")
+     async with aiohttp.ClientSession() as session:
+         async with session.ws_connect(uri) as ws:
+             printer.log("info", "connected!")
+             printer.print_header()
+             connection = Connection(printer, ws)
+             await connection.run()
+
+
+ def main():
+     parser = argparse.ArgumentParser("client_opus")
+     parser.add_argument("--host", default="localhost", type=str, help="Hostname to connect to.")
+     parser.add_argument("--port", default=8998, type=int, help="Port to connect to.")
+     parser.add_argument("--https", action='store_true',
+                         help="Set this flag for using a https connection.")
+     parser.add_argument("--url", type=str, help='Directly provides a URL, e.g. to a gradio tunnel.')
+     args = parser.parse_args()
+     printer: AnyPrinter
+
+     if sys.stdout.isatty():
+         printer = Printer()
+     else:
+         printer = RawPrinter()
+     try:
+         asyncio.run(run(printer, args))
+     except KeyboardInterrupt:
+         printer.log("warning", "Interrupting, exiting connection.")
+     printer.log("info", "All done!")
+
+
+ if __name__ == "__main__":
+     main()
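The binary framing used over the websocket in client.py is a single kind byte followed by the payload (1 = Opus audio, 2 = UTF-8 text, as seen in `_queue_loop` and `_recv_loop`). A minimal standalone sketch of that framing, independent of the client code:

```python
# Minimal sketch of the one-byte message framing used by the Moshi client:
# kind byte 1 carries an Opus audio payload, kind byte 2 a UTF-8 text payload.

AUDIO, TEXT = 1, 2

def frame(kind: int, payload: bytes) -> bytes:
    # Prepend the kind byte, mirroring b"\x01" + msg in Connection._queue_loop.
    return bytes([kind]) + payload

def parse(message: bytes) -> tuple[int, bytes]:
    # Split a received message back into (kind, payload), as _recv_loop does.
    if not message:
        raise ValueError("empty message")
    return message[0], message[1:]

kind, payload = parse(frame(TEXT, "hello".encode()))
print(kind, payload.decode())  # 2 hello
```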
moshi/client_utils.py ADDED
@@ -0,0 +1,211 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Utilities for the command line client, in particular for handling interactions with the terminal.
+ """
+
+ from dataclasses import dataclass
+ import sys
+
+
+ def colorize(text, color):
+     code = f"\033[{color}m"
+     restore = "\033[0m"
+     return "".join([code, text, restore])
+
+
+ def make_log(level: str, msg: str) -> str:
+     if level == "warning":
+         prefix = colorize("[Warn]", "1;31")
+     elif level == "info":
+         prefix = colorize("[Info]", "1;34")
+     elif level == "error":
+         prefix = colorize("[Err ]", "1;31")
+     else:
+         raise ValueError(f"Unknown level {level}")
+     return prefix + " " + msg
+
+
+ class RawPrinter:
+     def __init__(self, stream=sys.stdout, err_stream=sys.stderr):
+         self.stream = stream
+         self.err_stream = err_stream
+
+     def print_header(self):
+         pass
+
+     def print_token(self, token: str):
+         self.stream.write(token)
+         self.stream.flush()
+
+     def log(self, level: str, msg: str):
+         print(f"{level.capitalize()}: {msg}", file=self.err_stream)
+
+     def print_lag(self):
+         self.err_stream.write(colorize(" [LAG]", "31"))
+         self.err_stream.flush()
+
+     def print_pending(self):
+         pass
+
+
+ @dataclass
+ class LineEntry:
+     msg: str
+     color: str | None = None
+
+     def render(self):
+         if self.color is None:
+             return self.msg
+         else:
+             return colorize(self.msg, self.color)
+
+     def __len__(self):
+         return len(self.msg)
+
+
+ class Line:
+     def __init__(self, stream):
+         self.stream = stream
+         self._line: list[LineEntry] = []
+         self._has_padding: bool = False
+         self._max_line_length = 0
+
+     def __bool__(self):
+         return bool(self._line)
+
+     def __len__(self):
+         return sum(len(entry) for entry in self._line)
+
+     def add(self, msg: str, color: str | None = None) -> int:
+         entry = LineEntry(msg, color)
+         return self._add(entry)
+
+     def _add(self, entry: LineEntry) -> int:
+         if self._has_padding:
+             self.erase(count=0)
+         self._line.append(entry)
+         self.stream.write(entry.render())
+         self._max_line_length = max(self._max_line_length, len(self))
+         return len(entry)
+
+     def erase(self, count: int = 1):
+         if count:
+             entries = list(self._line[:-count])
+         else:
+             entries = list(self._line)
+         self._line.clear()
+         self.stream.write("\r")
+         for entry in entries:
+             self._line.append(entry)
+             self.stream.write(entry.render())
+
+         self._has_padding = False
+
+     def newline(self):
+         missing = self._max_line_length - len(self)
+         if missing > 0:
+             self.stream.write(" " * missing)
+         self.stream.write("\n")
+         self._line.clear()
+         self._max_line_length = 0
+         self._has_padding = False
+
+     def flush(self):
+         missing = self._max_line_length - len(self)
+         if missing > 0:
+             self.stream.write(" " * missing)
+             self._has_padding = True
+         self.stream.flush()
+
+
+ class Printer:
+     def __init__(self, max_cols: int = 80, stream=sys.stdout, err_stream=sys.stderr):
+         self.max_cols = max_cols
+         self.line = Line(stream)
+         self.stream = stream
+         self.err_stream = err_stream
+         self._pending_count = 0
+         self._pending_printed = False
+
+     def print_header(self):
+         self.line.add(" " + "-" * (self.max_cols) + " ")
+         self.line.newline()
+         self.line.flush()
+         self.line.add("| ")
+
+     def _remove_pending(self) -> bool:
+         if self._pending_printed:
+             self._pending_printed = False
+             self.line.erase(1)
+             return True
+         return False
+
+     def print_token(self, token: str, color: str | None = None):
+         self._remove_pending()
+         remaining = self.max_cols - len(self.line)
+         if len(token) <= remaining:
+             self.line.add(token, color)
+         else:
+             end = " " * remaining + " |"
+             if token.startswith(" "):
+                 token = token.lstrip()
+                 self.line.add(end)
+                 self.line.newline()
+                 self.line.add("| ")
+                 self.line.add(token, color)
+             else:
+                 assert color is None
+                 erase_count = None
+                 cumulated = ""
+                 for idx, entry in enumerate(self.line._line[::-1]):
+                     if entry.color:
+                         # probably a LAG message
+                         erase_count = idx
+                         break
+                     if entry.msg.startswith(" "):
+                         erase_count = idx + 1
+                         cumulated = entry.msg + cumulated
+                         break
+                 if erase_count is not None:
+                     if erase_count > 0:
+                         self.line.erase(erase_count)
+                     remaining = self.max_cols - len(self.line)
+                     end = " " * remaining + " |"
+                     self.line.add(end)
+                     self.line.newline()
+                     self.line.add("| ")
+                     token = cumulated.lstrip() + token
+                     self.line.add(token)
+                 else:
+                     self.line.add(token[:remaining])
+                     self.line.add(" |")
+                     self.line.newline()
+                     self.line.add("| ")
+                     self.line.add(token[remaining:])
+         self.line.flush()
+
+     def log(self, level: str, msg: str):
+         msg = make_log(level, msg)
+         self._remove_pending()
+         if self.line:
+             self.line.newline()
+             self.line.flush()
+         print(msg, file=self.err_stream)
+         self.err_stream.flush()
+
+     def print_lag(self):
+         self.print_token(" [LAG]", "31")
+
+     def print_pending(self):
+         chars = ["|", "/", "-", "\\"]
+         count = int(self._pending_count / 5)
+         char = chars[count % len(chars)]
+         colors = ["32", "33", "31"]
+         self._remove_pending()
+         self.line.add(char, colors[count % len(colors)])
+         self._pending_printed = True
+         self._pending_count += 1
+
+
+ AnyPrinter = Printer | RawPrinter
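The `colorize` helper above simply wraps text in ANSI SGR escape sequences: `"\033[<code>m"` starts a style and `"\033[0m"` resets it. A standalone re-implementation showing the exact bytes produced:

```python
# Standalone sketch of the ANSI coloring used by client_utils.colorize:
# "\033[<code>m" begins an SGR style; "\033[0m" resets the terminal style.

def colorize(text: str, color: str) -> str:
    return f"\033[{color}m{text}\033[0m"

sample = colorize("[Warn]", "1;31")  # bold red, as used by make_log
print(repr(sample))  # '\x1b[1;31m[Warn]\x1b[0m'
```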
moshi/models/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """
+ Models for Moshi: the Mimi compression model and the Moshi language model.
+ """
+
+ # flake8: noqa
+ from .compression import (
+     CompressionModel,
+     MimiModel,
+ )
+ from .lm import LMModel, LMGen
+ from .loaders import get_mimi, get_moshi_lm
moshi/models/compression.py ADDED
@@ -0,0 +1,474 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Part of this file is adapted from encodec.py in https://github.com/facebookresearch/audiocraft
+ # released under the following license.
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Compression models, or wrappers around existing models. In particular, provides the implementation
+ for Mimi. Also defines the main interface that a model must follow to be usable as an audio tokenizer.
+ """
+
+ from abc import abstractmethod
+ from contextlib import nullcontext
+ from dataclasses import dataclass
+ import logging
+ import typing as tp
+
+ import torch
+ from torch import nn
+
+
+ from ..quantization import (
+     QuantizedResult,
+     BaseQuantizer,
+     SplitResidualVectorQuantizer,
+     ResidualVectorQuantizer,
+ )
+ from ..modules.resample import ConvDownsample1d, ConvTrUpsample1d
+ from ..modules.streaming import StreamingModule, State
+ from ..utils.compile import no_compile, CUDAGraphed
+
+
+ logger = logging.getLogger()
+
+
+ class CompressionModel(StreamingModule[State]):
+     """Base API for all compression models that aim at being used as audio tokenizers
+     with a language model.
+     """
+
+     @abstractmethod
+     def forward(self, x: torch.Tensor) -> QuantizedResult: ...
+
+     @abstractmethod
+     def encode(self, x: torch.Tensor) -> torch.Tensor:
+         """See `MimiModel.encode`."""
+         ...
+
+     @abstractmethod
+     def decode(self, codes: torch.Tensor) -> torch.Tensor:
+         """See `MimiModel.decode`."""
+         ...
+
+     @abstractmethod
+     def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
+         """Decode from the discrete codes to continuous latent space."""
+         ...
+
+     @property
+     @abstractmethod
+     def channels(self) -> int: ...
+
+     @property
+     @abstractmethod
+     def frame_rate(self) -> float: ...
+
+     @property
+     @abstractmethod
+     def sample_rate(self) -> int: ...
+
+     @property
+     @abstractmethod
+     def cardinality(self) -> int: ...
+
+     @property
+     @abstractmethod
+     def num_codebooks(self) -> int: ...
+
+     @property
+     @abstractmethod
+     def total_codebooks(self) -> int: ...
+
+     @abstractmethod
+     def set_num_codebooks(self, n: int):
+         """Set the active number of codebooks used by the quantizer."""
+         ...
+
+
+ @dataclass
+ class _MimiState:
+     graphed_tr_enc: CUDAGraphed | None
+     graphed_tr_dec: CUDAGraphed | None
+
+     def reset(self):
+         pass
+
+
+ class MimiModel(CompressionModel[_MimiState]):
+     """Mimi model operating on the raw waveform.
+
+     Args:
+         encoder (nn.Module): Encoder network.
+         decoder (nn.Module): Decoder network.
+         quantizer (qt.BaseQuantizer): Quantizer network.
+         frame_rate (float): Final frame rate of the quantized representation.
+         encoder_frame_rate (float): Frame rate of the encoder model. Note that if `frame_rate != encoder_frame_rate`,
+             the latent will be resampled linearly to match the desired `frame_rate` before and after quantization.
+         sample_rate (int): Audio sample rate.
+         channels (int): Number of audio channels.
+         causal (bool): Whether to use a causal version of the model.
+         encoder_transformer (nn.Module or None): Optional transformer for the encoder.
+         decoder_transformer (nn.Module or None): Optional transformer for the decoder.
+         resample_method (str): Method to use for resampling the latent space before the quantizer.
+         upsample_channel_wise_bug (bool): Controls whether the upsampling is channel wise.
+             Defaults to True to reproduce a bug in the original implementation.
+         freeze_encoder (bool): Whether to freeze the encoder weights.
+         freeze_quantizer (bool): Whether to freeze the quantizer weights.
+         freeze_quantizer_level (int): If positive, freeze the quantizer up to this level.
+         torch_compile_encoder_decoder (bool): If True, uses torch.compile on the encoder / decoder.
+             Deactivated by default for training as this is incompatible at the moment with weight norm.
+             See https://github.com/pytorch/pytorch/issues/121902.
+             Also this seems to work well with PyTorch 2.2.0, but completely fails with 2.4.0.
+     """
+
+     def __init__(
+         self,
+         encoder: nn.Module,
+         decoder: nn.Module,
+         quantizer: BaseQuantizer,
+         frame_rate: float,
+         encoder_frame_rate: float,
+         sample_rate: int,
+         channels: int,
+         causal: bool = False,
+         encoder_transformer: tp.Optional[nn.Module] = None,
+         decoder_transformer: tp.Optional[nn.Module] = None,
+         resample_method: str = "interpolate",
+         upsample_channel_wise_bug: bool = True,
+         freeze_encoder: bool = False,
+         freeze_quantizer: bool = False,
+         freeze_quantizer_level: int = -1,
+         torch_compile_encoder_decoder: bool = False,
+     ):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.encoder_transformer = encoder_transformer
+         self.decoder_transformer = decoder_transformer
+         self.quantizer = quantizer
+         self._frame_rate = frame_rate
+         self._sample_rate = sample_rate
+         self._channels = channels
+         self.encoder_frame_rate = encoder_frame_rate
+         self.torch_compile_encoder_decoder = torch_compile_encoder_decoder
+
+         if freeze_encoder:
+             for p in self.encoder.parameters():
+                 p.requires_grad = False
+             if self.encoder_transformer is not None:
+                 for p in self.encoder_transformer.parameters():
+                     p.requires_grad = False
+             for name, p in self.quantizer.named_parameters():
+                 if name.endswith("input_proj.weight"):
+                     p.requires_grad = False
+         if freeze_quantizer:
+             self.quantizer.ema_frozen_(True)
+         self.freeze_quantizer = freeze_quantizer
+         self.freeze_quantizer_level = (
+             freeze_quantizer_level
+             if freeze_quantizer_level > 0
+             else self.quantizer.num_codebooks
+         )
+
+         # We will need the dimension for the resampling. In general the encoder will be a SeanetEncoder
+         # which exposes a `dimension` attribute.
+         dimension = encoder.dimension
+         assert isinstance(
+             dimension, int
+         ), f"Dimension should be int, got {dimension} of type {type(dimension)}."
+         self.dimension = dimension
+
+         assert resample_method in [
+             "interpolate",
+             "conv",
+             "avg_pool",
+         ], f"Invalid resample_method {resample_method}"
+         self.resample_method = resample_method
+         if encoder_frame_rate != frame_rate:
+             assert not (
+                 causal and resample_method == "interpolate"
+             ), "Cannot interpolate with causal model."
+             if resample_method in ["conv", "avg_pool"]:
+                 assert (
+                     self.encoder_frame_rate > self.frame_rate
+                 ), "Cannot upsample with conv."
+                 downsample_stride = self.encoder_frame_rate / self.frame_rate
+                 assert downsample_stride == int(
+                     downsample_stride
+                 ), f"Only integer strides are supported, got {downsample_stride}"
+                 learnt = resample_method == "conv"
+                 self.downsample = ConvDownsample1d(
+                     int(downsample_stride),
+                     dimension=dimension,
+                     learnt=learnt,
+                     causal=causal,
+                 )
+                 if freeze_encoder:
+                     for p in self.downsample.parameters():
+                         p.requires_grad = False
+                 self.upsample = ConvTrUpsample1d(
+                     int(downsample_stride),
+                     dimension=dimension,
+                     learnt=learnt,
+                     causal=causal,
+                     channel_wise=upsample_channel_wise_bug,
+                 )
+
+     def _init_streaming_state(self, batch_size: int) -> _MimiState:
+         device = next(self.parameters()).device
+         disable = device.type != 'cuda'
+         graphed_tr_dec = None
+         graphed_tr_enc = None
+         if self.encoder_transformer is not None:
+             graphed_tr_enc = CUDAGraphed(self.encoder_transformer, disable=disable)
+         if self.decoder_transformer is not None:
+             graphed_tr_dec = CUDAGraphed(self.decoder_transformer, disable=disable)
+         return _MimiState(graphed_tr_enc, graphed_tr_dec)
+
+     @property
+     def channels(self) -> int:
+         return self._channels
+
+     @property
+     def frame_rate(self) -> float:
+         return self._frame_rate
+
+     @property
+     def sample_rate(self) -> int:
+         return self._sample_rate
+
+     @property
+     def total_codebooks(self):
+         """Total number of quantizer codebooks available."""
+         return self.quantizer.total_codebooks
+
+     @property
+     def num_codebooks(self):
+         """Active number of codebooks used by the quantizer."""
+         return self.quantizer.num_codebooks
+
+     def set_num_codebooks(self, n: int):
+         """Set the active number of codebooks used by the quantizer."""
+         self.quantizer.set_num_codebooks(n)
+
+     @property
+     def cardinality(self):
+         """Cardinality of each codebook."""
+         return self.quantizer.cardinality
+
+     def _to_framerate(self, x: torch.Tensor):
+         # Convert from the encoder frame rate to the overall framerate.
+         _, _, length = x.shape
+         frame_rate = self.encoder_frame_rate
268
+ new_frame_rate = self.frame_rate
269
+ if frame_rate == new_frame_rate:
270
+ return x
271
+ if self.resample_method == "interpolate":
272
+ target_length = int(length * new_frame_rate / frame_rate)
273
+ return nn.functional.interpolate(x, size=target_length, mode="linear")
274
+ else:
275
+ return self.downsample(x)
276
+
277
+ def _to_encoder_framerate(self, x: torch.Tensor):
278
+ # Convert from overall framerate to the encoder frame rate.
279
+ _, _, length = x.shape
280
+ frame_rate = self.encoder_frame_rate
281
+ new_frame_rate = self.frame_rate
282
+ if frame_rate == new_frame_rate:
283
+ return x
284
+ if self.resample_method == "interpolate":
285
+ target_length = int(length * new_frame_rate / frame_rate)
286
+ return nn.functional.interpolate(x, size=target_length, mode="linear")
287
+ else:
288
+ return self.upsample(x)
289
+
290
+ @property
291
+ def _context_for_encoder_decoder(self):
292
+ if self.torch_compile_encoder_decoder:
293
+ return nullcontext()
294
+ else:
295
+ return no_compile()
296
+
297
+ def forward(self, x: torch.Tensor) -> QuantizedResult:
298
+ assert x.dim() == 3
299
+ length = x.shape[-1]
300
+ extra_metrics: tp.Dict[str, torch.Tensor] = {}
301
+
302
+ if self.freeze_quantizer:
303
+ if isinstance(self.quantizer, SplitResidualVectorQuantizer):
304
+ self.quantizer.rvq_first.eval()
305
+ for i in range(
306
+ self.freeze_quantizer_level - self.quantizer.n_q_semantic
307
+ ):
308
+ self.quantizer.rvq_rest.vq.layers[i].eval()
309
+ elif isinstance(self.quantizer, ResidualVectorQuantizer):
310
+ for i in range(self.freeze_quantizer_level):
311
+ self.quantizer.vq.layers[i].eval()
312
+ else:
313
+ raise ValueError(f"Unsupported quantizer type {type(self.quantizer)}")
314
+
315
+ with self._context_for_encoder_decoder:
316
+ emb = self.encoder(x)
317
+ if self.encoder_transformer is not None:
318
+ (emb,) = self.encoder_transformer(emb)
319
+ emb = self._to_framerate(emb)
320
+ expected_length = self.frame_rate * length / self.sample_rate
321
+ # Checking that we have the proper length given the advertised frame rate.
322
+ assert abs(emb.shape[-1] - expected_length) < 1, (
323
+ emb.shape[-1],
324
+ expected_length,
325
+ )
326
+
327
+ q_res = self.quantizer(emb, self.frame_rate)
328
+ emb = q_res.x
329
+ emb = self._to_encoder_framerate(emb)
330
+ if self.decoder_transformer is not None:
331
+ (emb,) = self.decoder_transformer(emb)
332
+
333
+ with self._context_for_encoder_decoder:
334
+ out = self.decoder(emb)
335
+
336
+ # remove extra padding added by the encoder and decoder
337
+ assert out.shape[-1] >= length, (out.shape[-1], length)
338
+ out = out[..., :length]
339
+
340
+ q_res.x = out
341
+ q_res.metrics.update(extra_metrics)
342
+ return q_res
343
+
344
+ def _encode_to_unquantized_latent(self, x: torch.Tensor) -> torch.Tensor:
345
+ """Projects a batch of waveforms to unquantized latent space.
346
+
347
+ Args:
348
+ x (torch.Tensor): Float tensor of shape [B, C, T].
349
+
350
+ Returns:
351
+ Unquantized embeddings.
352
+ """
353
+ assert (
354
+ x.dim() == 3
355
+ ), f"CompressionModel._encode_to_unquantized_latent expects audio of shape [B, C, T] but got {x.shape}"
356
+ state = self._streaming_state
357
+ with self._context_for_encoder_decoder:
358
+ emb = self.encoder(x)
359
+ if self.encoder_transformer is not None:
360
+ if state is None:
361
+ (emb,) = self.encoder_transformer(emb)
362
+ else:
363
+ assert state.graphed_tr_enc is not None
364
+ (emb,) = state.graphed_tr_enc(emb)
365
+ emb = self._to_framerate(emb)
366
+ return emb
367
+
368
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
369
+ """Encode the given input tensor to quantized representation.
370
+
371
+ Args:
372
+ x (torch.Tensor): Float tensor of shape [B, C, T]
373
+
374
+ Returns:
375
+ codes (torch.Tensor): an int tensor of shape [B, K, T]
376
+ with K the number of codebooks used and T the timestep.
377
+ """
378
+ emb = self._encode_to_unquantized_latent(x)
379
+ codes = self.quantizer.encode(emb)
380
+ return codes
381
+
382
+ def encode_to_latent(self, x: torch.Tensor, quantize: bool = True) -> torch.Tensor:
383
+ """Projects a batch of waveforms to latent space.
384
+
385
+ Args:
386
+ x (torch.Tensor): Float tensor of shape [B, C, T].
387
+
388
+ Returns:
389
+ Embeddings, either quantized or not.
390
+ """
391
+ emb = self._encode_to_unquantized_latent(x)
392
+ if not quantize:
393
+ return emb
394
+ else:
395
+ codes = self.quantizer.encode(emb)
396
+ return self.decode_latent(codes)
397
+
398
+ def decode(self, codes: torch.Tensor):
399
+ """Decode the given codes to a reconstructed representation.
400
+
401
+ Args:
402
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
403
+
404
+ Returns:
405
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
406
+ """
407
+ state = self._streaming_state
408
+ emb = self.decode_latent(codes)
409
+ emb = self._to_encoder_framerate(emb)
410
+ if self.decoder_transformer is not None:
411
+ if state is None:
412
+ (emb,) = self.decoder_transformer(emb)
413
+ else:
414
+ assert state.graphed_tr_dec is not None
415
+ (emb,) = state.graphed_tr_dec(emb)
416
+ with self._context_for_encoder_decoder:
417
+ out = self.decoder(emb)
418
+ # out contains extra padding added by the encoder and decoder
419
+ return out
420
+
421
+ def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
422
+ """Decode from the discrete codes to continuous latent space."""
423
+ return self.quantizer.decode(codes)
424
+
425
+
426
+ class WrapperCompressionModel(CompressionModel[State]):
427
+ """Base API for CompressionModel wrappers that do not depend on external frameworks."""
428
+
429
+ def __init__(self, model: CompressionModel):
430
+ super().__init__()
431
+ self.model = model
432
+
433
+ def forward(self, x: torch.Tensor) -> QuantizedResult:
434
+ return self.model.forward(x)
435
+
436
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
437
+ return self.model.encode(x)
438
+
439
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
440
+ return self.model.decode(codes)
441
+
442
+ def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
443
+ return self.model.decode_latent(codes)
444
+
445
+ def set_num_codebooks(self, n: int):
446
+ self.model.set_num_codebooks(n)
447
+
448
+ @property
449
+ def quantizer(self):
450
+ return self.model.quantizer
451
+
452
+ @property
453
+ def channels(self) -> int:
454
+ return self.model.channels
455
+
456
+ @property
457
+ def frame_rate(self) -> float:
458
+ return self.model.frame_rate
459
+
460
+ @property
461
+ def sample_rate(self) -> int:
462
+ return self.model.sample_rate
463
+
464
+ @property
465
+ def cardinality(self) -> int:
466
+ return self.model.cardinality
467
+
468
+ @property
469
+ def num_codebooks(self) -> int:
470
+ return self.model.num_codebooks
471
+
472
+ @property
473
+ def total_codebooks(self) -> int:
474
+ return self.model.total_codebooks
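`MimiModel` resamples the encoder output from `encoder_frame_rate` down to the overall `frame_rate`, and the constructor asserts that the ratio between the two is an integer stride. A minimal stdlib sketch of that arithmetic, assuming the 24 kHz sample rate and SEANet ratios `[8, 6, 5, 4]` that appear in `loaders.py` later in this commit:

```python
# Sketch of the stride check in MimiModel.__init__ (illustrative values taken
# from loaders.py in this commit; no torch needed for the arithmetic itself).
SAMPLE_RATE = 24000
FRAME_RATE = 12.5  # overall frame rate of the codec

hop_length = 8 * 6 * 5 * 4                      # product of the SEANet ratios: 960 samples/frame
encoder_frame_rate = SAMPLE_RATE / hop_length   # 25.0 Hz

downsample_stride = encoder_frame_rate / FRAME_RATE
# Same assertion as the constructor: only integer strides are supported.
assert downsample_stride == int(downsample_stride), "Only integer strides are supported"
print(int(downsample_stride))  # stride passed to ConvDownsample1d
```

With these values the stride is 2: the convolutional resampler halves the 25 Hz encoder sequence to the 12.5 Hz frame rate the quantizer operates at.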
moshi/models/lm.py ADDED
@@ -0,0 +1,488 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from dataclasses import dataclass
+ from functools import partial
+ import logging
+ import typing as tp
+
+ import torch
+ from torch import nn
+
+ from ..utils.sampling import sample_token
+ from ..utils.compile import CUDAGraphed
+ from ..modules.streaming import StreamingContainer, StreamingModule
+ from ..modules.transformer import (
+     StreamingTransformer,
+     create_norm_fn,
+ )
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class ScaledEmbedding(nn.Embedding):
+     """Boost learning rate for embeddings (with `scale`).
+
+     Args:
+         norm (bool): if True, uses a layer norm after the embedding.
+         zero_idx (int): special value indicating that the output should be exactly 0.
+     """
+
+     def __init__(self, *args, norm: bool = False, zero_idx: int = -1, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.norm = None
+         if norm:
+             self.norm = create_norm_fn("layer_norm", self.embedding_dim)
+         assert zero_idx < 0, "Please use negative values for the zero_idx."
+         self.zero_idx = zero_idx
+
+     def forward(self, input, *args, **kwargs):
+         is_zero = input == self.zero_idx
+         zero = torch.zeros(1, dtype=input.dtype, device=input.device)
+         input = input.clamp(min=0)
+         y = super().forward(input, *args, **kwargs)
+         if self.norm is not None:
+             y = self.norm(y)
+         y = torch.where(is_zero[..., None], zero, y)
+         return y
+
+
+ class LMModel(StreamingContainer):
+     """Transformer-based language model on multiple streams of codes.
+
+     Args:
+         n_q (int): Number of parallel streams to model as input.
+         dep_q (int): Number of parallel streams to model in the depformer.
+         card (int): Cardinality, vocabulary size.
+         text_card (int): Cardinality of the text vocabulary.
+         dim (int): Dimension of the transformer encoder.
+         num_heads (int): Number of heads for the transformer encoder.
+         hidden_scale (int): Scale for the hidden feed forward dimension of the transformer encoder.
+         norm (str): Normalization method.
+         norm_emb (bool): Whether to normalize embeddings.
+         bias_proj (bool): Use bias for output projections.
+         depformer_*: params used for the Depformer Transformer, all the others will be shared.
+         depformer_multi_linear (bool): if True, uses one linear layer per codebook to project the
+             output of the main transformer to the Depformer latent space.
+         depformer_dim_feedforward (int | list[int] | None): If None, defaults to hidden_scale * depformer_dim.
+         existing_text_padding_id (int, optional): if provided, use this existing token id for text
+             padding; otherwise a new padding token is appended to the text vocabulary.
+         same_initial (bool): if True, uses the same initial tokens for both text and audio mode.
+         **kwargs: Additional parameters for the transformer encoder.
+     """
+
+     def __init__(
+         self,
+         delays: tp.List[int] = [0],
+         n_q: int = 8,
+         dep_q: int = 8,
+         card: int = 1024,
+         text_card: int = 32000,
+         dim: int = 128,
+         num_heads: int = 8,
+         hidden_scale: int = 4,
+         norm: str = "layer_norm",
+         norm_emb: bool = False,
+         bias_proj: bool = False,
+         depformer_dim: int = 256,
+         depformer_dim_feedforward: int | list[int] | None = None,
+         depformer_multi_linear: bool = False,
+         depformer_weights_per_step: bool = False,
+         depformer_pos_emb: str = "sin",
+         existing_text_padding_id: tp.Optional[int] = None,
+         context: tp.Optional[int] = None,
+         device=None,
+         dtype=None,
+         **kwargs,
+     ):
+         super().__init__()
+         self.n_q = n_q
+         self.dep_q = dep_q
+         self.card = card
+         self.text_card = text_card
+         assert len(delays) == self.num_codebooks, "unexpected number of delays"
+         self.delays = delays
+         self.dim = dim
+         self.existing_text_padding_id = existing_text_padding_id
+         self.context = context
+         kwargs["context"] = context
+         EmbeddingFactory = partial(
+             ScaledEmbedding,
+             norm=norm_emb,
+             device=device,
+             dtype=dtype,
+             zero_idx=self.zero_token_id,
+         )
+         self.emb = nn.ModuleList(
+             [EmbeddingFactory(self.card + 1, dim) for _ in range(n_q)]
+         )
+         # Text card + padding token (if not in the original tokenizer).
+         extra_text = self.existing_text_padding_id is None
+         # Unlike for audio, here we authorize the model to output the special token.
+         self.text_emb = EmbeddingFactory(text_card + 1, dim)
+         self.text_linear = nn.Linear(dim, text_card + extra_text, bias=bias_proj)
+         depformer_prefix = "depformer_"
+         main_kwargs = {
+             k: v for k, v in kwargs.items() if not k.startswith(depformer_prefix)
+         }
+         self.transformer = StreamingTransformer(
+             d_model=dim,
+             num_heads=num_heads,
+             dim_feedforward=int(hidden_scale * dim),
+             norm=norm,
+             device=device,
+             dtype=dtype,
+             **main_kwargs,
+         )
+         self.out_norm = create_norm_fn(norm, dim)
+         self.depformer_multi_linear = depformer_multi_linear
+         kwargs_dep = main_kwargs.copy()
+         kwargs_dep.update(
+             {
+                 k.removeprefix(depformer_prefix): v
+                 for k, v in kwargs.items()
+                 if k.startswith(depformer_prefix)
+             }
+         )
+         kwargs_dep["positional_embedding"] = depformer_pos_emb
+         kwargs_dep["context"] = None
+         if depformer_weights_per_step:
+             kwargs_dep["weights_per_step"] = dep_q
+         if depformer_multi_linear:
+             # One linear layer per codebook to project different information from the main model.
+             self.depformer_in = nn.ModuleList(
+                 [nn.Linear(dim, depformer_dim, bias=False) for _ in range(dep_q)]
+             )
+         else:
+             self.depformer_in = nn.ModuleList(
+                 [nn.Linear(dim, depformer_dim, bias=False)]
+             )
+         # Only using up to dep_q - 1 because the last codebook is never an input to Depformer.
+         self.depformer_emb = nn.ModuleList(
+             [EmbeddingFactory(self.card + 1, depformer_dim) for _ in range(dep_q - 1)]
+         )
+         self.depformer_text_emb = EmbeddingFactory(text_card + 1, depformer_dim)
+         if depformer_dim_feedforward is None:
+             depformer_dim_feedforward = int(hidden_scale * depformer_dim)
+         self.depformer = StreamingTransformer(
+             d_model=depformer_dim,
+             dim_feedforward=depformer_dim_feedforward,
+             norm=norm,
+             device=device,
+             dtype=dtype,
+             **kwargs_dep,
+         )
+         self.depformer.set_streaming_propagate(False)
+         dim = depformer_dim  # we will directly apply the next linears to the output of the Depformer.
+
+         self.linears = nn.ModuleList(
+             [nn.Linear(dim, self.card, bias=bias_proj) for _ in range(dep_q)]
+         )
+
+     @property
+     def initial_token_id(self) -> int:
+         """Token id for the start of sequence (audio)."""
+         return self.card
+
+     @property
+     def text_initial_token_id(self) -> int:
+         """Token id for the start of sequence (text)."""
+         return self.text_card
+
+     @property
+     def text_padding_token_id(self) -> int:
+         """Token id for text padding."""
+         if self.existing_text_padding_id is None:
+             return self.text_card
+         else:
+             return self.existing_text_padding_id
+
+     @property
+     def end_of_text_padding_id(self) -> int:
+         """Token id for optionally marking the last padding step for a word."""
+         return 0
+
+     @property
+     def zero_token_id(self) -> int:
+         """Special value in the input tokens, indicating that no sampling should
+         happen for that value, and no input should be given to the model."""
+         return -1
+
+     @property
+     def ungenerated_token_id(self) -> int:
+         """Special value that can be provided in the prompt to indicate that this specific
+         value should be predicted and sampled. This allows for partial teacher forcing, by generating
+         one modality, with the other one fixed.
+         """
+         return -2
+
+     @property
+     def device(self):
+         first_param = next(iter(self.parameters()))
+         return first_param.device
+
+     @property
+     def num_codebooks(self) -> int:
+         return self.n_q + 1
+
+     @property
+     def num_audio_codebooks(self) -> int:
+         return self.n_q
+
+     @property
+     def audio_offset(self) -> int:
+         return 1
+
+     def _get_initial_token(self) -> torch.Tensor:
+         # Returns the initial token that will be fed to the model to predict the very first timestep.
+         # The output shape will be [B, K, 1].
+         device = next(iter(self.parameters())).device
+         zero = torch.full(
+             [1, 1, 1], self.zero_token_id, device=device, dtype=torch.long
+         )
+         special = torch.full_like(zero, self.initial_token_id)
+
+         text_special = torch.full_like(zero, self.text_initial_token_id)
+         audio_token = special
+         text_token = text_special
+         audio_token = audio_token.expand(-1, self.num_audio_codebooks, -1)
+         token = torch.cat([text_token, audio_token], dim=1)
+         return token
+
+     def forward_text(
+         self,
+         sequence: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         B, K, S = sequence.shape
+         assert (
+             K == self.num_codebooks
+         ), f"Sequence shape {sequence.shape} must match the number of codebooks."
+         input_sequence = sequence
+         input_ = None
+         for cb_index in range(self.num_audio_codebooks):
+             audio_emb = self.emb[cb_index](
+                 input_sequence[:, cb_index + self.audio_offset]
+             )
+             input_ = audio_emb if input_ is None else input_ + audio_emb
+         text_emb = self.text_emb(input_sequence[:, 0])
+         input_ = text_emb if input_ is None else input_ + text_emb
+         transformer_out = self.transformer(input_)
+
+         if self.out_norm:
+             transformer_out = self.out_norm(transformer_out)
+         assert isinstance(transformer_out, torch.Tensor)
+         text_logits = self.text_linear(transformer_out)
+         text_logits = text_logits[:, None]
+         return transformer_out, text_logits
+
+     def forward_depformer(
+         self,
+         depformer_cb_index: int,
+         sequence: torch.Tensor,
+         transformer_out: torch.Tensor,
+     ) -> torch.Tensor:
+         B, K, S = sequence.shape
+         assert (
+             K == 1
+         ), f"Codebooks for Depformer streaming should be passed 1 by 1, got {K}."
+         assert (
+             S == 1
+         ), f"Steps for Depformer streaming should be passed 1 by 1, got {S}."
+         assert (
+             transformer_out.shape[1] == 1
+         ), "Transformer out should be for a single step."
+         last_token_input: tp.Optional[torch.Tensor] = None
+         depformer_input = transformer_out
+         if self.depformer_multi_linear:
+             depformer_input = self.depformer_in[depformer_cb_index](depformer_input)
+         else:
+             depformer_input = self.depformer_in[0](depformer_input)
+         if depformer_cb_index == 0:
+             last_token_input = self.depformer_text_emb(sequence[:, 0])
+         else:
+             last_token_input = self.depformer_emb[depformer_cb_index - 1](
+                 sequence[:, 0]
+             )
+         depformer_input = depformer_input + last_token_input
+         assert depformer_input.shape[1] == 1
+         # depformer_input is [B, 1, depformer_dim].
+         # The streaming state of the depformer ensures that the proper layer is run.
+         dep_output = self.depformer(depformer_input)
+         logits = self.linears[depformer_cb_index](dep_output)
+         logits = logits[:, None]
+         assert logits.dim() == 4, logits.shape  # [B, Ka, S, card]
+         return logits
+
+
+ @dataclass
+ class _LMGenState:
+     cache: torch.Tensor
+     initial: torch.Tensor
+     graphed_main: CUDAGraphed
+     graphed_depth: CUDAGraphed
+     offset: int = 0
+
+     def reset(self):
+         self.offset = 0
+
+
+ class LMGen(StreamingModule[_LMGenState]):
+     def __init__(
+         self,
+         lm_model: LMModel,
+         use_sampling: bool = True,
+         temp: float = 0.8,
+         temp_text: float = 0.7,
+         top_k: int = 250,
+         top_k_text: int = 25,
+         check: bool = False,
+     ):
+         assert not lm_model.training, "generation shouldn't be used in training mode."
+         super().__init__()
+
+         self.lm_model = lm_model
+         self.use_sampling = use_sampling
+         self.temp = temp
+         self.temp_text = temp_text
+         self.top_k = top_k
+         self.top_k_text = top_k_text
+         self.check = check
+         self.max_delay = max(
+             lm_model.delays
+         )  # with delays, we need to generate a few more time steps.
+         self.delays_cuda = torch.tensor(
+             lm_model.delays, device=lm_model.device, dtype=torch.long
+         )
+
+     def _init_streaming_state(self, batch_size: int) -> _LMGenState:
+         lm_model = self.lm_model
+         initial = lm_model._get_initial_token()
+         cache = torch.full(
+             (batch_size, self.lm_model.num_codebooks, self.max_delay + 2),
+             lm_model.ungenerated_token_id,
+             device=lm_model.device,
+             dtype=torch.long,
+         )
+
+         disable = lm_model.device.type != 'cuda'
+         graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable)
+         graphed_depth = CUDAGraphed(self.depformer_step, disable=disable)
+
+         return _LMGenState(cache, initial, graphed_main, graphed_depth)
+
+     @torch.no_grad()
+     def step(self, input_tokens: torch.Tensor) -> torch.Tensor | None:
+         state = self._streaming_state
+         if state is None:
+             raise RuntimeError(
+                 "You should wrap those calls with a `with lm_gen.streaming(): ...`."
+             )
+         lm_model = self.lm_model
+
+         assert input_tokens.dim() == 3, "Shape should be [B, K, T]."
+         B, Ki, S = input_tokens.shape
+         assert S == 1, "Only support being given steps one by one."
+         needed_tokens = lm_model.num_codebooks - lm_model.dep_q - 1
+         assert (
+             Ki == needed_tokens
+         ), f"We expect {needed_tokens} tokens from the user stream, got {Ki}."
+
+         CT = state.cache.shape[2]
+
+         for q_other in range(input_tokens.shape[1]):
+             k = lm_model.dep_q + 1 + q_other
+             delay = lm_model.delays[k]
+             write_position = (state.offset + delay) % CT
+             state.cache[:, k, write_position : write_position + 1] = input_tokens[
+                 :, q_other
+             ]
+
+         position = state.offset % CT
+         for k, delay in enumerate(lm_model.delays):
+             # Only for the very beginning, we extend the initial token for the acoustic
+             # tokens that are delayed, and thus have no good value to take.
+             if state.offset <= delay:
+                 state.cache[:, k, position] = state.initial[:, k, 0]
+         input_ = state.cache[:, :, position : position + 1]
+
+         if self.check:
+             # Check that we are not feeding in any value that is not generated yet.
+             assert not (input_ == lm_model.ungenerated_token_id).any(), (
+                 state.offset,
+                 input_,
+             )
+             assert (input_[:, lm_model.audio_offset :] <= lm_model.card).all(), input_
+             assert (input_[:, :1] <= lm_model.text_card).all()
+
+         transformer_out, text_logits = state.graphed_main(input_)
+         # Shape of text_logits should be [B, K_text=1, T=1, Card_text].
+         text_token = sample_token(
+             text_logits.float(),
+             self.use_sampling,
+             self.temp_text,
+             self.top_k_text,
+         )
+         assert text_token.dim() == 3, text_token.shape
+         assert text_token.shape[2] == 1
+         assert text_token.shape[1] == 1, "Only one text stream supported."
+         text_token = text_token[:, 0, 0]  # shape is [B]
+         audio_tokens = state.graphed_depth(text_token, transformer_out)
+
+         # Ensure we don't overwrite prompt tokens; we only write over ungenerated tokens.
+         state.offset += 1
+         position = state.offset % CT
+         state.cache[:, 0, position] = text_token
+         state.cache[:, 1 : lm_model.dep_q + 1, position] = audio_tokens
+
+         if state.offset <= self.max_delay:
+             return None
+         B = state.cache.shape[0]
+         gen_delays_cuda = self.delays_cuda[: lm_model.dep_q + 1]
+         index = (
+             ((state.offset - self.max_delay + gen_delays_cuda) % CT)
+             .view(1, -1, 1)
+             .expand(B, -1, 1)
+         )
+         out = state.cache.gather(dim=2, index=index)
+         return out
+
+     def depformer_step(
+         self,
+         text_token: torch.Tensor,
+         transformer_out: torch.Tensor,
+     ) -> torch.Tensor:
+         (B,) = text_token.shape
+         prev_token = text_token
+         lm_model = self.lm_model
+         depformer_tokens: list[torch.Tensor] = []
+         assert not lm_model.depformer.is_streaming
+         with lm_model.depformer.streaming(B):
+             for cb_index in range(lm_model.dep_q):
+                 input_ = prev_token[:, None, None]
+                 logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
+                 next_token = sample_token(
+                     logits.float(),
+                     self.use_sampling,
+                     self.temp,
+                     self.top_k,
+                 )
+                 assert next_token.shape == (B, 1, 1)
+                 next_token = next_token[:, 0, 0]  # shape is [B]
+                 depformer_tokens.append(next_token)
+                 prev_token = next_token
+
+         assert len(depformer_tokens) == lm_model.dep_q, (
+             len(depformer_tokens),
+             lm_model.dep_q,
+         )
+         out = torch.stack(depformer_tokens, dim=1)
+         assert out.shape == (B, lm_model.dep_q), out.shape
+         return out
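`LMGen.step` applies the per-codebook delays with a small ring buffer: stream `k` is written at `(offset + delay_k) % CT` and the time-aligned output frame is read back at `(offset - max_delay + delay_k) % CT`, where `CT = max_delay + 2`. A plain-Python sketch of that index arithmetic, with illustrative values (two streams with delays `[0, 1]`, tokens replaced by `(stream, step)` tuples; the real code writes tensors and splits user/model streams):

```python
# Sketch of the delayed ring-buffer indexing in LMGen.step (illustrative only).
delays = [0, 1]                         # e.g. a semantic stream and a delayed acoustic stream
max_delay = max(delays)
CT = max_delay + 2                      # cache length along the time axis, as in _init_streaming_state
cache = [[None] * CT for _ in delays]   # cache[k][t]: one row per stream

outputs = []
for offset in range(6):
    for k, delay in enumerate(delays):
        # Write the token for logical step `offset` of stream k, shifted by its delay.
        cache[k][(offset + delay) % CT] = (k, offset)
    if offset >= max_delay:
        # Read one time-aligned frame: every stream contributes the same logical step.
        frame = [cache[k][(offset - max_delay + delay) % CT]
                 for k, delay in enumerate(delays)]
        outputs.append(frame)

# Each frame pairs tokens from the same logical step, max_delay steps behind the writes.
print(outputs[0])  # [(0, 0), (1, 0)]
```

This is why `step` returns `None` for the first `max_delay` calls: the aligned frame only becomes available once every delayed stream has been written.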
moshi/models/loaders.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+ """Retrieves the pretrained models for Moshi and Mimi."""
5
+ from pathlib import Path
6
+
7
+ from safetensors.torch import load_model
8
+ import torch
9
+
10
+ from .compression import MimiModel
11
+ from .lm import LMModel
12
+ from ..modules import SEANetEncoder, SEANetDecoder, transformer
13
+ from ..quantization import SplitResidualVectorQuantizer
14
+
15
+ SAMPLE_RATE = 24000
16
+ FRAME_RATE = 12.5
17
+
18
+ TEXT_TOKENIZER_NAME = 'tokenizer_spm_32k_3.model'
19
+ MOSHI_NAME = 'model.safetensors'
20
+ MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
21
+ DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'
22
+
23
+
24
+ _seanet_kwargs = {
25
+ "channels": 1,
26
+ "dimension": 512,
27
+ "causal": True,
28
+ "n_filters": 64,
29
+ "n_residual_layers": 1,
30
+ "activation": "ELU",
31
+ "compress": 2,
32
+ "dilation_base": 2,
33
+ "disable_norm_outer_blocks": 0,
34
+ "kernel_size": 7,
35
+ "residual_kernel_size": 3,
36
+ "last_kernel_size": 3,
37
+ # We train using weight_norm but then the weights are pre-processed for inference so
38
+ # that we can use a normal convolution.
39
+ "norm": "none",
40
+ "pad_mode": "constant",
41
+ "ratios": [8, 6, 5, 4],
42
+ "true_skip": True,
43
+ }
44
+ _quantizer_kwargs = {
45
+ "dimension": 256,
46
+ "n_q": 32,
47
+ "bins": 2048,
48
+ "input_dimension": _seanet_kwargs["dimension"],
49
+ "output_dimension": _seanet_kwargs["dimension"],
50
+ }
51
+ _transformer_kwargs = {
52
+ "d_model": _seanet_kwargs["dimension"],
53
+ "num_heads": 8,
54
+ "num_layers": 8,
55
+ "causal": True,
56
+ "layer_scale": 0.01,
57
+ "context": 250,
58
+ "conv_layout": True,
59
+ "max_period": 10000,
60
+ "gating": "none",
61
+ "norm": "layer_norm",
62
+ "positional_embedding": "rope",
63
+ "dim_feedforward": 2048,
64
+ "input_dimension": _seanet_kwargs["dimension"],
65
+ "output_dimensions": [_seanet_kwargs["dimension"]],
66
+ }
67
+
68
+ _lm_kwargs = {
+     "dim": 4096,
+     "text_card": 32000,
+     "existing_text_padding_id": 3,
+     "n_q": 16,
+     "dep_q": 8,
+     "card": _quantizer_kwargs["bins"],
+     "num_heads": 32,
+     "num_layers": 32,
+     "hidden_scale": 4.125,
+     "causal": True,
+     "layer_scale": None,
+     "context": 3000,
+     "max_period": 10000,
+     "gating": "silu",
+     "norm": "rms_norm_f32",
+     "positional_embedding": "rope",
+     "depformer_dim": 1024,
+     "depformer_dim_feedforward": int(4.125 * 1024),
+     "depformer_num_heads": 16,
+     "depformer_num_layers": 6,
+     "depformer_causal": True,
+     "depformer_layer_scale": None,
+     "depformer_multi_linear": True,
+     "depformer_context": 8,
+     "depformer_max_period": 10000,
+     "depformer_gating": "silu",
+     "depformer_pos_emb": "none",
+     "depformer_weights_per_step": True,
+     "delays": [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1],
+ }
+ 
+ 
+ def _is_safetensors(path: Path | str) -> bool:
+     return Path(path).suffix in (".safetensors", ".sft", ".sfts")
+ 
+ 
+ def get_mimi(filename: str | Path,
+              device: torch.device | str = 'cpu') -> MimiModel:
+     """Return a pretrained Mimi model."""
+     encoder = SEANetEncoder(**_seanet_kwargs)
+     decoder = SEANetDecoder(**_seanet_kwargs)
+     encoder_transformer = transformer.ProjectedTransformer(
+         device=device, **_transformer_kwargs
+     )
+     decoder_transformer = transformer.ProjectedTransformer(
+         device=device, **_transformer_kwargs
+     )
+     quantizer = SplitResidualVectorQuantizer(
+         **_quantizer_kwargs,
+     )
+     model = MimiModel(
+         encoder,
+         decoder,
+         quantizer,
+         channels=1,
+         sample_rate=SAMPLE_RATE,
+         frame_rate=FRAME_RATE,
+         encoder_frame_rate=SAMPLE_RATE / encoder.hop_length,
+         causal=True,
+         resample_method="conv",
+         encoder_transformer=encoder_transformer,
+         decoder_transformer=decoder_transformer,
+     ).to(device=device)
+     model.eval()
+     if _is_safetensors(filename):
+         load_model(model, filename)
+     else:
+         pkg = torch.load(filename, "cpu")
+         model.load_state_dict(pkg["model"])
+     model.set_num_codebooks(8)
+     return model
+ 
+ 
+ def get_moshi_lm(filename: str | Path,
+                  device: torch.device | str = 'cpu') -> LMModel:
+     dtype = torch.bfloat16
+     model = LMModel(
+         device=device,
+         dtype=dtype,
+         **_lm_kwargs,
+     ).to(device=device, dtype=dtype)
+     model.eval()
+     if _is_safetensors(filename):
+         load_model(model, filename)
+     else:
+         pkg = torch.load(
+             filename,
+             "cpu",
+         )
+         model.load_state_dict(pkg["fsdp_best_state"]["model"])
+     return model
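The `_is_safetensors` helper above dispatches purely on the file suffix: safetensors checkpoints are loaded with `load_model`, anything else is treated as a `torch.save` package. A minimal standalone sketch of the same check, runnable without torch or any weights (`is_safetensors` here is a local copy, not the module's private helper):

```python
from pathlib import Path


def is_safetensors(path) -> bool:
    # Same suffix test as `_is_safetensors` above: safetensors checkpoints
    # use one of these extensions; anything else falls back to torch.load.
    return Path(path).suffix in (".safetensors", ".sft", ".sfts")


print(is_safetensors("weights/model.safetensors"))  # True
print(is_safetensors("weights/model.th"))           # False
```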
moshi/modules/__init__.py ADDED
@@ -0,0 +1,23 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Modules used for building the models."""
+ 
+ # flake8: noqa
+ from .conv import (
+     NormConv1d,
+     NormConvTranspose1d,
+     StreamingConv1d,
+     StreamingConvTranspose1d,
+     pad_for_conv1d,
+     pad1d,
+     unpad1d,
+ )
+ from .seanet import SEANetEncoder, SEANetDecoder
+ from .transformer import StreamingTransformer
moshi/modules/conv.py ADDED
@@ -0,0 +1,329 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ from dataclasses import dataclass
+ import math
+ import typing as tp
+ import warnings
+ 
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.utils import weight_norm
+ 
+ from .streaming import RawStreamingConv1d, RawStreamingConvTranspose1d, StreamingModule
+ 
+ 
+ CONV_NORMALIZATIONS = frozenset(["none", "weight_norm"])
+ 
+ 
+ class TransposedLayerNorm(nn.Module):
+     """LayerNorm for [B, C, T] inputs."""
+ 
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.layer_norm = nn.LayerNorm(**kwargs)
+ 
+     def forward(self, x):
+         x = x.transpose(1, 2)
+         x = self.layer_norm(x)
+         return x.transpose(1, 2)
+ 
+ 
+ def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
+     assert norm in CONV_NORMALIZATIONS
+     if norm == "weight_norm":
+         return weight_norm(module)
+     else:
+         # We already checked that `norm` is in CONV_NORMALIZATIONS,
+         # so any other choice doesn't need reparametrization.
+         return module
+ 
+ 
+ def get_extra_padding_for_conv1d(
+     x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+ ) -> int:
+     """See `pad_for_conv1d`."""
+     length = x.shape[-1]
+     n_frames = (length - kernel_size + padding_total) / stride + 1
+     ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+     return ideal_length - length
+ 
+ 
+ def pad_for_conv1d(
+     x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+ ):
+     """Pad for a convolution to make sure that the last window is full.
+     Extra padding is added at the end. This is required to ensure that we can rebuild
+     an output of the same length, as otherwise, even with padding, some time steps
+     might get removed.
+     For instance, with total padding = 4, kernel size = 4, stride = 2:
+         0 0 1 2 3 4 5 0 0   # (0s are padding)
+         1   2   3           # (output frames of a convolution, last 0 is never used)
+         0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+             1 2 3 4         # once the padding is removed, we are missing one time step!
+     """
+     extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+     return F.pad(x, (0, extra_padding))
+ 
+ 
+ def pad1d(
+     x: torch.Tensor,
+     paddings: tp.Tuple[int, int],
+     mode: str = "constant",
+     value: float = 0.0,
+ ):
+     """Tiny wrapper around F.pad, just to allow for reflect padding on small inputs.
+     If this is the case, we insert extra 0 padding to the right before the reflection happens.
+     """
+     length = x.shape[-1]
+     padding_left, padding_right = paddings
+     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+     if mode == "reflect":
+         max_pad = max(padding_left, padding_right)
+         extra_pad = 0
+         if length <= max_pad:
+             extra_pad = max_pad - length + 1
+             x = F.pad(x, (0, extra_pad))
+         padded = F.pad(x, paddings, mode, value)
+         end = padded.shape[-1] - extra_pad
+         return padded[..., :end]
+     else:
+         return F.pad(x, paddings, mode, value)
+ 
+ 
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+     """Remove padding from x, properly handling zero padding. Only for 1d!"""
+     padding_left, padding_right = paddings
+     assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+     assert (padding_left + padding_right) <= x.shape[-1]
+     end = x.shape[-1] - padding_right
+     return x[..., padding_left:end]
+ 
+ 
+ class NormConv1d(nn.Module):
+     """Wrapper around Conv1d and normalization applied to this conv
+     to provide a uniform interface across normalization approaches.
+     """
+ 
+     def __init__(
+         self,
+         *args,
+         causal: bool = False,
+         norm: str = "none",
+         norm_kwargs: tp.Dict[str, tp.Any] = {},
+         **kwargs,
+     ):
+         super().__init__()
+         self.conv = apply_parametrization_norm(
+             RawStreamingConv1d(*args, **kwargs), norm
+         )
+         self.norm_type = norm
+ 
+     def forward(self, x):
+         x = self.conv(x)
+         return x
+ 
+ 
+ class NormConvTranspose1d(nn.Module):
+     """Wrapper around ConvTranspose1d and normalization applied to this conv
+     to provide a uniform interface across normalization approaches.
+     """
+ 
+     def __init__(
+         self,
+         *args,
+         causal: bool = False,
+         norm: str = "none",
+         norm_kwargs: tp.Dict[str, tp.Any] = {},
+         **kwargs,
+     ):
+         super().__init__()
+         self.convtr = apply_parametrization_norm(
+             RawStreamingConvTranspose1d(*args, **kwargs), norm
+         )
+         self.norm_type = norm
+ 
+     def forward(self, x):
+         x = self.convtr(x)
+         return x
+ 
+ 
+ @dataclass
+ class _StreamingConv1dState:
+     padding_to_add: int
+     original_padding_to_add: int
+ 
+     def reset(self):
+         self.padding_to_add = self.original_padding_to_add
+ 
+ 
+ class StreamingConv1d(StreamingModule[_StreamingConv1dState]):
+     """Conv1d with some builtin handling of asymmetric or causal padding
+     and normalization.
+     """
+ 
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         causal: bool = False,
+         norm: str = "none",
+         norm_kwargs: tp.Dict[str, tp.Any] = {},
+         pad_mode: str = "reflect",
+     ):
+         super().__init__()
+         # Warn the user on an unusual combination of dilation and stride.
+         if stride > 1 and dilation > 1:
+             warnings.warn(
+                 "StreamingConv1d has been initialized with stride > 1 and dilation > 1"
+                 f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
+             )
+         self.conv = NormConv1d(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride,
+             dilation=dilation,
+             groups=groups,
+             bias=bias,
+             causal=causal,
+             norm=norm,
+             norm_kwargs=norm_kwargs,
+         )
+         self.causal = causal
+         self.pad_mode = pad_mode
+ 
+     @property
+     def _stride(self) -> int:
+         return self.conv.conv.stride[0]
+ 
+     @property
+     def _kernel_size(self) -> int:
+         return self.conv.conv.kernel_size[0]
+ 
+     @property
+     def _effective_kernel_size(self) -> int:
+         dilation = self.conv.conv.dilation[0]
+         return (
+             self._kernel_size - 1
+         ) * dilation + 1  # effective kernel size with dilations
+ 
+     @property
+     def _padding_total(self) -> int:
+         return self._effective_kernel_size - self._stride
+ 
+     def _init_streaming_state(self, batch_size: int) -> _StreamingConv1dState:
+         assert self.causal, "streaming is only supported for causal convs"
+         return _StreamingConv1dState(self._padding_total, self._padding_total)
+ 
+     def forward(self, x):
+         B, C, T = x.shape
+         padding_total = self._padding_total
+         extra_padding = get_extra_padding_for_conv1d(
+             x, self._effective_kernel_size, self._stride, padding_total
+         )
+         state = self._streaming_state
+         if state is None:
+             if self.causal:
+                 # Left padding for causal
+                 x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+             else:
+                 # Asymmetric padding required for odd strides
+                 padding_right = padding_total // 2
+                 padding_left = padding_total - padding_right
+                 x = pad1d(
+                     x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
+                 )
+         else:
+             if state.padding_to_add > 0 and x.shape[-1] > 0:
+                 x = pad1d(x, (state.padding_to_add, 0), mode=self.pad_mode)
+                 state.padding_to_add = 0
+         return self.conv(x)
+ 
+ 
+ @dataclass
+ class _StreamingConvTr1dState:
+     def reset(self):
+         pass
+ 
+ 
+ class StreamingConvTranspose1d(StreamingModule[_StreamingConvTr1dState]):
+     """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+     and normalization.
+     """
+ 
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         causal: bool = False,
+         norm: str = "none",
+         trim_right_ratio: float = 1.0,
+         norm_kwargs: tp.Dict[str, tp.Any] = {},
+     ):
+         super().__init__()
+         self.convtr = NormConvTranspose1d(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride,
+             groups=groups,
+             bias=bias,
+             causal=causal,
+             norm=norm,
+             norm_kwargs=norm_kwargs,
+         )
+         self.causal = causal
+         self.trim_right_ratio = trim_right_ratio
+         assert (
+             self.causal or self.trim_right_ratio == 1.0
+         ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+         assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
+ 
+     def _init_streaming_state(self, batch_size: int) -> _StreamingConvTr1dState:
+         assert self.causal, "streaming is only supported for causal convtrs"
+         return _StreamingConvTr1dState()
+ 
+     def forward(self, x):
+         kernel_size = self.convtr.convtr.kernel_size[0]
+         stride = self.convtr.convtr.stride[0]
+         padding_total = kernel_size - stride
+ 
+         y = self.convtr(x)
+ 
+         if not self.is_streaming:
+             # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+             # removed at the very end, when keeping only the right length for the output,
+             # as removing it here would require also passing the length at the matching layer
+             # in the encoder.
+             if self.causal:
+                 # Trim the padding on the right according to the specified ratio;
+                 # if trim_right_ratio = 1.0, trim everything from the right.
+                 padding_right = math.ceil(padding_total * self.trim_right_ratio)
+                 padding_left = padding_total - padding_right
+                 y = unpad1d(y, (padding_left, padding_right))
+             else:
+                 # Asymmetric padding required for odd strides
+                 padding_right = padding_total // 2
+                 padding_left = padding_total - padding_right
+                 y = unpad1d(y, (padding_left, padding_right))
+         return y
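The frame arithmetic in `get_extra_padding_for_conv1d` can be checked without torch. A pure-Python sketch of the same formula, operating on a plain length instead of a tensor (`extra_padding_for_conv1d` is a local mirror, not the module's function):

```python
import math


def extra_padding_for_conv1d(length: int, kernel_size: int, stride: int,
                             padding_total: int = 0) -> int:
    # Mirror of `get_extra_padding_for_conv1d`: how much right padding is
    # needed so that the last convolution window is full.
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


# The docstring example above: total padding = 4, kernel size = 4, stride = 2,
# input of 5 samples -> one extra sample of right padding is needed.
print(extra_padding_for_conv1d(5, 4, 2, 4))  # 1
```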
moshi/modules/gating.py ADDED
@@ -0,0 +1,82 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ 
+ from ..utils.compile import torch_compile_lazy
+ 
+ 
+ @torch_compile_lazy
+ def gating_forward_kernel(
+     weight_in: torch.Tensor, weight_out: torch.Tensor, activation, x: torch.Tensor
+ ):
+     x = F.linear(x, weight_in)
+     B, T, _ = x.shape
+     x = x.view(B, T, 2, -1)
+     x = activation(x[..., 0, :]) * x[..., 1, :]
+     x = F.linear(x, weight_out)
+     return x
+ 
+ 
+ class ActivationGating(nn.Module):
+     """
+     Gating FFN layer, using the given activation.
+     Args:
+         dim (int): dimension of the input and output of the transformer.
+         activation (any callable Tensor to Tensor): activation function to use.
+         **factory_kwargs: other kwargs passed to the linear layer, in particular device and dtype.
+     """
+ 
+     _fsdp_final = True
+ 
+     def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
+         super().__init__()
+         # We should have 8 d^2 params, instead we will have
+         # 2 * h * d + h * d = 3 h * d = 8 d^2
+         # so h = 8 d / 3, but following Hervé's advice we use 21 / 8 as an approximation.
+         if dim_feedforward == 4 * dim:
+             hidden = (21 * dim) // 8
+         else:
+             hidden = (2 * dim_feedforward) // 3
+         self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
+         self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
+         self.activation = activation
+ 
+     def forward(self, x: torch.Tensor):
+         return gating_forward_kernel(
+             self.linear_in.weight, self.linear_out.weight, self.activation, x
+         )
+ 
+ 
+ def _get_activation(name: str):
+     if name in ["sigmoid", "tanh", "relu"]:
+         return getattr(torch, name)
+     elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
+         return getattr(torch.nn.functional, name)
+     elif name == "identity":
+         return torch.nn.Identity()
+     else:
+         raise ValueError(f"Unknown activation {name}")
+ 
+ 
+ def _make_gating(
+     name: str, dim: int, dim_feedforward: int, **factory_kwargs
+ ) -> nn.Module:
+     return ActivationGating(
+         dim, dim_feedforward, _get_activation(name), **factory_kwargs
+     )
+ 
+ 
+ def make_gating(
+     name: str, dim: int, dim_feedforward: int, **factory_kwargs
+ ) -> nn.Module:
+     gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs)
+     max_params = 2 * dim * dim_feedforward
+     params = sum(p.numel() for p in gating.parameters())
+     assert (
+         params <= max_params
+     ), f"{name} gating has {params} params, max is {max_params}"
+     return gating
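The hidden-size choice in `ActivationGating.__init__` can be verified with plain arithmetic: the gated FFN has three weight matrices (`2*h*d` for `linear_in`, `h*d` for `linear_out`, so `3*h*d` parameters total), and `h` is picked to stay within the `2 * dim * dim_feedforward` budget asserted in `make_gating`. A minimal sketch of that arithmetic (`gating_hidden` is a local mirror of the branch above):

```python
def gating_hidden(dim: int, dim_feedforward: int) -> int:
    # Mirror of the hidden-size selection in ActivationGating.__init__.
    if dim_feedforward == 4 * dim:
        return (21 * dim) // 8
    return (2 * dim_feedforward) // 3


dim = 4096                         # the "dim" used in _lm_kwargs above
h = gating_hidden(dim, 4 * dim)
params = 3 * h * dim               # linear_in (2*h*d) + linear_out (h*d)
budget = 2 * dim * (4 * dim)       # the `max_params` bound in make_gating
print(h, params <= budget)  # 10752 True
```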
moshi/modules/resample.py ADDED
@@ -0,0 +1,119 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ import typing as tp
+ 
+ from einops import rearrange
+ import torch
+ from torch import nn
+ 
+ from .conv import StreamingConv1d, StreamingConvTranspose1d
+ 
+ 
+ class ConvDownsample1d(nn.Module):
+     """
+     Downsampling by some integer amount `stride` using convolutions
+     with a kernel size of twice the stride.
+     If `causal` is True, the output uses a causal convolution.
+     """
+ 
+     def __init__(
+         self,
+         stride: int,
+         dimension: tp.Optional[int] = None,
+         causal: bool = False,
+         learnt: bool = False,
+         channel_wise: bool = False,
+     ):
+         super().__init__()
+         self.learnt = learnt
+         self.channel_wise = channel_wise
+         groups = 1
+         if learnt:
+             assert dimension is not None, "Dimension required for learnt convolutions."
+             in_channels = dimension
+             out_channels = dimension
+             if channel_wise:
+                 groups = dimension
+         else:
+             in_channels = 1
+             out_channels = 1
+ 
+         self.conv = StreamingConv1d(
+             in_channels,
+             out_channels,
+             kernel_size=2 * stride,
+             stride=stride,
+             causal=causal,
+             groups=groups,
+             bias=False,
+             pad_mode="replicate",
+         )
+         if not learnt:
+             actual_conv = self.conv.conv.conv
+             actual_conv.weight.requires_grad_(False)
+             actual_conv.weight.data.fill_(1.0 / (2 * stride))
+ 
+     def forward(self, x: torch.Tensor):
+         batch_size = len(x)
+         if not self.learnt:
+             x = rearrange(x, "b c t -> (b c) () t")
+         y = self.conv(x)
+         if not self.learnt:
+             y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
+         return y
+ 
+ 
+ class ConvTrUpsample1d(nn.Module):
+     """
+     Upsample by some integer amount `stride` using transposed convolutions.
+     """
+ 
+     def __init__(
+         self,
+         stride: int,
+         dimension: tp.Optional[int] = None,
+         causal: bool = False,
+         learnt: bool = False,
+         channel_wise: bool = False,
+     ):
+         super().__init__()
+         self.learnt = learnt
+         self.channel_wise = channel_wise
+         groups = 1
+         if learnt:
+             assert dimension is not None, "Dimension required for learnt convolutions."
+             in_channels = dimension
+             out_channels = dimension
+             if channel_wise:
+                 groups = dimension
+         else:
+             in_channels = 1
+             out_channels = 1
+ 
+         self.convtr = StreamingConvTranspose1d(
+             in_channels,
+             out_channels,
+             kernel_size=2 * stride,
+             stride=stride,
+             causal=causal,
+             groups=groups,
+             bias=False,
+         )
+         if not learnt:
+             actual_convtr = self.convtr.convtr.convtr
+             actual_convtr.weight.requires_grad_(False)
+             actual_convtr.weight.data.fill_(1.0)
+ 
+     def forward(self, x: torch.Tensor):
+         batch_size = len(x)
+         if not self.learnt:
+             x = rearrange(x, "b c t -> (b c) () t")
+         y = self.convtr(x)
+         if not self.learnt:
+             x_for_normalization = torch.ones_like(x[:1])
+             normalization = self.convtr(x_for_normalization)
+             y = y / normalization
+             y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
+         return y
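With `learnt=False`, `ConvDownsample1d` freezes the kernel at `1.0 / (2 * stride)`, so each output frame is the mean of a window of `2 * stride` input samples. A pure-Python sketch of that behaviour on a flat signal, ignoring the causal padding that `StreamingConv1d` adds (`avg_downsample` is a hypothetical illustration, not part of the module):

```python
def avg_downsample(xs, stride):
    # Each output is the mean of a 2*stride window, hopping by `stride` --
    # what the frozen 1/(2*stride) kernel computes on the interior samples.
    k = 2 * stride
    return [sum(xs[i:i + k]) / k for i in range(0, len(xs) - k + 1, stride)]


print(avg_downsample([1, 1, 3, 3, 5, 5, 7, 7], 2))  # [2.0, 4.0, 6.0]
```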
moshi/modules/rope.py ADDED
@@ -0,0 +1,90 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ import math
+ 
+ import torch
+ from torch import nn
+ 
+ from ..utils.compile import torch_compile_lazy
+ 
+ 
+ @torch_compile_lazy
+ def apply_rope(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     offset: torch.Tensor,
+     max_period: float = 10_000,
+     time_before_heads: bool = False,
+ ):
+     """
+     Args:
+         q (torch.Tensor): queries, shape `[B, T, H, D]`.
+         k (torch.Tensor): keys, shape `[B, T, H, D]`.
+         offset (torch.Tensor): current offset, e.g. when streaming.
+         max_period (float): maximum period for the cos and sin.
+         time_before_heads (bool): if True, the expected shape is [B, T, H, D], else [B, H, T, D].
+     """
+ 
+     if time_before_heads:
+         B, T, H, D = q.shape
+     else:
+         B, H, T, D = q.shape
+     assert k.shape == q.shape
+     assert D > 0
+     assert D % 2 == 0
+     assert max_period > 0
+ 
+     ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
+     freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))
+     ts = offset.float() + torch.arange(T, device=q.device, dtype=torch.float32)
+     if time_before_heads:
+         ts = ts.view(-1, 1, 1)
+     else:
+         ts = ts.view(1, -1, 1)
+ 
+     dims = q.shape[:-1]
+     q = q.view(*dims, D // 2, 2)
+     k = k.view(*dims, D // 2, 2)
+ 
+     # Convention: the `r` suffix is the real part, `i` the imaginary part.
+     qr = q[..., 0].float()
+     qi = q[..., 1].float()
+ 
+     kr = k[..., 0].float()
+     ki = k[..., 1].float()
+ 
+     rotr = torch.cos(freqs * ts)
+     roti = torch.sin(freqs * ts)
+     qor = qr * rotr - qi * roti
+     qoi = qr * roti + qi * rotr
+ 
+     kor = kr * rotr - ki * roti
+     koi = kr * roti + ki * rotr
+ 
+     dtype = q.dtype
+     qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
+     ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)
+ 
+     return qo.view(*dims, D), ko.view(*dims, D)
+ 
+ 
+ class RotaryEmbedding(nn.Module):
+     """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
+ 
+     Args:
+         max_period (float): Maximum period of the rotation frequencies.
+     """
+ 
+     def __init__(self, max_period: float = 10000.0):
+         super().__init__()
+         self.max_period = max_period
+ 
+     def forward(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         offset: torch.Tensor,
+         time_before_heads: bool = False,
+     ):
+         """Apply the rope rotation to the query and key tensors."""
+         return apply_rope(q, k, offset, self.max_period, time_before_heads)
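`apply_rope` rotates each (real, imaginary) channel pair by an angle `freqs * t`, which is exactly complex multiplication by `exp(i * freq * t)`; in particular the rotation never changes a pair's magnitude. A plain-Python sketch of one channel pair (a hypothetical 1-D toy, not the tensor kernel above):

```python
import cmath
import math


def rope_rotate(pair, t, d_index, dim, max_period=10000.0):
    # One channel pair (q_r, q_i) rotated by angle freq * t, where
    # freq = max_period ** (-2 * d_index / dim), matching `freqs` above.
    freq = math.exp(-math.log(max_period) * 2 * d_index / dim)
    z = complex(*pair) * cmath.exp(1j * freq * t)
    return (z.real, z.imag)


qr, qi = rope_rotate((1.0, 0.0), t=3, d_index=0, dim=64)
# Rotation preserves the magnitude of the pair.
print(round(math.hypot(qr, qi), 6))  # 1.0
```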
moshi/modules/seanet.py ADDED
@@ -0,0 +1,395 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ 
+ import typing as tp
+ 
+ import numpy as np
+ import torch.nn as nn
+ 
+ from .conv import StreamingConv1d, StreamingConvTranspose1d
+ from .streaming import StreamingContainer, StreamingAdd
+ from ..utils.compile import torch_compile_lazy
+ 
+ 
+ class SEANetResnetBlock(StreamingContainer):
+     """Residual block from SEANet model.
+ 
+     Args:
+         dim (int): Dimension of the input/output.
+         kernel_sizes (list): List of kernel sizes for the convolutions.
+         dilations (list): List of dilations for the convolutions.
+         activation (str): Activation function.
+         activation_params (dict): Parameters to provide to the activation function.
+         norm (str): Normalization method.
+         norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+         causal (bool): Whether to use fully causal convolution.
+         pad_mode (str): Padding mode for the convolutions.
+         compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+         true_skip (bool): Whether to use true skip connection or a simple
+             (streamable) convolution as the skip connection.
+     """
+ 
+     def __init__(
+         self,
+         dim: int,
+         kernel_sizes: tp.List[int] = [3, 1],
+         dilations: tp.List[int] = [1, 1],
+         activation: str = "ELU",
+         activation_params: dict = {"alpha": 1.0},
+         norm: str = "none",
+         norm_params: tp.Dict[str, tp.Any] = {},
+         causal: bool = False,
+         pad_mode: str = "reflect",
+         compress: int = 2,
+         true_skip: bool = True,
+     ):
+         super().__init__()
+         assert len(kernel_sizes) == len(
+             dilations
+         ), "Number of kernel sizes should match number of dilations"
+         act = getattr(nn, activation)
+         hidden = dim // compress
+         block = []
+         for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+             in_chs = dim if i == 0 else hidden
+             out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+             block += [
+                 act(**activation_params),
+                 StreamingConv1d(
+                     in_chs,
+                     out_chs,
+                     kernel_size=kernel_size,
+                     dilation=dilation,
+                     norm=norm,
+                     norm_kwargs=norm_params,
+                     causal=causal,
+                     pad_mode=pad_mode,
+                 ),
+             ]
+         self.block = nn.Sequential(*block)
+         self.add = StreamingAdd()
+         self.shortcut: nn.Module
+         if true_skip:
+             self.shortcut = nn.Identity()
+         else:
+             self.shortcut = StreamingConv1d(
+                 dim,
+                 dim,
+                 kernel_size=1,
+                 norm=norm,
+                 norm_kwargs=norm_params,
+                 causal=causal,
+                 pad_mode=pad_mode,
+             )
+ 
+     def forward(self, x):
+         u, v = self.shortcut(x), self.block(x)
+         return self.add(u, v)
+ 
+ 
+ class SEANetEncoder(StreamingContainer):
+     """SEANet encoder.
+ 
+     Args:
+         channels (int): Audio channels.
+         dimension (int): Intermediate representation dimension.
+         n_filters (int): Base width for the model.
+         n_residual_layers (int): Number of residual layers.
+         ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+             upsampling ratios, hence it will use the ratios in reverse order from the ones specified here,
+             which must match the decoder order. We use the decoder order as some models may only employ the decoder.
+         activation (str): Activation function.
+         activation_params (dict): Parameters to provide to the activation function.
+         norm (str): Normalization method.
+         norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+         kernel_size (int): Kernel size for the initial convolution.
+         last_kernel_size (int): Kernel size for the final convolution.
+         residual_kernel_size (int): Kernel size for the residual layers.
+         dilation_base (int): How much to increase the dilation with each layer.
+         causal (bool): Whether to use fully causal convolution.
+         pad_mode (str): Padding mode for the convolutions.
+         true_skip (bool): Whether to use true skip connection or a simple
+             (streamable) convolution as the skip connection in the residual network blocks.
+         compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+         disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+             For the encoder, it corresponds to the N first blocks.
+         mask_fn (nn.Module): Optional mask function to apply after convolution layers.
+         mask_position (int): Position of the mask function, with mask_position == 0 for the first convolution layer,
+             mask_position == 1 for the first conv block, etc.
+     """
+ 
+     def __init__(
+         self,
+         channels: int = 1,
+         dimension: int = 128,
+         n_filters: int = 32,
+         n_residual_layers: int = 3,
+         ratios: tp.List[int] = [8, 5, 4, 2],
+         activation: str = "ELU",
+         activation_params: dict = {"alpha": 1.0},
+         norm: str = "none",
+         norm_params: tp.Dict[str, tp.Any] = {},
+         kernel_size: int = 7,
+         last_kernel_size: int = 7,
+         residual_kernel_size: int = 3,
+         dilation_base: int = 2,
+         causal: bool = False,
+         pad_mode: str = "reflect",
+         true_skip: bool = True,
+         compress: int = 2,
+         disable_norm_outer_blocks: int = 0,
+         mask_fn: tp.Optional[nn.Module] = None,
+         mask_position: tp.Optional[int] = None,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.dimension = dimension
+         self.n_filters = n_filters
+         self.ratios = list(reversed(ratios))
+         del ratios
+         self.n_residual_layers = n_residual_layers
+         self.hop_length = int(np.prod(self.ratios))
+         self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+         self.disable_norm_outer_blocks = disable_norm_outer_blocks
+         assert (
+             self.disable_norm_outer_blocks >= 0
+             and self.disable_norm_outer_blocks <= self.n_blocks
+         ), (
+             "Number of blocks for which to disable norm is invalid. "
+             "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+         )
+ 
+         act = getattr(nn, activation)
+         mult = 1
+         model: tp.List[nn.Module] = [
+             StreamingConv1d(
+                 channels,
+                 mult * n_filters,
+                 kernel_size,
+                 norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
+                 norm_kwargs=norm_params,
+                 causal=causal,
+                 pad_mode=pad_mode,
+             )
+         ]
+         if mask_fn is not None and mask_position == 0:
+             model += [mask_fn]
+         # Downsample to raw audio scale
+         for i, ratio in enumerate(self.ratios):
+             block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm
+             # Add residual layers
+             for j in range(n_residual_layers):
+                 model += [
+                     SEANetResnetBlock(
+                         mult * n_filters,
+                         kernel_sizes=[residual_kernel_size, 1],
+                         dilations=[dilation_base**j, 1],
+                         norm=block_norm,
+                         norm_params=norm_params,
+                         activation=activation,
+                         activation_params=activation_params,
+                         causal=causal,
+                         pad_mode=pad_mode,
+                         compress=compress,
+                         true_skip=true_skip,
+                     )
+                 ]
+ 
+             # Add downsampling layers
+             model += [
+                 act(**activation_params),
+                 StreamingConv1d(
+                     mult * n_filters,
+                     mult * n_filters * 2,
+                     kernel_size=ratio * 2,
+                     stride=ratio,
+                     norm=block_norm,
+                     norm_kwargs=norm_params,
+                     causal=causal,
+                     pad_mode=pad_mode,
+                 ),
+             ]
+             mult *= 2
+             if mask_fn is not None and mask_position == i + 1:
+                 model += [mask_fn]
+ 
+         model += [
+             act(**activation_params),
+             StreamingConv1d(
+                 mult * n_filters,
+                 dimension,
+                 last_kernel_size,
+                 norm=(
+                     "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
+                 ),
+                 norm_kwargs=norm_params,
+                 causal=causal,
+                 pad_mode=pad_mode,
+             ),
+         ]
+ 
+         self.model = nn.Sequential(*model)
+ 
+     @torch_compile_lazy
+     def forward(self, x):
+         return self.model(x)
+ 
+ 
+ class SEANetDecoder(StreamingContainer):
+     """SEANet decoder.
+ 
+     Args:
+         channels (int): Audio channels.
+         dimension (int): Intermediate representation dimension.
+         n_filters (int): Base width for the model.
+         n_residual_layers (int): Number of residual layers.
+         ratios (Sequence[int]): kernel size and stride ratios.
+         activation (str): Activation function.
+         activation_params (dict): Parameters to provide to the activation function.
+         final_activation (str): Final activation function after all convolutions.
+         final_activation_params (dict): Parameters to provide to the activation function.
+         norm (str): Normalization method.
+         norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+         kernel_size (int): Kernel size for the initial convolution.
+         last_kernel_size (int): Kernel size for the final convolution.
+         residual_kernel_size (int): Kernel size for the residual layers.
+         dilation_base (int): How much to increase the dilation with each layer.
+         causal (bool): Whether to use fully causal convolution.
+         pad_mode (str): Padding mode for the convolutions.
+         true_skip (bool): Whether to use true skip connection or a simple
+             (streamable) convolution as the skip connection in the residual network blocks.
+         compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+         disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+             For the decoder, it corresponds to the N last blocks.
+         trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+             If equal to 1.0, it means that all the trimming is done at the right.
+     """
+ 
+     def __init__(
+         self,
+         channels: int = 1,
+         dimension: int = 128,
+         n_filters: int = 32,
+         n_residual_layers: int = 3,
+         ratios: tp.List[int] = [8, 5, 4, 2],
+         activation: str = "ELU",
+         activation_params: dict = {"alpha": 1.0},
+         final_activation: tp.Optional[str] = None,
+         final_activation_params: tp.Optional[dict] = None,
+         norm: str = "none",
+         norm_params: tp.Dict[str, tp.Any] = {},
+         kernel_size: int = 7,
+         last_kernel_size: int = 7,
289
+ residual_kernel_size: int = 3,
290
+ dilation_base: int = 2,
291
+ causal: bool = False,
292
+ pad_mode: str = "reflect",
293
+ true_skip: bool = True,
294
+ compress: int = 2,
295
+ disable_norm_outer_blocks: int = 0,
296
+ trim_right_ratio: float = 1.0,
297
+ ):
298
+ super().__init__()
299
+ self.dimension = dimension
300
+ self.channels = channels
301
+ self.n_filters = n_filters
302
+ self.ratios = ratios
303
+ del ratios
304
+ self.n_residual_layers = n_residual_layers
305
+ self.hop_length = int(np.prod(self.ratios))
306
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
307
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
308
+ assert (
309
+ self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks
310
+ ), (
311
+ "Number of blocks for which to disable norm is invalid."
312
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
313
+ )
314
+
315
+ act = getattr(nn, activation)
316
+ mult = int(2 ** len(self.ratios))
317
+ model: tp.List[nn.Module] = [
318
+ StreamingConv1d(
319
+ dimension,
320
+ mult * n_filters,
321
+ kernel_size,
322
+ norm=(
323
+ "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
324
+ ),
325
+ norm_kwargs=norm_params,
326
+ causal=causal,
327
+ pad_mode=pad_mode,
328
+ )
329
+ ]
330
+
331
+ # Upsample to raw audio scale
332
+ for i, ratio in enumerate(self.ratios):
333
+ block_norm = (
334
+ "none"
335
+ if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
336
+ else norm
337
+ )
338
+ # Add upsampling layers
339
+ model += [
340
+ act(**activation_params),
341
+ StreamingConvTranspose1d(
342
+ mult * n_filters,
343
+ mult * n_filters // 2,
344
+ kernel_size=ratio * 2,
345
+ stride=ratio,
346
+ norm=block_norm,
347
+ norm_kwargs=norm_params,
348
+ causal=causal,
349
+ trim_right_ratio=trim_right_ratio,
350
+ ),
351
+ ]
352
+ # Add residual layers
353
+ for j in range(n_residual_layers):
354
+ model += [
355
+ SEANetResnetBlock(
356
+ mult * n_filters // 2,
357
+ kernel_sizes=[residual_kernel_size, 1],
358
+ dilations=[dilation_base**j, 1],
359
+ activation=activation,
360
+ activation_params=activation_params,
361
+ norm=block_norm,
362
+ norm_params=norm_params,
363
+ causal=causal,
364
+ pad_mode=pad_mode,
365
+ compress=compress,
366
+ true_skip=true_skip,
367
+ )
368
+ ]
369
+
370
+ mult //= 2
371
+
372
+ # Add final layers
373
+ model += [
374
+ act(**activation_params),
375
+ StreamingConv1d(
376
+ n_filters,
377
+ channels,
378
+ last_kernel_size,
379
+ norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
380
+ norm_kwargs=norm_params,
381
+ causal=causal,
382
+ pad_mode=pad_mode,
383
+ ),
384
+ ]
385
+ # Add optional final activation to decoder (eg. tanh)
386
+ if final_activation is not None:
387
+ final_act = getattr(nn, final_activation)
388
+ final_activation_params = final_activation_params or {}
389
+ model += [final_act(**final_activation_params)]
390
+ self.model = nn.Sequential(*model)
391
+
392
+ @torch_compile_lazy
393
+ def forward(self, z):
394
+ y = self.model(z)
395
+ return y
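The decoder's `hop_length` is the product of the stride ratios, so each latent frame maps to that many output samples. A minimal sketch of the arithmetic, assuming the default `ratios`; the 24 kHz sample rate is an illustrative assumption, not taken from this file:

```python
import math

# Default `ratios` of SEANetDecoder; hop_length matches int(np.prod(self.ratios)).
ratios = [8, 5, 4, 2]
hop_length = math.prod(ratios)   # audio samples produced per latent frame

sample_rate = 24_000             # assumed sample rate, for illustration only
frame_rate = sample_rate / hop_length

print(hop_length, frame_rate)    # 320 samples per frame, 75.0 frames per second
```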
moshi/modules/streaming.py ADDED
@@ -0,0 +1,363 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Streaming module API that should be implemented by all streaming components.
+ """
+
+ import abc
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ import itertools
+ import math
+ import typing as tp
+ from torch import nn
+ import torch
+
+
+ class Resetable(tp.Protocol):
+ def reset(self) -> None:
+ pass
+
+
+ State = tp.TypeVar("State", bound=Resetable)
+
+
+ class StreamingModule(abc.ABC, nn.Module, tp.Generic[State]):
+ """Common API for streaming components.
+
+ Each streaming component has a streaming state, which is just a dict[str, Tensor].
+ By convention, the first dim of each tensor must be the batch size.
+ Don't use dots in the key names, as this would clash with submodules
+ (like in state_dict).
+
+ If `self.is_streaming` is True, the component should use and remember
+ the proper state inside `self._streaming_state`.
+
+ To set a streaming component in streaming state, use
+
+ with module.streaming(batch_size):
+ ...
+
+ This will automatically reset the streaming state when exiting the context manager.
+ This also automatically propagates to all streaming children modules.
+
+ Some modules might also implement the `StreamingModule.flush` method, although
+ this one is trickier, as all parent modules must be `StreamingModule` and implement
+ it as well for it to work properly. See `StreamingSequential` below.
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self._streaming_state: State | None = None
+ self._streaming_propagate: bool = True
+
+ @property
+ def is_streaming(self):
+ return self._streaming_state is not None
+
+ def set_streaming_propagate(self, streaming_propagate: bool):
+ self._streaming_propagate = streaming_propagate
+
+ def _apply_named_streaming(self, fn: tp.Any):
+ def _handle_module(prefix: str, module: nn.Module, recurse: bool = True):
+ propagate = True
+ if isinstance(module, StreamingModule):
+ if module._streaming_propagate:
+ fn(prefix, module)
+ else:
+ propagate = False
+ if not recurse:
+ return
+ if propagate:
+ for name, child in module.named_children():
+ _handle_module(prefix + "." + name, child)
+
+ _handle_module("", self, recurse=False)
+ for name, child in self.named_children():
+ _handle_module(name, child)
+
+ def _start_streaming(self, batch_size: int):
+ def _start_streaming(name: str, module: StreamingModule):
+ module._streaming_state = module._init_streaming_state(batch_size)
+
+ self._apply_named_streaming(_start_streaming)
+
+ def _stop_streaming(self):
+ def _stop_streaming(name: str, module: StreamingModule):
+ module._streaming_state = None
+
+ self._apply_named_streaming(_stop_streaming)
+
+ @abc.abstractmethod
+ def _init_streaming_state(self, batch_size: int) -> State: ...
+
+ def streaming_forever(self, batch_size: int):
+ self._start_streaming(batch_size)
+
+ @contextmanager
+ def streaming(self, batch_size: int):
+ """Context manager to enter streaming mode. Resets the streaming state on exit."""
+
+ self._start_streaming(batch_size)
+ try:
+ yield
+ finally:
+ self._stop_streaming()
+
+ def reset_streaming(self):
+ """Reset the streaming state."""
+
+ def _reset(name: str, module: StreamingModule):
+ state = module._streaming_state
+ if state is None:
+ raise ValueError(
+ f"Trying to reset streaming, but {name} wasn't streaming."
+ )
+ state.reset()
+
+ self._apply_named_streaming(_reset)
+
+ def get_streaming_state(self) -> dict[str, tp.Any]:
+ """Return the complete streaming state, including that of sub-modules."""
+ state: dict[str, tp.Any] = {}
+
+ def _add(name: str, module: StreamingModule):
+ state[name] = module._streaming_state
+
+ self._apply_named_streaming(_add)
+ return state
+
+ def set_streaming_state(self, state: dict[str, tp.Any]):
+ """Set the streaming state, including that of sub-modules."""
+ state = dict(state)
+
+ def _set(name: str, module: StreamingModule):
+ if name in state:
+ module._streaming_state = state[name]
+ state.pop(name)
+ else:
+ raise RuntimeError(f"Expected to find a streaming state for {name}.")
+
+ self._apply_named_streaming(_set)
+ if state:
+ raise RuntimeError(f"Some states were not consumed: {list(state.keys())}")
+
+
+ @dataclass
+ class _NullState:
+ pass
+
+ def reset(self) -> None:
+ pass
+
+
+ class StreamingContainer(StreamingModule[_NullState]):
+ def _init_streaming_state(self, batch_size: int) -> _NullState:
+ return _NullState()
+
+
+ @dataclass
+ class _StreamingAddState:
+ previous_x: torch.Tensor | None = None
+ previous_y: torch.Tensor | None = None
+
+ def reset(self):
+ self.previous_x = None
+ self.previous_y = None
+
+
+ class StreamingAdd(StreamingModule[_StreamingAddState]):
+ def _init_streaming_state(self, batch_size: int) -> _StreamingAddState:
+ return _StreamingAddState()
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ if self._streaming_state is None:
+ return x + y
+ else:
+ prev_x = self._streaming_state.previous_x
+ prev_y = self._streaming_state.previous_y
+ if prev_x is not None:
+ x = torch.cat([prev_x, x], dim=-1)
+ if prev_y is not None:
+ y = torch.cat([prev_y, y], dim=-1)
+ m_l = min(x.shape[-1], y.shape[-1])
+ self._streaming_state.previous_x = x[..., m_l:]
+ self._streaming_state.previous_y = y[..., m_l:]
+ return x[..., :m_l] + y[..., :m_l]
+
+
+ @dataclass
+ class _StreamingConvState:
+ previous: torch.Tensor | None = None
+
+ def reset(self):
+ self.previous = None
+
+
+ class RawStreamingConv1d(nn.Conv1d, StreamingModule[_StreamingConvState]):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ assert self.padding[0] == 0, "Padding should be handled outside."
+ assert (
+ self.stride[0] <= self.kernel_size[0]
+ ), "stride must not exceed kernel_size."
+
+ def _init_streaming_state(self, batch_size: int) -> _StreamingConvState:
+ return _StreamingConvState()
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ stride = self.stride[0]
+ # Effective kernel size accounting for dilation.
+ kernel = (self.kernel_size[0] - 1) * self.dilation[0] + 1
+ if self._streaming_state is None:
+ return super().forward(input)
+ else:
+ # Due to the potential overlap, we might have some cache of the previous time steps.
+ previous = self._streaming_state.previous
+ if previous is not None:
+ input = torch.cat([previous, input], dim=-1)
+ B, C, T = input.shape
+ # We now compute the number of full convolution frames, i.e. the frames
+ # that are ready to be computed.
+ num_frames = max(0, int(math.floor((T - kernel) / stride) + 1))
+ offset = num_frames * stride
+ # We will compute `num_frames` outputs, and we are advancing by `stride`
+ # for each of the frames, so we know the data before `stride * num_frames`
+ # will never be used again.
+ self._streaming_state.previous = input[..., offset:]
+ if num_frames > 0:
+ input_length = (num_frames - 1) * stride + kernel
+ out = super().forward(input[..., :input_length])
+ else:
+ # Not enough data at this point to output any new frames.
+ out = torch.empty(
+ B, self.out_channels, 0, device=input.device, dtype=input.dtype
+ )
+ return out
+
+
+ @dataclass
+ class _StreamingConvTrState:
+ partial: torch.Tensor | None = None
+
+ def reset(self):
+ self.partial = None
+
+
+ class RawStreamingConvTranspose1d(
+ nn.ConvTranspose1d, StreamingModule[_StreamingConvTrState]
+ ):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ assert self.padding[0] == 0, "Padding should be handled outside."
+ assert self.dilation[0] == 1, "No dilation for now"
+ assert (
+ self.stride[0] <= self.kernel_size[0]
+ ), "stride must not exceed kernel_size."
+ assert self.output_padding[0] == 0, "Output padding not supported."
+
+ def _init_streaming_state(self, batch_size: int) -> _StreamingConvTrState:
+ return _StreamingConvTrState()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
+ B, C, T = x.shape
+ stride = self.stride[0]
+ kernel = self.kernel_size[0]
+ if self._streaming_state is None:
+ return super().forward(x)
+ else:
+ if T == 0:
+ return torch.empty(
+ B, self.out_channels, 0, device=x.device, dtype=x.dtype
+ )
+ out = super().forward(x)
+ OT = out.shape[-1]
+ partial = self._streaming_state.partial
+ if partial is not None:
+ # Due to the potential overlap, the rightmost output of the conv transpose is not
+ # ready to be output, as it will receive contributions from the next input frames.
+ # Here we recover those `partial` output frames. We know that the first time step
+ # of the `partial` tensor corresponds to the first time step of `out`, as anything
+ # coming before the first time step of `out` would already have been flushed.
+ PT = partial.shape[-1]
+ if self.bias is not None:
+ out[..., :PT] += partial - self.bias[:, None]
+ else:
+ out[..., :PT] += partial
+ # The input is T, the output is S * (T - 1) + K.
+ # The offset of the left of the next frame will be S * T,
+ # so everything between 0 and S * T is ready to be output, and we need
+ # to keep in the internal state everything beyond that, i.e. S (T - 1) + K - S T = K - S.
+ invalid_steps = kernel - stride
+ partial = out[..., OT - invalid_steps :]
+ out = out[..., : OT - invalid_steps]
+ self._streaming_state.partial = partial
+ return out
+
+
+ def test():
+ torch.manual_seed(1234)
+ device = "cpu"
+ if torch.cuda.is_available():
+ # Avoid the cuda optimizations that would take place on single precision
+ # floats for convolutions.
+ torch.backends.cudnn.enabled = True
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+ device = "cuda:0"
+
+ kernel_sizes = [1, 3, 4, 8, 15, 16]
+ strides = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+ chin = 6
+ chout = 12
+
+ for kernel, stride in itertools.product(kernel_sizes, strides):
+ if stride > kernel:
+ continue
+ conv = RawStreamingConv1d(chin, chout, kernel, stride).to(device)
+ convtr = RawStreamingConvTranspose1d(chout, chin, kernel, stride).to(device)
+
+ for length in [4, 8, 32, 54, 65, 128, 1043]:
+ print(f"ksize {kernel} strides {stride} len {length}")
+ if length < kernel:
+ continue
+ batch_size = 3
+ x = torch.randn(batch_size, chin, length).to(device)
+ y = conv(x)
+ z = convtr(y)
+ for chunk_size in [1, 3, 5, 8]:
+ ys = []
+ zs = []
+ with conv.streaming(batch_size), convtr.streaming(batch_size):
+ for offset in range(0, length, chunk_size):
+ chunk = x[..., offset : offset + chunk_size]
+ ys.append(conv(chunk))
+ zs.append(convtr(ys[-1]))
+ y_stream = torch.cat(ys, dim=-1)
+ z_stream = torch.cat(zs, dim=-1)
+ y = y[..., : y_stream.shape[-1]]
+ z = z[..., : z_stream.shape[-1]]
+ assert y.shape == y_stream.shape, (y.shape, y_stream.shape)
+ delta = (y_stream - y).norm() / y.norm()
+ assert delta <= 1e-6, delta
+ num_frames = int((length - kernel) / stride) + 1
+ assert num_frames == y_stream.shape[-1]
+
+ assert z.shape == z_stream.shape, (z.shape, z_stream.shape)
+ delta = (z_stream - z).norm() / z.norm()
+ assert delta <= 1e-6, (delta, (z_stream - z).abs().mean(dim=(0, 1)))
+
+
+ if __name__ == "__main__":
+ with torch.no_grad():
+ test()
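The `num_frames` computation in `RawStreamingConv1d.forward` is the standard valid-convolution frame count, with the kernel size widened by dilation. A small pure-Python sketch (no torch required) checks the closed-form count against a brute-force sliding window:

```python
import itertools

def num_frames(T: int, kernel: int, stride: int, dilation: int = 1) -> int:
    # Effective kernel size accounting for dilation, as in the module.
    k_eff = (kernel - 1) * dilation + 1
    # Closed form used by RawStreamingConv1d: floor((T - k_eff) / stride) + 1.
    return max(0, (T - k_eff) // stride + 1)

def brute_force(T: int, kernel: int, stride: int, dilation: int = 1) -> int:
    # Count windows of effective size k_eff that fully fit in T samples.
    k_eff = (kernel - 1) * dilation + 1
    return sum(1 for start in range(0, max(T, 1), stride) if start + k_eff <= T)

for T, k, s in itertools.product(range(0, 40), [1, 3, 4, 8], [1, 2, 3, 4]):
    assert num_frames(T, k, s) == brute_force(T, k, s), (T, k, s)

print("frame counts verified")
```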
moshi/modules/transformer.py ADDED
@@ -0,0 +1,750 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Transformer model, with streaming support and CUDA Graph compatibility.
+ Optimized for inference.
+
+ See `StreamingTransformer` for more information.
+ """
+
+ from contextlib import ExitStack
+ from dataclasses import dataclass
+ import typing as tp
+
+ from einops import rearrange
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ from ..utils.compile import no_compile
+ from .gating import make_gating
+ from .rope import RotaryEmbedding
+ from .streaming import StreamingModule, StreamingContainer
+
+
+ class LayerNormF32(nn.LayerNorm):
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ x_f32 = input.float()
+ out_f32 = super().forward(x_f32)
+ return out_f32.to(input.dtype)
+
+
+ def _rms_norm(
+ x: torch.Tensor,
+ alpha: torch.Tensor,
+ dtype: tp.Optional[torch.dtype],
+ eps: float,
+ ):
+ assert x.dim() == 3, f"RMSNorm expects 3D inputs but got {x.shape}"
+ x_dtype = x.dtype
+ if dtype is not None:
+ x = x.to(dtype)
+ var = eps + torch.mean(x**2, dim=2, keepdim=True)
+ y = (x * (alpha.to(var) * torch.rsqrt(var))).to(x_dtype)
+ return y
+
+
+ class RMSNorm(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ eps: float = 1e-5,
+ dtype: tp.Optional[torch.dtype] = None,
+ device=None,
+ ):
+ super().__init__()
+ self.eps = eps
+ self.dtype = dtype
+ self.alpha = nn.Parameter(
+ torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype)
+ )
+
+ def forward(self, x: torch.Tensor):
+ return _rms_norm(x, self.alpha, self.dtype, self.eps)
+
+
+ class LayerScale(nn.Module):
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
+ This diagonally rescales the residual outputs close to 0, with a learnt scale.
+
+ Args:
+ channels (int): Number of channels.
+ init (float): Initial scale.
+ channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
+ device (torch.device or str, optional): Device on which to initialize the module.
+ dtype (torch.dtype, optional): dtype to use to initialize the module.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ init: float = 1e-4,
+ channel_last: bool = True,
+ device=None,
+ dtype=None,
+ ):
+ super().__init__()
+ self.channel_last = channel_last
+ self.scale = nn.Parameter(
+ torch.full(
+ (channels,), init, requires_grad=True, device=device, dtype=dtype
+ )
+ )
+
+ def forward(self, x: torch.Tensor):
+ if self.channel_last:
+ return self.scale * x
+ else:
+ return self.scale[:, None] * x
+
+
+ def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
+ """Create normalization module for transformer encoder layer.
+
+ Args:
+ norm_type (str): Normalization method.
+ dim (int): Dimension of the normalized layer.
+ **kwargs (dict): Additional parameters for normalization layer.
+ Returns:
+ nn.Module: Normalization module.
+ """
+ if norm_type == "layer_norm":
+ return nn.LayerNorm(dim, eps=1e-5, **kwargs)
+ elif norm_type == "layer_norm_f32":
+ kwargs.pop("dtype", None)
+ return LayerNormF32(dim, eps=1e-8, **kwargs)
+ elif norm_type in {"rms_norm"}:
+ return RMSNorm(dim, eps=1e-5, **kwargs)
+ elif norm_type in {"rms_norm_f32"}:
+ kwargs.pop("dtype", None)
+ return RMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs)
+ else:
+ raise ValueError(f"Unknown norm type: {norm_type}")
+
+
+ def create_sin_embedding(
+ positions: torch.Tensor,
+ dim: int,
+ max_period: float = 10000,
+ dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+ """Create sinusoidal positional embedding, with shape `[B, T, C]`.
+
+ Args:
+ positions (torch.Tensor): LongTensor of positions.
+ dim (int): Dimension of the embedding.
+ max_period (float): Maximum period of the cosine/sine functions.
+ dtype (torch.dtype or str): dtype to use to generate the embedding.
+ Returns:
+ torch.Tensor: Sinusoidal positional embedding.
+ """
+ # We aim for BTC format
+ assert dim % 2 == 0
+ half_dim = dim // 2
+ positions = positions.to(dtype)
+ adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
+ max_period_tensor = torch.full(
+ [], max_period, device=positions.device, dtype=dtype
+ )  # avoid sync point
+ phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
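For a single position, the phase computation in `create_sin_embedding` reduces to the following pure-Python sketch. It makes the channel layout concrete: the first half of the channels are cosines, the second half sines.

```python
import math

def sin_embedding(position: float, dim: int, max_period: float = 10000.0) -> list[float]:
    # Same formula as create_sin_embedding, for one position; dim must be even
    # and > 2 (the exponent divides by half - 1, matching the torch code).
    assert dim % 2 == 0
    half = dim // 2
    phases = [position / max_period ** (i / (half - 1)) for i in range(half)]
    return [math.cos(p) for p in phases] + [math.sin(p) for p in phases]

print(sin_embedding(0.0, 8))  # cosines are all 1.0, sines all 0.0
```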
+
+
+ def multi_linear(
+ num_linear: int,
+ weight: torch.Tensor,
+ x: torch.Tensor,
+ offset: int,
+ ):
+ """Utility to apply a multi linear layer to the given input. A multi linear layer
+ applies a different set of weights for each time step.
+
+ Args:
+ num_linear (int): Number of possible time steps and so number of linears.
+ weight (torch.Tensor): Weight tensor, with shape `[num_linear * chout, chin]`.
+ x (torch.Tensor): Input tensor, with shape `[B, T, C]`.
+ offset (int): offset for the current time step, in particular for decoding, with
+ time steps provided one by one.
+ """
+ B, T, C = x.shape
+ ys = []
+ chout, chin = weight.shape
+ weight = weight.view(num_linear, -1, chin)
+ for t in range(T):
+ y = F.linear(x[:, t], weight[t + offset])
+ ys.append(y)
+ out = torch.stack(ys, 1)
+ return out
+
+
+ def set_attention_context(model: nn.Module, context: tp.Optional[int] = None) -> None:
+ """Deactivates or changes the context span (in time steps) in a model.
+ Args:
+ model (nn.Module): model over which to look for attentions.
+ context (int or None): new temporary context value.
+
+ .. note:: This is not a context manager but a plain function changing the context forever.
+ Initially, it was a context manager, but that led to interesting bugs when using
+ activation checkpointing, with the context being inconsistent between the forward
+ and backward.
+ """
+ for module in model.modules():
+ if isinstance(module, StreamingMultiheadAttention):
+ module.context = context
+
+
+ class KVCacheResult(tp.NamedTuple):
+ keys: torch.Tensor
+ values: torch.Tensor
+ positions: torch.Tensor
+
+ @staticmethod
+ def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
+ B, H, T, D = keys.shape
+ assert tuple(values.shape[:-1]) == (B, H, T)
+ positions = torch.arange(T, device=keys.device, dtype=torch.long)
+ return KVCacheResult(keys, values, positions)
+
+
+ class RingKVCache:
+ """Efficient streaming KVCache, compatible with CUDA Graphs.
+
+ Args:
+ batch_size (int): Batch size.
+ num_heads (int): Number of heads in the attention.
+ dim_per_head (int): Dimension per head.
+ capacity (int): Maximum number of time steps the cache can hold.
+ device (torch.device): Device on which to initialize the cache.
+ dtype (torch.dtype): dtype to use for the cache.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ num_heads: int,
+ dim_per_head: int,
+ capacity: int,
+ device: torch.device = torch.device("cuda"),
+ dtype: torch.dtype = torch.bfloat16,
+ ):
+ self.capacity = capacity
+ self.cache = torch.zeros(
+ (2, batch_size, num_heads, capacity, dim_per_head),
+ device=device,
+ dtype=dtype,
+ )
+ self.end_offset = torch.zeros(1, device=device, dtype=torch.long)
+
+ def reset(self):
+ self.end_offset.zero_()
+
+ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
+ assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape)
+ B, H, T, D = k.shape
+ indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset
+ indexes = indexes % self.capacity
+ self.cache[0].index_copy_(2, indexes, k)
+ self.cache[1].index_copy_(2, indexes, v)
+ self.end_offset.add_(T)
+
+ keys = self.cache[0]
+ values = self.cache[1]
+
+ indexes = torch.arange(
+ self.capacity, device=self.end_offset.device, dtype=torch.long
+ )
+ invalid = indexes >= self.end_offset
+
+ end_index = self.end_offset % self.capacity
+ delta = indexes - end_index
+
+ # If the last key is for step S, and the capacity is C, the last key was written
+ # at index S % C. Then end_offset = S + 1, and end_index = (S + 1) % C.
+ # For index = (S % C), delta = -1, and the code below gives us:
+ # position(index) = (S + 1) - 1 = S, all good.
+ # Now the time step at end_index is actually the oldest in the KVCache, i.e. its
+ # position should be (S - self.capacity + 1).
+ # The code below gives us:
+ # position(index + 1) = S + 1 + 0 - self.capacity.
+
+ positions = torch.where(
+ delta <= 0,
+ self.end_offset + delta,
+ self.end_offset + delta - self.capacity,
+ )
+ positions = torch.where(invalid, torch.full_like(positions, -1), positions)
+
+ return KVCacheResult(keys, values, positions)
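The position arithmetic of `RingKVCache.complete` can be replayed in pure Python, as a sketch of what the `torch.where` computes per slot. One detail worth noting (my reading of the arithmetic, not a claim from the source): with the `delta <= 0` branch, once the buffer has wrapped, the slot at `end_index` (holding the oldest entry) is reported with position `end_offset`, one step past the newest key, so a causal mask will treat it as future and ignore it.

```python
def slot_positions(end_offset: int, capacity: int) -> list[int]:
    """Absolute time step reported for each ring-buffer slot (-1 if never
    written), mirroring the torch.where arithmetic in RingKVCache.complete."""
    end_index = end_offset % capacity
    positions = []
    for index in range(capacity):
        if index >= end_offset:           # slot never written yet ("invalid")
            positions.append(-1)
            continue
        delta = index - end_index
        if delta <= 0:
            positions.append(end_offset + delta)
        else:
            positions.append(end_offset + delta - capacity)
    return positions

# Before wrapping: positions are just the time steps written so far.
print(slot_positions(2, 4))  # [0, 1, -1, -1]
# After wrapping (6 steps written, capacity 4): slot 2 holds the oldest step,
# but is reported as position 6, one past the newest key (step 5).
print(slot_positions(6, 4))  # [4, 5, 6, 3]
```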
279
+
280
+
281
+ @dataclass
282
+ class _MHAState:
283
+ kv_cache: RingKVCache
284
+ offset: torch.Tensor
285
+ offset_cpu: int
286
+
287
+ def reset(self):
288
+ self.kv_cache.reset()
289
+ self.offset.zero_()
290
+ self.offset_cpu = 0
291
+
292
+
293
+ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
294
+ """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
295
+
296
+ Args:
297
+ embed_dim (int): Dimension to project to.
298
+ num_heads (int): Number of heads.
299
+ causal (bool): Causal mask applied automatically.
300
+ context (int, optional): Number of time steps the attention can access to.
301
+ When causal, can access `context` time steps into the past, and when non causal,
302
+ can access `context // 2` steps in the past, and the same in the future.
303
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
304
+ weights_per_step (int): use different weights per time step. If non zero, should correspond to the
305
+ number of possible time steps.
306
+ device (torch.device, optional): Device on which to initialize.
307
+ dtype (torch.dtype, optional): dtype to use.
308
+ """
309
+
310
+ _fsdp_final = True
311
+
312
+ def __init__(
313
+ self,
314
+ embed_dim: int,
315
+ num_heads: int,
316
+ causal: bool = False,
317
+ context: tp.Optional[int] = None,
318
+ rope: tp.Optional[RotaryEmbedding] = None,
319
+ weights_per_step: int = 0,
320
+ device=None,
321
+ dtype=None,
322
+ ):
323
+ super().__init__()
324
+ factory_kwargs = {"device": device, "dtype": dtype}
325
+
326
+ self.embed_dim = embed_dim
327
+ self.causal = causal
328
+ self.context = context
329
+ self.rope = rope
330
+ self.num_heads = num_heads
331
+
332
+ out_dim = embed_dim
333
+ out_dim = 3 * embed_dim
334
+ mult = 1
335
+ self.weights_per_step = weights_per_step
336
+ if weights_per_step:
337
+ mult = weights_per_step
338
+ in_proj = nn.Linear(embed_dim, mult * out_dim, bias=False, **factory_kwargs)
339
+ # We try to follow the default PyTorch MHA convention, to easily compare results.
340
+ self.in_proj_weight = in_proj.weight
341
+ self.in_proj_bias = in_proj.bias
342
+ self.out_proj = nn.Linear(
343
+ embed_dim, mult * embed_dim, bias=False, **factory_kwargs
344
+ )
345
+
346
+ def _init_streaming_state(self, batch_size: int) -> _MHAState:
347
+ if self.context is None:
348
+ if self.weights_per_step:
349
+ capacity = self.weights_per_step
350
+ else:
351
+ raise RuntimeError(
352
+ "Cannot create a streaming KVCache without a context to estimate capacity."
353
+ )
354
+ else:
355
+ capacity = self.context
356
+ device = self.in_proj_weight.device
357
+ # TODO: the following estimation will not work great with FSDP.
358
+ dtype = self.in_proj_weight.dtype
359
+ dim_per_head = self.embed_dim // self.num_heads
360
+ kv_cache = RingKVCache(
361
+ batch_size, self.num_heads, dim_per_head, capacity, device, dtype
362
+ )
363
+ return _MHAState(
364
+ kv_cache,
365
+ offset=torch.zeros(1, device=device, dtype=torch.long),
366
+ offset_cpu=0,
367
+ )
368
+
369
+ def _complete_kv(self, k, v) -> KVCacheResult:
370
+ state = self._streaming_state
371
+ if state is None:
372
+ return KVCacheResult.from_kv(k, v)
373
+ else:
374
+ return state.kv_cache.complete(k, v)
375
+
376
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
377
+ state = self._streaming_state
378
+ T = query.shape[1]
379
+
380
+ if state is None:
381
+ offset = torch.zeros(1, device=query.device, dtype=torch.long)
382
+ offset_cpu = 0
383
+ else:
384
+ assert self.causal, "Streaming only available for causal"
385
+ offset = state.offset
386
+ offset_cpu = state.offset_cpu
387
+
388
+ if self.weights_per_step:
389
+ projected = multi_linear(
390
+ self.weights_per_step, self.in_proj_weight, query, offset_cpu
391
+ )
392
+ else:
393
+ projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
394
+ q, k, v = rearrange(
395
+ projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads
396
+ )
397
+
398
+ if self.rope:
399
+ q, k = self.rope(q, k, offset, time_before_heads=False)
400
+
401
+ k, v, pos_k = self._complete_kv(k, v)
402
+ if self.causal:
403
+ pos_k = pos_k.view(1, -1)
404
+ pos_q = offset + torch.arange(T, device=q.device, dtype=torch.long).view(
405
+ -1, 1
406
+ )
407
+ delta = pos_q - pos_k
408
+ attn_bias = (pos_k >= 0) & (delta >= 0)
409
+ if self.context is not None:
410
+ attn_bias = attn_bias & (delta < self.context)
411
+ else:
412
+ attn_bias = None
413
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
414
+
415
+ x = rearrange(x, "b h t d -> b t (h d)")
416
+ if self.weights_per_step:
417
+ x = multi_linear(self.weights_per_step, self.out_proj.weight, x, offset_cpu)
418
+ else:
419
+ x = self.out_proj(x)
420
+ if state is not None:
421
+ state.offset.add_(T)
422
+ state.offset_cpu += T
423
+ return x
424
+
425
+
426
+ @dataclass
427
+ class _LayerState:
428
+ offset_cpu: int
429
+
430
+ def reset(self):
431
+ self.offset_cpu = 0
432
+
433
+
434
+ class StreamingTransformerLayer(StreamingModule[_LayerState]):
435
+ """TransformerLayer with Streaming / Causal support.
436
+
437
+ Args:
438
+ d_model (int): Dimension of the data.
439
+ num_heads (int): Number of heads.
440
+ dim_feedforward (int): Intermediate dimension of FF module.
441
+ causal (bool): Causal mask applied automatically.
442
+ context (int, optional): Receptive field for the causal mask, infinite if None.
444
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
445
+ norm (str): Normalization to use. Currently, only 'layer_norm' is supported.
446
+ layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale.
447
+ gating (str): if provided, replaces FFN with special gating, like GLU, GSiGLU etc.
448
+ weights_per_step (int): Use different weights per time step. If non-zero, should correspond to the
449
+ number of possible time steps.
450
+ skip_self_attn (bool): If True, skips the self-attention module and its norm.
451
+ device (torch.device, optional): Device on which to initialize.
452
+ dtype (torch.dtype, optional): dtype to use.
453
+ """
454
+
455
+ _fsdp_final = True
456
+
457
+ def __init__(
458
+ self,
459
+ d_model: int,
460
+ num_heads: int,
461
+ dim_feedforward: int | list[int] = 2048,
462
+ causal: bool = False,
463
+ context: tp.Optional[int] = None,
464
+ rope: tp.Optional[RotaryEmbedding] = None,
465
+ norm: str = "layer_norm",
466
+ layer_scale: tp.Optional[float] = None,
467
+ gating: str = "none",
468
+ weights_per_step: int = 0,
469
+ activation=F.gelu,
470
+ skip_self_attn: bool = False,
471
+ device=None,
472
+ dtype=None,
473
+ ):
474
+ super().__init__()
475
+ factory_kwargs = {"device": device, "dtype": dtype}
476
+ # Redefine self_attn to our streaming multi-head attention
477
+ attn_kwargs: tp.Dict[str, tp.Any] = {
478
+ "embed_dim": d_model,
479
+ "num_heads": num_heads,
480
+ }
481
+ if not skip_self_attn:
482
+ self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
483
+ causal=causal,
484
+ context=context,
485
+ rope=rope,
486
+ weights_per_step=weights_per_step,
487
+ **attn_kwargs, # type: ignore
488
+ **factory_kwargs, # type: ignore
489
+ ) # type: ignore
490
+ self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
491
+ self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)
492
+ # Redefine feedforward layers to expose bias parameter
493
+ self.weights_per_step = weights_per_step
494
+ self.gating: tp.Optional[nn.Module] = None
495
+ self.linear1: tp.Optional[nn.Module] = None
496
+ self.linear2: tp.Optional[nn.Module] = None
497
+ self.activation = activation
498
+ self.skip_self_attn = skip_self_attn
499
+
500
+ if isinstance(dim_feedforward, list):
501
+ assert dim_feedforward
502
+ assert len(dim_feedforward) == weights_per_step, (
503
+ "Length of dim_feedforward must match weights_per_step,"
504
+ f" got {len(dim_feedforward)} != {weights_per_step}"
505
+ )
506
+ if gating == "none":
507
+ assert (
508
+ not weights_per_step
509
+ ), "weights_per_step without gating not supported for now."
510
+ assert not isinstance(
511
+ dim_feedforward, list
512
+ ), "List dim_feedforward without gating not supported for now."
513
+ self.linear1 = nn.Linear(
514
+ d_model, dim_feedforward, bias=False, **factory_kwargs
515
+ )
516
+ self.linear2 = nn.Linear(
517
+ dim_feedforward, d_model, bias=False, **factory_kwargs
518
+ )
519
+ else:
520
+ self.linear1 = None
521
+ self.linear2 = None
522
+ if weights_per_step:
523
+ if isinstance(dim_feedforward, int):
524
+ dim_feedforward = [dim_feedforward] * weights_per_step
525
+ assert isinstance(dim_feedforward, list), dim_feedforward
526
+ self.gating = nn.ModuleList(
527
+ [
528
+ make_gating(gating, d_model, dim, **factory_kwargs)
529
+ for dim in dim_feedforward
530
+ ]
531
+ )
532
+ else:
533
+ assert isinstance(dim_feedforward, int)
534
+ self.gating = make_gating(
535
+ gating, d_model, dim_feedforward, **factory_kwargs
536
+ )
537
+
538
+ self.layer_scale_1: nn.Module
539
+ self.layer_scale_2: nn.Module
540
+ if layer_scale is None:
541
+ self.layer_scale_1 = nn.Identity()
542
+ self.layer_scale_2 = nn.Identity()
543
+ else:
544
+ self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore
545
+ self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore
546
+
547
+ def _init_streaming_state(self, batch_size: int) -> _LayerState:
548
+ return _LayerState(offset_cpu=0)
549
+
550
+ # feed forward block
551
+ def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
552
+ state = self._streaming_state
553
+ offset = 0
554
+ if state is not None:
555
+ offset = state.offset_cpu
556
+ x_orig = x
557
+ x = self.norm2(x)
558
+ if self.gating is None:
559
+ assert self.linear1 is not None
560
+ assert self.linear2 is not None
561
+ update = self.linear2(self.activation(self.linear1(x)))
562
+ else:
563
+ if self.weights_per_step:
564
+ assert isinstance(self.gating, nn.ModuleList)
565
+ B, T, D = x.shape
566
+ ys = []
567
+ for t in range(T):
568
+ y = self.gating[offset + t](x[:, t : t + 1])
569
+ ys.append(y)
570
+ update = torch.cat(ys, dim=1)
571
+ else:
572
+ update = self.gating(x)
573
+ return x_orig + self.layer_scale_2(update)
574
+
575
+ def _sa_block(self, x: torch.Tensor):
576
+ if self.skip_self_attn:
577
+ return x
578
+ x_orig = x
579
+ x = self.norm1(x)
580
+ update = self.self_attn(x, x, x)
581
+ return x_orig + self.layer_scale_1(update)
582
+
583
+ def forward(self, x: torch.Tensor):
584
+ with ExitStack() as stack:
585
+ if x.device.type != 'cuda':
586
+ stack.enter_context(no_compile())
587
+ x = self._sa_block(x)
588
+ x = self._ff_block(x)
589
+ state = self._streaming_state
590
+ if state:
591
+ state.offset_cpu += x.shape[1]
592
+ return x
593
+
594
+
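Both `_sa_block` and `_ff_block` follow the same pre-norm residual pattern: normalize, apply the sublayer, scale, and add the result back to the un-normalized stream. A toy sketch (names are ours):

```python
import torch
from torch import nn

def residual_block(x, norm, sublayer, layer_scale):
    # Pre-norm: the sublayer sees a normalized input, but the residual
    # connection carries the original (un-normalized) activations.
    return x + layer_scale(sublayer(norm(x)))

x = torch.randn(2, 5, 8)
y = residual_block(x, nn.LayerNorm(8), nn.Identity(), nn.Identity())
```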
595
+ @dataclass
596
+ class _TransformerState:
597
+ offset: torch.Tensor
598
+
599
+ def reset(self):
600
+ self.offset.zero_()
601
+
602
+
603
+ class StreamingTransformer(StreamingModule[_TransformerState]):
604
+ """Transformer with Streaming / Causal support.
605
+
606
+ Args:
607
+ d_model (int): Dimension of the data.
608
+ num_heads (int): Number of heads.
609
+ dim_feedforward (int): Intermediate dimension of FF module.
610
+ causal (bool): Causal mask applied automatically.
611
+ context (int, optional): Receptive field for the causal mask, infinite if None.
612
+ layer_scale (float, optional): If not None, LayerScale will be used
613
+ with the given value as initial scale.
614
+ positional_embedding (str): Positional embedding strategy (sin, rope, sin_rope, or none).
615
+ max_period (float): Maximum period of the time embedding.
616
+ positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
617
+ layer_class (subclass of `StreamingTransformerLayer`): class to use
618
+ to initialize the layers, allowing further customization outside of AudioCraft.
619
+ device (torch.device, optional): Device on which to initialize.
620
+ dtype (torch.dtype, optional): dtype to use.
621
+ **kwargs: See `StreamingTransformerLayer`.
622
+ """
623
+
624
+ def __init__(
625
+ self,
626
+ d_model: int,
627
+ num_heads: int,
628
+ num_layers: int,
629
+ dim_feedforward: int | list[int] = 2048,
630
+ causal: bool = False,
631
+ context: tp.Optional[int] = None,
632
+ positional_embedding: str = "sin",
633
+ max_period: float = 10_000,
634
+ positional_scale: float = 1.0,
635
+ betas: tp.Optional[tp.Tuple[float, float]] = None,
636
+ layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
637
+ device=None,
638
+ dtype=None,
639
+ **kwargs,
640
+ ):
641
+ super().__init__()
642
+ assert d_model % num_heads == 0
643
+
644
+ self.positional_embedding = positional_embedding
645
+ self.max_period = max_period
646
+ self.positional_scale = positional_scale
647
+ self.betas = betas
648
+
649
+ assert positional_embedding in {"sin", "rope", "sin_rope", "none"}
650
+ self.rope: tp.Optional[RotaryEmbedding] = None
651
+ if self.positional_embedding in {"rope", "sin_rope"}:
652
+ self.rope = RotaryEmbedding(max_period=max_period)
653
+
654
+ self.layers = nn.ModuleList()
655
+ for _ in range(num_layers):
656
+ self.layers.append(
657
+ layer_class(
658
+ d_model=d_model,
659
+ num_heads=num_heads,
660
+ dim_feedforward=dim_feedforward,
661
+ causal=causal,
662
+ context=context,
663
+ rope=self.rope,
664
+ device=device,
665
+ dtype=dtype,
666
+ **kwargs,
667
+ )
668
+ )
669
+
670
+ def _init_streaming_state(self, batch_size: int) -> _TransformerState:
671
+ device = next(self.parameters()).device
672
+ return _TransformerState(offset=torch.zeros(1, device=device, dtype=torch.long))
673
+
674
+ def forward(self, x: torch.Tensor, *args, **kwargs):
675
+ B, T, C = x.shape
676
+
677
+ state = self._streaming_state
678
+ if state is None:
679
+ offset = torch.zeros(1, dtype=torch.long, device=x.device)
680
+ else:
681
+ offset = state.offset
682
+
683
+ if self.positional_embedding in {"sin", "sin_rope"}:
684
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
685
+ positions = positions + offset.view(-1, 1, 1)
686
+ pos_emb = create_sin_embedding(
687
+ positions, C, max_period=self.max_period, dtype=x.dtype
688
+ )
689
+ x = x + self.positional_scale * pos_emb
690
+
691
+ for layer in self.layers:
692
+ x = layer(x, *args, **kwargs)
693
+
694
+ if state is not None:
695
+ state.offset.add_(T)
696
+ return x
697
+
698
+
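When `positional_embedding` is `"sin"` or `"sin_rope"`, the forward pass adds a sinusoidal embedding computed from absolute positions shifted by the streaming offset. `create_sin_embedding` is defined elsewhere in the repo; this is a hedged sketch of a standard sinusoidal embedding with the same shape conventions:

```python
import torch

def sin_embedding(positions: torch.Tensor, dim: int,
                  max_period: float = 10_000.0) -> torch.Tensor:
    # positions: [B, T, 1] absolute positions; returns [B, T, dim].
    half = dim // 2
    adim = torch.arange(half).view(1, 1, -1)
    phase = positions / (max_period ** (adim / (half - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

positions = torch.arange(4, dtype=torch.float).view(1, -1, 1)
emb = sin_embedding(positions, dim=8)
```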
699
+ class ProjectedTransformer(StreamingContainer):
700
+ """Transformer with optional projections of the input and output to different dimensions when needed.
701
+ Supports multiple outputs.
702
+
703
+ Args:
704
+ input_dimension (int): dimension of the input.
705
+ output_dimensions (tuple[int]): dimensions of the outputs.
706
+ d_model (int): inner dimension of the Transformer.
707
+ conv_layout (bool): If True, expects `[B, C, T]` shaped tensors, otherwise, `[B, T, C]`.
708
+ Similarly, the output will have the same layout.
709
+ """
710
+
711
+ def __init__(
712
+ self,
713
+ input_dimension: int,
714
+ output_dimensions: tp.Tuple[int, ...],
715
+ d_model: int,
716
+ *,
717
+ conv_layout: bool = False,
718
+ **kwargs,
719
+ ):
720
+ super().__init__()
721
+ self.transformer = StreamingTransformer(d_model=d_model, **kwargs)
722
+ self.input_dimension = input_dimension
723
+ self.output_dimensions = output_dimensions
724
+ self.conv_layout = conv_layout
725
+ self.input_proj = None
726
+ if d_model != input_dimension:
727
+ self.input_proj = nn.Linear(input_dimension, d_model, bias=False)
728
+
729
+ self.output_projs = nn.ModuleList()
730
+ for output_dimension in output_dimensions:
731
+ if d_model == output_dimension:
732
+ self.output_projs.append(nn.Identity())
733
+ else:
734
+ self.output_projs.append(
735
+ nn.Linear(d_model, output_dimension, bias=False)
736
+ )
737
+
738
+ def forward(self, x, *args, **kwargs):
739
+ if self.conv_layout:
740
+ x = x.transpose(1, 2)
741
+ if self.input_proj is not None:
742
+ x = self.input_proj(x)
743
+ z = self.transformer(x, *args, **kwargs)
744
+ ys = []
745
+ for output_proj in self.output_projs:
746
+ y = output_proj(z)
747
+ if self.conv_layout:
748
+ y = y.transpose(1, 2)
749
+ ys.append(y)
750
+ return ys
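With `conv_layout=True`, inputs follow the convolutional `[B, C, T]` convention and are transposed to the transformer's `[B, T, C]` on the way in and back on the way out:

```python
import torch

x = torch.randn(2, 16, 50)    # [B, C, T]: convolutional layout
x_t = x.transpose(1, 2)       # [B, T, C]: transformer layout
x_back = x_t.transpose(1, 2)  # back to [B, C, T] for the outputs
```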
moshi/quantization/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+ """RVQ."""
11
+ # flake8: noqa
12
+ from .vq import ResidualVectorQuantizer, SplitResidualVectorQuantizer
13
+ from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
moshi/quantization/base.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ """
12
+ Base class for all quantizers.
13
+ """
14
+
15
+ from dataclasses import dataclass, field
16
+ import typing as tp
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
+ @dataclass
23
+ class QuantizedResult:
24
+ x: torch.Tensor
25
+ codes: torch.Tensor
26
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
27
+ penalty: tp.Optional[torch.Tensor] = None
28
+ metrics: dict = field(default_factory=dict)
29
+
30
+
31
+ class BaseQuantizer(nn.Module):
32
+ """Base class for quantizers."""
33
+
34
+ def __init__(self):
35
+ super().__init__()
36
+ self._ema_frozen = False
37
+
38
+ def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
39
+ """
40
+ Given input tensor x, returns first the quantized (or approximately quantized)
41
+ representation along with quantized codes, bandwidth, and any penalty term for the loss.
42
+ Finally, this returns a dict of metrics to update logging etc.
43
+ Frame rate must be passed so that the bandwidth is properly computed.
44
+ """
45
+ raise NotImplementedError()
46
+
47
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
48
+ """Encode a given input tensor with the specified sample rate at the given bandwidth."""
49
+ raise NotImplementedError()
50
+
51
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
52
+ """Decode the given codes to the quantized representation."""
53
+ raise NotImplementedError()
54
+
55
+ @property
56
+ def cardinality(self) -> int:
57
+ """Cardinality of each codebook."""
58
+ raise NotImplementedError()
59
+
60
+ @property
61
+ def total_codebooks(self) -> int:
62
+ """Total number of codebooks."""
63
+ raise NotImplementedError()
64
+
65
+ @property
66
+ def num_codebooks(self) -> int:
67
+ """Number of active codebooks."""
68
+ raise NotImplementedError()
69
+
70
+ @property
71
+ def semantic_quantizer(self) -> 'BaseQuantizer':
72
+ """This returns the quantizer that models the first level of the hierarchy (typically semantic).
73
+
74
+ In this case, it's the quantizer itself.
75
+ """
76
+ return self
77
+
78
+ @property
79
+ def acoustic_quantizer(self) -> 'BaseQuantizer':
80
+ """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic).
81
+
82
+ In this case, it's the quantizer itself.
83
+ """
84
+ return self
85
+
86
+ def set_num_codebooks(self, n: int) -> None:
87
+ """Set the number of active codebooks."""
88
+ raise NotImplementedError()
89
+
90
+ @property
91
+ def ema_frozen(self) -> bool:
92
+ """Whether to apply ema to the codebooks."""
93
+ return self._ema_frozen
94
+
95
+ def ema_frozen_(self, ema_frozen: bool) -> None:
96
+ """Set whether ema should be applied to the codebooks."""
97
+ self._ema_frozen = ema_frozen
98
+
99
+
100
+ class DummyQuantizer(BaseQuantizer):
101
+ """Fake quantizer that actually does not perform any quantization."""
102
+
103
+ def __init__(
104
+ self,
105
+ dimension: int,
106
+ input_dimension: tp.Optional[int] = None,
107
+ output_dimension: tp.Optional[int] = None,
108
+ ):
109
+ super().__init__()
110
+ self.dimension = dimension
111
+ self.input_dimension = input_dimension or dimension
112
+ self.output_dimension = output_dimension or dimension
113
+ self.input_proj: torch.nn.Module
114
+ self.output_proj: torch.nn.Module
115
+ if self.input_dimension == self.dimension:
116
+ self.input_proj = torch.nn.Identity()
117
+ else:
118
+ self.input_proj = torch.nn.Conv1d(
119
+ self.input_dimension, self.dimension, 1, bias=False
120
+ )
121
+ if self.input_dimension == self.dimension:
122
+ self.output_proj = torch.nn.Identity()
123
+ else:
124
+ self.output_proj = torch.nn.Conv1d(
125
+ self.dimension, self.output_dimension, 1, bias=False
126
+ )
127
+
128
+ def forward(self, x: torch.Tensor, frame_rate: int):
129
+ q = x.unsqueeze(1)
130
+ x = self.output_proj(self.input_proj(x))
131
+ return QuantizedResult(
132
+ x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)
133
+ )
134
+
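The bandwidth scalar returned by `forward` above, `q.numel() * 32 * frame_rate / 1000 / len(x)`, corresponds to sending every value as an uncompressed 32-bit float. Illustrating how it scales, with toy shapes of our choosing:

```python
B, D, T = 2, 8, 5          # batch, dimension, time steps
frame_rate = 50            # frames per second
numel = B * 1 * D * T      # q = x.unsqueeze(1) adds a codebook axis of size 1
bandwidth_kbps = numel * 32 * frame_rate / 1000 / B  # per batch item, in kb/s
```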
135
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
136
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
137
+ In the case of the DummyQuantizer, the codes are actually identical
138
+ to the input and resulting quantized representation as no quantization is done.
139
+ """
140
+ x = self.input_proj(x)
141
+ return x.unsqueeze(1)
142
+
143
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
144
+ """Decode the given codes to the quantized representation.
145
+ In the case of the DummyQuantizer, the codes are actually identical
146
+ to the input and resulting quantized representation as no quantization is done.
147
+ """
148
+ y = codes.squeeze(1)
149
+ return self.output_proj(y)
150
+
151
+ @property
152
+ def total_codebooks(self):
153
+ """Total number of codebooks."""
154
+ return 1
155
+
156
+ @property
157
+ def num_codebooks(self):
158
+ """Total number of codebooks."""
159
+ return self.total_codebooks
160
+
161
+ def set_num_codebooks(self, n: int):
162
+ """Set the number of active codebooks."""
163
+ raise AttributeError(
164
+ "Cannot override the number of codebooks for the dummy quantizer"
165
+ )
166
+
167
+ @property
168
+ def cardinality(self) -> int:
169
+ """Cardinality of each codebook."""
170
+ return 1
moshi/quantization/core_vq.py ADDED
@@ -0,0 +1,384 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ import typing as tp
12
+
13
+ from einops import rearrange
14
+ import torch
15
+ from torch import nn
16
+ from torch import distributed
17
+ import torch.nn.functional as F
18
+
19
+
20
+ class _CodebookForwardResult(tp.NamedTuple):
21
+ quantized: torch.Tensor
22
+ codes: torch.Tensor
23
+ metrics: tp.Dict[str, torch.Tensor]
24
+
25
+
26
+ class _VQForwardResult(tp.NamedTuple):
27
+ quantized: torch.Tensor
28
+ codes: torch.Tensor
29
+ loss: torch.Tensor
30
+ metrics: tp.Dict[str, torch.Tensor]
31
+
32
+
33
+ def _ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
34
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
35
+
36
+
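`_ema_inplace` above is the standard exponential moving average update, `m ← decay·m + (1 − decay)·new`:

```python
import torch

def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
    # In-place: m <- decay * m + (1 - decay) * new
    moving_avg.mul_(decay).add_(new, alpha=1 - decay)

m = torch.tensor([0.0])
ema_inplace(m, torch.tensor([10.0]), decay=0.9)
```

After one step, `m` holds `0.9 * 0 + 0.1 * 10 = 1.0`.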
37
+ def _uniform_init(*shape: int) -> torch.Tensor:
38
+ t = torch.empty(shape)
39
+ nn.init.kaiming_uniform_(t)
40
+ return t
41
+
42
+
43
+ def _sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
44
+ num_samples, device = samples.shape[0], samples.device
45
+
46
+ if num_samples >= num:
47
+ indices = torch.randperm(num_samples, device=device)[:num]
48
+ else:
49
+ indices = torch.randint(0, num_samples, (num,), device=device)
50
+
51
+ return samples[indices]
52
+
53
+
54
+ def _compute_entropy(usage: torch.Tensor) -> torch.Tensor:
55
+ # Usage is some unnormalized distribution.
56
+ proba = usage / usage.sum()
57
+ p_log_p = torch.where(
58
+ proba == 0, zero_scalar(usage.device), proba * torch.log(proba)
59
+ )
60
+ return -p_log_p.sum()
61
+
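`_compute_entropy` treats the usage counts as an unnormalized distribution, so uniform usage yields the maximal entropy `log(codebook_size)`. Restated standalone:

```python
import math
import torch

def compute_entropy(usage: torch.Tensor) -> torch.Tensor:
    proba = usage / usage.sum()
    # Use the 0 * log(0) = 0 convention for unused clusters.
    p_log_p = torch.where(proba == 0, torch.zeros((), dtype=proba.dtype),
                          proba * torch.log(proba))
    return -p_log_p.sum()

entropy = compute_entropy(torch.ones(4))  # uniform usage over 4 codes
```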
62
+
63
+ def _is_distributed() -> bool:
64
+ # Checks if we need to use distributed routines.
65
+ return distributed.is_initialized() and distributed.get_world_size() > 1
66
+
67
+
68
+ def zero_scalar(device) -> torch.Tensor:
69
+ """Returns a 0. value on the given device without introducing a synchronization point."""
70
+ return torch.zeros([1], device=device)[0]
71
+
72
+
73
+ class EuclideanCodebook(nn.Module):
74
+ """Codebook with Euclidean distance.
75
+
76
+ Args:
77
+ dim (int): Dimension.
78
+ codebook_size (int): Codebook size.
79
+ decay (float): Decay for exponential moving average over the codebooks.
80
+ epsilon (float): Epsilon value for numerical stability.
81
+ threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
82
+ is replaced. This is expressed as a fraction of the usage a centroid would get under
83
+ a uniform distribution, so that it doesn't depend on the batch size etc.
84
+ replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
85
+ to avoid the centroid getting replaced too quickly.
86
+ check_unused_every (int): Check for unused centroids every `check_unused_every` iterations.
87
+ This is to avoid too many synchronization points.
88
+
89
+ Buffers:
90
+ cluster_usage (torch.Tensor): EMA of the cluster usage per batch, e.g. this will
91
+ be dependent on the batch size etc.
92
+ embedding_sum (torch.Tensor): EMA of the sum of the assigned points to each cluster.
93
+ In particular, this can be normalized by `cluster_usage` to obtain the
94
+ actual cluster centroids.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ dim: int,
100
+ codebook_size: int,
101
+ decay: float = 0.99,
102
+ epsilon: float = 1e-5,
103
+ threshold_usage_ratio: float = 0.1,
104
+ replaced_usage_ratio: float = 1.0,
105
+ check_unused_every: int = 5,
106
+ ):
107
+ super().__init__()
108
+ self.decay = decay
109
+ embedding = torch.zeros(codebook_size, dim)
110
+
111
+ self.dim = dim
112
+ self.codebook_size = codebook_size
113
+
114
+ self.epsilon = epsilon
115
+ self.threshold_usage_ratio = threshold_usage_ratio
116
+ self.replaced_usage_ratio = replaced_usage_ratio
117
+ self.check_unused_every = check_unused_every
118
+ self._next_unused_check = check_unused_every
119
+
120
+ self.register_buffer("_initialized", torch.tensor([False], dtype=torch.float))
121
+ self.register_buffer("cluster_usage", torch.ones(codebook_size))
122
+ self.register_buffer("embedding_sum", embedding)
123
+ self.register_buffer("_embedding", None, persistent=False)
124
+ self._cached_initialized = False
125
+
126
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs) -> None:
127
+ # Mapping old names to new names
128
+ mappings = {
129
+ "inited": "_initialized",
130
+ "cluster_size": "cluster_usage",
131
+ "embed_avg": "embedding_sum",
132
+ "embed_sum": "embedding_sum",
133
+ }
134
+ for old_name, new_name in mappings.items():
135
+ old_name = prefix + old_name
136
+ if old_name in state_dict:
137
+ value = state_dict.pop(old_name)
138
+ if new_name is not None:
139
+ state_dict[prefix + new_name] = value
140
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
141
+
142
+ @property
143
+ def embedding(self) -> torch.Tensor:
144
+ if self._embedding is None:
145
+ embedding = (
146
+ self.embedding_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None]
147
+ )
148
+ self.register_buffer("_embedding", embedding, persistent=False)
149
+ return embedding
150
+ return self._embedding
151
+
152
+ def _broadcast_buffers(self) -> None:
153
+ if _is_distributed():
154
+ for buffer in self.buffers():
155
+ distributed.broadcast(buffer, 0)
156
+
157
+ def _replace_expired_codes(self, samples: torch.Tensor, mask: torch.Tensor) -> None:
158
+ # Replaces expired centroids, as indicated by `mask` (a true value indicate the code needs to be replaced).
159
+ # The new codes are sampled from the batch `samples`.
160
+ new_vectors = _sample_vectors(samples, self.codebook_size)
161
+ replace_cluster_usage = (
162
+ self.replaced_usage_ratio * self.cluster_usage.sum() / self.codebook_size
163
+ )
164
+ self.embedding_sum[:] = torch.where(
165
+ mask[:, None], replace_cluster_usage * new_vectors, self.embedding_sum
166
+ )
167
+ self.cluster_usage[:] = torch.where(
168
+ mask, replace_cluster_usage, self.cluster_usage
169
+ )
170
+
171
+ def _reshape_input(self, x: torch.Tensor) -> torch.Tensor:
172
+ # Flattens all the dimensions but the last one, e.g. return a vector of shape `[N, D]`.
173
+ x = rearrange(x, "... d -> (...) d")
174
+ return x
175
+
176
+ def _reshape_codes(self, codes: torch.Tensor, shape: torch.Size) -> torch.Tensor:
177
+ return codes.view(*shape[:-1])
178
+
179
+ def _quantize(self, x: torch.Tensor) -> torch.Tensor:
180
+ # Projects each vector in `x` over the nearest centroid and return its index.
181
+ # `x` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
182
+ assert x.dim() == 2
183
+ dists = torch.cdist(x[None], self.embedding[None], p=2)[0]
184
+ codes = dists.argmin(dim=-1)
185
+ return codes
186
+
187
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
188
+ """Given a tensor `x` of shape `[*, D]`, returns a tensor of integer codes of shape `[*]`.
189
+ The codes are defined as the indexes of the centroids nearest to each vector in `x`.
190
+ """
191
+ assert x.dtype.is_floating_point, f"Input should be floats, got {x.dtype}"
192
+ shape = x.shape
193
+ x = self._reshape_input(x)
194
+ codes = self._quantize(x)
195
+ codes = self._reshape_codes(codes, shape)
196
+ return codes
197
+
198
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
199
+ """Given a tensor of codes of shape `[*]`, returns a tensor of shape `[*, D]`,
200
+ corresponding to the centroids associated to each code index.
201
+ """
202
+ assert (
203
+ not codes.dtype.is_floating_point
204
+ ), f"Codes should be integers, got {codes.dtype}"
205
+ quantized = F.embedding(codes, self.embedding)
206
+ return quantized
207
+
208
+ def forward(
209
+ self, x: torch.Tensor, initialize: bool = True
210
+ ) -> _CodebookForwardResult:
211
+ shape = x.shape
212
+ x = self._reshape_input(x)
213
+
214
+ flat_codes = self._quantize(x)
215
+ codes = self._reshape_codes(flat_codes, shape)
216
+ quantized = self.decode(codes)
217
+ metrics: tp.Dict[str, torch.Tensor] = {}
218
+
219
+ return _CodebookForwardResult(quantized, codes, metrics)
220
+
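The encode/decode round trip of `EuclideanCodebook` is a nearest-centroid assignment followed by a table lookup. A tiny concrete example (the centroids here are ours, for illustration):

```python
import torch

embedding = torch.tensor([[0.0, 0.0],   # centroid 0
                          [1.0, 1.0]])  # centroid 1
x = torch.tensor([[0.1, 0.0],
                  [0.9, 1.2]])

# encode: index of the nearest centroid, as in _quantize.
dists = torch.cdist(x[None], embedding[None], p=2)[0]
codes = dists.argmin(dim=-1)
# decode: look the centroids back up.
quantized = embedding[codes]
```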
221
+
222
+ class VectorQuantization(nn.Module):
223
+ """Vector quantization implementation.
224
+ Currently supports only euclidean distance.
225
+
226
+ Args:
227
+ dim (int): Dimension
228
+ codebook_size (int): Codebook size
229
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
230
+ decay (float): Decay for exponential moving average over the codebooks.
231
+ epsilon (float): Epsilon value for numerical stability.
232
+ threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
233
+ is replaced. This is expressed as a fraction of the usage a centroid would get under
234
+ a uniform distribution, so that it doesn't depend on the batch size etc.
235
+ replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
236
+ to avoid the centroid getting replaced too quickly.
237
+ check_unused_every (int): Check for unused centroids every `check_unused_every` iterations.
238
+ This is to avoid too many synchronization points.
239
+ """
240
+
241
+ def __init__(
242
+ self,
243
+ dim: int,
244
+ codebook_size: int,
245
+ codebook_dim: tp.Optional[int] = None,
246
+ decay: float = 0.99,
247
+ epsilon: float = 1e-5,
248
+ threshold_usage_ratio: float = 0.1,
249
+ **kwargs,
250
+ ):
251
+ super().__init__()
252
+ if codebook_dim is None:
253
+ codebook_dim = dim
254
+
255
+ requires_projection = codebook_dim != dim
256
+ self.project_in = (
257
+ nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
258
+ )
259
+ self.project_out = (
260
+ nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
261
+ )
262
+ self.epsilon = epsilon
263
+ self._codebook = EuclideanCodebook(
264
+ dim=codebook_dim,
265
+ codebook_size=codebook_size,
266
+ decay=decay,
267
+ epsilon=epsilon,
268
+ threshold_usage_ratio=threshold_usage_ratio,
269
+ **kwargs,
270
+ )
271
+ self.codebook_size = codebook_size
272
+
273
+ @property
274
+ def embedding(self):
275
+ return self._codebook.embedding
276
+
277
+ def _rearrange_input(self, x):
278
+ x = rearrange(x, "b d n -> b n d")
279
+ return x
280
+
281
+ def _rearrange_output(self, quantized):
282
+ quantized = rearrange(quantized, "b n d -> b d n")
283
+ return quantized
284
+
285
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
286
+ """Encodes `x` into discrete integer codes."""
287
+ x = self._rearrange_input(x)
288
+ x = self.project_in(x)
289
+ codes = self._codebook.encode(x)
290
+ return codes
291
+
292
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
293
+ """Converts integer codes into quantized vectors."""
294
+ quantized = self._codebook.decode(codes)
295
+ quantized = self.project_out(quantized)
296
+ quantized = self._rearrange_output(quantized)
297
+ return quantized
298
+
299
+ def forward(self, x: torch.Tensor, initialize: bool = True) -> _VQForwardResult:
300
+ x = self._rearrange_input(x)
301
+ quantized, codes, metrics = self._codebook(x, initialize=initialize)
302
+
303
+ loss = zero_scalar(x.device)
304
+
305
+ quantized = self.project_out(quantized)
306
+ quantized = self._rearrange_output(quantized)
307
+
308
+ return _VQForwardResult(quantized, codes, loss, metrics)
309
+
310
+
311
+ class ResidualVectorQuantization(nn.Module):
312
+ """Residual vector quantization implementation.
313
+
314
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
315
+ """
316
+
317
+ def __init__(self, *, num_quantizers: int, codebook_offset: int, **kwargs):
318
+ super().__init__()
319
+ self.layers = nn.ModuleList(
320
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
321
+ )
322
+ self.codebook_offset = codebook_offset
323
+
324
+ def forward(
325
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None
326
+ ) -> _VQForwardResult:
327
+ """
328
+ Args:
329
+ x (torch.Tensor): input tensor to quantize, of shape `[B, C, T]`.
330
+ n_q (int or None): if provided, number of codebook levels to use in RVQ.
331
+ """
332
+
333
+ quantized_out = zero_scalar(x.device)
334
+ residual = x
335
+
336
+ all_losses = []
337
+ all_codes = []
338
+ all_metrics: tp.Dict[str, torch.Tensor] = {}
339
+
340
+ n_q = n_q or len(self.layers)
341
+ previous_layer_is_initialized = True
342
+
343
+ for i, layer in enumerate(self.layers[:n_q]): # type: ignore
344
+ quantized, codes, loss, metrics = layer(
345
+ residual, initialize=previous_layer_is_initialized
346
+ )
347
+
348
+ quantized = quantized.detach()
349
+ residual = residual - quantized
350
+ quantized_out = quantized_out + quantized
351
+
352
+ all_codes.append(codes)
353
+ all_losses.append(loss)
354
+
355
+ for key, value in metrics.items():
356
+ if key in all_metrics:
357
+ all_metrics[key] += value / n_q
358
+ else:
359
+ all_metrics[key] = value / n_q
360
+ all_metrics[key + f"_{i + self.codebook_offset}"] = value
361
+
362
+ out_losses, out_codes = map(torch.stack, (all_losses, all_codes))
363
+ return _VQForwardResult(quantized_out, out_codes, out_losses, all_metrics)
364
+
365
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
366
+ """Encodes `x` into discrete integer codes. If `n_q` is provided, only uses the first `n_q` codebook levels."""
367
+ residual = x
368
+ all_indices = []
369
+ n_q = n_q or len(self.layers)
370
+ for layer in self.layers[:n_q]: # type: ignore
371
+ indices = layer.encode(residual)
372
+ quantized = layer.decode(indices)
373
+ residual = residual - quantized
374
+ all_indices.append(indices)
375
+ out_indices = torch.stack(all_indices)
376
+ return out_indices
377
+
378
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
379
+ """Converts the integer codes into quantized vectors."""
380
+ quantized = zero_scalar(codes.device)
381
+ for idx, layer_codes in enumerate(codes):
382
+ layer = self.layers[idx]
383
+ quantized = quantized + layer.decode(layer_codes)
384
+ return quantized
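The loop in `ResidualVectorQuantization.forward` — quantize the residual, subtract the reconstruction, accumulate, repeat — can be sketched in miniature. The toy below uses fixed scalar codebooks and hypothetical helper names (it is not the module's API) to show why each extra level refines the previous ones:

```python
# Toy residual vector quantization on scalars. The real module learns
# multi-dimensional codebooks with EMA updates; this sketch only
# illustrates the residual scheme (Algorithm 1 of the SoundStream paper).

def nearest(codebook, value):
    """Index of the codebook entry closest to `value` (euclidean distance)."""
    return min(range(len(codebook)), key=lambda i: abs(codebook[i] - value))

def rvq_encode(codebooks, value):
    """Each level quantizes the residual left over by the previous levels."""
    codes = []
    residual = value
    for cb in codebooks:
        idx = nearest(cb, residual)
        codes.append(idx)
        residual -= cb[idx]
    return codes

def rvq_decode(codebooks, codes):
    """Decoding sums the selected entries across levels."""
    return sum(cb[i] for cb, i in zip(codebooks, codes))

# Coarse-to-fine codebooks: each level covers the error range of the last.
codebooks = [[-1.0, 0.0, 1.0], [-0.25, 0.0, 0.25], [-0.05, 0.0, 0.05]]
codes = rvq_encode(codebooks, 0.8)
approx = rvq_decode(codebooks, codes)
```

Here level 1 picks 1.0, level 2 corrects by -0.25, and level 3 by +0.05, so the three-level reconstruction recovers 0.8 almost exactly; truncating to fewer levels (as `n_q` does above) trades quality for bitrate.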
moshi/quantization/vq.py ADDED
@@ -0,0 +1,340 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import torch
+
+from .base import BaseQuantizer, QuantizedResult
+from .core_vq import ResidualVectorQuantization
+
+
+class ResidualVectorQuantizer(BaseQuantizer):
+    """Residual Vector Quantizer.
+
+    Args:
+        dimension (int): Dimension of the codebooks.
+        input_dimension (None or int): dimension of the input, defaults to `dimension` if not provided.
+        output_dimension (None or int): dimension of the output, defaults to `dimension` if not provided.
+        n_q (int): Number of vector quantizers used.
+        q_dropout (bool): Random quantizer dropout at train time.
+        no_quantization_rate (float): Gives the probability of applying no quantization at all
+            at train time. The RVQ codebooks will still get the input value to learn the proper codebook.
+        bins (int): Codebook size.
+        decay (float): Decay for exponential moving average over the codebooks.
+        threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
+            is replaced. This is expressed as a fraction of the usage a centroid would get under
+            a uniform distribution, so that it doesn't depend on the batch size etc.
+        replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
+            to avoid the centroid getting replaced too quickly.
+        codebook_offset (int): Offset to use for the codebook indices. This is useful when using multiple quantizers
+            such as in SplitResidualVectorQuantizer.
+        force_projection (bool): Whether to force input and output projections even when dimension is constant.
+        generator_seed (int or None): seed used to initialize the RNG used for no quantization.
+    """
+
+    def __init__(
+        self,
+        dimension: int = 128,
+        input_dimension: tp.Optional[int] = None,
+        output_dimension: tp.Optional[int] = None,
+        n_q: int = 8,
+        q_dropout: bool = False,
+        q_first_only_proba: float = 0.0,
+        no_quantization_rate: float = 0.0,
+        bins: int = 1024,
+        decay: float = 0.99,
+        threshold_usage_ratio: float = 0.1,
+        replaced_usage_ratio: float = 1.0,
+        codebook_offset: int = 0,
+        force_projection: bool = False,
+        generator_seed: tp.Optional[int] = None,
+    ):
+        super().__init__()
+        self.max_n_q = n_q
+        self.n_q = n_q
+        self.q_dropout = q_dropout
+        self.no_quantization_rate = no_quantization_rate
+        self.q_first_only_proba = q_first_only_proba
+        self.dimension = dimension
+        self.input_dimension = input_dimension or dimension
+        self.output_dimension = output_dimension or dimension
+        self.bins = bins
+        self.decay = decay
+        self.input_proj: torch.nn.Module
+        self.output_proj: torch.nn.Module
+        self.generator = None
+        if generator_seed is not None:
+            self.generator = torch.Generator(
+                device="cuda" if torch.cuda.is_available() else "cpu"
+            )
+            self.generator.manual_seed(generator_seed)
+        if self.input_dimension == self.dimension and not force_projection:
+            self.input_proj = torch.nn.Identity()
+        else:
+            self.input_proj = torch.nn.Conv1d(
+                self.input_dimension, self.dimension, 1, bias=False
+            )
+        if self.output_dimension == self.dimension and not force_projection:
+            self.output_proj = torch.nn.Identity()
+        else:
+            self.output_proj = torch.nn.Conv1d(
+                self.dimension, self.output_dimension, 1, bias=False
+            )
+        self.vq = ResidualVectorQuantization(
+            dim=self.dimension,
+            codebook_size=self.bins,
+            num_quantizers=self.n_q,
+            decay=self.decay,
+            threshold_usage_ratio=threshold_usage_ratio,
+            replaced_usage_ratio=replaced_usage_ratio,
+            codebook_offset=codebook_offset,
+        )
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        """
+        Args:
+            x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels.
+            frame_rate (int): frame rate of the input (e.g. `T = frame_rate * duration`), used to compute
+                the bandwidth.
+
+        Returns:
+            QuantizedResult: Quantized result with the following attributes:
+                - `x` (torch.Tensor): Quantized tensor of shape [B, C, T].
+                - `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks.
+                - `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second.
+                - `penalty` (torch.Tensor): Commitment loss.
+                - `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy.
+        """
+        n_q = self.n_q
+        x = self.input_proj(x)
+
+        bw_per_q = math.log2(self.bins) * frame_rate / 1000
+        quantized, codes, commit_loss, metrics = self.vq(x, n_q=n_q)
+        B, _, _ = quantized.shape
+        quantized = self.output_proj(quantized)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        bw = torch.tensor(n_q * bw_per_q).to(x)
+        return QuantizedResult(
+            quantized, codes, bw, penalty=torch.mean(commit_loss), metrics=metrics
+        )
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified frame rate at the given bandwidth.
+
+        The RVQ encode method sets the appropriate number of quantizers to use
+        and returns indices for each quantizer.
+        """
+        n_q = self.n_q
+        if x.shape[-1] == 0:
+            return torch.empty((x.shape[0], n_q, 0), device=x.device, dtype=torch.int64)
+
+        x = self.input_proj(x)
+        codes = self.vq.encode(x, n_q=n_q)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        return codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation."""
+        # codes is [B, K, T], with T frames, K nb of codebooks; vq.decode expects [K, B, T].
+        codes = codes.transpose(0, 1)
+        quantized = self.vq.decode(codes)
+        quantized = self.output_proj(quantized)
+        return quantized
+
+    @property
+    def total_codebooks(self):
+        return self.max_n_q
+
+    @property
+    def num_codebooks(self):
+        return self.n_q
+
+    def set_num_codebooks(self, n: int):
+        assert n >= 0 and n <= self.max_n_q
+        self.n_q = n
+
+    @property
+    def cardinality(self) -> int:
+        return self.bins
+
+
+class SplitResidualVectorQuantizer(BaseQuantizer):
+    """Residual Vector Quantizer with separate projections for the first quantizer and the rest.
+
+    Args:
+        n_q (int): Number of residual vector quantizers used.
+        n_q_semantic (int): Number of residual vector quantizers used for the semantic quantizer.
+        no_quantization_mode (str): if `true_skip`, when doing no quantization, the input will not go
+            through the sub quantizers. If `independent`, independent decisions are taken by
+            the semantic and acoustic quantizers. If `same` (the default), the same decision is taken by both.
+        **kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both.
+    """
+
+    def __init__(
+        self,
+        *,
+        n_q: int = 8,
+        no_quantization_rate: float = 0.0,
+        no_quantization_mode: str = "same",
+        n_q_semantic: int = 1,
+        **kwargs,
+    ):
+        super().__init__()
+        assert n_q > n_q_semantic, (
+            f"Number of quantizers {n_q} must be larger "
+            f"than the number of semantic quantizers {n_q_semantic}."
+        )
+        self.max_n_q = n_q
+        self.n_q_semantic = n_q_semantic
+        self.n_q_acoustic = n_q - n_q_semantic
+        if no_quantization_mode == "true_skip":
+            self.no_quantization_rate = no_quantization_rate
+            # Setting to zero for the underlying RVQ.
+            no_quantization_rate = 0.0
+        else:
+            self.no_quantization_rate = 0.0
+        if no_quantization_mode == "same":
+            kwargs["generator_seed"] = 1234
+        kwargs["no_quantization_rate"] = no_quantization_rate
+        q_dropout = kwargs.pop("q_dropout", False)
+        self.rvq_first = ResidualVectorQuantizer(
+            n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs
+        )
+        self.rvq_rest = ResidualVectorQuantizer(
+            n_q=n_q - n_q_semantic,
+            codebook_offset=1,
+            force_projection=True,
+            q_dropout=q_dropout,
+            **kwargs,
+        )
+        if no_quantization_mode == "true_skip":
+            assert self.rvq_first.input_dimension == self.rvq_first.output_dimension
+            assert self.rvq_rest.input_dimension == self.rvq_rest.output_dimension
+
+    def _renorm_and_add(
+        self,
+        first_val: torch.Tensor,
+        rest_val: torch.Tensor,
+        n_q_semantic: int,
+        n_q_acoustic: int,
+    ):
+        """Renormalizes values from `rvq_first` and `rvq_rest` and adds them.
+
+        This allows correcting statistics that are normalized by the number of quantizers. To renormalize, we use the
+        number of quantizers that are actually used, e.g. taking into account quantizer dropout.
+        """
+        n_q = n_q_semantic + n_q_acoustic
+        renorm_first_val = first_val * n_q_semantic / n_q
+        renorm_rest_val = rest_val * n_q_acoustic / n_q
+        return renorm_first_val + renorm_rest_val
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        """
+        Args:
+            x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels.
+            frame_rate (int): frame rate of the input (e.g. `T = frame_rate * duration`), used to compute
+                the bandwidth.
+
+        Returns:
+            QuantizedResult: Quantized result with the following attributes:
+                - `x` (torch.Tensor): Quantized tensor of shape [B, C, T].
+                - `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks.
+                - `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second.
+                - `penalty` (torch.Tensor): Commitment loss.
+                - `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy.
+        """
+        semantic_result = self.rvq_first(x, frame_rate)
+        if self.n_q == self.n_q_semantic:
+            return semantic_result
+        acoustic_result = self.rvq_rest(x, frame_rate)
+        full_quantized_emb = semantic_result.x + acoustic_result.x
+        full_quantized_codes = torch.cat(
+            [semantic_result.codes, acoustic_result.codes], dim=1
+        )
+        # This is the actual number of quantizers used, e.g. taking into account quantizer dropout.
+        n_q_semantic = semantic_result.codes.shape[1]
+        n_q_acoustic = acoustic_result.codes.shape[1]
+        full_quantized_bandwidth = semantic_result.bandwidth + acoustic_result.bandwidth
+        full_quantized_penalty = self._renorm_and_add(
+            semantic_result.penalty, acoustic_result.penalty, n_q_semantic, n_q_acoustic
+        )
+        full_quantized_metrics = semantic_result.metrics
+        for key, value in acoustic_result.metrics.items():
+            if key in full_quantized_metrics:
+                full_quantized_metrics[key] = self._renorm_and_add(
+                    full_quantized_metrics[key], value, n_q_semantic, n_q_acoustic
+                )
+            else:
+                full_quantized_metrics[key] = value
+        return QuantizedResult(
+            full_quantized_emb,
+            full_quantized_codes,
+            full_quantized_bandwidth,
+            penalty=full_quantized_penalty,
+            metrics=full_quantized_metrics,
+        )
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified frame rate at the given bandwidth.
+
+        The RVQ encode method sets the appropriate number of quantizers to use
+        and returns indices for each quantizer.
+        """
+        codes = self.rvq_first.encode(x)
+        if self.n_q > self.n_q_semantic:
+            acoustic_codes = self.rvq_rest.encode(x)
+            codes = torch.cat([codes, acoustic_codes], dim=1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        return codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation."""
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        quantized = self.rvq_first.decode(codes[:, : self.n_q_semantic])
+        if codes.shape[1] > self.n_q_semantic:
+            quantized += self.rvq_rest.decode(codes[:, self.n_q_semantic :])
+        return quantized
+
+    @property
+    def total_codebooks(self):
+        return self.rvq_first.max_n_q + self.rvq_rest.max_n_q
+
+    @property
+    def num_codebooks(self):
+        return self.rvq_first.num_codebooks + self.rvq_rest.num_codebooks
+
+    @property
+    def n_q(self):
+        return self.rvq_first.n_q + self.rvq_rest.n_q
+
+    @property
+    def dimension(self):
+        return self.rvq_first.dimension
+
+    @property
+    def semantic_quantizer(self) -> ResidualVectorQuantizer:
+        """The quantizer that models the first level of the hierarchy (typically semantic)."""
+        return self.rvq_first
+
+    @property
+    def acoustic_quantizer(self) -> ResidualVectorQuantizer:
+        """The quantizer that models the higher levels of the hierarchy (typically acoustic)."""
+        return self.rvq_rest
+
+    def set_num_codebooks(self, n: int):
+        assert n >= self.n_q_semantic and n <= self.total_codebooks
+        self.rvq_rest.set_num_codebooks(n - self.n_q_semantic)
+
+    @property
+    def cardinality(self) -> int:
+        assert self.rvq_rest.cardinality == self.rvq_first.cardinality
+        return self.rvq_first.cardinality
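The bandwidth reported by `forward` follows directly from `bw_per_q = math.log2(self.bins) * frame_rate / 1000`: each codebook emits `log2(bins)` bits per frame. A standalone sketch of that arithmetic (the numeric values below are illustrative, not read from any checkpoint):

```python
import math

def rvq_bandwidth_kbps(bins: int, frame_rate: float, n_q: int) -> float:
    """Mirror of the computation in ResidualVectorQuantizer.forward:
    log2(bins) bits per frame per codebook, summed over n_q codebooks,
    expressed in kbits per second."""
    bw_per_q = math.log2(bins) * frame_rate / 1000  # kbits/s for one codebook
    return n_q * bw_per_q

# Example: 2048-entry codebooks (11 bits each) at 12.5 frames/s,
# with 8 residual levels.
bw = rvq_bandwidth_kbps(bins=2048, frame_rate=12.5, n_q=8)
```

Dropping codebooks via `set_num_codebooks` scales the bitrate linearly, which is the quality/bandwidth trade-off RVQ exposes.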
moshi/server.py ADDED
@@ -0,0 +1,256 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import asyncio
+from dataclasses import dataclass
+import random
+import os
+from pathlib import Path
+import tarfile
+import time
+import secrets
+import sys
+
+import aiohttp
+from aiohttp import web
+from huggingface_hub import hf_hub_download
+import numpy as np
+import sentencepiece
+import sphn
+import torch
+
+from .client_utils import make_log
+from .models import loaders, MimiModel, LMModel, LMGen
+
+
+def log(level: str, msg: str):
+    print(make_log(level, msg))
+
+
+def seed_all(seed):
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.backends.cudnn.deterministic = False
+    torch.backends.cudnn.benchmark = False
+
+
+@dataclass
+class ServerState:
+    mimi: MimiModel
+    text_tokenizer: sentencepiece.SentencePieceProcessor
+    lm_gen: LMGen
+    lock: asyncio.Lock
+
+    def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor,
+                 lm: LMModel, device: str | torch.device):
+        self.mimi = mimi
+        self.text_tokenizer = text_tokenizer
+        self.lm_gen = LMGen(lm)
+
+        self.device = device
+        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
+        self.lock = asyncio.Lock()
+
+        self.mimi.streaming_forever(1)
+        self.lm_gen.streaming_forever(1)
+
+    def warmup(self):
+        for _ in range(4):
+            chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device)
+            codes = self.mimi.encode(chunk)
+            for c in range(codes.shape[-1]):
+                tokens = self.lm_gen.step(codes[:, :, c: c + 1])
+                if tokens is None:
+                    continue
+                _ = self.mimi.decode(tokens[:, 1:])
+        torch.cuda.synchronize()
+
+    async def handle_chat(self, request):
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        async def recv_loop():
+            nonlocal close
+            try:
+                async for message in ws:
+                    if message.type == aiohttp.WSMsgType.ERROR:
+                        log("error", f"{ws.exception()}")
+                        break
+                    elif message.type == aiohttp.WSMsgType.CLOSED:
+                        break
+                    elif message.type != aiohttp.WSMsgType.BINARY:
+                        log("error", f"unexpected message type {message.type}")
+                        continue
+                    message = message.data
+                    if not isinstance(message, bytes):
+                        log("error", f"unsupported message type {type(message)}")
+                        continue
+                    if len(message) == 0:
+                        log("warning", "empty message")
+                        continue
+                    kind = message[0]
+                    if kind == 1:  # audio
+                        payload = message[1:]
+                        opus_reader.append_bytes(payload)
+                    else:
+                        log("warning", f"unknown message kind {kind}")
+            finally:
+                close = True
+                log("info", "connection closed")
+
+        async def opus_loop():
+            all_pcm_data = None
+
+            while True:
+                if close:
+                    return
+                await asyncio.sleep(0.001)
+                pcm = opus_reader.read_pcm()
+                if pcm.shape[-1] == 0:
+                    continue
+                if all_pcm_data is None:
+                    all_pcm_data = pcm
+                else:
+                    all_pcm_data = np.concatenate((all_pcm_data, pcm))
+                while all_pcm_data.shape[-1] >= self.frame_size:
+                    be = time.time()
+                    chunk = all_pcm_data[: self.frame_size]
+                    all_pcm_data = all_pcm_data[self.frame_size:]
+                    chunk = torch.from_numpy(chunk)
+                    chunk = chunk.to(device=self.device)[None, None]
+                    codes = self.mimi.encode(chunk)
+                    for c in range(codes.shape[-1]):
+                        tokens = self.lm_gen.step(codes[:, :, c: c + 1])
+                        if tokens is None:
+                            continue
+                        assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1
+                        main_pcm = self.mimi.decode(tokens[:, 1:])
+                        main_pcm = main_pcm.cpu()
+                        opus_writer.append_pcm(main_pcm[0, 0].numpy())
+                        text_token = tokens[0, 0, 0].item()
+                        if text_token not in (0, 3):
+                            _text = self.text_tokenizer.id_to_piece(text_token)  # type: ignore
+                            _text = _text.replace("▁", " ")
+                            msg = b"\x02" + bytes(_text, encoding="utf8")
+                            log("info", f"text token '{_text}'")
+                            await ws.send_bytes(msg)
+                    log("info", f"frame handled in {1000 * (time.time() - be):.1f}ms")
+
+        async def send_loop():
+            while True:
+                if close:
+                    return
+                await asyncio.sleep(0.001)
+                msg = opus_writer.read_bytes()
+                if len(msg) > 0:
+                    await ws.send_bytes(b"\x01" + msg)
+
+        log("info", "accepted connection")
+        close = False
+        async with self.lock:
+            opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate)
+            opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate)
+            self.mimi.reset_streaming()
+            self.lm_gen.reset_streaming()
+            # Send the handshake.
+            await ws.send_bytes(b"\x00")
+            await asyncio.gather(opus_loop(), recv_loop(), send_loop())
+        log("info", "done with connection")
+        return ws
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", default="localhost", type=str)
+    parser.add_argument("--port", default=8998, type=int)
+    parser.add_argument("--static", type=str)
+    parser.add_argument("--gradio-tunnel", action='store_true', help='Activate a gradio tunnel.')
+    parser.add_argument("--gradio-tunnel-token",
+                        help='Provide a custom (secret) token here to keep getting the same URL.')
+
+    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
+    parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
+    parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
+    parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
+                        help="HF repo to look into, defaults to Moshiko. "
+                             "Use this to select a different pre-trained model.")
+    parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")
+
+    args = parser.parse_args()
+    seed_all(42424242)
+
+    setup_tunnel = None
+    tunnel_token = ''
+    if args.gradio_tunnel:
+        try:
+            from gradio import networking  # type: ignore
+        except ImportError:
+            log("error", "Cannot find gradio which is required to activate a tunnel. "
+                         "Please install with `pip install gradio`.")
+            sys.exit(1)
+        setup_tunnel = networking.setup_tunnel
+        if args.gradio_tunnel_token is None:
+            tunnel_token = secrets.token_urlsafe(32)
+        else:
+            tunnel_token = args.gradio_tunnel_token
+
+    log("info", "loading mimi")
+    if args.mimi_weight is None:
+        args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
+    mimi = loaders.get_mimi(args.mimi_weight, args.device)
+    log("info", "mimi loaded")
+
+    if args.tokenizer is None:
+        args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME)
+    text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer)  # type: ignore
+
+    log("info", "loading moshi")
+    if args.moshi_weight is None:
+        args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME)
+    lm = loaders.get_moshi_lm(args.moshi_weight, args.device)
+    log("info", "moshi loaded")
+
+    state = ServerState(mimi, text_tokenizer, lm, args.device)
+    log("info", "warming up the model")
+    state.warmup()
+    app = web.Application()
+    app.router.add_get("/api/chat", state.handle_chat)
+    static_path: None | str = None
+    if args.static is None:
+        log("info", "retrieving the static content")
+        dist_tgz = hf_hub_download("kyutai/moshi-artifacts", "dist.tgz")
+        dist_tgz = Path(dist_tgz)
+        dist = dist_tgz.parent / "dist"
+        if not dist.exists():
+            with tarfile.open(dist_tgz, "r:gz") as tar:
+                tar.extractall(path=dist_tgz.parent)
+        static_path = str(dist)
+    elif args.static != "none":
+        # When set to the "none" string, we don't serve any static content.
+        static_path = args.static
+    if static_path is not None:
+        async def handle_root(_):
+            return web.FileResponse(os.path.join(static_path, "index.html"))
+
+        log("info", f"serving static content from {static_path}")
+        app.router.add_get("/", handle_root)
+        app.router.add_static(
+            "/", path=static_path, follow_symlinks=True, name="static"
+        )
+    log("info", f"Access the Web UI directly at http://{args.host}:{args.port}")
+    if setup_tunnel is not None:
+        tunnel = setup_tunnel('localhost', args.port, tunnel_token, None)
+        log("info", f"Tunnel started; if executing on a remote GPU, you can use {tunnel}.")
+        log("info", "Note that this tunnel goes through the US and you might experience high latency in Europe.")
+    web.run_app(app, port=args.port)
+
+
+with torch.no_grad():
+    main()
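`handle_chat` frames every binary WebSocket message with a one-byte kind tag: `0x00` for the server handshake, `0x01` for Opus audio, and `0x02` for UTF-8 text tokens. A client-side sketch of that framing (the helper names here are illustrative, not part of the server module):

```python
# Pack/unpack helpers for the one-byte-tag framing used on the /api/chat
# WebSocket: recv_loop reads message[0] as the kind and message[1:] as
# the payload; the server prepends the tag the same way when sending.

KIND_HANDSHAKE, KIND_AUDIO, KIND_TEXT = 0, 1, 2

def pack_frame(kind: int, payload: bytes = b"") -> bytes:
    """Prefix a payload with its one-byte kind tag."""
    return bytes([kind]) + payload

def unpack_frame(message: bytes) -> tuple[int, bytes]:
    """Split a received message into (kind, payload)."""
    if not message:
        raise ValueError("empty message")
    return message[0], message[1:]

# A client would send Opus packets tagged as audio...
frame = pack_frame(KIND_AUDIO, b"<opus packet bytes>")
# ...and route incoming frames on the tag.
kind, payload = unpack_frame(b"\x02hello")
```

The empty handshake frame (`b"\x00"`) sent after the locks are acquired tells the client the server is ready before any audio flows.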
moshi/utils/__init__.py ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utilities."""
moshi/utils/autocast.py ADDED
@@ -0,0 +1,45 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class TorchAutocast:
+    """TorchAutocast utility class.
+    Allows you to enable and disable autocast. This is especially useful
+    when dealing with different architectures and clusters with different
+    levels of support.
+
+    Args:
+        enabled (bool): Whether to enable torch.autocast or not.
+        args: Additional args for torch.autocast.
+        kwargs: Additional kwargs for torch.autocast.
+    """
+
+    def __init__(self, enabled: bool, *args, **kwargs):
+        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+    def __enter__(self):
+        if self.autocast is None:
+            return
+        try:
+            self.autocast.__enter__()
+        except RuntimeError:
+            device = self.autocast.device
+            dtype = self.autocast.fast_dtype
+            raise RuntimeError(
+                f"There was an error autocasting with dtype={dtype} device={device}\n"
+                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+            )
+
+    def __exit__(self, *args, **kwargs):
+        if self.autocast is None:
+            return
+        self.autocast.__exit__(*args, **kwargs)
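The pattern behind `TorchAutocast` is an "optional context manager": a no-op when `enabled` is false, delegation to the wrapped context otherwise. The same idea in plain Python, with a recording context standing in for `torch.autocast` (names here are illustrative, not part of the module):

```python
class MaybeContext:
    """No-op unless `enabled`; otherwise defers to the wrapped context.
    A minimal sketch of the TorchAutocast pattern above."""

    def __init__(self, enabled: bool, ctx_factory):
        self.ctx = ctx_factory() if enabled else None

    def __enter__(self):
        if self.ctx is not None:
            return self.ctx.__enter__()

    def __exit__(self, *exc):
        if self.ctx is not None:
            return self.ctx.__exit__(*exc)

events = []

class Recorder:
    """Stand-in for torch.autocast that records enter/exit."""
    def __enter__(self):
        events.append("enter")
    def __exit__(self, *exc):
        events.append("exit")

with MaybeContext(False, Recorder):
    events.append("body")   # Recorder is never constructed's enter/exit
with MaybeContext(True, Recorder):
    events.append("body")   # Recorder wraps the body
```

Keeping the decision in `__init__` (rather than branching at every call site) lets callers write a single `with` block regardless of whether autocast is supported on the current hardware.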
moshi/utils/compile.py ADDED
@@ -0,0 +1,284 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Provides some extra utilities around torch compile, in particular a way
+ to fully deactivate it easily with a context manager.
+ Provides a simple activation checkpointing that is compatible with FSDP and torch compile.
+ Finally, provides some utilities for CUDA graphing functions.
+ """
+ from contextlib import contextmanager
+ from functools import wraps
+ import inspect
+ import os
+ import typing as tp
+
+ import torch
+ from torch import cuda
+
+
+ _compile_disabled: bool = False
+
+
+ @contextmanager
+ def no_compile():
+     """Disable torch.compile locally. PyTorch 2.4 now provides a builtin function to do that."""
+     global _compile_disabled
+
+     prev_disabled = _compile_disabled
+     _compile_disabled = True
+     try:
+         yield
+     finally:
+         _compile_disabled = prev_disabled
+
+
+ def torch_compile_lazy(fun):
+     """torch.compile creates a huge pool of processes, even when not using the function at all,
+     e.g. with Dora. This can pollute stderr when doing CTRL+C. So we compile in a lazy way.
+     """
+     if os.environ.get("NO_TORCH_COMPILE"):
+         return fun
+     fun_compiled = None
+
+     @wraps(fun)
+     def _wrapped(*args, **kwargs):
+         nonlocal fun_compiled
+         if _compile_disabled:
+             return fun(*args, **kwargs)
+         if fun_compiled is None:
+             fun_compiled = torch.compile(fun)
+         return fun_compiled(*args, **kwargs)
+
+     return _wrapped
+
+
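The lazy wrapper above can be exercised without torch at all. The following sketch uses a hypothetical `fake_compile` standing in for `torch.compile`, and shows that compilation happens at most once, and only on the first call made outside the disabling context:

```python
from contextlib import contextmanager
from functools import wraps

_disabled = False
compile_calls = 0  # counts how often the (stand-in) compiler actually runs


def fake_compile(fun):
    """Stand-in for torch.compile: records that compilation happened."""
    global compile_calls
    compile_calls += 1
    return fun


@contextmanager
def no_compile():
    global _disabled
    prev, _disabled = _disabled, True
    try:
        yield
    finally:
        _disabled = prev


def compile_lazy(fun):
    compiled = None

    @wraps(fun)
    def _wrapped(*args, **kwargs):
        nonlocal compiled
        if _disabled:
            return fun(*args, **kwargs)   # bypass: never triggers compilation
        if compiled is None:
            compiled = fake_compile(fun)  # compile on first real use only
        return compiled(*args, **kwargs)

    return _wrapped


@compile_lazy
def double(x):
    return 2 * x


with no_compile():
    assert double(3) == 6      # runs eagerly; no compilation yet
assert compile_calls == 0
assert double(4) == 8          # first call outside the context compiles
assert double(5) == 10         # reuses the cached compiled function
assert compile_calls == 1
```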
+ class Checkpoint(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, function, *args) -> tp.Any:
+         to_save = []
+         ctx.others = []
+         ctx.function = function
+         # Sources will indicate whether the arg in position N is
+         # a tensor stored in ctx.save_for_backward, or inside ctx.others.
+         ctx.sources = []
+         new_args = []
+         for arg in args:
+             if isinstance(arg, torch.Tensor):
+                 to_save.append(arg)
+                 ctx.sources.append("tensor")
+                 new_args.append(arg.detach())
+             else:
+                 ctx.sources.append("other")
+                 ctx.others.append(arg)
+                 new_args.append(arg)
+         ctx.save_for_backward(*to_save)
+         # During the forward, we just make a pass with no gradient computed.
+         with torch.no_grad():
+             res = function(*new_args)
+         return res
+
+     @staticmethod
+     def backward(ctx, *grads) -> tp.Tuple[tp.Optional[torch.Tensor], ...]:
+         pseudo_tensors = []
+         with torch.set_grad_enabled(True):
+             # We create leaf tensors to collect the output gradients.
+             # We call them pseudo_tensors because they are pretending to be the input
+             # to `function`, while not being the original inputs.
+             for tensor in ctx.saved_tensors:
+                 pseudo_tensor = tensor.detach()
+                 pseudo_tensor.requires_grad_(True)
+                 pseudo_tensors.append(pseudo_tensor)
+             pseudo_tensors_copy = list(pseudo_tensors)
+             args = []
+             for source in ctx.sources:
+                 if source == "other":
+                     args.append(ctx.others.pop(0))
+                 else:
+                     assert source == "tensor"
+                     args.append(pseudo_tensors_copy.pop(0))
+             res = ctx.function(*args)
+             # The second forward with grad computation allows us to connect the input leaf tensors
+             # inside pseudo_tensors to the outputs of the function called.
+             if not isinstance(res, tuple):
+                 res = (res,)
+             # Now we just ask Torch to compute the derivative of `res` given the gradient coming from above
+             # in `grads`. The computed gradients will end up in the `pseudo_tensors` grad attributes.
+             torch.autograd.backward(res, grads)
+         out: tp.List[tp.Optional[torch.Tensor]] = [None]
+         for source in ctx.sources:
+             # We still need to output `None` values for non tensor parameters.
+             if source == "other":
+                 out.append(None)
+             else:
+                 assert source == "tensor"
+                 out.append(pseudo_tensors.pop(0).grad)
+         return tuple(out)
+
+
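For comparison only (this is PyTorch's builtin API, not this file's), `torch.utils.checkpoint` exhibits the same two-pass behaviour that `Checkpoint` implements by hand: the forward runs once without grad, then is re-run during backward to rebuild the graph. A small sketch, assuming torch is available:

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = {"forward": 0}


def layer(x):
    # Count how many times the forward actually executes.
    calls["forward"] += 1
    return (x * x).sum()


x = torch.ones(4, requires_grad=True)
# Forward pass runs under no_grad internally; intermediate activations are dropped.
out = checkpoint(layer, x, use_reentrant=False)
# During backward, the forward is recomputed to reconnect the graph.
out.backward()
assert calls["forward"] == 2
assert torch.allclose(x.grad, 2 * torch.ones(4))  # d/dx sum(x^2) = 2x
```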
+ def simple_checkpoint(module: torch.nn.Module, *args, **kwargs):
+     """Custom implementation of checkpointing in PyTorch as the builtin implementation is broken
+     when using torch compile. Only supports wrapping a `nn.Module` whose forward takes no `*args` or `**kwargs`.
+
+     See https://github.com/pytorch/pytorch/issues/97436.
+     Should be resolved in nightlies, but it is quite fun and simple to code it ourselves.
+     """
+     if hasattr(module, "_fsdp_wrapped_module"):
+         module_for_sig = module._fsdp_wrapped_module
+     else:
+         module_for_sig = module
+     sig = inspect.signature(module_for_sig.forward)
+     # We first flatten all arguments to use only *args, to make things easier and because
+     # torch.autograd.Function has weird support for kwargs.
+     bounded = sig.bind(*args, **kwargs)
+     new_args = []
+     for name, param in sig.parameters.items():
+         if param.kind in {
+             inspect.Parameter.VAR_POSITIONAL,
+             inspect.Parameter.VAR_KEYWORD,
+         }:
+             raise RuntimeError("simple_checkpoint doesn't support var args.")
+         if name not in bounded.arguments:
+             break
+         new_args.append(bounded.arguments[name])
+     return Checkpoint.apply(module, *new_args)
+
+
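The argument flattening in `simple_checkpoint` is plain `inspect` machinery: bind the mixed positional/keyword call, then read the arguments back in declaration order. A minimal, torch-free sketch (the `forward` signature here is hypothetical):

```python
import inspect


def forward(x, scale=1.0, bias=0.0):
    return x * scale + bias


# Bind mixed positional/keyword arguments against the signature.
sig = inspect.signature(forward)
bounded = sig.bind(2.0, 4.0, bias=3.0)

# Flatten back into purely positional args, in declaration order.
new_args = []
for name, param in sig.parameters.items():
    if name not in bounded.arguments:
        break  # stop at the first argument the caller did not supply
    new_args.append(bounded.arguments[name])

assert new_args == [2.0, 4.0, 3.0]
```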
+ _in_cuda_graph = False
+ _disable_cuda_graph = False
+
+
+ def in_cuda_graph() -> bool:
+     """Indicate whether we are in a function that is being CUDA graphed (or soon will be)."""
+     return _in_cuda_graph
+
+
+ @contextmanager
+ def _set_in_cuda_graph():
+     global _in_cuda_graph
+     assert not _in_cuda_graph
+     _in_cuda_graph = True
+     try:
+         yield
+     finally:
+         _in_cuda_graph = False
+
+
+ def _is_cuda_graph_enabled() -> bool:
+     if _disable_cuda_graph:
+         return False
+     no_cuda_graph = os.environ.get("NO_CUDA_GRAPH", "")
+     if no_cuda_graph.lower() not in {"0", "no", "n", ""}:
+         return False
+     return True
+
+
+ @contextmanager
+ def no_cuda_graph():
+     """Deactivate CUDA graphing for all the calls inside this context manager."""
+     global _disable_cuda_graph
+     old_value = _disable_cuda_graph
+     _disable_cuda_graph = True
+     try:
+         yield
+     finally:
+         _disable_cuda_graph = old_value
+
+
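The `NO_CUDA_GRAPH` check treats the variable as an opt-out flag: the empty string and falsy-looking values (`"0"`, `"no"`, `"n"`, case-insensitive) leave graphing on, anything else turns it off. A torch-free sketch (`graphing_enabled` is an illustrative name, not part of the module):

```python
def graphing_enabled(env: dict) -> bool:
    """Mirror of the NO_CUDA_GRAPH parsing above, taking the env as a dict."""
    flag = env.get("NO_CUDA_GRAPH", "")
    return flag.lower() in {"0", "no", "n", ""}


assert graphing_enabled({})                          # unset: graphing on
assert graphing_enabled({"NO_CUDA_GRAPH": "0"})      # explicit falsy value: still on
assert graphing_enabled({"NO_CUDA_GRAPH": "No"})     # case-insensitive
assert not graphing_enabled({"NO_CUDA_GRAPH": "1"})  # anything else disables
assert not graphing_enabled({"NO_CUDA_GRAPH": "yes"})
```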
+ class CUDAGraphed:
+     """Allow simple CUDA graphing of a function.
+
+     Args:
+         func: callable, taking any number of arguments. Its tensor arguments should
+             be top level args, not nested in structures (tuples, dicts, etc). Keyword
+             arguments are NOT supported for simplicity.
+         warmup_steps: how many calls to make normally before CUDA graphing. In particular, this
+             allows torch.compiled functions to get properly compiled.
+         disable: if True, just call the func directly, useful to quickly deactivate on CPU.
+     """
+
+     def __init__(self, func: tp.Callable, warmup_steps: int = 1, disable: bool = False):
+         self.func = func
+         self.warmup_steps = warmup_steps
+         self.disable = disable
+         self._graph: cuda.CUDAGraph | None = None
+         self._output: tuple | None = None
+         self._args: tuple | None = None
+
+     def reset(self, warmup_steps: int = 0) -> None:
+         """Reset the state, meaning the next call will get CUDA graphed again. Useful if some
+         shapes have changed, or external state (e.g. KVCache) has changed."""
+         self.warmup_steps = warmup_steps
+         self._graph = None
+         self._output = None
+         self._args = None
+
+     def __call__(self, *args, **kwargs) -> tp.Any:
+         if kwargs:
+             raise RuntimeError("Named arguments not supported for now.")
+         if self.disable or not _is_cuda_graph_enabled() or in_cuda_graph():
+             return self.func(*args, **kwargs)
+
+         def _clone_tensors(args: tuple) -> tuple:
+             out: list = []
+             for arg in args:
+                 if isinstance(arg, torch.Tensor):
+                     arg = arg.clone()
+                 out.append(arg)
+             return tuple(out)
+
+         def _match_values_copy_tensors(args: tuple, target_args: tuple) -> None:
+             if len(args) != len(target_args):
+                 raise ValueError(
+                     f"Expected {len(target_args)} arguments, but got {len(args)} for CUDA graphed function."
+                 )
+             for idx, (source, target) in enumerate(zip(args, target_args)):
+                 if isinstance(target, torch.Tensor):
+                     if not isinstance(source, torch.Tensor):
+                         raise ValueError(
+                             f"Argument #{idx} was a tensor, and is no longer (now {source})."
+                         )
+                     if source.shape != target.shape:
+                         raise ValueError(
+                             f"Argument #{idx} had shape {target.shape}, but got shape {source.shape}."
+                         )
+                     target.copy_(source)
+                 else:
+                     if isinstance(source, torch.Tensor):
+                         raise ValueError(
+                             f"Argument #{idx} was not a tensor {target}, but is now one."
+                         )
+                     if source is not target and source != target:
+                         raise ValueError(
+                             f"Argument #{idx} changed value from {target} to {source}."
+                         )
+
+         with _set_in_cuda_graph():
+             # Prevent anyone below us from trying to CUDA graph things.
+             if self._graph is None:
+                 if self.warmup_steps <= 0:
+                     self._graph = cuda.CUDAGraph()
+                     # Making a copy just to ensure those are not used elsewhere.
+                     self._args = _clone_tensors(args)
+                     with cuda.graph(self._graph):
+                         self._output = self.func(*self._args)
+                     # At this point nothing has really run, so we have to replay the graph once for real.
+                     self._graph.replay()
+                     return self._output
+                 else:
+                     self.warmup_steps -= 1
+                     return self.func(*args)
+             else:
+                 assert self._args is not None
+                 assert self._output is not None
+                 _match_values_copy_tensors(args, self._args)
+                 self._graph.replay()
+                 return self._output
+
+
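A CPU-only analogy of the buffer discipline enforced by `_match_values_copy_tensors` (illustrative names, no real graph capture happens here): replay always reads the captured storages, so new inputs must be copied in place with `copy_`, never passed as fresh tensors:

```python
import torch


def capture(func, example_input):
    """Freeze which tensor storages the computation reads and writes.

    A real CUDA graph would record the kernels; here we just keep the buffers
    and recompute in place, which preserves the same calling discipline.
    """
    static_in = example_input.clone()  # captured input buffer
    static_out = func(static_in)       # captured output buffer

    def replay():
        # Stands in for graph.replay(): re-runs on the same storages.
        static_out.copy_(func(static_in))

    return static_in, static_out, replay


static_in, static_out, replay = capture(lambda x: x * 2, torch.zeros(3))

# New data goes *into* the captured buffer, as _match_values_copy_tensors does.
static_in.copy_(torch.tensor([1.0, 2.0, 3.0]))
replay()
assert torch.equal(static_out, torch.tensor([2.0, 4.0, 6.0]))
```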
+ def cuda_graph(func: tp.Callable, warmup_steps: int = 1):
+     """Just calls `CUDAGraphed` on the given function."""
+     if not _is_cuda_graph_enabled():
+         return func
+     return CUDAGraphed(func, warmup_steps)
moshi/utils/sampling.py ADDED
@@ -0,0 +1,126 @@
+ # Copyright (c) Kyutai, all rights reserved.
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import torch
+
+
+ def multinomial(
+     input: torch.Tensor, num_samples: int, replacement=False, *, generator=None
+ ):
+     """torch.multinomial with an arbitrary number of dimensions, and number of candidates on the last dimension.
+
+     Args:
+         input (torch.Tensor): The input tensor containing probabilities.
+         num_samples (int): Number of samples to draw.
+         replacement (bool): Whether to draw with replacement or not.
+     Keyword args:
+         generator (torch.Generator): A pseudorandom number generator for sampling.
+     Returns:
+         torch.Tensor: Last dimension contains num_samples indices
+             sampled from the multinomial probability distribution
+             located in the last dimension of tensor input.
+     """
+     input_ = input.reshape(-1, input.shape[-1])
+     # We should probably be able to remove this once the following PR has landed:
+     # https://github.com/pytorch/pytorch/pull/134818/files
+     # In the meantime, we specialize the no-replacement, num_samples=1 case so as not
+     # to have a synchronization point.
+     if replacement or num_samples != 1:
+         output_ = torch.multinomial(
+             input_,
+             num_samples=num_samples,
+             replacement=replacement,
+             generator=generator,
+         )
+     else:
+         q = torch.empty_like(input_).exponential_(1, generator=generator)
+         q = input_ / q
+         output_ = q.argmax(dim=-1, keepdim=True)
+     output = output_.reshape(*list(input.shape[:-1]), -1)
+     return output
+
+
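The no-replacement fast path relies on the "exponential race": if q_i ~ Exp(1), then q_i / p_i is exponential with rate p_i, the smallest such arrival wins with probability proportional to its rate, and argmin(q / p) = argmax(p / q) is therefore distributed according to the normalized p, with no CPU-GPU synchronization. A quick statistical check (tolerances chosen loosely):

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3])

# Draw many samples via the exponential-race trick and compare frequencies to probs.
n = 20000
q = torch.empty(n, 3).exponential_(1.0)
samples = (probs / q).argmax(dim=-1)
freq = torch.bincount(samples, minlength=3).float() / n
assert torch.allclose(freq, probs, atol=0.02)
```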
+ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
+     """Sample next token from top K values along the last dimension of the input probs tensor.
+
+     Args:
+         probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+         k (int): The k in “top-k”.
+     Returns:
+         torch.Tensor: Sampled tokens.
+     """
+     probs, indices = torch.topk(probs, k, dim=-1)
+     next_token = multinomial(probs, num_samples=1)
+     next_token = indices.gather(-1, next_token)
+     return next_token
+
+
+ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
+     """Sample next token from top P probabilities along the last dimension of the input probs tensor.
+
+     Args:
+         probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+         p (float): The p in “top-p”.
+     Returns:
+         torch.Tensor: Sampled tokens.
+     """
+     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+     probs_sum = torch.cumsum(probs_sort, dim=-1)
+     mask = probs_sum - probs_sort > p
+     probs_sort *= (~mask).float()
+     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+     next_token = multinomial(probs_sort, num_samples=1)
+     next_token = torch.gather(probs_idx, -1, next_token)
+     return next_token
+
+
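The masking rule `probs_sum - probs_sort > p` keeps every token whose cumulative mass *excluding itself* is at most p, so the most probable token always survives even when its own mass exceeds p. A small check with hand-picked values:

```python
import torch

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])
p = 0.7

probs_sort, probs_idx = torch.sort(probs, descending=True)
cum = torch.cumsum(probs_sort, dim=-1)
# Mass before each token: [0.0, 0.5, 0.8, 0.95]; tokens whose prior mass
# exceeds p fall outside the nucleus and get masked.
mask = cum - probs_sort > p
kept = probs_idx[~mask]
assert set(kept.tolist()) == {0, 1}
```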
+ def sample_token(
+     logits: torch.Tensor,
+     use_sampling: bool = False,
+     temp: float = 1.0,
+     top_k: int = 0,
+     top_p: float = 0.0,
+ ) -> torch.Tensor:
+     """Given logits of shape [*, Card], returns a LongTensor of shape [*]."""
+     # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
+     if use_sampling and temp > 0.0:
+         probs = torch.softmax(logits / temp, dim=-1)
+         if top_p > 0.0:
+             next_token = sample_top_p(probs, p=top_p)
+         elif top_k > 0:
+             next_token = sample_top_k(probs, k=top_k)
+         else:
+             next_token = multinomial(probs, num_samples=1)
+     else:
+         next_token = torch.argmax(logits, dim=-1, keepdim=True)
+     assert next_token.shape[-1] == 1
+     return next_token[..., 0]
+
+
+ if __name__ == "__main__":
+     torch.manual_seed(1234)
+     device = "cpu"
+     if torch.cuda.is_available():
+         torch.backends.cuda.matmul.allow_tf32 = False
+         torch.backends.cudnn.allow_tf32 = False
+         device = "cuda:0"
+
+     ps = torch.tensor([5.0, 2.0, 12.0, 6.0, 8.0, 1.0, 0.0, 4.0], device=device)
+     cnts = torch.zeros(ps.shape, dtype=torch.long, device=device)
+     total_samples = 1000
+     for _ in range(total_samples):
+         vs = multinomial(ps, num_samples=1, replacement=False)
+         cnts[vs] += 1
+     diff = cnts / cnts.sum() - ps / ps.sum()
+     max_diff = diff.abs().max().cpu().item()
+     print(ps / ps.sum())
+     print(cnts / cnts.sum())
+     assert max_diff < 1.5e-2
pyproject.toml ADDED
@@ -0,0 +1,33 @@
+ [project]
+ name = "moshi"
+ requires-python = ">= 3.10"
+ description = "Moshi is moshi"
+ dependencies = [
+     "numpy >= 1.26, < 2.2",
+     "safetensors >= 0.4.0, < 0.5",
+     "huggingface-hub >= 0.24, < 0.25",
+     "einops == 0.7",
+     "sentencepiece == 0.2",
+     "sounddevice == 0.5",
+     "sphn >= 0.1.4",
+     "torch >= 2.2.0, < 2.5",
+     "aiohttp >= 3.10.5, < 3.11",
+ ]
+ authors = [{name="Laurent Mazaré", email="laurent@kyutai.org"}]
+ maintainers = [{name="Laurent Mazaré", email="laurent@kyutai.org"}]
+ license = {text = "MIT"}
+ dynamic = ["version"]
+
+ [tool.setuptools.dynamic]
+ version = {attr = "moshi.__version__"}
+
+ [build-system]
+ requires = ["setuptools"]
+ build-backend = "setuptools.build_meta"
+
+ [project.optional-dependencies]
+ dev = [
+     "pyright",
+     "flake8",
+     "pre-commit",
+ ]
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ einops==0.7.0
+ safetensors==0.4.4
+ sentencepiece==0.2.0
+ sounddevice==0.5.0
+ soundfile==0.12.1
+ sphn==0.1.4
+ torch==2.2.0
+ numpy==1.26.4
+ aiohttp>=3.10.5, <3.11
+ huggingface-hub==0.24.6
setup.cfg ADDED
@@ -0,0 +1,10 @@
+ [pep8]
+ max-line-length = 120
+
+ [flake8]
+ max-line-length = 120
+ ignore = E203,E704
+ exclude =
+     dist
+     build
+