Daniel Rothmann committed on
Commit
fad9fad
·
1 Parent(s): a2c97d7

WIP audio decoder

Browse files
KanadeDecoder.mlpackage/Data/com.apple.CoreML/model.mlmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a72aeec4e105b9d593a721317e9fce1ca7783e21293e82f898d810c6bf1c1fe
3
+ size 178115
KanadeDecoder.mlpackage/Data/com.apple.CoreML/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d2922387d7a2ef3f41db7a069ca9be2d313250137841ffbe8ab7b912bddd96a
3
+ size 364866112
KanadeDecoder.mlpackage/Manifest.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fileFormatVersion": "1.0.0",
3
+ "itemInfoEntries": {
4
+ "3D07005F-6244-406D-9DD3-91CF5F26CCAE": {
5
+ "author": "com.apple.CoreML",
6
+ "description": "CoreML Model Specification",
7
+ "name": "model.mlmodel",
8
+ "path": "com.apple.CoreML/model.mlmodel"
9
+ },
10
+ "FD090485-11AF-465F-8569-E149E7086201": {
11
+ "author": "com.apple.CoreML",
12
+ "description": "CoreML Model Weights",
13
+ "name": "weights",
14
+ "path": "com.apple.CoreML/weights"
15
+ }
16
+ },
17
+ "rootModelIdentifier": "3D07005F-6244-406D-9DD3-91CF5F26CCAE"
18
+ }
PlaprePicoDecode.mlpackage/Data/com.apple.CoreML/model.mlmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3204fd09a746e1814b5a6803aaf73fc910d7723710ea6a3eeac7dd5970a77341
3
- size 579193
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cbc60dac941edc9fbe212c52c4a96677e6ca547d575bdf461d19695e35de86a
3
+ size 579443
PlaprePicoDecode.mlpackage/Manifest.json CHANGED
@@ -1,18 +1,18 @@
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
- "D9ED4ABB-3CF3-496D-8858-06948CEBC48F": {
5
- "author": "com.apple.CoreML",
6
- "description": "CoreML Model Weights",
7
- "name": "weights",
8
- "path": "com.apple.CoreML/weights"
9
- },
10
- "E087C383-13E2-4E2C-B87A-990925041088": {
11
  "author": "com.apple.CoreML",
12
  "description": "CoreML Model Specification",
13
  "name": "model.mlmodel",
14
  "path": "com.apple.CoreML/model.mlmodel"
 
 
 
 
 
 
15
  }
16
  },
17
- "rootModelIdentifier": "E087C383-13E2-4E2C-B87A-990925041088"
18
  }
 
1
  {
2
  "fileFormatVersion": "1.0.0",
3
  "itemInfoEntries": {
4
+ "668A6F00-934D-4D44-9C27-7881268451D9": {
 
 
 
 
 
 
5
  "author": "com.apple.CoreML",
6
  "description": "CoreML Model Specification",
7
  "name": "model.mlmodel",
8
  "path": "com.apple.CoreML/model.mlmodel"
9
+ },
10
+ "BA90DDE9-E076-4B65-A23D-91E3BFAD284D": {
11
+ "author": "com.apple.CoreML",
12
+ "description": "CoreML Model Weights",
13
+ "name": "weights",
14
+ "path": "com.apple.CoreML/weights"
15
  }
16
  },
17
+ "rootModelIdentifier": "668A6F00-934D-4D44-9C27-7881268451D9"
18
  }
Vocoder.mlpackage/Data/com.apple.CoreML/model.mlmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88536c7f82ce5963c40ab46ab192452ddd1af731ecd4e08a40ea827fc544fbb6
3
+ size 1298694
Vocoder.mlpackage/Data/com.apple.CoreML/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f1b0ee1106eb66c74b00639159b27c910123caa778ffb2b7b4ece2eb88a180c
3
+ size 85215120
Vocoder.mlpackage/Manifest.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fileFormatVersion": "1.0.0",
3
+ "itemInfoEntries": {
4
+ "1865D6B1-DF08-4C5C-8B25-53058EF04D75": {
5
+ "author": "com.apple.CoreML",
6
+ "description": "CoreML Model Specification",
7
+ "name": "model.mlmodel",
8
+ "path": "com.apple.CoreML/model.mlmodel"
9
+ },
10
+ "6D12622B-E675-4537-9163-574EA27CA0C1": {
11
+ "author": "com.apple.CoreML",
12
+ "description": "CoreML Model Weights",
13
+ "name": "weights",
14
+ "path": "com.apple.CoreML/weights"
15
+ }
16
+ },
17
+ "rootModelIdentifier": "1865D6B1-DF08-4C5C-8B25-53058EF04D75"
18
+ }
scripts/convert.py CHANGED
@@ -203,16 +203,17 @@ def convert_decode(model: PlaprePicoDecode, output_dir: Path):
203
  causal_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
204
  causal_mask[0, 0, 0, :PREFILL_SEQ_LEN] = 0.0
205
 
206
- # Pre-sliced RoPE for a single position (caller computes these)
207
  cos = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
208
  sin = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
209
 
210
- # One-hot position mask for cache update (caller builds this)
211
  update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
212
  update_mask[0, 0, PREFILL_SEQ_LEN, 0] = 1.0
213
 
 
 
 
214
  with torch.no_grad():
215
- traced = torch.jit.trace(model, (input_ids, causal_mask, cos, sin, update_mask))
216
 
217
  print("Converting decode to CoreML...")
218
  mlmodel = ct.convert(
@@ -231,6 +232,11 @@ def convert_decode(model: PlaprePicoDecode, output_dir: Path):
231
  shape=(1, 1, MAX_CONTEXT, 1),
232
  dtype=np.float16,
233
  ),
 
 
 
 
 
234
  ],
235
  outputs=[ct.TensorType(name="logits", dtype=np.float16)],
236
  states=build_kv_cache_states(),
 
203
  causal_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
204
  causal_mask[0, 0, 0, :PREFILL_SEQ_LEN] = 0.0
205
 
 
206
  cos = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
207
  sin = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
208
 
 
209
  update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
210
  update_mask[0, 0, PREFILL_SEQ_LEN, 0] = 1.0
211
 
212
+ # Pre-projected speaker hidden: (1, 1, HIDDEN_SIZE) — zeros for non-speaker steps
213
+ speaker_hidden = torch.zeros(1, 1, HIDDEN_SIZE, dtype=torch.float16)
214
+
215
  with torch.no_grad():
216
+ traced = torch.jit.trace(model, (input_ids, causal_mask, cos, sin, update_mask, speaker_hidden))
217
 
218
  print("Converting decode to CoreML...")
219
  mlmodel = ct.convert(
 
232
  shape=(1, 1, MAX_CONTEXT, 1),
233
  dtype=np.float16,
234
  ),
235
+ ct.TensorType(
236
+ name="speaker_hidden",
237
+ shape=(1, 1, HIDDEN_SIZE),
238
+ dtype=np.float16,
239
+ ),
240
  ],
241
  outputs=[ct.TensorType(name="logits", dtype=np.float16)],
242
  states=build_kv_cache_states(),
scripts/convert_kanade.py ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert Kanade decoder and HiFT vocoder to CoreML.
4
+
5
+ These are non-autoregressive models (single forward pass), so conversion
6
+ is simpler than the LLM — no KV cache or StateType needed.
7
+
8
+ Two models are produced:
9
+ - KanadeDecoder.mlpackage: audio token indices + speaker embedding → mel spectrogram
10
+ - HiFTVocoder.mlpackage: mel spectrogram → PCM waveform
11
+
12
+ Usage:
13
+ python scripts/convert_kanade.py [--output-dir PATH] [--num-tokens 100]
14
+ """
15
+
16
+ import argparse
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import coremltools as ct
24
+ from kanade_tokenizer import KanadeModel, load_vocoder
25
+ import kanade_tokenizer.module.transformer as kanade_transformer
26
+
27
+
28
+ # ── Monkey-patch Kanade's complex RoPE with real-valued version ───────────
29
+
30
+ def _apply_rotary_emb_real(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
31
+ """Real-valued RoPE replacement for Kanade's complex-number version.
32
+ Converts complex freqs_cis to cos/sin and applies split-half rotation.
33
+ """
34
+ # freqs_cis is complex: (seq_len, head_dim/2)
35
+ cos = freqs_cis.real # (seq_len, head_dim/2)
36
+ sin = freqs_cis.imag
37
+ # Broadcast to match x shape: (bsz, seq_len, n_heads, head_dim)
38
+ # x has head_dim, cos/sin have head_dim/2 — need to double them
39
+ cos = torch.cat([cos, cos], dim=-1) # (seq_len, head_dim)
40
+ sin = torch.cat([sin, sin], dim=-1)
41
+ # Reshape for broadcast: (1, seq_len, 1, head_dim)
42
+ cos = cos.unsqueeze(0).unsqueeze(2)
43
+ sin = sin.unsqueeze(0).unsqueeze(2)
44
+ # Split-half rotation
45
+ half = x.shape[-1] // 2
46
+ x1 = x[..., :half]
47
+ x2 = x[..., half:]
48
+ rotated = torch.cat((-x2, x1), dim=-1)
49
+ return (x * cos + rotated * sin).type_as(x)
50
+
51
+
52
+ def _apply_rotary_emb_precomputed(x: torch.Tensor, freqs_cos_sin: torch.Tensor) -> torch.Tensor:
53
+ """Real-valued RoPE using precomputed cos/sin stored as (seq_len, head_dim).
54
+ head_dim is always 64, hardcoded to avoid dynamic size ops.
55
+ """
56
+ cos = freqs_cos_sin[..., :32]
57
+ sin = freqs_cos_sin[..., 32:]
58
+ cos = torch.cat([cos, cos], dim=-1)
59
+ sin = torch.cat([sin, sin], dim=-1)
60
+ cos = cos.unsqueeze(0).unsqueeze(2)
61
+ sin = sin.unsqueeze(0).unsqueeze(2)
62
+ x1 = x[..., :32]
63
+ x2 = x[..., 32:]
64
+ rotated = torch.cat((-x2, x1), dim=-1)
65
+ return (x * cos + rotated * sin).type_as(x)
66
+
67
+
68
+ def _patched_attention_forward_v2(self, x, freqs_cis, mask, return_kv=False):
69
+ """Attention forward with real-valued RoPE and explicit matmul."""
70
+ bsz, seqlen, _ = x.shape
71
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
72
+
73
+ xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
74
+ xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
75
+ xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
76
+
77
+ if freqs_cis is not None:
78
+ xq = _apply_rotary_emb_precomputed(xq, freqs_cis[:seqlen])
79
+ xk = _apply_rotary_emb_precomputed(xk, freqs_cis[:seqlen])
80
+
81
+ xq = xq.transpose(1, 2)
82
+ xk = xk.transpose(1, 2)
83
+ xv = xv.transpose(1, 2)
84
+
85
+ attn_weights = torch.matmul(xq, xk.transpose(2, 3)) * self.scale
86
+ if mask is not None:
87
+ attn_weights = attn_weights + mask
88
+ if self.causal:
89
+ causal_mask = torch.triu(
90
+ torch.full((seqlen, seqlen), float("-inf"), device=x.device), diagonal=1
91
+ )
92
+ attn_weights = attn_weights + causal_mask
93
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
94
+ output = torch.matmul(attn_weights, xv)
95
+
96
+ # 12 heads * 64 head_dim = 768
97
+ output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, 768)
98
+ output = self.wo(output)
99
+
100
+ if return_kv:
101
+ return output, (xk, xv)
102
+ return output
103
+
104
+
105
+ def _convert_freqs_cis_to_real(transformer_module):
106
+ """Replace complex freqs_cis buffer with real-valued cos/sin concatenation."""
107
+ if hasattr(transformer_module, 'freqs_cis') and transformer_module.freqs_cis is not None:
108
+ fc = transformer_module.freqs_cis # (max_len, head_dim/2) complex
109
+ cos = fc.real.float() # (max_len, head_dim/2)
110
+ sin = fc.imag.float()
111
+ real_freqs = torch.cat([cos, sin], dim=-1) # (max_len, head_dim)
112
+ # Replace the buffer
113
+ del transformer_module.freqs_cis
114
+ transformer_module.register_buffer('freqs_cis', real_freqs)
115
+
116
+
117
def patch_kanade_for_coreml(kanade: KanadeModel):
    """Monkey-patch Kanade in place so coremltools can trace it.

    Two changes: every ``Attention`` instance picks up the real-valued
    forward (class-level patch), and each ``Transformer``'s complex
    ``freqs_cis`` buffer is rewritten into a real [cos | sin] table.
    """
    kanade_transformer.Attention.forward = _patched_attention_forward_v2
    # Rewrite the RoPE tables of every transformer submodule.
    transformers = [
        module for module in kanade.modules()
        if isinstance(module, kanade_transformer.Transformer)
    ]
    for transformer in transformers:
        _convert_freqs_cis_to_real(transformer)
124
+
125
+
126
class KanadeDecoderWrapper(nn.Module):
    """Wraps Kanade's decode pipeline for tracing.

    Pipeline: token indices → quantizer decode → mel_prenet → upsample →
    mel_decoder (conditioned on speaker) → mel_postnet → mel

    All shapes are static: ``num_tokens`` (and the mel length derived from
    it) are fixed at construction time, so the traced CoreML model accepts
    exactly one input size.
    """

    def __init__(self, kanade: KanadeModel, num_tokens: int):
        """Capture the decode-side submodules of *kanade*.

        Args:
            kanade: Loaded Kanade model whose decode submodules are reused.
            num_tokens: Fixed number of codebook indices per forward pass.
        """
        super().__init__()
        self.local_quantizer = kanade.local_quantizer
        self.mel_prenet = kanade.mel_prenet
        self.mel_conv_upsample = kanade.mel_conv_upsample
        self.mel_decoder = kanade.mel_decoder
        self.mel_postnet = kanade.mel_postnet
        self.num_tokens = num_tokens
        # Precompute mel_length for this token count
        # (token count → original audio length → target mel frame count;
        # both helpers are Kanade-private — NOTE(review): relies on private
        # API that may change between kanade_tokenizer versions).
        self.mel_length = kanade._calculate_target_mel_length(
            kanade._calculate_original_audio_length(num_tokens)
        )

    def forward(
        self,
        token_indices: torch.Tensor,
        speaker_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            token_indices: (num_tokens,) int32 — Kanade codebook indices (0-12799)
            speaker_embedding: (1, 128) float32 — speaker embedding

        Returns:
            mel: (1, 80, mel_length) float32
        """
        # Quantizer decode: indices → content embedding
        content_emb = self.local_quantizer.decode(token_indices)  # (num_tokens, 768)
        content_emb = content_emb.unsqueeze(0)  # (1, num_tokens, 768)

        # Mel prenet (transformer)
        local_latent = self.mel_prenet(content_emb)

        # Upsample to mel length
        if self.mel_conv_upsample is not None:
            local_latent = self.mel_conv_upsample(
                local_latent.transpose(1, 2)
            ).transpose(1, 2)
        # "nearest" upsampling to the precomputed frame count — presumably
        # chosen for CoreML compatibility; confirm it matches Kanade's own
        # decode path before shipping.
        local_latent = F.interpolate(
            local_latent.transpose(1, 2), size=self.mel_length, mode="nearest"
        ).transpose(1, 2)

        # Mel decoder (conditioned on speaker)
        mel = self.mel_decoder(local_latent, condition=speaker_embedding.unsqueeze(1))
        mel = mel.transpose(1, 2)  # (1, 80, mel_length)

        # Postnet
        mel = self.mel_postnet(mel)
        return mel
182
+
183
+
184
class FullVocoderWrapper(nn.Module):
    """Complete mel → waveform pipeline: F0 prediction + source gen + HiFT decode + iSTFT.

    Noise is replaced with zeros for deterministic tracing.

    Everything runs in one traced graph: the number of STFT frames is a
    construction-time constant, so the traced model accepts exactly one
    mel length.
    """

    def __init__(self, vocoder, num_stft_frames: int):
        super().__init__()
        self.vocoder = vocoder
        self.num_stft_frames = num_stft_frames
        n_fft = vocoder.istft_n_fft  # 16
        hop_len = vocoder.istft_hop_len  # 4

        # iDFT basis
        n = torch.arange(n_fft, dtype=torch.float32)
        k = torch.arange(n_fft, dtype=torch.float32)
        angles = 2.0 * torch.pi * n.unsqueeze(1) * k.unsqueeze(0) / n_fft
        self.register_buffer("idft_cos", torch.cos(angles) / n_fft)
        self.register_buffer("idft_sin", torch.sin(angles) / n_fft)
        self.register_buffer("window", vocoder.stft_window.clone())

        # Source generation constants
        self.sampling_rate = vocoder.m_source.l_sin_gen.sampling_rate
        self.harmonic_num = vocoder.m_source.l_sin_gen.harmonic_num  # 8
        self.sine_amp = vocoder.m_source.l_sin_gen.sine_amp  # 0.1
        self.upsample_scale = vocoder.m_source.l_sin_gen.upsample_scale  # 480

        # Harmonic multipliers: [1, 2, ..., 9]
        self.register_buffer(
            "harmonic_muls",
            torch.arange(1, self.harmonic_num + 2, dtype=torch.float32),
        )

        # l_linear and l_tanh from m_source
        self.source_linear = vocoder.m_source.l_linear
        self.source_tanh = vocoder.m_source.l_tanh

        self.n_fft = n_fft
        self.hop_len = hop_len
        self.n_fft_half = n_fft // 2 + 1

    def _generate_source(self, f0: torch.Tensor) -> torch.Tensor:
        """f0: (1, mel_length) → source_stft: (1, 18, stft_frames)"""
        # Upsample f0: (1, mel_length) → (1, 1, mel_length) → nearest → (1, 1, audio_length)
        f0_up = F.interpolate(
            f0.unsqueeze(1), scale_factor=float(self.upsample_scale), mode="nearest"
        ).squeeze(1)  # (1, audio_length)

        # Generate harmonics: f0 * [1..9]
        # f0_up: (1, L) → (1, L, 1) * (9,) → (1, L, 9)
        fn = f0_up.unsqueeze(-1) * self.harmonic_muls.unsqueeze(0).unsqueeze(0)

        # Phase accumulation: cumsum(f/sr) * 2pi
        # NOTE(review): upstream SineGen smooths f0 during upsampling before
        # accumulating phase — this nearest-neighbor variant assumes the
        # difference is acceptable; confirm against reference audio.
        rad = (fn / self.sampling_rate)  # instantaneous frequency in cycles per sample
        phase = torch.cumsum(rad, dim=1) * 2.0 * torch.pi  # (1, L, 9)

        # Sine waves
        sines = torch.sin(phase) * self.sine_amp  # (1, L, 9)

        # UV mask (voiced/unvoiced)
        uv = (f0_up > 0).float().unsqueeze(-1)  # (1, L, 1)

        # Apply UV (no noise — zeros instead of randn for tracing)
        sines = sines * uv  # (1, L, 9)

        # l_linear + tanh: (1, L, 9) → linear → (1, L, 1) → tanh
        source = self.source_tanh(self.source_linear(sines))  # (1, L, 1)
        source = source.squeeze(-1)  # (1, L)

        # Manual STFT (torch.stft/unfold not CoreML-compatible)
        # n_fft=16, hop=4. With center padding, we get num_stft_frames frames.
        # Pad source: reflect pad n_fft//2 on each side
        padded = F.pad(source, (self.n_fft // 2, self.n_fft // 2), mode="reflect")
        # padded: (1, L + n_fft) where L = audio_length

        # Extract overlapping frames using conv1d with identity kernel
        # This replaces unfold: conv1d with (n_fft, 1, n_fft) identity kernel, stride=hop
        # Equivalent to: frames[i] = padded[i*hop : i*hop + n_fft]
        eye_kernel = torch.eye(self.n_fft, dtype=source.dtype, device=source.device).unsqueeze(1)
        # padded: (1, L+16) → (1, 1, L+16) for conv1d
        frames = F.conv1d(padded.unsqueeze(1), eye_kernel, stride=self.hop_len)
        # frames: (1, 16, num_frames)
        frames = frames * self.window.unsqueeze(0).unsqueeze(-1)  # window each frame
        # Transpose to (1, num_frames, 16) for matmul
        frames = frames.transpose(1, 2)

        # DFT via matmul
        dft_cos = self.idft_cos[:self.n_fft_half, :] * self.n_fft  # undo 1/N normalization
        dft_sin = self.idft_sin[:self.n_fft_half, :] * self.n_fft
        s_real = torch.matmul(frames, dft_cos.T)  # (1, NF, 9)
        s_imag = -torch.matmul(frames, dft_sin.T)  # (1, NF, 9)
        source_stft = torch.cat([s_real.transpose(1, 2), s_imag.transpose(1, 2)], dim=1)
        return source_stft

    def _istft_overlap_add(self, x: torch.Tensor) -> torch.Tensor:
        """x: (1, 18, num_frames) conv_post output → waveform (1, samples)

        Condensed copy of HiFTDecodeWrapper's overlap-add; see that class
        for the step-by-step derivation of each stage. Assumes n_fft=16,
        hop_len=4 (the 4×4 block structure below is hardcoded).
        """
        magnitude = torch.exp(x[:, :self.n_fft_half, :])
        # sin() as the phase activation mirrors upstream HiFT — not a bug.
        phase = torch.sin(x[:, self.n_fft_half:, :])

        real_half = magnitude * torch.cos(phase)
        imag_half = magnitude * torch.sin(phase)

        # Hermitian mirror to reconstruct the full 16-bin spectrum.
        real_mirror = torch.flip(real_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        imag_mirror = -torch.flip(imag_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        real_full = torch.cat([real_half, real_mirror], dim=1)
        imag_full = torch.cat([imag_half, imag_mirror], dim=1)

        real_t = real_full.transpose(1, 2)
        imag_t = imag_full.transpose(1, 2)
        segments = torch.matmul(real_t, self.idft_cos.T) - torch.matmul(imag_t, self.idft_sin.T)

        NF = self.num_stft_frames
        segments = segments * self.window.unsqueeze(0).unsqueeze(0)
        seg = segments.squeeze(0)
        seg_chunks = seg.reshape(NF, 4, 4)

        b0 = seg_chunks[:, 0, :].reshape(-1)
        b1 = seg_chunks[:, 1, :].reshape(-1)
        b2 = seg_chunks[:, 2, :].reshape(-1)
        b3 = seg_chunks[:, 3, :].reshape(-1)

        F4 = NF * 4
        padded_samples = NF * 4 + 12
        output = torch.zeros(padded_samples)
        output[0:F4] = output[0:F4] + b0
        output[4:F4 + 4] = output[4:F4 + 4] + b1
        output[8:F4 + 8] = output[8:F4 + 8] + b2
        output[12:F4 + 12] = output[12:F4 + 12] + b3

        # Squared-window normalization, laid out with the same shifts.
        win_sq = self.window * self.window
        win_chunks = win_sq.reshape(4, 4)
        w0 = win_chunks[0].repeat(NF)
        w1 = win_chunks[1].repeat(NF)
        w2 = win_chunks[2].repeat(NF)
        w3 = win_chunks[3].repeat(NF)

        wnorm = torch.zeros(padded_samples)
        wnorm[0:F4] = wnorm[0:F4] + w0
        wnorm[4:F4 + 4] = wnorm[4:F4 + 4] + w1
        wnorm[8:F4 + 8] = wnorm[8:F4 + 8] + w2
        wnorm[12:F4 + 12] = wnorm[12:F4 + 12] + w3

        output = output / (wnorm + 1e-8)
        pad = 8
        trimmed_len = (NF - 1) * 4
        output = output[pad:pad + trimmed_len]
        output = torch.clamp(output, -0.99, 0.99)
        return output.unsqueeze(0)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """mel: (1, 80, T) → waveform: (1, samples)"""
        # F0 prediction
        f0 = self.vocoder.f0_predictor(mel)  # (1, T)

        # Source generation
        source_stft = self._generate_source(f0)

        # HiFT decode
        x = self.vocoder.conv_pre(mel)
        for i in range(self.vocoder.num_upsamples):
            x = F.leaky_relu(x, self.vocoder.lrelu_slope)
            x = self.vocoder.ups[i](x)
            if i == self.vocoder.num_upsamples - 1:
                x = self.vocoder.reflection_pad(x)
            # Inject the downsampled source signal at every scale.
            si = self.vocoder.source_downs[i](source_stft)
            si = self.vocoder.source_resblocks[i](si)
            x = x + si
            # Average over the parallel multi-kernel resblocks.
            xs = None
            for j in range(self.vocoder.num_kernels):
                if xs is None:
                    xs = self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
                else:
                    xs += self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
            x = xs / self.vocoder.num_kernels

        x = F.leaky_relu(x)
        x = self.vocoder.conv_post(x)

        return self._istft_overlap_add(x)
363
+
364
+
365
class F0PredictorWrapper(nn.Module):
    """Thin tracing wrapper around HiFT's fundamental-frequency predictor.

    Exists so the predictor can be traced and exported on its own, without
    the rest of the vocoder graph.
    """

    def __init__(self, vocoder):
        """Keep only the ``f0_predictor`` submodule of *vocoder*."""
        super().__init__()
        self.f0_predictor = vocoder.f0_predictor

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Predict f0 from a mel spectrogram.

        Args:
            mel: (1, 80, T) float32 mel spectrogram.

        Returns:
            The predictor's output — documented upstream as (1, 1, T).
        """
        return self.f0_predictor(mel)
375
+
376
+
377
class HiFTDecodeWrapper(nn.Module):
    """Wraps HiFT's decode stage: mel + source_stft → waveform.

    Includes a manual iSTFT implementation using matmul with a precomputed
    DFT basis matrix, so the entire pipeline runs inside CoreML.

    The frame count is a construction-time constant: the traced model only
    supports the exact mel/source lengths it was built for.
    """

    def __init__(self, vocoder, num_stft_frames: int):
        super().__init__()
        self.vocoder = vocoder
        self.num_stft_frames = num_stft_frames  # hardcoded for tracing
        n_fft = vocoder.istft_n_fft  # 16
        hop_len = vocoder.istft_hop_len  # 4

        # Precompute DFT basis for iSTFT: (n_fft, n_fft) real-valued IDFT matrix
        # X[k] = sum_n x[n] * exp(j*2pi*n*k/N) → x[n] = (1/N) * sum_k X[k] * exp(j*2pi*n*k/N)
        n = torch.arange(n_fft, dtype=torch.float32)
        k = torch.arange(n_fft, dtype=torch.float32)
        angles = 2.0 * torch.pi * n.unsqueeze(1) * k.unsqueeze(0) / n_fft  # (n_fft, n_fft)
        # cos/sin basis for real/imag parts
        self.register_buffer("idft_cos", torch.cos(angles) / n_fft)  # (n_fft, n_fft)
        self.register_buffer("idft_sin", torch.sin(angles) / n_fft)  # (n_fft, n_fft)

        # Window for overlap-add
        self.register_buffer("window", vocoder.stft_window.clone())
        self.n_fft = n_fft
        self.hop_len = hop_len
        self.n_fft_half = n_fft // 2 + 1  # 9

    def forward(self, mel: torch.Tensor, source_stft: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel: (1, 80, T) float32
            source_stft: (1, 18, T') float32 — real+imag STFT of source signal

        Returns:
            waveform: (1, samples) float32
        """
        x = self.vocoder.conv_pre(mel)
        for i in range(self.vocoder.num_upsamples):
            x = F.leaky_relu(x, self.vocoder.lrelu_slope)
            x = self.vocoder.ups[i](x)
            if i == self.vocoder.num_upsamples - 1:
                x = self.vocoder.reflection_pad(x)

            # Inject the downsampled source signal at every scale.
            si = self.vocoder.source_downs[i](source_stft)
            si = self.vocoder.source_resblocks[i](si)
            x = x + si

            # Average over the parallel multi-kernel resblocks.
            xs = None
            for j in range(self.vocoder.num_kernels):
                if xs is None:
                    xs = self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
                else:
                    xs += self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
            x = xs / self.vocoder.num_kernels

        x = F.leaky_relu(x)
        x = self.vocoder.conv_post(x)  # (1, 18, num_frames)

        # Split into magnitude and phase
        magnitude = torch.exp(x[:, :self.n_fft_half, :])  # (1, 9, num_frames)
        # sin() as the phase activation mirrors upstream HiFT — not a bug.
        phase = torch.sin(x[:, self.n_fft_half:, :])  # (1, 9, num_frames)

        # Convert to real/imag
        real_half = magnitude * torch.cos(phase)  # (1, 9, num_frames)
        imag_half = magnitude * torch.sin(phase)

        # Mirror to full spectrum (Hermitian symmetry)
        # real: [r0, r1, ..., r8, r7, r6, ..., r1]
        # imag: [i0, i1, ..., i8, -i7, -i6, ..., -i1]
        real_mirror = torch.flip(real_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        imag_mirror = -torch.flip(imag_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        real_full = torch.cat([real_half, real_mirror], dim=1)  # (1, 16, num_frames)
        imag_full = torch.cat([imag_half, imag_mirror], dim=1)  # (1, 16, num_frames)

        # iDFT via matmul: output[n] = sum_k (real[k]*cos[n,k] - imag[k]*sin[n,k])
        # (1, 16, num_frames) → transpose to (1, num_frames, 16) → matmul with (16, 16)
        real_t = real_full.transpose(1, 2)  # (1, num_frames, 16)
        imag_t = imag_full.transpose(1, 2)
        # segments[n] = sum_k real[k]*cos[n,k] - imag[k]*sin[n,k]
        #            = real_t @ idft_cos.T - imag_t @ idft_sin.T
        # But idft_cos is (n_fft, n_fft) where idft_cos[n,k] = cos(2pi*n*k/N)/N
        # We want segments[frame, n] = sum_k (real[frame,k] * idft_cos[n,k] - imag[frame,k] * idft_sin[n,k])
        #                            = (real_t @ idft_cos^T - imag_t @ idft_sin^T)[frame, n]
        segments = torch.matmul(real_t, self.idft_cos.T) - torch.matmul(imag_t, self.idft_sin.T)
        # segments: (1, num_frames, 16)

        # Overlap-add with window
        # n_fft=16, hop=4, so overlap ratio = 4 (each sample covered by 4 frames)
        NF = self.num_stft_frames  # hardcoded constant for tracing
        segments = segments * self.window.unsqueeze(0).unsqueeze(0)  # (1, NF, 16)
        seg = segments.squeeze(0)  # (NF, 16)

        # Reshape each 16-sample segment into 4 chunks of 4 (hop_len) samples
        # seg: (F, 16) → (F, 4, 4)
        # NOTE(review): the 4×4 block structure assumes n_fft=16, hop_len=4;
        # revisit if those vocoder constants ever change.
        seg_chunks = seg.reshape(NF, 4, 4)  # (F, 4_blocks, 4_samples)

        # Block b of frame f lands at output position (f + b) * hop_len
        # Rearrange so block b from all frames is contiguous:
        #   chunk_b[f] = seg_chunks[f, b, :] lands at output[(f+b)*4 : (f+b)*4 + 4]
        #             = output index f*4 + b*4 ... but shifted by b frames
        # Equivalently: for block b, we have F values that go to positions b, b+1, ..., b+F-1
        # in units of hop_len

        # For each sub-block offset (0..3), create a flat array and add shifted
        # Using static slicing only — no dynamic indexing
        padded_samples = NF * 4 + 12  # (NF-1)*4 + 16
        # Actually: (num_frames - 1) * 4 + 16 = num_frames * 4 + 12

        # Each sub-block b contributes F chunks of 4 samples, placed at positions
        # starting from b*4 with stride 4 between frames.
        # block_b = seg_chunks[:, b, :].reshape(-1) → F*4 contiguous values
        # These go to output[b*4 : b*4 + F*4]
        b0 = seg_chunks[:, 0, :].reshape(-1)  # (F*4,) → output[0 : F*4]
        b1 = seg_chunks[:, 1, :].reshape(-1)  # (F*4,) → output[4 : F*4 + 4]
        b2 = seg_chunks[:, 2, :].reshape(-1)  # (F*4,) → output[8 : F*4 + 8]
        b3 = seg_chunks[:, 3, :].reshape(-1)  # (F*4,) → output[12 : F*4 + 12]

        F4 = NF * 4
        output = torch.zeros(padded_samples)
        output[0:F4] = output[0:F4] + b0
        output[4:F4 + 4] = output[4:F4 + 4] + b1
        output[8:F4 + 8] = output[8:F4 + 8] + b2
        output[12:F4 + 12] = output[12:F4 + 12] + b3

        # Window normalization — same structure
        win_sq = self.window * self.window  # (16,)
        win_chunks = win_sq.reshape(4, 4)  # (4_blocks, 4_samples)
        w0 = win_chunks[0].repeat(NF)
        w1 = win_chunks[1].repeat(NF)
        w2 = win_chunks[2].repeat(NF)
        w3 = win_chunks[3].repeat(NF)

        wnorm = torch.zeros(padded_samples)
        wnorm[0:F4] = wnorm[0:F4] + w0
        wnorm[4:F4 + 4] = wnorm[4:F4 + 4] + w1
        wnorm[8:F4 + 8] = wnorm[8:F4 + 8] + w2
        wnorm[12:F4 + 12] = wnorm[12:F4 + 12] + w3

        output = output / (wnorm + 1e-8)

        # Trim center padding: n_fft//2 = 8 from start
        pad = 8
        trimmed_len = (NF - 1) * 4  # expected output length
        output = output[pad:pad + trimmed_len]
        output = torch.clamp(output, -0.99, 0.99)
        return output.unsqueeze(0)  # (1, samples)
525
+
526
+
527
def convert_kanade_decoder(kanade: KanadeModel, num_tokens: int, output_dir: Path):
    """Convert Kanade decoder to CoreML.

    Traces ``KanadeDecoderWrapper`` at a fixed token count and writes
    ``KanadeDecoder.mlpackage`` into *output_dir*.

    Args:
        kanade: Loaded Kanade model (decode submodules are reused).
        num_tokens: Fixed number of codebook indices the model accepts.
        output_dir: Destination directory for the .mlpackage.
    """
    wrapper = KanadeDecoderWrapper(kanade, num_tokens).eval().float()
    print(f"Tracing Kanade decoder (num_tokens={num_tokens}, mel_length={wrapper.mel_length})...")

    # Dummy inputs: only shape/dtype matter for tracing, values are arbitrary.
    token_indices = torch.arange(num_tokens, dtype=torch.int32)
    speaker_embedding = torch.randn(1, 128, dtype=torch.float32)

    with torch.no_grad():
        # Test forward
        mel = wrapper(token_indices, speaker_embedding)
        print(f" Output mel shape: {mel.shape}")

        traced = torch.jit.trace(wrapper, (token_indices, speaker_embedding))

    print("Converting Kanade decoder to CoreML...")
    # compute_precision FLOAT32 keeps full fp32 weights — presumably for
    # audio fidelity; confirm fp16 is insufficient before changing.
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="token_indices", shape=(num_tokens,), dtype=np.int32),
            ct.TensorType(name="speaker_embedding", shape=(1, 128), dtype=np.float32),
        ],
        outputs=[ct.TensorType(name="mel", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "KanadeDecoder.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved Kanade decoder to {out_path}")
557
+
558
+
559
def convert_f0_predictor(vocoder, mel_length: int, output_dir: Path):
    """Convert HiFT f0 predictor to CoreML.

    Traces ``F0PredictorWrapper`` at a fixed mel length and writes
    ``F0Predictor.mlpackage`` into *output_dir*.

    Args:
        vocoder: Loaded HiFT vocoder (only ``f0_predictor`` is used).
        mel_length: Fixed mel frame count the exported model accepts.
        output_dir: Destination directory for the .mlpackage.
    """
    wrapper = F0PredictorWrapper(vocoder).eval().float()
    print(f"Tracing F0 predictor (mel_length={mel_length})...")

    # Dummy input: only shape/dtype matter for tracing.
    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)

    with torch.no_grad():
        f0 = wrapper(mel)
        print(f" Output f0 shape: {f0.shape}")
        traced = torch.jit.trace(wrapper, (mel,))

    print("Converting F0 predictor to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32),
        ],
        outputs=[ct.TensorType(name="f0", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "F0Predictor.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved F0 predictor to {out_path}")
585
+
586
+
587
def convert_hift_decode(vocoder, mel_length: int, output_dir: Path):
    """Convert HiFT decode stage to CoreML.

    Source signal STFT must be computed externally (Swift side).

    Args:
        vocoder: Loaded HiFT vocoder.
        mel_length: Fixed mel frame count the exported model accepts.
        output_dir: Destination directory for the .mlpackage.
    """
    # Compute source_stft shape: run f0 predictor + source module to get it
    # NOTE(review): relies on the vocoder's private ``_stft`` helper — may
    # break across vocoder versions.
    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)
    with torch.no_grad():
        f0 = vocoder.f0_predictor(mel)
        s = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
        s, _, _ = vocoder.m_source(s)
        s = s.transpose(1, 2)
        s_stft_real, s_stft_imag = vocoder._stft(s.squeeze(1))
        source_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
    num_stft_frames = source_stft.shape[2]
    print(f" Source STFT shape: {source_stft.shape} ({num_stft_frames} frames)")

    wrapper = HiFTDecodeWrapper(vocoder, num_stft_frames).eval().float()

    print(f"Tracing HiFT decode (mel_length={mel_length})...")
    with torch.no_grad():
        waveform = wrapper(mel, source_stft)
        print(f" Output waveform shape: {waveform.shape}")
        traced = torch.jit.trace(wrapper, (mel, source_stft))

    print("Converting HiFT decode to CoreML...")
    source_stft_channels = source_stft.shape[1]
    source_stft_time = source_stft.shape[2]
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32),
            ct.TensorType(
                name="source_stft",
                shape=(1, source_stft_channels, source_stft_time),
                dtype=np.float32,
            ),
        ],
        outputs=[ct.TensorType(name="waveform", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "HiFTDecode.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved HiFT decode to {out_path}")
633
+
634
+
635
def main():
    """CLI entry point: convert the Kanade decoder and the HiFT vocoder to CoreML."""
    parser = argparse.ArgumentParser(description="Convert Kanade + HiFT to CoreML")
    parser.add_argument(
        "--output-dir", type=str,
        default=str(Path(__file__).parent.parent),
        help="Output directory",
    )
    parser.add_argument(
        "--num-tokens", type=int, default=100,
        help="Fixed number of audio tokens (determines mel length)",
    )
    args = parser.parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading Kanade model...")
    kanade = KanadeModel.from_pretrained("frothywater/kanade-25hz-clean").eval().float()
    patch_kanade_for_coreml(kanade)
    vocoder = load_vocoder(kanade.config.vocoder_name).eval().float()

    # Mel length is a fixed function of the token count; CoreML inputs here
    # are static, so it must be computed once up front.
    mel_length = kanade._calculate_target_mel_length(
        kanade._calculate_original_audio_length(args.num_tokens)
    )

    # f-prefix removed from the two banner strings below: no placeholders (F541).
    print("\n=== Converting Kanade decoder ===")
    convert_kanade_decoder(kanade, args.num_tokens, output_dir)

    print("\n=== Converting full vocoder (mel → waveform) ===")
    convert_full_vocoder(vocoder, mel_length, output_dir)

    print("\nDone!")
    print(f"  KanadeDecoder: {args.num_tokens} tokens → mel (80, {mel_length})")
    print(f"  Vocoder: mel (80, {mel_length}) → waveform")
669
+
670
+
671
def convert_full_vocoder(vocoder, mel_length: int, output_dir: Path):
    """Convert complete mel→waveform vocoder to CoreML.

    Args:
        vocoder: loaded HiFT vocoder module.
        mel_length: fixed mel time dimension the exported model accepts.
        output_dir: directory where ``Vocoder.mlpackage`` is written.
    """
    # Get num_stft_frames by running a dummy forward
    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)
    with torch.no_grad():
        f0 = vocoder.f0_predictor(mel)
        s = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
        s, _, _ = vocoder.m_source(s)
        s = s.transpose(1, 2)
        sr, si = vocoder._stft(s.squeeze(1))
    num_stft_frames = sr.shape[2]
    print(f"  STFT frames: {num_stft_frames}")

    wrapper = FullVocoderWrapper(vocoder, num_stft_frames).eval().float()

    print(f"Tracing full vocoder (mel_length={mel_length})...")
    # Replace randn_like with zeros so tracing is deterministic. The restore
    # is in a finally block so a tracing failure cannot leave torch globally
    # monkey-patched for the rest of the process.
    orig_randn = torch.randn_like
    torch.randn_like = lambda x, **kw: torch.zeros_like(x)
    try:
        with torch.no_grad():
            wav = wrapper(mel)
            print(f"  Output waveform: {wav.shape}")
            traced = torch.jit.trace(wrapper, (mel,))
    finally:
        torch.randn_like = orig_randn

    print("Converting full vocoder to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32)],
        outputs=[ct.TensorType(name="waveform", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "Vocoder.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved vocoder to {out_path}")
708
+
709
+
710
# Run the conversion pipeline only when executed as a script.
if __name__ == "__main__":
    main()
scripts/model_wrapper.py CHANGED
@@ -118,8 +118,8 @@ class PlaprePicoPrefill(nn.Module):
118
  class PlaprePicoDecode(nn.Module):
119
  """Generates one token at a time using the KV cache.
120
 
121
- Position encoding is handled externally: caller provides pre-sliced cos/sin
122
- for the current position, and a one-hot update_mask for cache writing.
123
 
124
  Inputs:
125
  input_ids: (1, 1) int32
@@ -127,6 +127,7 @@ class PlaprePicoDecode(nn.Module):
127
  cos: (1, 1, 1, 64) float16 — RoPE cos for current position
128
  sin: (1, 1, 1, 64) float16 — RoPE sin for current position
129
  update_mask: (1, 1, 2048, 1) float16 — one-hot at current position
 
130
 
131
  State buffers:
132
  k_cache_0..29, v_cache_0..29: (1, 3, 2048, 64) float16
@@ -174,8 +175,12 @@ class PlaprePicoDecode(nn.Module):
174
  cos: torch.Tensor,
175
  sin: torch.Tensor,
176
  update_mask: torch.Tensor,
 
177
  ) -> torch.Tensor:
178
  hidden = self.embed_tokens(input_ids) # (1, 1, 576)
 
 
 
179
 
180
  for i, layer in enumerate(self.layers):
181
  k_cache = getattr(self, f"k_cache_{i}")
 
118
  class PlaprePicoDecode(nn.Module):
119
  """Generates one token at a time using the KV cache.
120
 
121
+ Also used for token-by-token prefill. For the speaker token (position 0),
122
+ pass a non-zero speaker_hidden to replace the token embedding.
123
 
124
  Inputs:
125
  input_ids: (1, 1) int32
 
127
  cos: (1, 1, 1, 64) float16 — RoPE cos for current position
128
  sin: (1, 1, 1, 64) float16 — RoPE sin for current position
129
  update_mask: (1, 1, 2048, 1) float16 — one-hot at current position
130
+ speaker_hidden: (1, 1, 576) float16 — pre-projected speaker embedding, or zeros
131
 
132
  State buffers:
133
  k_cache_0..29, v_cache_0..29: (1, 3, 2048, 64) float16
 
175
  cos: torch.Tensor,
176
  sin: torch.Tensor,
177
  update_mask: torch.Tensor,
178
+ speaker_hidden: torch.Tensor,
179
  ) -> torch.Tensor:
180
  hidden = self.embed_tokens(input_ids) # (1, 1, 576)
181
+ # Speaker conditioning: caller passes pre-projected (1,1,576) for position 0,
182
+ # zeros for all other positions. Additive — zeros are a no-op.
183
+ hidden = hidden + speaker_hidden
184
 
185
  for i, layer in enumerate(self.layers):
186
  k_cache = getattr(self, f"k_cache_{i}")
scripts/test_generate.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ End-to-end test: generate Danish speech using our custom PyTorch wrappers
4
+ (the same code converted to CoreML), decode with Kanade, save as WAV.
5
+
6
+ Usage:
7
+ python scripts/test_generate.py [--text "Hej verden"] [--speaker tor] [--output test.wav]
8
+ """
9
+
10
+ import argparse
11
+ import json
12
+ import sys
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import soundfile as sf
19
+
20
+ sys.path.insert(0, str(Path(__file__).parent))
21
+
22
+ from attention import precompute_rope_frequencies
23
+ from model_wrapper import (
24
+ PlaprePicoPrefill,
25
+ PlaprePicoDecode,
26
+ NUM_LAYERS,
27
+ MAX_CONTEXT,
28
+ HEAD_DIM,
29
+ PREFILL_SEQ_LEN,
30
+ SPEAKER_DIM,
31
+ )
32
+ from convert import load_weights, populate_weights
33
+
34
# Vocabulary layout of the combined text+audio token space.
AUDIO_TOKEN_OFFSET = 8002  # first audio-codebook id; subtracted to get Kanade indices
AUDIO_MARKER_TOKEN = 8001  # marker token that precedes the audio segment
TEXT_MARKER_TOKEN = 8000   # marker token that precedes the text segment
EOS_TOKEN = 2              # end-of-sequence; also used as padding and the position-0 placeholder
38
+
39
+
40
def load_speaker(speakers_path: Path, name: str) -> torch.Tensor:
    """Return the embedding for speaker `name` as a (1, D) float16 tensor.

    Raises:
        ValueError: if `name` is not a key in the speakers JSON file.
    """
    with open(speakers_path) as handle:
        table = json.load(handle)
    if name in table:
        return torch.tensor(table[name], dtype=torch.float16).unsqueeze(0)
    raise ValueError(f"Speaker '{name}' not found. Available: {list(table.keys())}")
46
+
47
+
48
def sample(logits: torch.Tensor, temperature: float, top_k: int, top_p: float) -> int:
    """Sample a token id from `logits` with temperature / top-k / top-p filtering.

    Args:
        logits: 1-D tensor of unnormalized scores over the vocabulary.
        temperature: values <= 0 select the argmax (greedy decoding);
            otherwise logits are divided by it before softmax.
        top_k: keep only the k highest logits (0 disables the filter).
            Values larger than the vocabulary are clamped instead of raising.
        top_p: nucleus threshold over the sorted probabilities.

    Returns:
        The sampled token index as a Python int.
    """
    # Greedy decoding short-circuit.
    if temperature <= 0:
        return int(logits.argmax())
    logits = logits.float() / temperature
    if top_k > 0:
        # Clamp k so torch.topk never fails when top_k exceeds the vocab size.
        k = min(top_k, logits.size(-1))
        topv, topi = torch.topk(logits, k)
        logits_filtered = torch.full_like(logits, float("-inf"))
        logits_filtered.scatter_(0, topi, topv)
    else:
        logits_filtered = logits
    probs = F.softmax(logits_filtered, dim=-1)
    # Nucleus (top-p): zero out tokens whose cumulative probability mass,
    # excluding themselves, already exceeds top_p. The highest-probability
    # token is never masked, so the renormalization below is always safe.
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=0)
    mask = cumsum - sorted_probs > top_p
    sorted_probs[mask] = 0
    sorted_probs /= sorted_probs.sum()
    idx = torch.multinomial(sorted_probs, 1)
    return int(sorted_idx[idx])
66
+
67
+
68
def generate(
    prefill_model: PlaprePicoPrefill,
    decode_model: PlaprePicoDecode,
    text: str,
    speaker_embedding: torch.Tensor,
    tokenizer_path: Path,
    max_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> list[int]:
    """Run prompt prefill then autoregressive decoding; return generated token ids.

    Args:
        prefill_model: wrapper whose layers are run manually here so the
            logits at position ``input_len - 1`` can be read (the wrapper's
            own forward only returns the last position).
        decode_model: single-token decode wrapper with internal KV caches.
        text: input text, tokenized with the tokenizer at `tokenizer_path`.
        speaker_embedding: speaker vector; projected and substituted for the
            position-0 token embedding.
        max_tokens: upper bound on generated tokens (including the first one
            sampled from the prefill logits).
        temperature / top_k / top_p: sampling parameters (see `sample`).

    Returns:
        List of sampled token ids; ends with EOS_TOKEN if it was produced
        before `max_tokens` was reached.
    """
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(str(tokenizer_path))
    token_ids = tokenizer.encode(text).ids

    # Plapre format: [placeholder, <text>, tokens..., <audio>]
    # Position 0 placeholder gets replaced by speaker_proj output
    input_ids_list = [EOS_TOKEN] + [TEXT_MARKER_TOKEN] + token_ids + [AUDIO_MARKER_TOKEN]
    input_len = len(input_ids_list)
    print(f"Input ({input_len} tokens): {input_ids_list}")

    # Pad to prefill length (padding slots are filled with EOS_TOKEN and
    # masked out below, so their content never influences attention).
    padded_ids = torch.full((1, PREFILL_SEQ_LEN), EOS_TOKEN, dtype=torch.int32)
    for i, tid in enumerate(input_ids_list):
        padded_ids[0, i] = tid

    # Causal mask: only real tokens (0..input_len-1) attend
    causal_mask = torch.full(
        (1, 1, PREFILL_SEQ_LEN, MAX_CONTEXT), float("-inf"), dtype=torch.float16
    )
    for i in range(input_len):
        causal_mask[0, 0, i, :i + 1] = 0.0

    # === Prefill ===
    # We can't get logits at an arbitrary position from the wrapper (it returns pos -1).
    # So run the layers manually to read logits at input_len - 1.
    print("Running prefill...")
    with torch.no_grad():
        hidden = prefill_model.embed_tokens(padded_ids)
        # Replace the position-0 embedding with the projected speaker vector.
        spk = prefill_model.speaker_proj(speaker_embedding).unsqueeze(1)
        hidden = torch.cat([spk, hidden[:, 1:, :]], dim=1)

        cos = prefill_model.rope_cos
        sin = prefill_model.rope_sin

        for i, layer in enumerate(prefill_model.layers):
            k_cache = getattr(prefill_model, f"k_cache_{i}")
            v_cache = getattr(prefill_model, f"v_cache_{i}")
            hidden, k_new, v_new = layer(hidden, cos, sin, causal_mask, k_cache, v_cache)
            # Update caches on the model so decode can copy them
            setattr(prefill_model, f"k_cache_{i}", k_new)
            setattr(prefill_model, f"v_cache_{i}", v_new)

        hidden = prefill_model.norm(hidden)
        # Tied embeddings: the LM head reuses embed_tokens.weight.
        logits = F.linear(hidden[0, input_len - 1, :], prefill_model.embed_tokens.weight)

    generated = []
    next_token = sample(logits, temperature, top_k, top_p)
    generated.append(next_token)
    print(f"  Token 0: {next_token}")

    # === Copy KV cache to decode model ===
    with torch.no_grad():
        for i in range(NUM_LAYERS):
            getattr(decode_model, f"k_cache_{i}").copy_(getattr(prefill_model, f"k_cache_{i}"))
            getattr(decode_model, f"v_cache_{i}").copy_(getattr(prefill_model, f"v_cache_{i}"))

    # === Decode loop ===
    cos_full, sin_full = precompute_rope_frequencies(HEAD_DIM, MAX_CONTEXT, 100000.0)
    cos_full = cos_full.half()
    sin_full = sin_full.half()

    print("Decoding...")
    for step in range(1, max_tokens):
        # Absolute position of the token being written this step.
        pos = input_len + step - 1

        decode_ids = torch.tensor([[next_token]], dtype=torch.int32)
        decode_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
        decode_mask[0, 0, 0, :pos + 1] = 0.0
        # RoPE cos/sin pre-sliced for the current position.
        pos_cos = cos_full[:, :, pos:pos + 1, :]
        pos_sin = sin_full[:, :, pos:pos + 1, :]
        # One-hot mask telling the decode model which cache slot to write.
        update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
        update_mask[0, 0, pos, 0] = 1.0

        # NOTE(review): the model_wrapper diff in this commit adds a required
        # `speaker_hidden` argument to PlaprePicoDecode.forward; this call
        # does not pass it — confirm the two files are in sync.
        with torch.no_grad():
            logits = decode_model(decode_ids, decode_mask, pos_cos, pos_sin, update_mask)

        next_token = sample(logits[0, 0], temperature, top_k, top_p)
        generated.append(next_token)

        if next_token == EOS_TOKEN:
            print(f"  EOS at step {step}")
            break
        if step % 25 == 0:
            # 25 tokens ≈ 1 second of audio (25 Hz codec).
            print(f"  Step {step}: ({step / 25:.1f}s of audio)")

    return generated
165
+
166
+
167
def decode_audio(tokens: list[int], speaker_embedding: torch.Tensor) -> np.ndarray:
    """Decode generated audio tokens to a waveform via the Kanade decoder.

    Raises:
        ValueError: when `tokens` contains no ids in the audio range.
    """
    from kanade_tokenizer import KanadeModel, load_vocoder, vocode

    # Keep only ids inside the audio codebook range, shifted to Kanade indices.
    shifted = [t - AUDIO_TOKEN_OFFSET for t in tokens if AUDIO_TOKEN_OFFSET <= t <= 20801]
    if not shifted:
        raise ValueError("No audio tokens generated!")

    kanade_indices = torch.tensor(shifted)
    print(f"Decoding {len(kanade_indices)} audio tokens ({len(kanade_indices) / 25:.1f}s)...")

    model = KanadeModel.from_pretrained("frothywater/kanade-25hz-clean").eval()
    vocoder = load_vocoder(model.config.vocoder_name)

    with torch.no_grad():
        spk = speaker_embedding.squeeze(0).float()
        mel = model.decode(global_embedding=spk, content_token_indices=kanade_indices)
        waveform = vocode(vocoder, mel.unsqueeze(0))

    return waveform.squeeze().cpu().numpy()
186
+
187
+
188
def main():
    """CLI entry point: generate speech with the PyTorch wrappers and save a WAV."""
    parser = argparse.ArgumentParser(description="Generate Danish speech (custom model)")
    parser.add_argument("--text", type=str, default="Hej, mit navn er Daniel.")
    parser.add_argument("--speaker", type=str, default="tor")
    parser.add_argument("--output", type=str, default="test.wav")
    parser.add_argument("--max-tokens", type=int, default=500)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--model-dir", type=str, default=None)
    args = parser.parse_args()

    if args.model_dir:
        model_dir = Path(args.model_dir)
    else:
        cache = Path.home() / ".cache/huggingface/hub/models--syvai--plapre-pico"
        snapshots = cache / "snapshots"
        # Prefer a cached snapshot, sorted for determinism. Fall back to a
        # download when the snapshots dir is missing OR empty — the previous
        # next(snapshots.iterdir()) raised StopIteration on an empty dir.
        snapshot_dirs = sorted(snapshots.iterdir()) if snapshots.exists() else []
        if snapshot_dirs:
            model_dir = snapshot_dirs[0]
        else:
            from huggingface_hub import snapshot_download
            model_dir = Path(snapshot_download("syvai/plapre-pico"))

    repo_root = Path(__file__).parent.parent

    # Prefer repo-local speaker/tokenizer files, fall back to the model dir.
    speakers_path = repo_root / "speakers.json"
    if not speakers_path.exists():
        speakers_path = model_dir / "speakers.json"
    speaker_embedding = load_speaker(speakers_path, args.speaker)
    print(f"Speaker: {args.speaker}")

    tokenizer_path = repo_root / "tokenizer.json"
    if not tokenizer_path.exists():
        tokenizer_path = model_dir / "tokenizer.json"

    # Load weights into our custom models
    weights = load_weights(model_dir)

    prefill = PlaprePicoPrefill()
    populate_weights(prefill, weights, is_prefill=True)
    prefill = prefill.half().eval()

    decode = PlaprePicoDecode()
    populate_weights(decode, weights, is_prefill=False)
    decode = decode.half().eval()

    # Generate
    tokens = generate(
        prefill, decode, args.text, speaker_embedding, tokenizer_path,
        args.max_tokens, args.temperature, args.top_k, args.top_p,
    )

    audio_count = sum(1 for t in tokens if AUDIO_TOKEN_OFFSET <= t <= 20801)
    print(f"\nGenerated {len(tokens)} tokens: {audio_count} audio ({audio_count / 25:.1f}s)")
    print(f"First 20: {tokens[:20]}")

    waveform = decode_audio(tokens, speaker_embedding)
    sf.write(args.output, waveform, 24000)
    print(f"Saved {len(waveform) / 24000:.1f}s audio to {args.output}")
247
+
248
+
249
+ if __name__ == "__main__":
250
+ main()