klemenk commited on
Commit
f8b510b
·
verified ·
1 Parent(s): 78a362d

Upload WavCoch random-init model (WavCochV8192CausalConfig)

Browse files
README.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - audio
5
+ - speech
6
+ - tokenizer
7
+ - vocoder
8
+ - wavcoch
9
+ library_name: transformers
10
+ ---
11
+
12
+ # WavCochCausalV8192-vocoder-randinit
13
+
14
+ **WavCoch** is a causal waveform-to-cochleagram tokenizer by **Greta Tuckute** and **Klemen Kotar**.
15
+
16
+ This repository contains a freshly initialized `WavCochV8192CausalConfig` model with a bundled random-initialized vocoder. The weights are random and have not been trained from a checkpoint.
17
+
18
+ ## Model Details
19
+
20
+ | Parameter | Value |
21
+ |-----------|-------|
22
+ | Parameters | ~24.42M |
23
+ | Window Size | 1001 |
24
+ | Hop Length | 80 |
25
+ | Encoder Dim | 512 |
26
+ | Vocabulary Size | 8192 |
27
+ | Includes Vocoder | True |
28
+
29
+ ## Usage
30
+
31
+ ```python
32
+ from transformers import AutoModel
33
+
34
+ wavcoch = AutoModel.from_pretrained(
35
+ "TuKoResearch/WavCochCausalV8192-vocoder-randinit",
36
+ trust_remote_code=True,
37
+ )
38
+
39
+ codes = wavcoch.quantize(waveform_tensor)
40
+ coch = wavcoch.decode(codes)
41
+
42
+ audio = wavcoch.decode_audio(codes)
43
+ ```
44
+
45
+ ## Notes
46
+
47
+ This repo includes a bundled vocoder and supports `decode_audio(...)` for end-to-end waveform synthesis.
config.json ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "wavcoch",
3
+ "architectures": [
4
+ "WavCoch"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_wavcoch.WavCochConfig",
8
+ "AutoModel": "modeling_wavcoch.WavCoch"
9
+ },
10
+ "torch_dtype": "float32",
11
+ "transformers_version": "4.40.0",
12
+ "sample_rate": 16000,
13
+ "causal_pad_mode": "repeat",
14
+ "out_channels": 211,
15
+ "has_vocoder": true,
16
+ "vocoder_upsample_rates": [
17
+ 5,
18
+ 4,
19
+ 2,
20
+ 2
21
+ ],
22
+ "vocoder_upsample_kernel_sizes": [
23
+ 10,
24
+ 8,
25
+ 4,
26
+ 4
27
+ ],
28
+ "vocoder_upsample_initial_channel": 512,
29
+ "vocoder_resblock": "1",
30
+ "vocoder_resblock_kernel_sizes": [
31
+ 11,
32
+ 7,
33
+ 3
34
+ ],
35
+ "vocoder_resblock_dilation_sizes": [
36
+ [
37
+ 1,
38
+ 3,
39
+ 5
40
+ ],
41
+ [
42
+ 1,
43
+ 3,
44
+ 5
45
+ ],
46
+ [
47
+ 1,
48
+ 3,
49
+ 5
50
+ ]
51
+ ],
52
+ "window_size": 1001,
53
+ "window_padding": 1000,
54
+ "hop_length": 80,
55
+ "causal_convs": true,
56
+ "encoder_layers": 8,
57
+ "encoder_dim": 512,
58
+ "encoder_kernel_size": 3,
59
+ "decoder_layers": 8,
60
+ "decoder_dim": 512,
61
+ "decoder_kernel_size": 9,
62
+ "quantizer": "FSQ",
63
+ "channels": [
64
+ 8,
65
+ 8,
66
+ 8,
67
+ 4,
68
+ 4
69
+ ],
70
+ "vocab_size": 8192
71
+ }
configuration_wavcoch.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WavCoch configuration for Hugging Face Transformers.
3
+ """
4
+
5
+ from transformers import PretrainedConfig
6
+
7
+
8
+ class WavCochConfig(PretrainedConfig):
9
+ """Configuration class for WavCoch checkpoints with optional vocoder."""
10
+
11
+ model_type = "wavcoch"
12
+
13
+ def __init__(
14
+ self,
15
+ window_size: int = 1001,
16
+ window_padding: int = 1000,
17
+ hop_length: int = 80,
18
+ out_channels: int = 211,
19
+ causal_convs: bool = True,
20
+ causal_pad_mode: str = "repeat",
21
+ encoder_layers: int = 8,
22
+ encoder_dim: int = 512,
23
+ encoder_kernel_size: int = 3,
24
+ decoder_layers: int = 8,
25
+ decoder_dim: int = 512,
26
+ decoder_kernel_size: int = 9,
27
+ quantizer: str = "FSQ",
28
+ channels=None,
29
+ vocab_size: int = None,
30
+ sample_rate: int = 16000,
31
+ has_vocoder: bool = False,
32
+ vocoder_upsample_rates=None,
33
+ vocoder_upsample_kernel_sizes=None,
34
+ vocoder_upsample_initial_channel: int = 512,
35
+ vocoder_resblock: str = "1",
36
+ vocoder_resblock_kernel_sizes=None,
37
+ vocoder_resblock_dilation_sizes=None,
38
+ **kwargs,
39
+ ):
40
+ channels = list(channels or [8, 8, 8, 4, 4])
41
+ if vocab_size is None:
42
+ vocab_size = 1
43
+ for level in channels:
44
+ vocab_size *= int(level)
45
+
46
+ self.window_size = int(window_size)
47
+ self.window_padding = int(window_padding)
48
+ self.hop_length = int(hop_length)
49
+ self.out_channels = int(out_channels)
50
+ self.causal_convs = bool(causal_convs)
51
+ self.causal_pad_mode = str(causal_pad_mode)
52
+ self.encoder_layers = int(encoder_layers)
53
+ self.encoder_dim = int(encoder_dim)
54
+ self.encoder_kernel_size = int(encoder_kernel_size)
55
+ self.decoder_layers = int(decoder_layers)
56
+ self.decoder_dim = int(decoder_dim)
57
+ self.decoder_kernel_size = int(decoder_kernel_size)
58
+ self.quantizer = str(quantizer)
59
+ self.channels = channels
60
+ self.vocab_size = int(vocab_size)
61
+ self.sample_rate = int(sample_rate)
62
+
63
+ self.has_vocoder = bool(has_vocoder)
64
+ self.vocoder_upsample_rates = list(vocoder_upsample_rates or [5, 4, 2, 2])
65
+ self.vocoder_upsample_kernel_sizes = list(vocoder_upsample_kernel_sizes or [10, 8, 4, 4])
66
+ self.vocoder_upsample_initial_channel = int(vocoder_upsample_initial_channel)
67
+ self.vocoder_resblock = str(vocoder_resblock)
68
+ self.vocoder_resblock_kernel_sizes = list(vocoder_resblock_kernel_sizes or [11, 7, 3])
69
+ self.vocoder_resblock_dilation_sizes = [
70
+ list(d) for d in (vocoder_resblock_dilation_sizes or [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
71
+ ]
72
+
73
+ super().__init__(**kwargs)
configure_wavcoch.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Backward-compatible import shim for older WavCoch repos.
3
+ """
4
+
5
+ from .configuration_wavcoch import WavCochConfig
6
+
7
+
8
+ __all__ = ["WavCochConfig"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:487d0a8c2dba58367919fe898e3a1812bcaf931b5ef5b97792d7c3ac8f4de15a
3
+ size 97726648
modeling_wavcoch.py ADDED
@@ -0,0 +1,583 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WavCoch model for Hugging Face Transformers.
3
+
4
+ This implementation is self-contained so HF-hosted WavCoch checkpoints do not
5
+ depend on the local auristream package or vector_quantize_pytorch.
6
+ """
7
+
8
+ import math
9
+ import os
10
+ from typing import List, Optional
11
+
12
+ os.environ.setdefault("USE_TORCH_XLA", "0")
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.nn import Conv1d, ConvTranspose1d
18
+ from torch.nn.utils import remove_weight_norm
19
+ try:
20
+ from torch.nn.utils.parametrizations import weight_norm
21
+ except ImportError: # pragma: no cover - older PyTorch compatibility
22
+ from torch.nn.utils import weight_norm
23
+
24
+ from transformers import PreTrainedModel
25
+ try:
26
+ from transformers.tokenization_utils_base import BatchEncoding
27
+ except ImportError: # pragma: no cover - compatibility with older Transformers
28
+ from transformers.tokenization_utils import BatchEncoding
29
+ import transformers.modeling_utils as transformers_modeling_utils
30
+ import transformers.utils.import_utils as transformers_import_utils
31
+
32
+ transformers_import_utils.is_torch_xla_available = lambda *args, **kwargs: False
33
+ transformers_modeling_utils.is_torch_xla_available = lambda *args, **kwargs: False
34
+
35
+ try:
36
+ from .configuration_wavcoch import WavCochConfig
37
+ except ImportError: # pragma: no cover - compatibility with older repos
38
+ from .configure_wavcoch import WavCochConfig
39
+
40
+
41
+ class CausalConv1d(nn.Module):
42
+ """1D causal convolution with left-only padding."""
43
+
44
+ def __init__(
45
+ self,
46
+ in_channels: int,
47
+ out_channels: int,
48
+ kernel_size: int,
49
+ stride: int = 1,
50
+ dilation: int = 1,
51
+ bias: bool = True,
52
+ groups: int = 1,
53
+ pad_mode: str = "repeat",
54
+ constant_value: float = 0.0,
55
+ ):
56
+ super().__init__()
57
+ left_pad = dilation * (kernel_size - 1)
58
+ if pad_mode == "repeat":
59
+ self.pad = nn.ReplicationPad1d((left_pad, 0))
60
+ elif pad_mode == "constant":
61
+ self.pad = nn.ConstantPad1d((left_pad, 0), constant_value)
62
+ else:
63
+ raise ValueError(f"Unsupported pad_mode: {pad_mode}")
64
+ self.conv = nn.Conv1d(
65
+ in_channels,
66
+ out_channels,
67
+ kernel_size=kernel_size,
68
+ stride=stride,
69
+ padding=0,
70
+ dilation=dilation,
71
+ groups=groups,
72
+ bias=bias,
73
+ )
74
+
75
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
76
+ return self.conv(self.pad(x))
77
+
78
+
79
+ class FSQ(nn.Module):
80
+ """Finite Scalar Quantization with the subset of functionality needed for inference."""
81
+
82
+ def __init__(self, levels: List[int], dim: int):
83
+ super().__init__()
84
+ if not levels:
85
+ raise ValueError("FSQ levels must be non-empty")
86
+
87
+ self.levels = [int(level) for level in levels]
88
+ self.codebook_dim = len(self.levels)
89
+ self.dim = int(dim)
90
+
91
+ level_tensor = torch.tensor(self.levels, dtype=torch.int32)
92
+ basis = torch.cumprod(torch.tensor([1] + self.levels[:-1], dtype=torch.int32), dim=0)
93
+ self.register_buffer("_levels", level_tensor, persistent=False)
94
+ self.register_buffer("_basis", basis, persistent=False)
95
+
96
+ if self.dim != self.codebook_dim:
97
+ self.project_in = nn.Linear(self.dim, self.codebook_dim)
98
+ self.project_out = nn.Linear(self.codebook_dim, self.dim)
99
+ else:
100
+ self.project_in = nn.Identity()
101
+ self.project_out = nn.Identity()
102
+
103
+ def _refresh_level_buffers(self, device: Optional[torch.device] = None):
104
+ level_values = [int(level) for level in self.levels]
105
+ if device is None:
106
+ if isinstance(self.project_in, nn.Linear):
107
+ device = self.project_in.weight.device
108
+ elif isinstance(self.project_out, nn.Linear):
109
+ device = self.project_out.weight.device
110
+ else:
111
+ device = self._levels.device
112
+
113
+ self._levels = torch.tensor(level_values, dtype=torch.int32, device=device)
114
+ self._basis = torch.cumprod(
115
+ torch.tensor([1] + level_values[:-1], dtype=torch.int32, device=device),
116
+ dim=0,
117
+ )
118
+
119
+ def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
120
+ levels = self._levels.to(dtype=z.dtype, device=z.device)
121
+ half_l = (levels - 1) * (1 + eps) / 2
122
+ offset = torch.where(
123
+ (self._levels % 2).to(device=z.device) == 0,
124
+ torch.tensor(0.5, device=z.device, dtype=z.dtype),
125
+ torch.tensor(0.0, device=z.device, dtype=z.dtype),
126
+ )
127
+ shift = (offset / half_l).atanh()
128
+ return (z + shift).tanh() * half_l - offset
129
+
130
+ def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
131
+ half_width = (self._levels // 2).to(dtype=zhat_normalized.dtype, device=zhat_normalized.device)
132
+ return (zhat_normalized * half_width) + half_width
133
+
134
+ def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
135
+ half_width = (self._levels // 2).to(dtype=zhat.dtype, device=zhat.device)
136
+ return (zhat - half_width) / half_width
137
+
138
+ def quantize_values(self, z: torch.Tensor) -> torch.Tensor:
139
+ self._refresh_level_buffers(device=z.device)
140
+ half_width = (self._levels // 2).to(dtype=z.dtype, device=z.device)
141
+ return self.bound(z).round() / half_width
142
+
143
+ def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
144
+ self._refresh_level_buffers(device=zhat.device)
145
+ zhat = self._scale_and_shift(zhat)
146
+ basis = self._basis.to(device=zhat.device, dtype=zhat.dtype)
147
+ return (zhat * basis).sum(dim=-1).to(torch.int32)
148
+
149
+ def indices_to_level_indices(self, indices: torch.Tensor) -> torch.Tensor:
150
+ self._refresh_level_buffers(device=indices.device)
151
+ indices = indices.unsqueeze(-1)
152
+ levels = self._levels.to(device=indices.device)
153
+ basis = self._basis.to(device=indices.device)
154
+ return (indices // basis) % levels
155
+
156
+ def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
157
+ self._refresh_level_buffers(device=indices.device)
158
+ level_indices = self.indices_to_level_indices(indices)
159
+ codes = self._scale_and_shift_inverse(level_indices.to(dtype=torch.float32))
160
+ return self.project_out(codes)
161
+
162
+ def forward(self, z: torch.Tensor):
163
+ orig_dtype = z.dtype
164
+ z = self.project_in(z.to(torch.float32))
165
+ q = self.quantize_values(z)
166
+ indices = self.codes_to_indices(q)
167
+ out = self.project_out(q).to(orig_dtype)
168
+ return out, indices.long()
169
+
170
+
171
+ LRELU_SLOPE = 0.1
172
+
173
+
174
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
175
+ return int((kernel_size * dilation - dilation) / 2)
176
+
177
+
178
+ def init_weights(module, mean: float = 0.0, std: float = 0.01):
179
+ classname = module.__class__.__name__
180
+ if classname.find("Conv") != -1 and hasattr(module, "weight"):
181
+ module.weight.data.normal_(mean, std)
182
+
183
+
184
+ class ResBlock1(nn.Module):
185
+ __constants__ = ["lrelu_slope"]
186
+
187
+ def __init__(self, channels: int, kernel_size: int = 3, dilation=(1, 3, 5)):
188
+ super().__init__()
189
+ self.lrelu_slope = LRELU_SLOPE
190
+
191
+ ch = channels
192
+ ks = kernel_size
193
+ self.convs1 = nn.Sequential(
194
+ weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[0]), dilation[0])),
195
+ weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[1]), dilation[1])),
196
+ weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, dilation[2]), dilation[2])),
197
+ )
198
+ self.convs2 = nn.Sequential(
199
+ weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, 1))),
200
+ weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, 1))),
201
+ weight_norm(Conv1d(ch, ch, ks, 1, get_padding(ks, 1))),
202
+ )
203
+ self.convs1.apply(init_weights)
204
+ self.convs2.apply(init_weights)
205
+
206
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
207
+ for conv1, conv2 in zip(self.convs1, self.convs2):
208
+ xt = F.leaky_relu(x, self.lrelu_slope)
209
+ xt = conv1(xt)
210
+ xt = F.leaky_relu(xt, self.lrelu_slope)
211
+ xt = conv2(xt)
212
+ x = xt + x
213
+ return x
214
+
215
+ def remove_weight_norm(self):
216
+ for layer in self.convs1:
217
+ remove_weight_norm(layer)
218
+ for layer in self.convs2:
219
+ remove_weight_norm(layer)
220
+
221
+
222
+ class ResBlock2(nn.Module):
223
+ __constants__ = ["lrelu_slope"]
224
+
225
+ def __init__(self, channels: int, kernel_size: int = 3, dilation=(1, 3)):
226
+ super().__init__()
227
+ self.lrelu_slope = LRELU_SLOPE
228
+
229
+ ch = channels
230
+ ks = kernel_size
231
+ self.convs = nn.ModuleList(
232
+ [
233
+ weight_norm(Conv1d(ch, ch, ks, 1, get_padding(kernel_size, dilation[0]), dilation[0])),
234
+ weight_norm(Conv1d(ch, ch, ks, 1, get_padding(kernel_size, dilation[1]), dilation[1])),
235
+ ]
236
+ )
237
+ self.convs.apply(init_weights)
238
+
239
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
240
+ for conv in self.convs:
241
+ xt = F.leaky_relu(x, self.lrelu_slope)
242
+ xt = conv(xt)
243
+ x = xt + x
244
+ return x
245
+
246
+ def remove_weight_norm(self):
247
+ for layer in self.convs:
248
+ remove_weight_norm(layer)
249
+
250
+
251
+ class Generator(nn.Module):
252
+ __constants__ = ["lrelu_slope", "num_kernels", "num_upsamples"]
253
+
254
+ def __init__(
255
+ self,
256
+ out_channels: int = 211,
257
+ upsample_rates=None,
258
+ upsample_kernel_sizes=None,
259
+ upsample_initial_channel: int = 512,
260
+ resblock: str = "1",
261
+ resblock_kernel_sizes=None,
262
+ resblock_dilation_sizes=None,
263
+ ):
264
+ super().__init__()
265
+ upsample_rates = list(upsample_rates or [5, 4, 2, 2])
266
+ upsample_kernel_sizes = list(upsample_kernel_sizes or [10, 8, 4, 4])
267
+ resblock_kernel_sizes = list(resblock_kernel_sizes or [11, 7, 3])
268
+ resblock_dilation_sizes = [list(d) for d in (resblock_dilation_sizes or [[1, 3, 5], [1, 3, 5], [1, 3, 5]])]
269
+
270
+ self.num_kernels = len(resblock_kernel_sizes)
271
+ self.num_upsamples = len(upsample_rates)
272
+ self.lrelu_slope = LRELU_SLOPE
273
+
274
+ self.conv_pre = weight_norm(Conv1d(out_channels, upsample_initial_channel, 7, 1, padding=3))
275
+ resblock_cls = ResBlock1 if resblock == "1" else ResBlock2
276
+
277
+ ups = []
278
+ for i, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
279
+ ups.append(
280
+ weight_norm(
281
+ ConvTranspose1d(
282
+ upsample_initial_channel // (2 ** i),
283
+ upsample_initial_channel // (2 ** (i + 1)),
284
+ kernel,
285
+ rate,
286
+ padding=(kernel - rate) // 2,
287
+ )
288
+ )
289
+ )
290
+ self.ups = nn.Sequential(*ups)
291
+
292
+ resblocks = []
293
+ for i in range(len(self.ups)):
294
+ ch = upsample_initial_channel // (2 ** (i + 1))
295
+ resblocks.append(
296
+ nn.Sequential(
297
+ *[
298
+ resblock_cls(ch, kernel, dilation)
299
+ for kernel, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes)
300
+ ]
301
+ )
302
+ )
303
+ self.resblocks = nn.Sequential(*resblocks)
304
+
305
+ self.conv_post = weight_norm(Conv1d(ch, 1, 17, 1, padding=0))
306
+ self.ups.apply(init_weights)
307
+ self.conv_post.apply(init_weights)
308
+
309
+ def load_state_dict(self, state_dict, strict: bool = True):
310
+ new_state_dict = {}
311
+ for key, value in state_dict.items():
312
+ new_key = key
313
+ if "resblocks" in key:
314
+ parts = key.split(".")
315
+ if len(parts) == 5:
316
+ layer = int(parts[1])
317
+ new_key = f"resblocks.{layer // 3}.{layer % 3}.{'.'.join(parts[2:])}"
318
+ new_state_dict[new_key] = value
319
+
320
+ current_state = self.state_dict()
321
+ for key, value in list(new_state_dict.items()):
322
+ if key not in current_state:
323
+ continue
324
+ len_diff = value.dim() - current_state[key].dim()
325
+ if len_diff == -1:
326
+ new_state_dict[key] = value.unsqueeze(-1)
327
+ elif len_diff == 1:
328
+ new_state_dict[key] = value.squeeze(-1)
329
+
330
+ super().load_state_dict(new_state_dict, strict=strict)
331
+
332
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
333
+ x = self.conv_pre(x.permute(0, 2, 1))
334
+
335
+ for upsample_layer, resblock_group in zip(self.ups, self.resblocks):
336
+ x = F.leaky_relu(x, self.lrelu_slope)
337
+ x = upsample_layer(x)
338
+ xs = 0
339
+ for resblock in resblock_group:
340
+ xs = xs + resblock(x)
341
+ x = xs / self.num_kernels
342
+
343
+ x = F.leaky_relu(x)
344
+ x = self.conv_post(x)
345
+ return torch.tanh(x)
346
+
347
+ def remove_weight_norm(self):
348
+ for layer in self.ups:
349
+ remove_weight_norm(layer)
350
+ for group in self.resblocks:
351
+ for block in group:
352
+ block.remove_weight_norm()
353
+ remove_weight_norm(self.conv_pre)
354
+ remove_weight_norm(self.conv_post)
355
+
356
+
357
+ class WavCoch(PreTrainedModel):
358
+ """Causal waveform-to-cochleagram tokenizer with optional vocoder."""
359
+
360
+ config_class = WavCochConfig
361
+ main_input_name = "wav"
362
+
363
+ def __init__(self, config: WavCochConfig):
364
+ super().__init__(config)
365
+ self.config = config
366
+
367
+ self.N = int(config.window_size)
368
+ self.hop_length = int(config.hop_length)
369
+ self.window_padding = int(getattr(config, "window_padding", self.N - self.hop_length))
370
+ self.causal_convs = bool(getattr(config, "causal_convs", True))
371
+ self.causal_pad_mode = getattr(config, "causal_pad_mode", "repeat")
372
+
373
+ out_bins = self.N // 2 + 1
374
+ self.conv_real_filters = nn.Conv1d(1, out_bins, kernel_size=self.N, stride=self.hop_length)
375
+ self.conv_imag_filters = nn.Conv1d(1, out_bins, kernel_size=self.N, stride=self.hop_length)
376
+ self._initialize_conv_filters()
377
+
378
+ self.encoder = self._build_conv_stack(
379
+ in_channels=out_bins,
380
+ out_channels=config.encoder_dim,
381
+ num_layers=config.encoder_layers,
382
+ kernel_size=config.encoder_kernel_size,
383
+ causal=self.causal_convs,
384
+ )
385
+ self.quantizer = FSQ(levels=list(config.channels), dim=config.encoder_dim)
386
+ self.decoder = self._build_conv_stack(
387
+ in_channels=config.decoder_dim,
388
+ out_channels=config.out_channels,
389
+ num_layers=config.decoder_layers,
390
+ kernel_size=config.decoder_kernel_size,
391
+ causal=self.causal_convs,
392
+ )
393
+
394
+ self.has_vocoder = bool(getattr(config, "has_vocoder", False))
395
+ if self.has_vocoder:
396
+ if int(config.out_channels) != 211:
397
+ raise ValueError("Bundled vocoder currently expects 211 cochleagram channels")
398
+ self.vocoder = Generator(
399
+ out_channels=config.out_channels,
400
+ upsample_rates=config.vocoder_upsample_rates,
401
+ upsample_kernel_sizes=config.vocoder_upsample_kernel_sizes,
402
+ upsample_initial_channel=config.vocoder_upsample_initial_channel,
403
+ resblock=config.vocoder_resblock,
404
+ resblock_kernel_sizes=config.vocoder_resblock_kernel_sizes,
405
+ resblock_dilation_sizes=config.vocoder_resblock_dilation_sizes,
406
+ )
407
+ else:
408
+ self.vocoder = None
409
+
410
+ self._vocab_size = int(config.vocab_size)
411
+ self.post_init()
412
+
413
+ def _build_conv_stack(
414
+ self,
415
+ in_channels: int,
416
+ out_channels: int,
417
+ num_layers: int,
418
+ kernel_size: int,
419
+ causal: bool,
420
+ ) -> nn.Sequential:
421
+ layers = []
422
+ for layer_idx in range(int(num_layers)):
423
+ input_channels = in_channels if layer_idx == 0 else out_channels
424
+ if causal:
425
+ conv = CausalConv1d(
426
+ input_channels,
427
+ out_channels,
428
+ kernel_size=kernel_size,
429
+ stride=1,
430
+ pad_mode=self.causal_pad_mode,
431
+ )
432
+ else:
433
+ conv = nn.Conv1d(
434
+ input_channels,
435
+ out_channels,
436
+ kernel_size=kernel_size,
437
+ stride=1,
438
+ padding=kernel_size // 2,
439
+ )
440
+ layers.extend([conv, nn.ReLU()])
441
+ return nn.Sequential(*layers)
442
+
443
+ def _compute_twiddle_factors(self):
444
+ n = torch.arange(self.N, dtype=torch.float32).unsqueeze(1)
445
+ k = torch.arange(self.N, dtype=torch.float32).unsqueeze(0)
446
+ angles = -2.0 * math.pi * n * k / float(self.N)
447
+ return torch.cos(angles), torch.sin(angles)
448
+
449
+ def _initialize_conv_filters(self):
450
+ with torch.no_grad():
451
+ cos_matrix, sin_matrix = self._compute_twiddle_factors()
452
+ cos_matrix = cos_matrix[: self.N // 2 + 1, :]
453
+ sin_matrix = sin_matrix[: self.N // 2 + 1, :]
454
+ window = torch.hann_window(self.N, periodic=True).view(1, 1, -1)
455
+ real_weights = (cos_matrix.unsqueeze(1) * window).to(dtype=self.conv_real_filters.weight.dtype)
456
+ imag_weights = (sin_matrix.unsqueeze(1) * window).to(dtype=self.conv_imag_filters.weight.dtype)
457
+ self.conv_real_filters.weight.copy_(real_weights)
458
+ self.conv_imag_filters.weight.copy_(imag_weights)
459
+
460
+ for param in self.conv_real_filters.parameters():
461
+ param.requires_grad_(False)
462
+ for param in self.conv_imag_filters.parameters():
463
+ param.requires_grad_(False)
464
+
465
+ def _normalize_sample_rate(self, sample_rate: Optional[int], sampling_rate: Optional[int]) -> int:
466
+ if sample_rate is not None and sampling_rate is not None and sample_rate != sampling_rate:
467
+ raise ValueError(f"sample_rate ({sample_rate}) and sampling_rate ({sampling_rate}) conflict")
468
+ resolved = int(sample_rate or sampling_rate or self.config.sample_rate)
469
+ if resolved != int(self.config.sample_rate):
470
+ raise ValueError(
471
+ f"WavCoch expects {self.config.sample_rate} Hz audio, but received {resolved} Hz"
472
+ )
473
+ return resolved
474
+
475
+ def _prepare_wav_batch(self, wav) -> torch.Tensor:
476
+ if isinstance(wav, list):
477
+ wav = [item if isinstance(item, torch.Tensor) else torch.tensor(item) for item in wav]
478
+ normalized = []
479
+ for item in wav:
480
+ if item.ndim == 1:
481
+ normalized.append(item)
482
+ elif item.ndim == 2 and 1 in item.shape:
483
+ normalized.append(item.reshape(-1))
484
+ else:
485
+ raise ValueError(f"Unexpected list element shape {tuple(item.shape)}")
486
+ wav = torch.nn.utils.rnn.pad_sequence(normalized, batch_first=True).unsqueeze(1)
487
+ elif isinstance(wav, torch.Tensor):
488
+ if wav.ndim == 1:
489
+ wav = wav.unsqueeze(0).unsqueeze(0)
490
+ elif wav.ndim == 2:
491
+ wav = wav.unsqueeze(1)
492
+ elif wav.ndim != 3:
493
+ raise ValueError(f"Unexpected tensor shape {tuple(wav.shape)}, expected 1D, 2D or 3D")
494
+ else:
495
+ raise TypeError(f"Unsupported input type: {type(wav)}")
496
+
497
+ return wav.to(dtype=torch.float32)
498
+
499
+ @property
500
+ def vocab_size(self) -> int:
501
+ return self._vocab_size
502
+
503
+ def forward(
504
+ self,
505
+ wav: torch.Tensor,
506
+ coch: Optional[torch.Tensor] = None,
507
+ return_tensors: str = "pt",
508
+ sample_rate: Optional[int] = None,
509
+ sampling_rate: Optional[int] = None,
510
+ pad: bool = True,
511
+ ):
512
+ del return_tensors # unused, kept for tokenizer-like API compatibility
513
+ self._normalize_sample_rate(sample_rate, sampling_rate)
514
+
515
+ wav = self._prepare_wav_batch(wav)
516
+ if coch is None:
517
+ codes = self.quantize(wav, pad=pad)
518
+ return BatchEncoding({"input_values": codes, "input_ids": codes})
519
+
520
+ if pad:
521
+ wav = F.pad(wav, (self.window_padding, 0), mode="constant", value=0.0)
522
+ with torch.no_grad():
523
+ real_part = self.conv_real_filters(wav)
524
+ imag_part = self.conv_imag_filters(wav)
525
+
526
+ x = real_part + imag_part
527
+ x = self.encoder(x).permute(0, 2, 1)
528
+ quantized, _ = self.quantizer(x)
529
+ pred_coch = self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)
530
+ loss = F.l1_loss(pred_coch, coch)
531
+ return pred_coch, loss, None
532
+
533
+ @torch.no_grad()
534
+ def quantize(self, wav: torch.Tensor, pad: bool = True) -> torch.Tensor:
535
+ wav = self._prepare_wav_batch(wav)
536
+ if pad:
537
+ wav = F.pad(wav, (self.window_padding, 0), mode="constant", value=0.0)
538
+
539
+ real_part = self.conv_real_filters(wav)
540
+ imag_part = self.conv_imag_filters(wav)
541
+ x = real_part + imag_part
542
+ x = self.encoder(x).permute(0, 2, 1)
543
+ _, indices = self.quantizer(x)
544
+ return indices.long()
545
+
546
+ @torch.no_grad()
547
+ def decode(self, indices: torch.Tensor) -> torch.Tensor:
548
+ if indices.ndim == 1:
549
+ indices = indices.unsqueeze(0)
550
+ emb = self.quantizer.indices_to_codes(indices.long())
551
+ return self.decoder(emb.permute(0, 2, 1)).permute(0, 2, 1)
552
+
553
+ @torch.no_grad()
554
+ def wav2coch(self, wav: torch.Tensor, pad: bool = True) -> torch.Tensor:
555
+ wav = self._prepare_wav_batch(wav)
556
+ if pad:
557
+ wav = F.pad(wav, (self.window_padding, 0), mode="constant", value=0.0)
558
+
559
+ real_part = self.conv_real_filters(wav)
560
+ imag_part = self.conv_imag_filters(wav)
561
+ x = real_part + imag_part
562
+ x = self.encoder(x).permute(0, 2, 1)
563
+ quantized, _ = self.quantizer(x)
564
+ return self.decoder(quantized.permute(0, 2, 1)).permute(0, 2, 1)
565
+
566
+ @torch.no_grad()
567
+ def vocode(self, coch: torch.Tensor) -> torch.Tensor:
568
+ if self.vocoder is None:
569
+ raise ValueError("This WavCoch checkpoint does not include a bundled vocoder")
570
+
571
+ if coch.ndim == 2:
572
+ coch = coch.unsqueeze(0)
573
+ elif coch.ndim != 3:
574
+ raise ValueError(f"Unexpected cochleagram shape {tuple(coch.shape)}")
575
+
576
+ if coch.shape[-1] != self.config.out_channels and coch.shape[1] == self.config.out_channels:
577
+ coch = coch.transpose(1, 2)
578
+
579
+ return self.vocoder(coch)
580
+
581
+ @torch.no_grad()
582
+ def decode_audio(self, indices: torch.Tensor) -> torch.Tensor:
583
+ return self.vocode(self.decode(indices))