Daniel Rothmann committed on
Commit
c6a22f5
·
1 Parent(s): 95c6137

Add swift test CLI

Browse files
.gitignore CHANGED
@@ -3,3 +3,4 @@
3
  __pycache__
4
  test_data
5
  **.wav
 
 
3
  __pycache__
4
  test_data
5
  **.wav
6
+ swift-cli/.build
swift-cli/Package.resolved ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pins" : [
3
+ {
4
+ "identity" : "eventsource",
5
+ "kind" : "remoteSourceControl",
6
+ "location" : "https://github.com/mattt/EventSource.git",
7
+ "state" : {
8
+ "revision" : "a3a85a85214caf642abaa96ae664e4c772a59f6e",
9
+ "version" : "1.4.1"
10
+ }
11
+ },
12
+ {
13
+ "identity" : "swift-asn1",
14
+ "kind" : "remoteSourceControl",
15
+ "location" : "https://github.com/apple/swift-asn1.git",
16
+ "state" : {
17
+ "revision" : "9f542610331815e29cc3821d3b6f488db8715517",
18
+ "version" : "1.6.0"
19
+ }
20
+ },
21
+ {
22
+ "identity" : "swift-atomics",
23
+ "kind" : "remoteSourceControl",
24
+ "location" : "https://github.com/apple/swift-atomics.git",
25
+ "state" : {
26
+ "revision" : "b601256eab081c0f92f059e12818ac1d4f178ff7",
27
+ "version" : "1.3.0"
28
+ }
29
+ },
30
+ {
31
+ "identity" : "swift-collections",
32
+ "kind" : "remoteSourceControl",
33
+ "location" : "https://github.com/apple/swift-collections.git",
34
+ "state" : {
35
+ "revision" : "6675bc0ff86e61436e615df6fc5174e043e57924",
36
+ "version" : "1.4.1"
37
+ }
38
+ },
39
+ {
40
+ "identity" : "swift-crypto",
41
+ "kind" : "remoteSourceControl",
42
+ "location" : "https://github.com/apple/swift-crypto.git",
43
+ "state" : {
44
+ "revision" : "bb4ba815dab96d4edc1e0b86d7b9acf9ff973a84",
45
+ "version" : "4.3.1"
46
+ }
47
+ },
48
+ {
49
+ "identity" : "swift-huggingface",
50
+ "kind" : "remoteSourceControl",
51
+ "location" : "https://github.com/huggingface/swift-huggingface.git",
52
+ "state" : {
53
+ "revision" : "b721959445b617d0bf03910b2b4aced345fd93bf",
54
+ "version" : "0.9.0"
55
+ }
56
+ },
57
+ {
58
+ "identity" : "swift-jinja",
59
+ "kind" : "remoteSourceControl",
60
+ "location" : "https://github.com/huggingface/swift-jinja.git",
61
+ "state" : {
62
+ "revision" : "0aeefadec459ce8e11a333769950fb86183aca43",
63
+ "version" : "2.3.5"
64
+ }
65
+ },
66
+ {
67
+ "identity" : "swift-nio",
68
+ "kind" : "remoteSourceControl",
69
+ "location" : "https://github.com/apple/swift-nio.git",
70
+ "state" : {
71
+ "revision" : "558f24a4647193b5a0e2104031b71c55d31ff83a",
72
+ "version" : "2.97.1"
73
+ }
74
+ },
75
+ {
76
+ "identity" : "swift-system",
77
+ "kind" : "remoteSourceControl",
78
+ "location" : "https://github.com/apple/swift-system.git",
79
+ "state" : {
80
+ "revision" : "7c6ad0fc39d0763e0b699210e4124afd5041c5df",
81
+ "version" : "1.6.4"
82
+ }
83
+ },
84
+ {
85
+ "identity" : "swift-transformers",
86
+ "kind" : "remoteSourceControl",
87
+ "location" : "https://github.com/huggingface/swift-transformers",
88
+ "state" : {
89
+ "revision" : "b38443e44d93eca770f2eb68e2a4d0fa100f9aa2",
90
+ "version" : "1.3.0"
91
+ }
92
+ },
93
+ {
94
+ "identity" : "yyjson",
95
+ "kind" : "remoteSourceControl",
96
+ "location" : "https://github.com/ibireme/yyjson.git",
97
+ "state" : {
98
+ "revision" : "8b4a38dc994a110abaec8a400615567bd996105f",
99
+ "version" : "0.12.0"
100
+ }
101
+ }
102
+ ],
103
+ "version" : 2
104
+ }
swift-cli/Package.swift ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// swift-tools-version: 5.9
import PackageDescription

// SwiftPM manifest for the test CLI. The executable lives in
// swift-cli/Sources (see `path:` below); its only external dependency is
// swift-transformers, used for the Tokenizers product.
let package = Package(
    name: "plapre-cli",
    // Minimum deployment target. NOTE(review): "15.0" as a string implies
    // macOS 15 — confirm this is intentional and not meant to be iOS.
    platforms: [.macOS("15.0")],
    dependencies: [
        // Hugging Face tokenizer support (BPE encode used in main.swift).
        .package(url: "https://github.com/huggingface/swift-transformers", from: "1.3.0"),
    ],
    targets: [
        .executableTarget(
            name: "plapre-cli",
            dependencies: [
                .product(name: "Tokenizers", package: "swift-transformers"),
            ],
            path: "Sources"
        ),
    ]
)
swift-cli/Sources/main.swift ADDED
@@ -0,0 +1,714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import Foundation
2
+ import CoreML
3
+ import Accelerate
4
+ import Tokenizers
5
+
6
// MARK: - Constants

// Output audio sample rate in Hz (used by the WAV writer and RTF math).
let sampleRate: Int = 24000
// The input token sequence is padded with <eos> up to this length.
let prefillSeqLen = 512
// KV-cache / attention-mask length fed to the decode model.
let maxContext = 2048
// Per-head dimension of the RoPE tables and KV caches.
let headDim = 64
// Number of key/value heads in each cached layer.
let numKvHeads = 3
// Length of the speaker embedding vector passed to the models.
let speakerDim = 128
// First LLM vocab id that represents an audio token; Kanade indices are
// obtained by subtracting this offset (see the chunk loop in main).
let audioTokenOffset = 8002
let audioMarkerToken: Int32 = 8001
let textMarkerToken: Int32 = 8000
let eosToken: Int32 = 0 // <eos> is token 0 in plapre tokenizer
let vocabSize = 20802

// HiFT source generation parameters
let hiftNfft = 16
let hiftHopLen = 4
let hiftSamplingRate: Float = 24000.0
let hiftHarmonicNum = 8
let hiftSineAmp: Float = 0.1
let hiftNoiseStd: Float = 0.003
// Samples of source signal per mel frame (see generateSourceSTFT).
let hiftUpsampleScale = 480
// Precomputed 16-point Hann window used by both the STFT and iSTFT.
let hiftWindow: [Float] = [0.0, 0.03806023, 0.14644662, 0.30865827, 0.5, 0.6913417, 0.85355341, 0.96193975, 1.0, 0.96193975, 0.85355341, 0.6913417, 0.5, 0.30865827, 0.14644662, 0.03806023]
// l_linear: 9 harmonics → 1, then tanh
// NOTE(review): weights presumably exported from the HiFT checkpoint — confirm.
let sourceLinearWeight: [Float] = [-0.27458203, -0.27744064, 0.07214482, 0.12596518, 0.02788151, 0.00307915, 0.01020926, -0.01141518, -0.01324173]
let sourceLinearBias: Float = 7.7338242e-05
32
+
33
// MARK: - Model paths

// Repo root derived from this source file's compile-time path:
// Sources/main.swift → swift-cli/Sources → swift-cli → repo root.
// This means the binary must run from a checkout, not a relocated bundle.
let repoRoot = URL(fileURLWithPath: #filePath)
    .deletingLastPathComponent() // Sources
    .deletingLastPathComponent() // swift-cli
    .deletingLastPathComponent() // repo root

/// URL of the named .mlpackage bundle at the repo root.
func modelURL(_ name: String) -> URL {
    repoRoot.appendingPathComponent("\(name).mlpackage")
}
43
+
44
// MARK: - RoPE tables

/// Loads a float16 .npy file from the repo root and returns its payload as [Float].
///
/// Parses the NPY header per the format spec instead of heuristically scanning
/// for a newline (the old scan could stop at a stray 0x0A inside the header
/// text and silently mis-slice the payload): magic "\x93NUMPY", one byte each
/// of major/minor version, then a little-endian header length (2 bytes for
/// v1.x, 4 bytes for v2+), then the padded header text, then raw float16 data.
///
/// - Parameter name: file name relative to `repoRoot`.
/// - Returns: the float16 payload widened to Float.
/// Crashes with a descriptive message on a missing or malformed file, matching
/// the fail-fast style of the rest of this CLI.
func loadRopeTable(_ name: String) -> [Float] {
    let url = repoRoot.appendingPathComponent(name)
    guard let data = try? Data(contentsOf: url) else {
        fatalError("Could not read RoPE table at \(url.path)")
    }
    // Magic string: \x93NUMPY
    let magic: [UInt8] = [0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59]
    guard data.count > 12, Array(data.prefix(6)) == magic else {
        fatalError("\(name) is not a valid .npy file")
    }
    let majorVersion = data[6]
    // Header length field is uint16 LE for format v1.x, uint32 LE for v2+.
    let headerEnd: Int
    if majorVersion >= 2 {
        let len = Int(data[8]) | (Int(data[9]) << 8) | (Int(data[10]) << 16) | (Int(data[11]) << 24)
        headerEnd = 12 + len
    } else {
        let len = Int(data[8]) | (Int(data[9]) << 8)
        headerEnd = 10 + len
    }
    guard headerEnd <= data.count, (data.count - headerEnd) % 2 == 0 else {
        fatalError("\(name): malformed .npy header or truncated float16 payload")
    }
    let rawData = data.subdata(in: headerEnd..<data.count)
    // Widen the float16 payload to float32.
    let count = rawData.count / 2
    var result = [Float](repeating: 0, count: count)
    rawData.withUnsafeBytes { ptr in
        let f16 = ptr.bindMemory(to: Float16.self)
        for i in 0..<count {
            result[i] = Float(f16[i])
        }
    }
    return result
}
78
+
79
// MARK: - Speaker embeddings

/// Loads the named speaker embedding from speakers.json at the repo root.
///
/// The file is expected to be a single JSON object mapping speaker name to an
/// array of numbers. Replaces the previous `try!` + force casts so that an
/// unreadable or malformed file produces a descriptive crash message instead
/// of a bare trap.
///
/// - Parameter name: key to look up in the JSON object.
/// - Returns: the embedding widened to [Float].
func loadSpeaker(_ name: String) -> [Float] {
    let url = repoRoot.appendingPathComponent("speakers.json")
    guard let data = try? Data(contentsOf: url) else {
        fatalError("Could not read speakers file at \(url.path)")
    }
    guard let json = (try? JSONSerialization.jsonObject(with: data)) as? [String: [Double]] else {
        fatalError("speakers.json is not a {String: [Double]} JSON object")
    }
    guard let emb = json[name] else {
        fatalError("Speaker '\(name)' not found. Available: \(json.keys.sorted())")
    }
    return emb.map { Float($0) }
}
90
+
91
+
92
// MARK: - Tokenizer

// MARK: - CoreML helpers

/// Compiles an .mlpackage on the fly and loads it with CPU-only compute.
/// - Parameter url: location of the .mlpackage bundle.
/// - Throws: any CoreML compile/load error.
/// NOTE(review): .cpuOnly presumably chosen for deterministic/debuggable
/// output while validating the port — confirm before enabling ANE/GPU.
func compileModel(at url: URL) throws -> MLModel {
    print(" Compiling \(url.lastPathComponent)...")
    let compiled = try MLModel.compileModel(at: url)
    let config = MLModelConfiguration()
    config.computeUnits = .cpuOnly
    return try MLModel(contentsOf: compiled, configuration: config)
}
103
+
104
/// Packs [Float] values into a float16 MLMultiArray with the given shape.
/// Each value is narrowed to Float16 element-wise.
/// NOTE(review): no check that values.count equals the product of shape —
/// callers are assumed to always match; confirm.
func mlArray(_ values: [Float], shape: [Int]) -> MLMultiArray {
    let arr = try! MLMultiArray(shape: shape.map { NSNumber(value: $0) }, dataType: .float16)
    let count = values.count
    arr.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
        for i in 0..<count {
            ptr[i] = Float16(values[i])
        }
    }
    return arr
}
114
+
115
/// Packs [Float] values into a float32 MLMultiArray with the given shape.
/// The caller is responsible for values.count matching the shape's element count.
func mlArrayFloat32(_ values: [Float], shape: [Int]) -> MLMultiArray {
    let shapeNumbers = shape.map { NSNumber(value: $0) }
    let container = try! MLMultiArray(shape: shapeNumbers, dataType: .float32)
    container.withUnsafeMutableBufferPointer(ofType: Float.self) { out, _ in
        for (offset, value) in values.enumerated() {
            out[offset] = value
        }
    }
    return container
}
124
+
125
/// Packs [Int32] values into an int32 MLMultiArray with the given shape.
/// The caller is responsible for values.count matching the shape's element count.
func mlArrayInt32(_ values: [Int32], shape: [Int]) -> MLMultiArray {
    let shapeNumbers = shape.map { NSNumber(value: $0) }
    let container = try! MLMultiArray(shape: shapeNumbers, dataType: .int32)
    container.withUnsafeMutableBufferPointer(ofType: Int32.self) { out, _ in
        for (offset, value) in values.enumerated() {
            out[offset] = value
        }
    }
    return container
}
134
+
135
/// Copies a float16 MLMultiArray out into a plain [Float], widening each
/// element. The array's declared dtype must actually be float16.
func readFloat16Array(_ arr: MLMultiArray) -> [Float] {
    let count = arr.count
    var result = [Float](repeating: 0, count: count)
    arr.withUnsafeBufferPointer(ofType: Float16.self) { ptr in
        for i in 0..<count {
            result[i] = Float(ptr[i])
        }
    }
    return result
}
145
+
146
/// Copies a float32 MLMultiArray out into a plain [Float]. The array's
/// declared dtype must actually be float32.
func readFloat32Array(_ arr: MLMultiArray) -> [Float] {
    var values = [Float](repeating: 0, count: arr.count)
    arr.withUnsafeBufferPointer(ofType: Float.self) { src in
        for idx in values.indices {
            values[idx] = src[idx]
        }
    }
    return values
}
156
+
157
// MARK: - Source signal generation (replaces HiFT's m_source in Swift)

/// Builds the harmonic-plus-noise excitation signal from a per-frame f0 track
/// and returns its 16-point STFT as a flat (18, numFrames) buffer:
/// channels 0..8 are the real parts, 9..17 the imaginary parts.
///
/// - Parameters:
///   - f0: one fundamental-frequency value (Hz) per mel frame; 0 marks
///     unvoiced frames. Assumes f0.count >= melLength — TODO confirm callers.
///   - melLength: number of mel frames; the excitation is
///     melLength * hiftUpsampleScale samples long.
/// - Returns: flat STFT buffer of 18 * (audioLength / hiftHopLen + 1) floats.
/// Non-deterministic: injects uniform random noise (step 4).
func generateSourceSTFT(f0: [Float], melLength: Int) -> [Float] {
    // f0 shape: (melLength,) — one f0 value per mel frame

    // 1. Upsample f0 by hiftUpsampleScale (nearest neighbor)
    let audioLength = melLength * hiftUpsampleScale
    var f0Up = [Float](repeating: 0, count: audioLength)
    for i in 0..<audioLength {
        f0Up[i] = f0[min(i / hiftUpsampleScale, melLength - 1)]
    }

    // 2. Generate harmonics: f0 * [1, 2, ..., harmonic_num+1]
    let numHarmonics = hiftHarmonicNum + 1 // 9
    var sineWaves = [[Float]](repeating: [Float](repeating: 0, count: audioLength), count: numHarmonics)

    for h in 0..<numHarmonics {
        let harmonicMul = Float(h + 1)
        // Cumulative phase: phase[t] = sum(f0[0..t] * harmonic / sr) * 2pi
        var phase: Float = 0
        for t in 0..<audioLength {
            let f = f0Up[t] * harmonicMul
            phase += f / hiftSamplingRate
            // Keep phase in [0, 1) to avoid precision loss
            phase = phase - Float(Int(phase))
            sineWaves[h][t] = sin(phase * 2 * .pi) * hiftSineAmp
        }
    }

    // 3. UV detection: voiced (f0 > 0) vs unvoiced
    var uv = [Float](repeating: 0, count: audioLength)
    for t in 0..<audioLength {
        uv[t] = f0Up[t] > 0 ? 1.0 : 0.0
    }

    // 4. Apply UV masking + noise
    // Voiced samples get low-level noise (hiftNoiseStd); unvoiced samples are
    // pure noise at a third of the sine amplitude.
    for h in 0..<numHarmonics {
        for t in 0..<audioLength {
            let noise = Float.random(in: -1...1) * (uv[t] * hiftNoiseStd + (1 - uv[t]) * hiftSineAmp / 3)
            sineWaves[h][t] = sineWaves[h][t] * uv[t] + noise
        }
    }

    // 5. Linear combination: 9 harmonics → 1 via l_linear + tanh
    var source = [Float](repeating: 0, count: audioLength)
    for t in 0..<audioLength {
        var val: Float = sourceLinearBias
        for h in 0..<numHarmonics {
            val += sineWaves[h][t] * sourceLinearWeight[h]
        }
        source[t] = tanh(val)
    }

    // 6. STFT of source signal
    // n_fft=16, hop=4, hann window
    let nfftHalf = hiftNfft / 2 + 1 // 9
    let numFrames = audioLength / hiftHopLen + 1
    // Output: (18, numFrames) — 9 real + 9 imag channels
    var stft = [Float](repeating: 0, count: 18 * numFrames)

    for frame in 0..<numFrames {
        let center = frame * hiftHopLen
        // Windowed segment (frames centered on `center`; out-of-range samples
        // are implicitly zero-padded)
        var segment = [Float](repeating: 0, count: hiftNfft)
        for k in 0..<hiftNfft {
            let idx = center - hiftNfft / 2 + k
            if idx >= 0 && idx < audioLength {
                segment[k] = source[idx] * hiftWindow[k]
            }
        }
        // DFT for each frequency bin (naive O(n²); fine for n_fft = 16)
        for f in 0..<nfftHalf {
            var real: Float = 0
            var imag: Float = 0
            for k in 0..<hiftNfft {
                let angle = -2.0 * Float.pi * Float(f) * Float(k) / Float(hiftNfft)
                real += segment[k] * cos(angle)
                imag += segment[k] * sin(angle)
            }
            stft[f * numFrames + frame] = real // real part
            stft[(nfftHalf + f) * numFrames + frame] = imag // imag part
        }
    }

    return stft
}
243
+
244
// MARK: - iSTFT (magnitude + phase → waveform)

/// Inverse STFT: reconstructs a waveform from one-sided magnitude/phase
/// spectra via IDFT + windowed overlap-add, then trims the center padding.
///
/// - Parameters:
///   - magnitude: flat (9, numFrames) one-sided magnitude buffer.
///   - phase: flat (9, numFrames) phase buffer (radians).
///   - numFrames: number of STFT frames.
/// - Returns: (numFrames - 1) * hiftHopLen samples, clamped to [-0.99, 0.99].
func istft(magnitude: [Float], phase: [Float], numFrames: Int) -> [Float] {
    // Matches torch.istft(spec, n_fft=16, hop_length=4, win_length=16, window=hann, center=True)
    let nfftHalf = hiftNfft / 2 + 1 // 9
    // center=True means the STFT was padded by n_fft//2 on each side
    // Total overlap-add length includes this padding
    let padded_length = (numFrames - 1) * hiftHopLen + hiftNfft
    var output = [Float](repeating: 0, count: padded_length)
    var windowSum = [Float](repeating: 0, count: padded_length)

    for frame in 0..<numFrames {
        // Build full complex spectrum from one-sided
        var real = [Float](repeating: 0, count: hiftNfft)
        var imag = [Float](repeating: 0, count: hiftNfft)

        for f in 0..<nfftHalf {
            let mag = magnitude[f * numFrames + frame]
            let ph = phase[f * numFrames + frame]
            real[f] = mag * cos(ph)
            imag[f] = mag * sin(ph)
        }
        // Mirror for negative frequencies (Hermitian symmetry)
        // Bins 1..7 mirror to 15..9; DC (0) and Nyquist (8) come straight
        // from the one-sided spectrum above.
        for f in 1..<(hiftNfft / 2) {
            real[hiftNfft - f] = real[f]
            imag[hiftNfft - f] = -imag[f]
        }

        // IDFT (naive O(n²); fine for n_fft = 16)
        var segment = [Float](repeating: 0, count: hiftNfft)
        for k in 0..<hiftNfft {
            var val: Float = 0
            for fi in 0..<hiftNfft {
                let angle = 2.0 * Float.pi * Float(fi) * Float(k) / Float(hiftNfft)
                val += real[fi] * cos(angle) - imag[fi] * sin(angle)
            }
            segment[k] = val / Float(hiftNfft)
        }

        // Overlap-add with window
        let start = frame * hiftHopLen
        for k in 0..<hiftNfft {
            let idx = start + k
            if idx < padded_length {
                output[idx] += segment[k] * hiftWindow[k]
                windowSum[idx] += hiftWindow[k] * hiftWindow[k]
            }
        }
    }

    // Normalize by window sum (skip near-zero sums to avoid blow-up at edges)
    for i in 0..<padded_length {
        if windowSum[i] > 1e-8 {
            output[i] /= windowSum[i]
        }
    }

    // Trim center padding: remove n_fft//2 from start, and from end to match expected length
    let pad = hiftNfft / 2 // 8
    let expectedLength = (numFrames - 1) * hiftHopLen // what torch.istft returns
    let trimStart = pad
    let trimEnd = min(trimStart + expectedLength, padded_length)
    var trimmed = Array(output[trimStart..<trimEnd])

    // Clamp
    for i in 0..<trimmed.count {
        trimmed[i] = max(-0.99, min(0.99, trimmed[i]))
    }

    return trimmed
}
315
+
316
// MARK: - WAV writer

/// Serializes mono float samples as a 16-bit PCM WAV file.
/// Samples are clamped to [-1, 1] and quantized to Int16 (truncating).
/// Crashes if the destination is not writable (fail-fast CLI style).
func writeWAV(_ samples: [Float], to url: URL, sampleRate: Int = 24000) {
    let pcmByteCount = samples.count * 2 // 16-bit mono PCM

    // Little-endian field encoders for the RIFF header.
    func le16(_ v: UInt16) -> Data { withUnsafeBytes(of: v.littleEndian) { Data($0) } }
    func le32(_ v: UInt32) -> Data { withUnsafeBytes(of: v.littleEndian) { Data($0) } }

    var data = Data()

    // RIFF container header
    data.append(contentsOf: "RIFF".utf8)
    data.append(le32(UInt32(36 + pcmByteCount)))
    data.append(contentsOf: "WAVE".utf8)

    // fmt chunk: uncompressed PCM, mono, 16 bits/sample
    data.append(contentsOf: "fmt ".utf8)
    data.append(le32(16))                     // fmt chunk size
    data.append(le16(1))                      // audio format: PCM
    data.append(le16(1))                      // channel count
    data.append(le32(UInt32(sampleRate)))
    data.append(le32(UInt32(sampleRate * 2))) // byte rate
    data.append(le16(2))                      // block align
    data.append(le16(16))                     // bits per sample

    // data chunk
    data.append(contentsOf: "data".utf8)
    data.append(le32(UInt32(pcmByteCount)))
    for sample in samples {
        let bounded = max(-1.0, min(1.0, sample))
        data.append(le16(UInt16(bitPattern: Int16(bounded * 32767.0))))
    }

    try! data.write(to: url)
}
350
+
351
// MARK: - Sampling

/// Draws a token id from raw logits using temperature scaling, top-k
/// filtering, and top-p (nucleus) filtering. A temperature of 0 or below
/// degenerates to greedy argmax.
func sampleToken(logits: [Float], temperature: Float = 0.8, topK: Int = 50, topP: Float = 0.95) -> Int32 {
    guard temperature > 0 else {
        return Int32(logits.enumerated().max(by: { $0.element < $1.element })!.offset)
    }

    var working = logits.map { $0 / temperature }

    // Top-k: mask everything strictly below the k-th largest value.
    let ranked = working.enumerated().sorted { $0.element > $1.element }
    let cutoff = ranked[min(topK - 1, ranked.count - 1)].element
    for idx in working.indices where working[idx] < cutoff {
        working[idx] = -.infinity
    }

    // Numerically-stable softmax.
    let peak = working.max()!
    var probs = working.map { exp($0 - peak) }
    let total = probs.reduce(0, +)
    probs = probs.map { $0 / total }

    // Top-p: smallest prefix of the probability-sorted tokens whose mass
    // reaches topP survives.
    var nucleus = Set<Int>()
    var mass: Float = 0
    for (idx, p) in probs.enumerated().sorted(by: { $0.element > $1.element }) {
        mass += p
        nucleus.insert(idx)
        if mass >= topP { break }
    }
    for idx in probs.indices where !nucleus.contains(idx) {
        probs[idx] = 0
    }
    let surviving = probs.reduce(0, +)
    if surviving > 0 { probs = probs.map { $0 / surviving } }

    // Inverse-CDF draw over the renormalized distribution.
    let draw = Float.random(in: 0..<1)
    var acc: Float = 0
    for (idx, p) in probs.enumerated() {
        acc += p
        if acc >= draw { return Int32(idx) }
    }
    return Int32(probs.count - 1)
}
399
+
400
// MARK: - Timing

/// Human-readable duration: microseconds below 1ms, milliseconds below 1s,
/// plain seconds otherwise.
func formatTime(_ seconds: Double) -> String {
    switch seconds {
    case ..<0.001: return String(format: "%.2fµs", seconds * 1_000_000)
    case ..<1.0:   return String(format: "%.1fms", seconds * 1000)
    default:       return String(format: "%.2fs", seconds)
    }
}
407
+
408
/// Runs `block`, prints its wall-clock duration under `label`, and returns
/// the block's result. If the block throws, nothing is printed and the error
/// propagates.
func measure<T>(_ label: String, _ block: () throws -> T) rethrows -> T {
    let begin = CFAbsoluteTimeGetCurrent()
    let value = try block()
    print(" ⏱ \(label): \(formatTime(CFAbsoluteTimeGetCurrent() - begin))")
    return value
}
415
+
416
/// Async twin of `measure`: awaits `block`, prints its wall-clock duration
/// under `label`, and returns the result. Errors propagate without printing.
func measureAsync<T>(_ label: String, _ block: () async throws -> T) async rethrows -> T {
    let begin = CFAbsoluteTimeGetCurrent()
    let value = try await block()
    print(" ⏱ \(label): \(formatTime(CFAbsoluteTimeGetCurrent() - begin))")
    return value
}
423
+
424
// MARK: - Main pipeline

print("Plapre Pico CoreML TTS Pipeline")
print("================================\n")

// CLI arguments (all optional, positional): [text] [speaker] [output.wav]
let text = CommandLine.arguments.count > 1 ? CommandLine.arguments[1] : "Hej, mit navn er Daniel."
let speakerName = CommandLine.arguments.count > 2 ? CommandLine.arguments[2] : "tor"
let outputPath = CommandLine.arguments.count > 3 ? CommandLine.arguments[3] : "output.wav"

print("Text: \(text)")
print("Speaker: \(speakerName)")
print("Output: \(outputPath)\n")

let pipelineStart = CFAbsoluteTimeGetCurrent()

// Load speaker
let speakerEmb = loadSpeaker(speakerName)
print("Loaded speaker embedding (\(speakerEmb.count) dims)")

// Tokenize using HuggingFace BPE tokenizer (tokenizer files are expected at
// the repo root alongside the models)
let tokenizer = try await measureAsync("Tokenizer load") { try await AutoTokenizer.from(modelFolder: repoRoot) }
let textTokens = tokenizer.encode(text: text, addSpecialTokens: false).map { Int32($0) }
print("Tokenized: \(textTokens.count) tokens: \(textTokens)")

// Build input sequence: [placeholder, <text>, tokens..., <audio>]
var inputSeq: [Int32] = [eosToken, textMarkerToken] + textTokens + [audioMarkerToken]
let inputLen = inputSeq.count
print("Input sequence: \(inputLen) tokens")

// Pad to prefillSeqLen
while inputSeq.count < prefillSeqLen {
    inputSeq.append(eosToken)
}

// Load RoPE tables
print("\nLoading RoPE tables...")
let ropeCos = loadRopeTable("rope_cos.npy")
let ropeSin = loadRopeTable("rope_sin.npy")
print("RoPE cos: \(ropeCos.count) values, sin: \(ropeSin.count) values")

// Compile models
print("\nCompiling models...")
// Populated by the LLM decode loop below; left empty in --test-audio mode
// until the hardcoded tokens are assigned.
var generatedTokens: [Int32] = []

let kanadeModel = try measure("Compile KanadeDecoder") { try compileModel(at: modelURL("KanadeDecoder")) }
let vocoderModel = try measure("Compile Vocoder") { try compileModel(at: modelURL("Vocoder")) }
470
+
471
// LLM stage (skipped entirely in --test-audio mode): prefill the prompt one
// token at a time through the stateless-KV decode model, then sample audio
// tokens autoregressively.
if !CommandLine.arguments.contains("--test-audio") {
    let decodeModel = try measure("Compile PlaprePico") { try compileModel(at: modelURL("PlaprePico")) }

    // === Step 1: Prefill via decode model (one token at a time) ===
    print("\n--- Prefill (token-by-token, stateless KV cache) ---")

    // Allocate KV cache buffers (managed by Swift, passed as model inputs/outputs)
    let numLayers = 30
    let cacheShape = [1, numKvHeads, maxContext, headDim]
    let cacheSize = numKvHeads * maxContext * headDim
    var kvCaches: [MLMultiArray] = []
    for _ in 0..<(numLayers * 2) {
        kvCaches.append(mlArray([Float](repeating: 0, count: cacheSize), shape: cacheShape))
    }

    // Helper to run one token through decode model.
    // Builds the causal mask (positions 0...pos visible, rest at -65504, the
    // float16 minimum), slices the per-position RoPE cos/sin rows, and swaps
    // the KV caches with the model's "_out" copies after each call.
    func runDecodeStep(token: Int32, pos: Int, isSpeaker: Bool = false) throws -> [Float] {
        var maskValues = [Float](repeating: -65504.0, count: maxContext)
        for j in 0...pos { maskValues[j] = 0.0 }

        let ropeOffset = pos * headDim
        let cosBuf = Array(ropeCos[ropeOffset..<(ropeOffset + headDim)])
        let sinBuf = Array(ropeSin[ropeOffset..<(ropeOffset + headDim)])

        // One-hot row selecting which cache slot this step writes.
        var updateMask = [Float](repeating: 0, count: maxContext)
        updateMask[pos] = 1.0

        var input: [String: MLFeatureValue] = [
            "input_ids": .init(multiArray: mlArrayInt32([token], shape: [1, 1])),
            "causal_mask": .init(multiArray: mlArray(maskValues, shape: [1, 1, 1, maxContext])),
            "cos": .init(multiArray: mlArray(cosBuf, shape: [1, 1, 1, headDim])),
            "sin": .init(multiArray: mlArray(sinBuf, shape: [1, 1, 1, headDim])),
            "update_mask": .init(multiArray: mlArray(updateMask, shape: [1, 1, maxContext, 1])),
            "speaker_embedding": .init(multiArray: mlArray(speakerEmb, shape: [1, speakerDim])),
            "is_speaker_step": .init(multiArray: mlArray([isSpeaker ? Float(1.0) : Float(0.0)], shape: [1])),
        ]

        // Add KV cache inputs
        for i in 0..<numLayers {
            input["k_cache_\(i)"] = .init(multiArray: kvCaches[2 * i])
            input["v_cache_\(i)"] = .init(multiArray: kvCaches[2 * i + 1])
        }

        let provider = try MLDictionaryFeatureProvider(dictionary: input)
        let output = try decodeModel.prediction(from: provider)

        // Read updated KV caches from output
        for i in 0..<numLayers {
            kvCaches[2 * i] = output.featureValue(for: "k_cache_\(i)_out")!.multiArrayValue!
            kvCaches[2 * i + 1] = output.featureValue(for: "v_cache_\(i)_out")!.multiArrayValue!
        }

        let logitsArr = output.featureValue(for: "logits")!.multiArrayValue!
        let count = logitsArr.shape.last!.intValue
        var result = [Float](repeating: 0, count: count)
        for i in 0..<count {
            result[i] = logitsArr[[0, 0, i] as [NSNumber]].floatValue
        }
        return result
    }

    // The input sequence is: [placeholder(speaker), <text>, tokens..., <audio>]
    // For the speaker token at position 0, we need to handle it differently.
    // The decode model uses embed_tokens, but position 0 should be the speaker projection.
    // WORKAROUND: feed the placeholder token (EOS=2) at position 0. The speaker conditioning
    // won't be perfect since we can't inject the speaker_proj output through the decode model.
    // For proper speaker conditioning, we'd need a dedicated prefill model or a combined model.
    // For now, feed all tokens including the placeholder to validate the pipeline.

    let inputTokens: [Int32] = Array(inputSeq.prefix(inputLen))
    print("Processing \(inputTokens.count) input tokens...")

    let prefillStart = CFAbsoluteTimeGetCurrent()
    var lastLogits: [Float] = []
    for (i, token) in inputTokens.enumerated() {
        lastLogits = try runDecodeStep(token: token, pos: i, isSpeaker: i == 0)
        let argmax = lastLogits.enumerated().max(by: { $0.element < $1.element })!.offset
        let maxVal = lastLogits.max()!
        print(" Prefill pos=\(i) token=\(token): argmax=\(argmax) max=\(String(format: "%.4f", maxVal))")
    }
    let prefillElapsed = CFAbsoluteTimeGetCurrent() - prefillStart
    let prefillTokPerSec = Double(inputTokens.count) / prefillElapsed
    print(" ⏱ Prefill: \(formatTime(prefillElapsed)) (\(inputTokens.count) tokens, \(String(format: "%.1f", prefillTokPerSec)) tok/s)")

    let firstToken = sampleToken(logits: lastLogits, temperature: 0.8, topK: 50, topP: 0.95)
    print("First generated token: \(firstToken)")
    // Debug
    let dbgLogits = lastLogits
    let sortedIndices = dbgLogits.enumerated().sorted { $0.element > $1.element }
    print(" Top 5: \(sortedIndices.prefix(5).map { "\($0.offset):\($0.element)" })")
    print(" Logits count: \(dbgLogits.count), nonzero: \(dbgLogits.filter { $0 != 0 }.count)")
    print(" Speaker emb first 3: \(speakerEmb.prefix(3))")

    // === Step 2: Autoregressive decode ===
    print("\n--- Decode ---")
    generatedTokens = [firstToken]
    let maxTokens = 500

    print("Generating up to \(maxTokens) tokens...")
    let decodeStart = CFAbsoluteTimeGetCurrent()
    var nextToken = firstToken
    var consecutiveNonAudio = 0
    let nonAudioStopThreshold = 10 // stop after this many consecutive non-audio tokens
    for step in 1..<maxTokens {
        let pos = inputLen + step - 1

        let logits = try runDecodeStep(token: nextToken, pos: pos)
        nextToken = sampleToken(logits: logits, temperature: 0.8, topK: 50, topP: 0.95)
        generatedTokens.append(nextToken)

        if nextToken == eosToken {
            print(" EOS at step \(step)")
            break
        }

        // Track consecutive non-audio tokens — model may be done speaking
        if nextToken >= audioTokenOffset && nextToken <= 20801 {
            consecutiveNonAudio = 0
        } else {
            consecutiveNonAudio += 1
            if consecutiveNonAudio >= nonAudioStopThreshold {
                print(" Stopping: \(nonAudioStopThreshold) consecutive non-audio tokens at step \(step)")
                break
            }
        }

        // Progress every 25 tokens (25 tokens ≈ 1 second of audio).
        if step % 25 == 0 {
            let elapsed = CFAbsoluteTimeGetCurrent() - decodeStart
            let tokPerSec = Double(step) / elapsed
            print(" Step \(step) (\(Float(step) / 25.0)s audio) — \(formatTime(elapsed)) elapsed, \(String(format: "%.1f", tokPerSec)) tok/s")
        }
    }
    let decodeElapsed = CFAbsoluteTimeGetCurrent() - decodeStart
    let decodeSteps = generatedTokens.count - 1 // first token came from prefill
    let decodeTokPerSec = Double(decodeSteps) / decodeElapsed
    let audioSeconds = Float(generatedTokens.filter { $0 >= audioTokenOffset && $0 <= 20801 }.count) / 25.0
    let rtf = Float(decodeElapsed) / audioSeconds // real-time factor: wall time / audio duration
    print(" ⏱ Decode: \(formatTime(decodeElapsed)) (\(decodeSteps) steps, \(String(format: "%.1f", decodeTokPerSec)) tok/s)")
    print(" ⏱ Audio generated: \(String(format: "%.1f", audioSeconds))s — RTF \(String(format: "%.2f", rtf))x (1.0 = realtime)")

} // end if !test-audio
612
+
613
// Audio stage: select the audio tokens (either the LLM output or a hardcoded
// reference sequence), run them through Kanade + Vocoder in chunks, and
// write the concatenated waveform to a WAV file.
var audioTokens: [Int32]
let testAudioOnly = CommandLine.arguments.contains("--test-audio")

if testAudioOnly {
    // Skip LLM, use known-good tokens from Python pipeline for audio testing
    print("\n--- Using hardcoded test tokens (--test-audio) ---")
    audioTokens = [11620, 17958, 13738, 15707, 12635, 12635, 12131, 12637, 20677, 12903,
        17769, 17841, 20016, 20080, 17520, 20080, 17528, 14832, 14774, 12200,
        12199, 12263, 11693, 11622, 12130, 12066, 12050, 12050, 12050, 12050,
        14578, 14642, 14610, 14082, 12058, 11482, 11474, 14538, 14610, 14642,
        14610, 14082, 14082, 11490, 11482, 11482, 11482, 11482, 11482, 11474,
        11410, 11394, 12066, 12058, 14610, 14610, 14098, 11490, 11482, 11490,
        11482, 11482, 11482, 11482, 11482, 11474, 11410, 11394, 11394, 11954,
        12010, 12002, 11426, 11418, 11026, 14618, 14082, 12061, 19682, 19933,
        20590, 19877, 17770, 17322, 14832, 14760, 12192, 12200, 12192, 12200,
        12199, 12263, 11693, 11686, 11677, 11686, 8914, 8978, 8914, 8978]
    generatedTokens = audioTokens
} else {
    // Keep only ids in the audio-token range of the LLM vocab.
    audioTokens = generatedTokens.filter { $0 >= audioTokenOffset && $0 <= 20801 }
}

print("\nGenerated \(generatedTokens.count) tokens, \(audioTokens.count) audio (\(Float(audioTokens.count) / 25.0)s)")
print("All tokens: \(generatedTokens.map { String($0) }.joined(separator: ", "))")
print("Audio tokens: \(audioTokens.prefix(20).map { String($0) }.joined(separator: ", "))...")

if audioTokens.isEmpty {
    print("No audio tokens generated!")
    exit(1)
}

// === Step 3: Kanade + Vocoder in chunks ===
// Kanade expects exactly 100 tokens (4s at 25 tokens/sec).
// Process audio tokens in 100-token chunks, concatenate waveforms.
let kanadeChunkSize = 100
let numChunks = (audioTokens.count + kanadeChunkSize - 1) / kanadeChunkSize
print("\n--- Kanade + Vocoder (\(numChunks) chunk\(numChunks == 1 ? "" : "s") of \(kanadeChunkSize) tokens) ---")

var waveform: [Float] = []
let audioDecodeStart = CFAbsoluteTimeGetCurrent()

for chunkIdx in 0..<numChunks {
    let chunkStart = CFAbsoluteTimeGetCurrent()
    let start = chunkIdx * kanadeChunkSize
    let end = min(start + kanadeChunkSize, audioTokens.count)
    let chunkTokens = Array(audioTokens[start..<end])

    // Convert to Kanade indices (subtract audio offset) and pad to chunk size
    var kanadeIndices = chunkTokens.map { $0 - Int32(audioTokenOffset) }
    let actualCount = kanadeIndices.count
    while kanadeIndices.count < kanadeChunkSize {
        kanadeIndices.append(kanadeIndices.last ?? 0) // repeat last token as padding
    }

    // Kanade: tokens → mel
    let kanadeStart = CFAbsoluteTimeGetCurrent()
    let kanadeInput: [String: MLFeatureValue] = [
        "token_indices": .init(multiArray: mlArrayInt32(kanadeIndices, shape: [kanadeChunkSize])),
        "speaker_embedding": .init(multiArray: mlArrayFloat32(speakerEmb, shape: [1, speakerDim])),
    ]
    let kanadeProvider = try MLDictionaryFeatureProvider(dictionary: kanadeInput)
    let kanadeOutput = try kanadeModel.prediction(from: kanadeProvider)
    let mel = kanadeOutput.featureValue(for: "mel")!.multiArrayValue!
    let kanadeElapsed = CFAbsoluteTimeGetCurrent() - kanadeStart

    // Vocoder: mel → waveform
    let vocoderStart = CFAbsoluteTimeGetCurrent()
    let vocoderInput: [String: MLFeatureValue] = [
        "mel": .init(multiArray: mel),
    ]
    let vocoderProvider = try MLDictionaryFeatureProvider(dictionary: vocoderInput)
    let vocoderOutput = try vocoderModel.prediction(from: vocoderProvider)
    let chunkWaveform = readFloat32Array(vocoderOutput.featureValue(for: "waveform")!.multiArrayValue!)
    let vocoderElapsed = CFAbsoluteTimeGetCurrent() - vocoderStart

    // If this chunk was padded, trim the waveform proportionally
    let samplesPerToken = chunkWaveform.count / kanadeChunkSize // 960 samples per token at 24kHz
    let usableSamples = actualCount * samplesPerToken
    waveform.append(contentsOf: chunkWaveform.prefix(usableSamples))

    let chunkElapsed = CFAbsoluteTimeGetCurrent() - chunkStart
    let chunkDuration = String(format: "%.1f", Float(usableSamples) / Float(sampleRate))
    print(" Chunk \(chunkIdx + 1)/\(numChunks): \(actualCount) tokens → \(chunkDuration)s audio — Kanade \(formatTime(kanadeElapsed)), Vocoder \(formatTime(vocoderElapsed)), total \(formatTime(chunkElapsed))")
}
let audioDecodeElapsed = CFAbsoluteTimeGetCurrent() - audioDecodeStart
print(" ⏱ Audio decode total: \(formatTime(audioDecodeElapsed)) (\(numChunks) chunk\(numChunks == 1 ? "" : "s"))")

print("Total waveform: \(waveform.count) samples (\(String(format: "%.1f", Float(waveform.count) / Float(sampleRate)))s)")

// === Write WAV ===
let outputURL = URL(fileURLWithPath: outputPath)
writeWAV(waveform, to: outputURL)
print("\nSaved to \(outputPath)")

// === Timing Summary ===
let pipelineElapsed = CFAbsoluteTimeGetCurrent() - pipelineStart
let totalAudioDuration = Float(waveform.count) / Float(sampleRate)
print("\n========== Timing Summary ==========")
print(" Total pipeline: \(formatTime(pipelineElapsed))")
print(" Audio output: \(String(format: "%.1f", totalAudioDuration))s")
print(" Overall RTF: \(String(format: "%.2f", Float(pipelineElapsed) / totalAudioDuration))x")
print("====================================")
print("Done!")