Daniel Rothmann commited on
Commit ·
fad9fad
1
Parent(s): a2c97d7
WIP audio decoder
Browse files- KanadeDecoder.mlpackage/Data/com.apple.CoreML/model.mlmodel +3 -0
- KanadeDecoder.mlpackage/Data/com.apple.CoreML/weights/weight.bin +3 -0
- KanadeDecoder.mlpackage/Manifest.json +18 -0
- PlaprePicoDecode.mlpackage/Data/com.apple.CoreML/model.mlmodel +2 -2
- PlaprePicoDecode.mlpackage/Manifest.json +8 -8
- Vocoder.mlpackage/Data/com.apple.CoreML/model.mlmodel +3 -0
- Vocoder.mlpackage/Data/com.apple.CoreML/weights/weight.bin +3 -0
- Vocoder.mlpackage/Manifest.json +18 -0
- scripts/convert.py +9 -3
- scripts/convert_kanade.py +711 -0
- scripts/model_wrapper.py +7 -2
- scripts/test_generate.py +250 -0
KanadeDecoder.mlpackage/Data/com.apple.CoreML/model.mlmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1a72aeec4e105b9d593a721317e9fce1ca7783e21293e82f898d810c6bf1c1fe
|
| 3 |
+
size 178115
|
KanadeDecoder.mlpackage/Data/com.apple.CoreML/weights/weight.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1d2922387d7a2ef3f41db7a069ca9be2d313250137841ffbe8ab7b912bddd96a
|
| 3 |
+
size 364866112
|
KanadeDecoder.mlpackage/Manifest.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fileFormatVersion": "1.0.0",
|
| 3 |
+
"itemInfoEntries": {
|
| 4 |
+
"3D07005F-6244-406D-9DD3-91CF5F26CCAE": {
|
| 5 |
+
"author": "com.apple.CoreML",
|
| 6 |
+
"description": "CoreML Model Specification",
|
| 7 |
+
"name": "model.mlmodel",
|
| 8 |
+
"path": "com.apple.CoreML/model.mlmodel"
|
| 9 |
+
},
|
| 10 |
+
"FD090485-11AF-465F-8569-E149E7086201": {
|
| 11 |
+
"author": "com.apple.CoreML",
|
| 12 |
+
"description": "CoreML Model Weights",
|
| 13 |
+
"name": "weights",
|
| 14 |
+
"path": "com.apple.CoreML/weights"
|
| 15 |
+
}
|
| 16 |
+
},
|
| 17 |
+
"rootModelIdentifier": "3D07005F-6244-406D-9DD3-91CF5F26CCAE"
|
| 18 |
+
}
|
PlaprePicoDecode.mlpackage/Data/com.apple.CoreML/model.mlmodel
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cbc60dac941edc9fbe212c52c4a96677e6ca547d575bdf461d19695e35de86a
|
| 3 |
+
size 579443
|
PlaprePicoDecode.mlpackage/Manifest.json
CHANGED
|
@@ -1,18 +1,18 @@
|
|
| 1 |
{
|
| 2 |
"fileFormatVersion": "1.0.0",
|
| 3 |
"itemInfoEntries": {
|
| 4 |
-
"
|
| 5 |
-
"author": "com.apple.CoreML",
|
| 6 |
-
"description": "CoreML Model Weights",
|
| 7 |
-
"name": "weights",
|
| 8 |
-
"path": "com.apple.CoreML/weights"
|
| 9 |
-
},
|
| 10 |
-
"E087C383-13E2-4E2C-B87A-990925041088": {
|
| 11 |
"author": "com.apple.CoreML",
|
| 12 |
"description": "CoreML Model Specification",
|
| 13 |
"name": "model.mlmodel",
|
| 14 |
"path": "com.apple.CoreML/model.mlmodel"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
}
|
| 16 |
},
|
| 17 |
-
"rootModelIdentifier": "
|
| 18 |
}
|
|
|
|
| 1 |
{
|
| 2 |
"fileFormatVersion": "1.0.0",
|
| 3 |
"itemInfoEntries": {
|
| 4 |
+
"668A6F00-934D-4D44-9C27-7881268451D9": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"author": "com.apple.CoreML",
|
| 6 |
"description": "CoreML Model Specification",
|
| 7 |
"name": "model.mlmodel",
|
| 8 |
"path": "com.apple.CoreML/model.mlmodel"
|
| 9 |
+
},
|
| 10 |
+
"BA90DDE9-E076-4B65-A23D-91E3BFAD284D": {
|
| 11 |
+
"author": "com.apple.CoreML",
|
| 12 |
+
"description": "CoreML Model Weights",
|
| 13 |
+
"name": "weights",
|
| 14 |
+
"path": "com.apple.CoreML/weights"
|
| 15 |
}
|
| 16 |
},
|
| 17 |
+
"rootModelIdentifier": "668A6F00-934D-4D44-9C27-7881268451D9"
|
| 18 |
}
|
Vocoder.mlpackage/Data/com.apple.CoreML/model.mlmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:88536c7f82ce5963c40ab46ab192452ddd1af731ecd4e08a40ea827fc544fbb6
|
| 3 |
+
size 1298694
|
Vocoder.mlpackage/Data/com.apple.CoreML/weights/weight.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2f1b0ee1106eb66c74b00639159b27c910123caa778ffb2b7b4ece2eb88a180c
|
| 3 |
+
size 85215120
|
Vocoder.mlpackage/Manifest.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fileFormatVersion": "1.0.0",
|
| 3 |
+
"itemInfoEntries": {
|
| 4 |
+
"1865D6B1-DF08-4C5C-8B25-53058EF04D75": {
|
| 5 |
+
"author": "com.apple.CoreML",
|
| 6 |
+
"description": "CoreML Model Specification",
|
| 7 |
+
"name": "model.mlmodel",
|
| 8 |
+
"path": "com.apple.CoreML/model.mlmodel"
|
| 9 |
+
},
|
| 10 |
+
"6D12622B-E675-4537-9163-574EA27CA0C1": {
|
| 11 |
+
"author": "com.apple.CoreML",
|
| 12 |
+
"description": "CoreML Model Weights",
|
| 13 |
+
"name": "weights",
|
| 14 |
+
"path": "com.apple.CoreML/weights"
|
| 15 |
+
}
|
| 16 |
+
},
|
| 17 |
+
"rootModelIdentifier": "1865D6B1-DF08-4C5C-8B25-53058EF04D75"
|
| 18 |
+
}
|
scripts/convert.py
CHANGED
|
@@ -203,16 +203,17 @@ def convert_decode(model: PlaprePicoDecode, output_dir: Path):
|
|
| 203 |
causal_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
|
| 204 |
causal_mask[0, 0, 0, :PREFILL_SEQ_LEN] = 0.0
|
| 205 |
|
| 206 |
-
# Pre-sliced RoPE for a single position (caller computes these)
|
| 207 |
cos = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
|
| 208 |
sin = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
|
| 209 |
|
| 210 |
-
# One-hot position mask for cache update (caller builds this)
|
| 211 |
update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
|
| 212 |
update_mask[0, 0, PREFILL_SEQ_LEN, 0] = 1.0
|
| 213 |
|
|
|
|
|
|
|
|
|
|
| 214 |
with torch.no_grad():
|
| 215 |
-
traced = torch.jit.trace(model, (input_ids, causal_mask, cos, sin, update_mask))
|
| 216 |
|
| 217 |
print("Converting decode to CoreML...")
|
| 218 |
mlmodel = ct.convert(
|
|
@@ -231,6 +232,11 @@ def convert_decode(model: PlaprePicoDecode, output_dir: Path):
|
|
| 231 |
shape=(1, 1, MAX_CONTEXT, 1),
|
| 232 |
dtype=np.float16,
|
| 233 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
],
|
| 235 |
outputs=[ct.TensorType(name="logits", dtype=np.float16)],
|
| 236 |
states=build_kv_cache_states(),
|
|
|
|
| 203 |
causal_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
|
| 204 |
causal_mask[0, 0, 0, :PREFILL_SEQ_LEN] = 0.0
|
| 205 |
|
|
|
|
| 206 |
cos = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
|
| 207 |
sin = torch.zeros(1, 1, 1, HEAD_DIM, dtype=torch.float16)
|
| 208 |
|
|
|
|
| 209 |
update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
|
| 210 |
update_mask[0, 0, PREFILL_SEQ_LEN, 0] = 1.0
|
| 211 |
|
| 212 |
+
# Pre-projected speaker hidden: (1, 1, HIDDEN_SIZE) — zeros for non-speaker steps
|
| 213 |
+
speaker_hidden = torch.zeros(1, 1, HIDDEN_SIZE, dtype=torch.float16)
|
| 214 |
+
|
| 215 |
with torch.no_grad():
|
| 216 |
+
traced = torch.jit.trace(model, (input_ids, causal_mask, cos, sin, update_mask, speaker_hidden))
|
| 217 |
|
| 218 |
print("Converting decode to CoreML...")
|
| 219 |
mlmodel = ct.convert(
|
|
|
|
| 232 |
shape=(1, 1, MAX_CONTEXT, 1),
|
| 233 |
dtype=np.float16,
|
| 234 |
),
|
| 235 |
+
ct.TensorType(
|
| 236 |
+
name="speaker_hidden",
|
| 237 |
+
shape=(1, 1, HIDDEN_SIZE),
|
| 238 |
+
dtype=np.float16,
|
| 239 |
+
),
|
| 240 |
],
|
| 241 |
outputs=[ct.TensorType(name="logits", dtype=np.float16)],
|
| 242 |
states=build_kv_cache_states(),
|
scripts/convert_kanade.py
ADDED
|
@@ -0,0 +1,711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Convert Kanade decoder and HiFT vocoder to CoreML.
|
| 4 |
+
|
| 5 |
+
These are non-autoregressive models (single forward pass), so conversion
|
| 6 |
+
is simpler than the LLM — no KV cache or StateType needed.
|
| 7 |
+
|
| 8 |
+
Two models are produced:
|
| 9 |
+
- KanadeDecoder.mlpackage: audio token indices + speaker embedding → mel spectrogram
|
| 10 |
+
- HiFTVocoder.mlpackage: mel spectrogram → PCM waveform
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python scripts/convert_kanade.py [--output-dir PATH] [--num-tokens 100]
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import coremltools as ct
|
| 24 |
+
from kanade_tokenizer import KanadeModel, load_vocoder
|
| 25 |
+
import kanade_tokenizer.module.transformer as kanade_transformer
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ── Monkey-patch Kanade's complex RoPE with real-valued version ───────────
|
| 29 |
+
|
| 30 |
+
def _apply_rotary_emb_real(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
| 31 |
+
"""Real-valued RoPE replacement for Kanade's complex-number version.
|
| 32 |
+
Converts complex freqs_cis to cos/sin and applies split-half rotation.
|
| 33 |
+
"""
|
| 34 |
+
# freqs_cis is complex: (seq_len, head_dim/2)
|
| 35 |
+
cos = freqs_cis.real # (seq_len, head_dim/2)
|
| 36 |
+
sin = freqs_cis.imag
|
| 37 |
+
# Broadcast to match x shape: (bsz, seq_len, n_heads, head_dim)
|
| 38 |
+
# x has head_dim, cos/sin have head_dim/2 — need to double them
|
| 39 |
+
cos = torch.cat([cos, cos], dim=-1) # (seq_len, head_dim)
|
| 40 |
+
sin = torch.cat([sin, sin], dim=-1)
|
| 41 |
+
# Reshape for broadcast: (1, seq_len, 1, head_dim)
|
| 42 |
+
cos = cos.unsqueeze(0).unsqueeze(2)
|
| 43 |
+
sin = sin.unsqueeze(0).unsqueeze(2)
|
| 44 |
+
# Split-half rotation
|
| 45 |
+
half = x.shape[-1] // 2
|
| 46 |
+
x1 = x[..., :half]
|
| 47 |
+
x2 = x[..., half:]
|
| 48 |
+
rotated = torch.cat((-x2, x1), dim=-1)
|
| 49 |
+
return (x * cos + rotated * sin).type_as(x)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _apply_rotary_emb_precomputed(x: torch.Tensor, freqs_cos_sin: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
"""Real-valued RoPE using precomputed cos/sin stored as (seq_len, head_dim).
|
| 54 |
+
head_dim is always 64, hardcoded to avoid dynamic size ops.
|
| 55 |
+
"""
|
| 56 |
+
cos = freqs_cos_sin[..., :32]
|
| 57 |
+
sin = freqs_cos_sin[..., 32:]
|
| 58 |
+
cos = torch.cat([cos, cos], dim=-1)
|
| 59 |
+
sin = torch.cat([sin, sin], dim=-1)
|
| 60 |
+
cos = cos.unsqueeze(0).unsqueeze(2)
|
| 61 |
+
sin = sin.unsqueeze(0).unsqueeze(2)
|
| 62 |
+
x1 = x[..., :32]
|
| 63 |
+
x2 = x[..., 32:]
|
| 64 |
+
rotated = torch.cat((-x2, x1), dim=-1)
|
| 65 |
+
return (x * cos + rotated * sin).type_as(x)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _patched_attention_forward_v2(self, x, freqs_cis, mask, return_kv=False):
|
| 69 |
+
"""Attention forward with real-valued RoPE and explicit matmul."""
|
| 70 |
+
bsz, seqlen, _ = x.shape
|
| 71 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
| 72 |
+
|
| 73 |
+
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 74 |
+
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 75 |
+
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 76 |
+
|
| 77 |
+
if freqs_cis is not None:
|
| 78 |
+
xq = _apply_rotary_emb_precomputed(xq, freqs_cis[:seqlen])
|
| 79 |
+
xk = _apply_rotary_emb_precomputed(xk, freqs_cis[:seqlen])
|
| 80 |
+
|
| 81 |
+
xq = xq.transpose(1, 2)
|
| 82 |
+
xk = xk.transpose(1, 2)
|
| 83 |
+
xv = xv.transpose(1, 2)
|
| 84 |
+
|
| 85 |
+
attn_weights = torch.matmul(xq, xk.transpose(2, 3)) * self.scale
|
| 86 |
+
if mask is not None:
|
| 87 |
+
attn_weights = attn_weights + mask
|
| 88 |
+
if self.causal:
|
| 89 |
+
causal_mask = torch.triu(
|
| 90 |
+
torch.full((seqlen, seqlen), float("-inf"), device=x.device), diagonal=1
|
| 91 |
+
)
|
| 92 |
+
attn_weights = attn_weights + causal_mask
|
| 93 |
+
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
|
| 94 |
+
output = torch.matmul(attn_weights, xv)
|
| 95 |
+
|
| 96 |
+
# 12 heads * 64 head_dim = 768
|
| 97 |
+
output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, 768)
|
| 98 |
+
output = self.wo(output)
|
| 99 |
+
|
| 100 |
+
if return_kv:
|
| 101 |
+
return output, (xk, xv)
|
| 102 |
+
return output
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _convert_freqs_cis_to_real(transformer_module):
|
| 106 |
+
"""Replace complex freqs_cis buffer with real-valued cos/sin concatenation."""
|
| 107 |
+
if hasattr(transformer_module, 'freqs_cis') and transformer_module.freqs_cis is not None:
|
| 108 |
+
fc = transformer_module.freqs_cis # (max_len, head_dim/2) complex
|
| 109 |
+
cos = fc.real.float() # (max_len, head_dim/2)
|
| 110 |
+
sin = fc.imag.float()
|
| 111 |
+
real_freqs = torch.cat([cos, sin], dim=-1) # (max_len, head_dim)
|
| 112 |
+
# Replace the buffer
|
| 113 |
+
del transformer_module.freqs_cis
|
| 114 |
+
transformer_module.register_buffer('freqs_cis', real_freqs)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def patch_kanade_for_coreml(kanade: KanadeModel):
|
| 118 |
+
"""Apply monkey-patches to make Kanade traceable by coremltools."""
|
| 119 |
+
kanade_transformer.Attention.forward = _patched_attention_forward_v2
|
| 120 |
+
# Convert complex freqs_cis to real in all transformers
|
| 121 |
+
for name, module in kanade.named_modules():
|
| 122 |
+
if isinstance(module, kanade_transformer.Transformer):
|
| 123 |
+
_convert_freqs_cis_to_real(module)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class KanadeDecoderWrapper(nn.Module):
|
| 127 |
+
"""Wraps Kanade's decode pipeline for tracing.
|
| 128 |
+
|
| 129 |
+
Pipeline: token indices → quantizer decode → mel_prenet → upsample →
|
| 130 |
+
mel_decoder (conditioned on speaker) → mel_postnet → mel
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
def __init__(self, kanade: KanadeModel, num_tokens: int):
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.local_quantizer = kanade.local_quantizer
|
| 136 |
+
self.mel_prenet = kanade.mel_prenet
|
| 137 |
+
self.mel_conv_upsample = kanade.mel_conv_upsample
|
| 138 |
+
self.mel_decoder = kanade.mel_decoder
|
| 139 |
+
self.mel_postnet = kanade.mel_postnet
|
| 140 |
+
self.num_tokens = num_tokens
|
| 141 |
+
# Precompute mel_length for this token count
|
| 142 |
+
self.mel_length = kanade._calculate_target_mel_length(
|
| 143 |
+
kanade._calculate_original_audio_length(num_tokens)
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def forward(
|
| 147 |
+
self,
|
| 148 |
+
token_indices: torch.Tensor,
|
| 149 |
+
speaker_embedding: torch.Tensor,
|
| 150 |
+
) -> torch.Tensor:
|
| 151 |
+
"""
|
| 152 |
+
Args:
|
| 153 |
+
token_indices: (num_tokens,) int32 — Kanade codebook indices (0-12799)
|
| 154 |
+
speaker_embedding: (1, 128) float32 — speaker embedding
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
mel: (1, 80, mel_length) float32
|
| 158 |
+
"""
|
| 159 |
+
# Quantizer decode: indices → content embedding
|
| 160 |
+
content_emb = self.local_quantizer.decode(token_indices) # (num_tokens, 768)
|
| 161 |
+
content_emb = content_emb.unsqueeze(0) # (1, num_tokens, 768)
|
| 162 |
+
|
| 163 |
+
# Mel prenet (transformer)
|
| 164 |
+
local_latent = self.mel_prenet(content_emb)
|
| 165 |
+
|
| 166 |
+
# Upsample to mel length
|
| 167 |
+
if self.mel_conv_upsample is not None:
|
| 168 |
+
local_latent = self.mel_conv_upsample(
|
| 169 |
+
local_latent.transpose(1, 2)
|
| 170 |
+
).transpose(1, 2)
|
| 171 |
+
local_latent = F.interpolate(
|
| 172 |
+
local_latent.transpose(1, 2), size=self.mel_length, mode="nearest"
|
| 173 |
+
).transpose(1, 2)
|
| 174 |
+
|
| 175 |
+
# Mel decoder (conditioned on speaker)
|
| 176 |
+
mel = self.mel_decoder(local_latent, condition=speaker_embedding.unsqueeze(1))
|
| 177 |
+
mel = mel.transpose(1, 2) # (1, 80, mel_length)
|
| 178 |
+
|
| 179 |
+
# Postnet
|
| 180 |
+
mel = self.mel_postnet(mel)
|
| 181 |
+
return mel
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class FullVocoderWrapper(nn.Module):
|
| 185 |
+
"""Complete mel → waveform pipeline: F0 prediction + source gen + HiFT decode + iSTFT.
|
| 186 |
+
|
| 187 |
+
Noise is replaced with zeros for deterministic tracing.
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
def __init__(self, vocoder, num_stft_frames: int):
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.vocoder = vocoder
|
| 193 |
+
self.num_stft_frames = num_stft_frames
|
| 194 |
+
n_fft = vocoder.istft_n_fft # 16
|
| 195 |
+
hop_len = vocoder.istft_hop_len # 4
|
| 196 |
+
|
| 197 |
+
# iDFT basis
|
| 198 |
+
n = torch.arange(n_fft, dtype=torch.float32)
|
| 199 |
+
k = torch.arange(n_fft, dtype=torch.float32)
|
| 200 |
+
angles = 2.0 * torch.pi * n.unsqueeze(1) * k.unsqueeze(0) / n_fft
|
| 201 |
+
self.register_buffer("idft_cos", torch.cos(angles) / n_fft)
|
| 202 |
+
self.register_buffer("idft_sin", torch.sin(angles) / n_fft)
|
| 203 |
+
self.register_buffer("window", vocoder.stft_window.clone())
|
| 204 |
+
|
| 205 |
+
# Source generation constants
|
| 206 |
+
self.sampling_rate = vocoder.m_source.l_sin_gen.sampling_rate
|
| 207 |
+
self.harmonic_num = vocoder.m_source.l_sin_gen.harmonic_num # 8
|
| 208 |
+
self.sine_amp = vocoder.m_source.l_sin_gen.sine_amp # 0.1
|
| 209 |
+
self.upsample_scale = vocoder.m_source.l_sin_gen.upsample_scale # 480
|
| 210 |
+
|
| 211 |
+
# Harmonic multipliers: [1, 2, ..., 9]
|
| 212 |
+
self.register_buffer(
|
| 213 |
+
"harmonic_muls",
|
| 214 |
+
torch.arange(1, self.harmonic_num + 2, dtype=torch.float32),
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# l_linear and l_tanh from m_source
|
| 218 |
+
self.source_linear = vocoder.m_source.l_linear
|
| 219 |
+
self.source_tanh = vocoder.m_source.l_tanh
|
| 220 |
+
|
| 221 |
+
self.n_fft = n_fft
|
| 222 |
+
self.hop_len = hop_len
|
| 223 |
+
self.n_fft_half = n_fft // 2 + 1
|
| 224 |
+
|
| 225 |
+
def _generate_source(self, f0: torch.Tensor) -> torch.Tensor:
|
| 226 |
+
"""f0: (1, mel_length) → source_stft: (1, 18, stft_frames)"""
|
| 227 |
+
# Upsample f0: (1, mel_length) → (1, 1, mel_length) → nearest → (1, 1, audio_length)
|
| 228 |
+
f0_up = F.interpolate(
|
| 229 |
+
f0.unsqueeze(1), scale_factor=float(self.upsample_scale), mode="nearest"
|
| 230 |
+
).squeeze(1) # (1, audio_length)
|
| 231 |
+
|
| 232 |
+
# Generate harmonics: f0 * [1..9]
|
| 233 |
+
# f0_up: (1, L) → (1, L, 1) * (9,) → (1, L, 9)
|
| 234 |
+
fn = f0_up.unsqueeze(-1) * self.harmonic_muls.unsqueeze(0).unsqueeze(0)
|
| 235 |
+
|
| 236 |
+
# Phase accumulation: cumsum(f/sr) * 2pi
|
| 237 |
+
rad = (fn / self.sampling_rate) # instantaneous frequency in cycles per sample
|
| 238 |
+
phase = torch.cumsum(rad, dim=1) * 2.0 * torch.pi # (1, L, 9)
|
| 239 |
+
|
| 240 |
+
# Sine waves
|
| 241 |
+
sines = torch.sin(phase) * self.sine_amp # (1, L, 9)
|
| 242 |
+
|
| 243 |
+
# UV mask (voiced/unvoiced)
|
| 244 |
+
uv = (f0_up > 0).float().unsqueeze(-1) # (1, L, 1)
|
| 245 |
+
|
| 246 |
+
# Apply UV (no noise — zeros instead of randn for tracing)
|
| 247 |
+
sines = sines * uv # (1, L, 9)
|
| 248 |
+
|
| 249 |
+
# l_linear + tanh: (1, L, 9) → linear → (1, L, 1) → tanh
|
| 250 |
+
source = self.source_tanh(self.source_linear(sines)) # (1, L, 1)
|
| 251 |
+
source = source.squeeze(-1) # (1, L)
|
| 252 |
+
|
| 253 |
+
# Manual STFT (torch.stft/unfold not CoreML-compatible)
|
| 254 |
+
# n_fft=16, hop=4. With center padding, we get num_stft_frames frames.
|
| 255 |
+
# Pad source: reflect pad n_fft//2 on each side
|
| 256 |
+
padded = F.pad(source, (self.n_fft // 2, self.n_fft // 2), mode="reflect")
|
| 257 |
+
# padded: (1, L + n_fft) where L = audio_length
|
| 258 |
+
|
| 259 |
+
# Extract overlapping frames using conv1d with identity kernel
|
| 260 |
+
# This replaces unfold: conv1d with (n_fft, 1, n_fft) identity kernel, stride=hop
|
| 261 |
+
# Equivalent to: frames[i] = padded[i*hop : i*hop + n_fft]
|
| 262 |
+
eye_kernel = torch.eye(self.n_fft, dtype=source.dtype, device=source.device).unsqueeze(1)
|
| 263 |
+
# padded: (1, L+16) → (1, 1, L+16) for conv1d
|
| 264 |
+
frames = F.conv1d(padded.unsqueeze(1), eye_kernel, stride=self.hop_len)
|
| 265 |
+
# frames: (1, 16, num_frames)
|
| 266 |
+
frames = frames * self.window.unsqueeze(0).unsqueeze(-1) # window each frame
|
| 267 |
+
# Transpose to (1, num_frames, 16) for matmul
|
| 268 |
+
frames = frames.transpose(1, 2)
|
| 269 |
+
|
| 270 |
+
# DFT via matmul
|
| 271 |
+
dft_cos = self.idft_cos[:self.n_fft_half, :] * self.n_fft # undo 1/N normalization
|
| 272 |
+
dft_sin = self.idft_sin[:self.n_fft_half, :] * self.n_fft
|
| 273 |
+
s_real = torch.matmul(frames, dft_cos.T) # (1, NF, 9)
|
| 274 |
+
s_imag = -torch.matmul(frames, dft_sin.T) # (1, NF, 9)
|
| 275 |
+
source_stft = torch.cat([s_real.transpose(1, 2), s_imag.transpose(1, 2)], dim=1)
|
| 276 |
+
return source_stft
|
| 277 |
+
|
| 278 |
+
def _istft_overlap_add(self, x: torch.Tensor) -> torch.Tensor:
|
| 279 |
+
"""x: (1, 18, num_frames) conv_post output → waveform (1, samples)"""
|
| 280 |
+
magnitude = torch.exp(x[:, :self.n_fft_half, :])
|
| 281 |
+
phase = torch.sin(x[:, self.n_fft_half:, :])
|
| 282 |
+
|
| 283 |
+
real_half = magnitude * torch.cos(phase)
|
| 284 |
+
imag_half = magnitude * torch.sin(phase)
|
| 285 |
+
|
| 286 |
+
real_mirror = torch.flip(real_half[:, 1:self.n_fft_half - 1, :], dims=[1])
|
| 287 |
+
imag_mirror = -torch.flip(imag_half[:, 1:self.n_fft_half - 1, :], dims=[1])
|
| 288 |
+
real_full = torch.cat([real_half, real_mirror], dim=1)
|
| 289 |
+
imag_full = torch.cat([imag_half, imag_mirror], dim=1)
|
| 290 |
+
|
| 291 |
+
real_t = real_full.transpose(1, 2)
|
| 292 |
+
imag_t = imag_full.transpose(1, 2)
|
| 293 |
+
segments = torch.matmul(real_t, self.idft_cos.T) - torch.matmul(imag_t, self.idft_sin.T)
|
| 294 |
+
|
| 295 |
+
NF = self.num_stft_frames
|
| 296 |
+
segments = segments * self.window.unsqueeze(0).unsqueeze(0)
|
| 297 |
+
seg = segments.squeeze(0)
|
| 298 |
+
seg_chunks = seg.reshape(NF, 4, 4)
|
| 299 |
+
|
| 300 |
+
b0 = seg_chunks[:, 0, :].reshape(-1)
|
| 301 |
+
b1 = seg_chunks[:, 1, :].reshape(-1)
|
| 302 |
+
b2 = seg_chunks[:, 2, :].reshape(-1)
|
| 303 |
+
b3 = seg_chunks[:, 3, :].reshape(-1)
|
| 304 |
+
|
| 305 |
+
F4 = NF * 4
|
| 306 |
+
padded_samples = NF * 4 + 12
|
| 307 |
+
output = torch.zeros(padded_samples)
|
| 308 |
+
output[0:F4] = output[0:F4] + b0
|
| 309 |
+
output[4:F4 + 4] = output[4:F4 + 4] + b1
|
| 310 |
+
output[8:F4 + 8] = output[8:F4 + 8] + b2
|
| 311 |
+
output[12:F4 + 12] = output[12:F4 + 12] + b3
|
| 312 |
+
|
| 313 |
+
win_sq = self.window * self.window
|
| 314 |
+
win_chunks = win_sq.reshape(4, 4)
|
| 315 |
+
w0 = win_chunks[0].repeat(NF)
|
| 316 |
+
w1 = win_chunks[1].repeat(NF)
|
| 317 |
+
w2 = win_chunks[2].repeat(NF)
|
| 318 |
+
w3 = win_chunks[3].repeat(NF)
|
| 319 |
+
|
| 320 |
+
wnorm = torch.zeros(padded_samples)
|
| 321 |
+
wnorm[0:F4] = wnorm[0:F4] + w0
|
| 322 |
+
wnorm[4:F4 + 4] = wnorm[4:F4 + 4] + w1
|
| 323 |
+
wnorm[8:F4 + 8] = wnorm[8:F4 + 8] + w2
|
| 324 |
+
wnorm[12:F4 + 12] = wnorm[12:F4 + 12] + w3
|
| 325 |
+
|
| 326 |
+
output = output / (wnorm + 1e-8)
|
| 327 |
+
pad = 8
|
| 328 |
+
trimmed_len = (NF - 1) * 4
|
| 329 |
+
output = output[pad:pad + trimmed_len]
|
| 330 |
+
output = torch.clamp(output, -0.99, 0.99)
|
| 331 |
+
return output.unsqueeze(0)
|
| 332 |
+
|
| 333 |
+
def forward(self, mel: torch.Tensor) -> torch.Tensor:
|
| 334 |
+
"""mel: (1, 80, T) → waveform: (1, samples)"""
|
| 335 |
+
# F0 prediction
|
| 336 |
+
f0 = self.vocoder.f0_predictor(mel) # (1, T)
|
| 337 |
+
|
| 338 |
+
# Source generation
|
| 339 |
+
source_stft = self._generate_source(f0)
|
| 340 |
+
|
| 341 |
+
# HiFT decode
|
| 342 |
+
x = self.vocoder.conv_pre(mel)
|
| 343 |
+
for i in range(self.vocoder.num_upsamples):
|
| 344 |
+
x = F.leaky_relu(x, self.vocoder.lrelu_slope)
|
| 345 |
+
x = self.vocoder.ups[i](x)
|
| 346 |
+
if i == self.vocoder.num_upsamples - 1:
|
| 347 |
+
x = self.vocoder.reflection_pad(x)
|
| 348 |
+
si = self.vocoder.source_downs[i](source_stft)
|
| 349 |
+
si = self.vocoder.source_resblocks[i](si)
|
| 350 |
+
x = x + si
|
| 351 |
+
xs = None
|
| 352 |
+
for j in range(self.vocoder.num_kernels):
|
| 353 |
+
if xs is None:
|
| 354 |
+
xs = self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
|
| 355 |
+
else:
|
| 356 |
+
xs += self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
|
| 357 |
+
x = xs / self.vocoder.num_kernels
|
| 358 |
+
|
| 359 |
+
x = F.leaky_relu(x)
|
| 360 |
+
x = self.vocoder.conv_post(x)
|
| 361 |
+
|
| 362 |
+
return self._istft_overlap_add(x)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class F0PredictorWrapper(nn.Module):
|
| 366 |
+
"""Wraps HiFT's f0 predictor: mel → f0."""
|
| 367 |
+
|
| 368 |
+
def __init__(self, vocoder):
|
| 369 |
+
super().__init__()
|
| 370 |
+
self.f0_predictor = vocoder.f0_predictor
|
| 371 |
+
|
| 372 |
+
def forward(self, mel: torch.Tensor) -> torch.Tensor:
|
| 373 |
+
"""mel: (1, 80, T) → f0: (1, 1, T)"""
|
| 374 |
+
return self.f0_predictor(mel)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class HiFTDecodeWrapper(nn.Module):
    """Wraps HiFT's decode stage: mel + source_stft → waveform.

    Includes a manual iSTFT implementation using matmul with a precomputed
    DFT basis matrix, so the entire pipeline runs inside CoreML (no FFT op).
    The number of STFT frames is fixed at construction time so the traced
    graph has fully static shapes.
    """

    def __init__(self, vocoder, num_stft_frames: int):
        """
        Args:
            vocoder: HiFT vocoder whose submodules (conv_pre, ups, resblocks,
                source_downs, source_resblocks, conv_post) are reused directly.
            num_stft_frames: Fixed frame count of the output STFT; baked in
                because tracing cannot handle dynamic shapes here.
        """
        super().__init__()
        self.vocoder = vocoder
        self.num_stft_frames = num_stft_frames  # hardcoded for tracing
        n_fft = vocoder.istft_n_fft  # 16
        hop_len = vocoder.istft_hop_len  # 4

        # Precompute DFT basis for iSTFT: (n_fft, n_fft) real-valued IDFT matrix.
        # X[k] = sum_n x[n] * exp(-j*2pi*n*k/N) → x[n] = (1/N) * sum_k X[k] * exp(j*2pi*n*k/N)
        n = torch.arange(n_fft, dtype=torch.float32)
        k = torch.arange(n_fft, dtype=torch.float32)
        angles = 2.0 * torch.pi * n.unsqueeze(1) * k.unsqueeze(0) / n_fft  # (n_fft, n_fft)
        # cos/sin basis for real/imag parts; 1/N normalization folded into the basis.
        self.register_buffer("idft_cos", torch.cos(angles) / n_fft)  # (n_fft, n_fft)
        self.register_buffer("idft_sin", torch.sin(angles) / n_fft)  # (n_fft, n_fft)

        # Analysis/synthesis window used for the overlap-add reconstruction.
        self.register_buffer("window", vocoder.stft_window.clone())
        self.n_fft = n_fft
        self.hop_len = hop_len
        self.n_fft_half = n_fft // 2 + 1  # 9 (one-sided spectrum size)

    def forward(self, mel: torch.Tensor, source_stft: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel: (1, 80, T) float32
            source_stft: (1, 18, T') float32 — real+imag STFT of source signal

        Returns:
            waveform: (1, samples) float32, clamped to (-0.99, 0.99)
        """
        # --- HiFT generator body: upsample mel while injecting the source STFT ---
        x = self.vocoder.conv_pre(mel)
        for i in range(self.vocoder.num_upsamples):
            x = F.leaky_relu(x, self.vocoder.lrelu_slope)
            x = self.vocoder.ups[i](x)
            if i == self.vocoder.num_upsamples - 1:
                x = self.vocoder.reflection_pad(x)

            # Downsample the source STFT to the current resolution and add it in.
            si = self.vocoder.source_downs[i](source_stft)
            si = self.vocoder.source_resblocks[i](si)
            x = x + si

            # Average the parallel multi-kernel residual blocks (MRF).
            xs = None
            for j in range(self.vocoder.num_kernels):
                if xs is None:
                    xs = self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
                else:
                    xs += self.vocoder.resblocks[i * self.vocoder.num_kernels + j](x)
            x = xs / self.vocoder.num_kernels

        x = F.leaky_relu(x)
        x = self.vocoder.conv_post(x)  # (1, 18, num_frames)

        # Split head output into log-magnitude and phase channels.
        # NOTE(review): sin() on the raw phase channel, then cos/sin again below,
        # mirrors the upstream HiFT head — confirm against the reference vocoder.
        magnitude = torch.exp(x[:, :self.n_fft_half, :])  # (1, 9, num_frames)
        phase = torch.sin(x[:, self.n_fft_half:, :])  # (1, 9, num_frames)

        # Convert polar (magnitude, phase) to rectangular real/imag.
        real_half = magnitude * torch.cos(phase)  # (1, 9, num_frames)
        imag_half = magnitude * torch.sin(phase)

        # Mirror to full spectrum (Hermitian symmetry of a real signal's DFT):
        # real: [r0, r1, ..., r8, r7, r6, ..., r1]
        # imag: [i0, i1, ..., i8, -i7, -i6, ..., -i1]
        real_mirror = torch.flip(real_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        imag_mirror = -torch.flip(imag_half[:, 1:self.n_fft_half - 1, :], dims=[1])
        real_full = torch.cat([real_half, real_mirror], dim=1)  # (1, 16, num_frames)
        imag_full = torch.cat([imag_half, imag_mirror], dim=1)  # (1, 16, num_frames)

        # iDFT via matmul: segments[frame, n] = sum_k (real[frame,k]*idft_cos[n,k]
        #                                             - imag[frame,k]*idft_sin[n,k])
        # Transpose to (1, num_frames, 16) so the contraction is over k.
        real_t = real_full.transpose(1, 2)  # (1, num_frames, 16)
        imag_t = imag_full.transpose(1, 2)
        segments = torch.matmul(real_t, self.idft_cos.T) - torch.matmul(imag_t, self.idft_sin.T)
        # segments: (1, num_frames, 16) — one 16-sample time-domain frame per STFT frame.

        # --- Overlap-add with window, using only static slicing ---
        # n_fft=16, hop=4 → each output sample is covered by exactly 4 frames.
        NF = self.num_stft_frames  # hardcoded constant for tracing
        segments = segments * self.window.unsqueeze(0).unsqueeze(0)  # (1, NF, 16)
        seg = segments.squeeze(0)  # (NF, 16)

        # Split each 16-sample frame into 4 hop-sized chunks of 4 samples.
        seg_chunks = seg.reshape(NF, 4, 4)  # (F, 4_blocks, 4_samples)

        # Chunk b of frame f lands at output[(f+b)*4 : (f+b)*4 + 4].
        # For a fixed b, the F chunks are contiguous in the output, shifted by b*4,
        # so the whole overlap-add reduces to 4 static shifted adds.
        padded_samples = NF * 4 + 12  # (NF-1)*4 + 16

        b0 = seg_chunks[:, 0, :].reshape(-1)  # (F*4,) → output[0 : F*4]
        b1 = seg_chunks[:, 1, :].reshape(-1)  # (F*4,) → output[4 : F*4 + 4]
        b2 = seg_chunks[:, 2, :].reshape(-1)  # (F*4,) → output[8 : F*4 + 8]
        b3 = seg_chunks[:, 3, :].reshape(-1)  # (F*4,) → output[12 : F*4 + 12]

        F4 = NF * 4
        output = torch.zeros(padded_samples)
        output[0:F4] = output[0:F4] + b0
        output[4:F4 + 4] = output[4:F4 + 4] + b1
        output[8:F4 + 8] = output[8:F4 + 8] + b2
        output[12:F4 + 12] = output[12:F4 + 12] + b3

        # Window-squared normalization, accumulated with the same 4 shifted adds.
        win_sq = self.window * self.window  # (16,)
        win_chunks = win_sq.reshape(4, 4)  # (4_blocks, 4_samples)
        w0 = win_chunks[0].repeat(NF)
        w1 = win_chunks[1].repeat(NF)
        w2 = win_chunks[2].repeat(NF)
        w3 = win_chunks[3].repeat(NF)

        wnorm = torch.zeros(padded_samples)
        wnorm[0:F4] = wnorm[0:F4] + w0
        wnorm[4:F4 + 4] = wnorm[4:F4 + 4] + w1
        wnorm[8:F4 + 8] = wnorm[8:F4 + 8] + w2
        wnorm[12:F4 + 12] = wnorm[12:F4 + 12] + w3

        # Epsilon guards against division by zero at the padded edges.
        output = output / (wnorm + 1e-8)

        # Trim centered-STFT padding: n_fft//2 = 8 samples from the start.
        pad = 8
        trimmed_len = (NF - 1) * 4  # expected output length
        output = output[pad:pad + trimmed_len]
        output = torch.clamp(output, -0.99, 0.99)
        return output.unsqueeze(0)  # (1, samples)
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def convert_kanade_decoder(kanade: KanadeModel, num_tokens: int, output_dir: Path):
    """Trace the Kanade decoder and export it as KanadeDecoder.mlpackage.

    Args:
        kanade: Loaded Kanade model; wrapped for tracing with a fixed token count.
        num_tokens: Number of audio tokens baked into the traced graph.
        output_dir: Directory that receives the mlpackage.
    """
    wrapper = KanadeDecoderWrapper(kanade, num_tokens).eval().float()
    print(f"Tracing Kanade decoder (num_tokens={num_tokens}, mel_length={wrapper.mel_length})...")

    # Example inputs for tracing: sequential token ids plus a random speaker vector.
    example_tokens = torch.arange(num_tokens, dtype=torch.int32)
    example_speaker = torch.randn(1, 128, dtype=torch.float32)

    with torch.no_grad():
        # Sanity-check the forward pass before committing to a trace.
        mel_out = wrapper(example_tokens, example_speaker)
        print(f"  Output mel shape: {mel_out.shape}")

    traced = torch.jit.trace(wrapper, (example_tokens, example_speaker))

    print("Converting Kanade decoder to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="token_indices", shape=(num_tokens,), dtype=np.int32),
            ct.TensorType(name="speaker_embedding", shape=(1, 128), dtype=np.float32),
        ],
        outputs=[ct.TensorType(name="mel", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    destination = output_dir / "KanadeDecoder.mlpackage"
    mlmodel.save(str(destination))
    print(f"Saved Kanade decoder to {destination}")
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def convert_f0_predictor(vocoder, mel_length: int, output_dir: Path):
    """Trace HiFT's f0 predictor and export it as F0Predictor.mlpackage.

    Args:
        vocoder: HiFT vocoder providing the f0_predictor submodule.
        mel_length: Fixed mel time dimension baked into the traced graph.
        output_dir: Directory that receives the mlpackage.
    """
    wrapper = F0PredictorWrapper(vocoder).eval().float()
    print(f"Tracing F0 predictor (mel_length={mel_length})...")

    example_mel = torch.randn(1, 80, mel_length, dtype=torch.float32)

    with torch.no_grad():
        # Sanity-check the forward pass, then trace with the same example input.
        predicted = wrapper(example_mel)
        print(f"  Output f0 shape: {predicted.shape}")
        traced = torch.jit.trace(wrapper, (example_mel,))

    print("Converting F0 predictor to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32),
        ],
        outputs=[ct.TensorType(name="f0", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    destination = output_dir / "F0Predictor.mlpackage"
    mlmodel.save(str(destination))
    print(f"Saved F0 predictor to {destination}")
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def convert_hift_decode(vocoder, mel_length: int, output_dir: Path):
    """Convert HiFT decode stage to CoreML.

    Source signal STFT must be computed externally (Swift side); this export
    covers only mel + source_stft → waveform.

    Args:
        vocoder: HiFT vocoder whose decode submodules are wrapped for tracing.
        mel_length: Fixed mel time dimension baked into the traced graph.
        output_dir: Directory that receives HiFTDecode.mlpackage.
    """
    # Compute source_stft shape: run f0 predictor + source module to get it.
    # The frame count must be known up front so HiFTDecodeWrapper can unroll
    # its static overlap-add.
    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)
    with torch.no_grad():
        f0 = vocoder.f0_predictor(mel)
        s = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
        s, _, _ = vocoder.m_source(s)
        s = s.transpose(1, 2)
        s_stft_real, s_stft_imag = vocoder._stft(s.squeeze(1))
        # Stack real and imaginary parts along channels (9 + 9 = 18).
        source_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
    num_stft_frames = source_stft.shape[2]
    print(f"  Source STFT shape: {source_stft.shape} ({num_stft_frames} frames)")

    wrapper = HiFTDecodeWrapper(vocoder, num_stft_frames).eval().float()

    print(f"Tracing HiFT decode (mel_length={mel_length})...")
    with torch.no_grad():
        waveform = wrapper(mel, source_stft)
        print(f"  Output waveform shape: {waveform.shape}")
        traced = torch.jit.trace(wrapper, (mel, source_stft))

    print("Converting HiFT decode to CoreML...")
    source_stft_channels = source_stft.shape[1]
    source_stft_time = source_stft.shape[2]
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32),
            ct.TensorType(
                name="source_stft",
                shape=(1, source_stft_channels, source_stft_time),
                dtype=np.float32,
            ),
        ],
        outputs=[ct.TensorType(name="waveform", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "HiFTDecode.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved HiFT decode to {out_path}")
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def main():
    """CLI entry point: load Kanade + HiFT and export the CoreML packages."""
    parser = argparse.ArgumentParser(description="Convert Kanade + HiFT to CoreML")
    parser.add_argument(
        "--output-dir", type=str,
        default=str(Path(__file__).parent.parent),
        help="Output directory",
    )
    parser.add_argument(
        "--num-tokens", type=int, default=100,
        help="Fixed number of audio tokens (determines mel length)",
    )
    args = parser.parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading Kanade model...")
    kanade = KanadeModel.from_pretrained("frothywater/kanade-25hz-clean").eval().float()
    patch_kanade_for_coreml(kanade)
    vocoder = load_vocoder(kanade.config.vocoder_name).eval().float()

    # Mel length is a fixed function of the token count; needed because the
    # exported graphs use fully static shapes.
    mel_length = kanade._calculate_target_mel_length(
        kanade._calculate_original_audio_length(args.num_tokens)
    )

    # Fix: these two banners had no placeholders, so the f-prefix was noise (F541).
    print("\n=== Converting Kanade decoder ===")
    convert_kanade_decoder(kanade, args.num_tokens, output_dir)

    print("\n=== Converting full vocoder (mel → waveform) ===")
    convert_full_vocoder(vocoder, mel_length, output_dir)

    print("\nDone!")
    print(f"  KanadeDecoder: {args.num_tokens} tokens → mel (80, {mel_length})")
    print(f"  Vocoder: mel (80, {mel_length}) → waveform")
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def convert_full_vocoder(vocoder, mel_length: int, output_dir: Path):
    """Convert complete mel→waveform vocoder to CoreML.

    Args:
        vocoder: HiFT vocoder to wrap and trace end to end.
        mel_length: Fixed mel time dimension baked into the traced graph.
        output_dir: Directory that receives Vocoder.mlpackage.
    """
    # Get num_stft_frames by running a dummy forward through the source path;
    # the wrapper needs it for its static overlap-add.
    mel = torch.randn(1, 80, mel_length, dtype=torch.float32)
    with torch.no_grad():
        f0 = vocoder.f0_predictor(mel)
        s = vocoder.f0_upsamp(f0[:, None]).transpose(1, 2)
        s, _, _ = vocoder.m_source(s)
        s = s.transpose(1, 2)
        sr, si = vocoder._stft(s.squeeze(1))
    num_stft_frames = sr.shape[2]
    print(f"  STFT frames: {num_stft_frames}")

    wrapper = FullVocoderWrapper(vocoder, num_stft_frames).eval().float()

    print(f"Tracing full vocoder (mel_length={mel_length})...")
    # Replace randn_like with zeros so tracing is deterministic.
    # Fix: restore the monkeypatch in a finally-block — previously an exception
    # during tracing would leave torch.randn_like patched for the whole process.
    orig_randn = torch.randn_like
    torch.randn_like = lambda x, **kw: torch.zeros_like(x)
    try:
        with torch.no_grad():
            wav = wrapper(mel)
            print(f"  Output waveform: {wav.shape}")
            traced = torch.jit.trace(wrapper, (mel,))
    finally:
        torch.randn_like = orig_randn

    print("Converting full vocoder to CoreML...")
    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="mel", shape=(1, 80, mel_length), dtype=np.float32)],
        outputs=[ct.TensorType(name="waveform", dtype=np.float32)],
        compute_precision=ct.precision.FLOAT32,
        minimum_deployment_target=ct.target.iOS17,
    )

    out_path = output_dir / "Vocoder.mlpackage"
    mlmodel.save(str(out_path))
    print(f"Saved vocoder to {out_path}")
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
# Script entry point: run the full conversion pipeline.
if __name__ == "__main__":
    main()
|
scripts/model_wrapper.py
CHANGED
|
@@ -118,8 +118,8 @@ class PlaprePicoPrefill(nn.Module):
|
|
| 118 |
class PlaprePicoDecode(nn.Module):
|
| 119 |
"""Generates one token at a time using the KV cache.
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
Inputs:
|
| 125 |
input_ids: (1, 1) int32
|
|
@@ -127,6 +127,7 @@ class PlaprePicoDecode(nn.Module):
|
|
| 127 |
cos: (1, 1, 1, 64) float16 — RoPE cos for current position
|
| 128 |
sin: (1, 1, 1, 64) float16 — RoPE sin for current position
|
| 129 |
update_mask: (1, 1, 2048, 1) float16 — one-hot at current position
|
|
|
|
| 130 |
|
| 131 |
State buffers:
|
| 132 |
k_cache_0..29, v_cache_0..29: (1, 3, 2048, 64) float16
|
|
@@ -174,8 +175,12 @@ class PlaprePicoDecode(nn.Module):
|
|
| 174 |
cos: torch.Tensor,
|
| 175 |
sin: torch.Tensor,
|
| 176 |
update_mask: torch.Tensor,
|
|
|
|
| 177 |
) -> torch.Tensor:
|
| 178 |
hidden = self.embed_tokens(input_ids) # (1, 1, 576)
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
for i, layer in enumerate(self.layers):
|
| 181 |
k_cache = getattr(self, f"k_cache_{i}")
|
|
|
|
| 118 |
class PlaprePicoDecode(nn.Module):
|
| 119 |
"""Generates one token at a time using the KV cache.
|
| 120 |
|
| 121 |
+
Also used for token-by-token prefill. For the speaker token (position 0),
|
| 122 |
+
pass a non-zero speaker_hidden to replace the token embedding.
|
| 123 |
|
| 124 |
Inputs:
|
| 125 |
input_ids: (1, 1) int32
|
|
|
|
| 127 |
cos: (1, 1, 1, 64) float16 — RoPE cos for current position
|
| 128 |
sin: (1, 1, 1, 64) float16 — RoPE sin for current position
|
| 129 |
update_mask: (1, 1, 2048, 1) float16 — one-hot at current position
|
| 130 |
+
speaker_hidden: (1, 1, 576) float16 — pre-projected speaker embedding, or zeros
|
| 131 |
|
| 132 |
State buffers:
|
| 133 |
k_cache_0..29, v_cache_0..29: (1, 3, 2048, 64) float16
|
|
|
|
| 175 |
cos: torch.Tensor,
|
| 176 |
sin: torch.Tensor,
|
| 177 |
update_mask: torch.Tensor,
|
| 178 |
+
speaker_hidden: torch.Tensor,
|
| 179 |
) -> torch.Tensor:
|
| 180 |
hidden = self.embed_tokens(input_ids) # (1, 1, 576)
|
| 181 |
+
# Speaker conditioning: caller passes pre-projected (1,1,576) for position 0,
|
| 182 |
+
# zeros for all other positions. Additive — zeros are a no-op.
|
| 183 |
+
hidden = hidden + speaker_hidden
|
| 184 |
|
| 185 |
for i, layer in enumerate(self.layers):
|
| 186 |
k_cache = getattr(self, f"k_cache_{i}")
|
scripts/test_generate.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
End-to-end test: generate Danish speech using our custom PyTorch wrappers
|
| 4 |
+
(the same code converted to CoreML), decode with Kanade, save as WAV.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python scripts/test_generate.py [--text "Hej verden"] [--speaker tor] [--output test.wav]
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import soundfile as sf
|
| 19 |
+
|
| 20 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 21 |
+
|
| 22 |
+
from attention import precompute_rope_frequencies
|
| 23 |
+
from model_wrapper import (
|
| 24 |
+
PlaprePicoPrefill,
|
| 25 |
+
PlaprePicoDecode,
|
| 26 |
+
NUM_LAYERS,
|
| 27 |
+
MAX_CONTEXT,
|
| 28 |
+
HEAD_DIM,
|
| 29 |
+
PREFILL_SEQ_LEN,
|
| 30 |
+
SPEAKER_DIM,
|
| 31 |
+
)
|
| 32 |
+
from convert import load_weights, populate_weights
|
| 33 |
+
|
| 34 |
+
# Token-id layout of the Plapre vocabulary.
# Audio tokens occupy [AUDIO_TOKEN_OFFSET, 20801]; Kanade index = id - offset
# (see decode_audio). EOS doubles as padding and the position-0 placeholder
# that prefill replaces with the projected speaker embedding (see generate).
AUDIO_TOKEN_OFFSET = 8002
AUDIO_MARKER_TOKEN = 8001
TEXT_MARKER_TOKEN = 8000
EOS_TOKEN = 2
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_speaker(speakers_path: Path, name: str) -> torch.Tensor:
    """Load the embedding for *name* from a speakers JSON file.

    Returns:
        A (1, D) float16 tensor ready to feed the model.

    Raises:
        ValueError: if *name* is not a key in the JSON file.
    """
    speakers = json.loads(Path(speakers_path).read_text())
    if name not in speakers:
        raise ValueError(f"Speaker '{name}' not found. Available: {list(speakers.keys())}")
    embedding = torch.tensor(speakers[name], dtype=torch.float16)
    return embedding.unsqueeze(0)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def sample(logits: torch.Tensor, temperature: float, top_k: int, top_p: float) -> int:
    """Sample a token id from 1-D logits with temperature, top-k and top-p filtering.

    temperature <= 0 degenerates to greedy argmax; top_k <= 0 disables the
    top-k filter.
    """
    if temperature <= 0:
        return int(logits.argmax())

    scaled = logits.float() / temperature

    # Top-k: keep only the k largest logits, -inf everywhere else.
    if top_k > 0:
        kept_vals, kept_idx = torch.topk(scaled, top_k)
        filtered = torch.full_like(scaled, float("-inf"))
        filtered.scatter_(0, kept_idx, kept_vals)
    else:
        filtered = scaled

    probs = F.softmax(filtered, dim=-1)

    # Top-p (nucleus): zero out tokens once the cumulative mass *before* them
    # exceeds top_p, then renormalize. The most probable token always survives.
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=0)
    drop = cumulative - sorted_probs > top_p
    sorted_probs = torch.where(drop, torch.zeros_like(sorted_probs), sorted_probs)
    sorted_probs = sorted_probs / sorted_probs.sum()

    choice = torch.multinomial(sorted_probs, 1)
    return int(sorted_idx[choice])
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def generate(
    prefill_model: PlaprePicoPrefill,
    decode_model: PlaprePicoDecode,
    text: str,
    speaker_embedding: torch.Tensor,
    tokenizer_path: Path,
    max_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> list[int]:
    """Tokenize *text*, prefill the LM, then autoregressively decode token ids.

    Returns the full list of sampled token ids (audio tokens plus a possible
    trailing EOS). Mutates the KV-cache buffers of both models.
    """
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(str(tokenizer_path))
    token_ids = tokenizer.encode(text).ids

    # Plapre format: [placeholder, <text>, tokens..., <audio>]
    # Position 0 placeholder gets replaced by speaker_proj output below.
    input_ids_list = [EOS_TOKEN] + [TEXT_MARKER_TOKEN] + token_ids + [AUDIO_MARKER_TOKEN]
    input_len = len(input_ids_list)
    print(f"Input ({input_len} tokens): {input_ids_list}")

    # Pad to the fixed prefill length (EOS as padding).
    padded_ids = torch.full((1, PREFILL_SEQ_LEN), EOS_TOKEN, dtype=torch.int32)
    for i, tid in enumerate(input_ids_list):
        padded_ids[0, i] = tid

    # Causal mask: only real tokens (0..input_len-1) attend; padding rows stay -inf.
    causal_mask = torch.full(
        (1, 1, PREFILL_SEQ_LEN, MAX_CONTEXT), float("-inf"), dtype=torch.float16
    )
    for i in range(input_len):
        causal_mask[0, 0, i, :i + 1] = 0.0

    # === Prefill ===
    # We can't get logits at an arbitrary position from the wrapper (it returns pos -1).
    # So run the layers manually to read logits at input_len - 1.
    print("Running prefill...")
    with torch.no_grad():
        hidden = prefill_model.embed_tokens(padded_ids)
        # Replace the position-0 embedding with the projected speaker embedding.
        spk = prefill_model.speaker_proj(speaker_embedding).unsqueeze(1)
        hidden = torch.cat([spk, hidden[:, 1:, :]], dim=1)

        cos = prefill_model.rope_cos
        sin = prefill_model.rope_sin

        for i, layer in enumerate(prefill_model.layers):
            k_cache = getattr(prefill_model, f"k_cache_{i}")
            v_cache = getattr(prefill_model, f"v_cache_{i}")
            hidden, k_new, v_new = layer(hidden, cos, sin, causal_mask, k_cache, v_cache)
            # Update caches on the model so decode can copy them.
            setattr(prefill_model, f"k_cache_{i}", k_new)
            setattr(prefill_model, f"v_cache_{i}", v_new)

        hidden = prefill_model.norm(hidden)
        # Tied LM head: project with the embedding matrix at the last real position.
        logits = F.linear(hidden[0, input_len - 1, :], prefill_model.embed_tokens.weight)

    generated = []
    next_token = sample(logits, temperature, top_k, top_p)
    generated.append(next_token)
    print(f"  Token 0: {next_token}")

    # === Copy KV cache to decode model ===
    with torch.no_grad():
        for i in range(NUM_LAYERS):
            getattr(decode_model, f"k_cache_{i}").copy_(getattr(prefill_model, f"k_cache_{i}"))
            getattr(decode_model, f"v_cache_{i}").copy_(getattr(prefill_model, f"v_cache_{i}"))

    # === Decode loop ===
    # RoPE tables for all positions; sliced per step below.
    cos_full, sin_full = precompute_rope_frequencies(HEAD_DIM, MAX_CONTEXT, 100000.0)
    cos_full = cos_full.half()
    sin_full = sin_full.half()

    print("Decoding...")
    for step in range(1, max_tokens):
        pos = input_len + step - 1

        decode_ids = torch.tensor([[next_token]], dtype=torch.int32)
        # Attend to everything up to and including the current position.
        decode_mask = torch.full((1, 1, 1, MAX_CONTEXT), float("-inf"), dtype=torch.float16)
        decode_mask[0, 0, 0, :pos + 1] = 0.0
        pos_cos = cos_full[:, :, pos:pos + 1, :]
        pos_sin = sin_full[:, :, pos:pos + 1, :]
        # One-hot mask telling the model which cache slot to write.
        update_mask = torch.zeros(1, 1, MAX_CONTEXT, 1, dtype=torch.float16)
        update_mask[0, 0, pos, 0] = 1.0

        with torch.no_grad():
            # NOTE(review): the updated PlaprePicoDecode.forward appears to take an
            # additional speaker_hidden argument — confirm this call site matches
            # the current signature (pass zeros for non-speaker positions?).
            logits = decode_model(decode_ids, decode_mask, pos_cos, pos_sin, update_mask)

        next_token = sample(logits[0, 0], temperature, top_k, top_p)
        generated.append(next_token)

        if next_token == EOS_TOKEN:
            print(f"  EOS at step {step}")
            break
        if step % 25 == 0:
            # 25 audio tokens per second of audio.
            print(f"  Step {step}: ({step / 25:.1f}s of audio)")

    return generated
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def decode_audio(tokens: list[int], speaker_embedding: torch.Tensor) -> np.ndarray:
    """Decode generated LM tokens into a waveform via Kanade + the HiFT vocoder."""
    from kanade_tokenizer import KanadeModel, load_vocoder, vocode

    # Keep only ids inside the audio-token vocabulary range.
    audio_tokens = [t for t in tokens if AUDIO_TOKEN_OFFSET <= t <= 20801]
    if not audio_tokens:
        raise ValueError("No audio tokens generated!")

    kanade_indices = torch.tensor([t - AUDIO_TOKEN_OFFSET for t in audio_tokens])
    print(f"Decoding {len(kanade_indices)} audio tokens ({len(kanade_indices) / 25:.1f}s)...")

    model = KanadeModel.from_pretrained("frothywater/kanade-25hz-clean").eval()
    vocoder = load_vocoder(model.config.vocoder_name)

    with torch.no_grad():
        global_emb = speaker_embedding.squeeze(0).float()
        mel_spec = model.decode(global_embedding=global_emb, content_token_indices=kanade_indices)
        waveform = vocode(vocoder, mel_spec.unsqueeze(0))

    return waveform.squeeze().cpu().numpy()
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def main():
    """Generate speech with the custom PyTorch wrappers and save it as a WAV file."""
    parser = argparse.ArgumentParser(description="Generate Danish speech (custom model)")
    parser.add_argument("--text", type=str, default="Hej, mit navn er Daniel.")
    parser.add_argument("--speaker", type=str, default="tor")
    parser.add_argument("--output", type=str, default="test.wav")
    parser.add_argument("--max-tokens", type=int, default=500)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--model-dir", type=str, default=None)
    args = parser.parse_args()

    # Resolve the model directory: explicit flag, local HF cache, or fresh download.
    if args.model_dir:
        model_dir = Path(args.model_dir)
    else:
        cache = Path.home() / ".cache/huggingface/hub/models--syvai--plapre-pico"
        snapshots = cache / "snapshots"
        if snapshots.exists():
            model_dir = next(snapshots.iterdir())
        else:
            from huggingface_hub import snapshot_download
            model_dir = Path(snapshot_download("syvai/plapre-pico"))

    repo_root = Path(__file__).parent.parent

    def resolve(filename: str) -> Path:
        # Prefer a copy checked into the repo; fall back to the model snapshot.
        candidate = repo_root / filename
        return candidate if candidate.exists() else model_dir / filename

    speaker_embedding = load_speaker(resolve("speakers.json"), args.speaker)
    print(f"Speaker: {args.speaker}")

    tokenizer_path = resolve("tokenizer.json")

    # Build our custom prefill/decode models and load the pretrained weights.
    weights = load_weights(model_dir)

    prefill = PlaprePicoPrefill()
    populate_weights(prefill, weights, is_prefill=True)
    prefill = prefill.half().eval()

    decode = PlaprePicoDecode()
    populate_weights(decode, weights, is_prefill=False)
    decode = decode.half().eval()

    # Generate token ids, then vocode them to audio.
    tokens = generate(
        prefill, decode, args.text, speaker_embedding, tokenizer_path,
        args.max_tokens, args.temperature, args.top_k, args.top_p,
    )

    audio_count = sum(1 for t in tokens if AUDIO_TOKEN_OFFSET <= t <= 20801)
    print(f"\nGenerated {len(tokens)} tokens: {audio_count} audio ({audio_count / 25:.1f}s)")
    print(f"First 20: {tokens[:20]}")

    waveform = decode_audio(tokens, speaker_embedding)
    sf.write(args.output, waveform, 24000)
    print(f"Saved {len(waveform) / 24000:.1f}s audio to {args.output}")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# Script entry point: run generation end to end.
if __name__ == "__main__":
    main()
|