// Author: Daniel Rothmann
// Last change: Remove int4 variant which had too poor quality (5ed6654)
import ArgumentParser
import CoreML
import Foundation
import Tokenizers
// MARK: - CLI Arguments
@main
struct PlapreCLI: AsyncParsableCommand {
    static var configuration = CommandConfiguration(
        commandName: "plapre-cli",
        abstract: "Plapre Pico CoreML TTS Pipeline",
        discussion: "Danish text-to-speech using CoreML models."
    )

    /// Text to synthesize (Danish).
    @Argument(help: "Text to synthesize")
    var text: String = "Hej, mit navn er Daniel."

    /// Speaker voice identifier; must correspond to a bundled speaker embedding.
    @Option(name: .shortAndLong, help: "Speaker voice (available: tor, ida, liv, ask, kaj)")
    var speaker: String = "tor"

    /// Destination path for the rendered WAV file.
    @Option(name: .shortAndLong, help: "Output WAV file path")
    var output: String = "output.wav"

    /// When set, loads the int8-quantized decoder model instead of the full-precision one.
    @Flag(name: .long, help: "Use int8 quantized model (smaller)")
    var int8 = false

    // MARK: - Run Pipeline

    /// Runs the full pipeline: tokenize text → token-by-token prefill → autoregressive
    /// decode of audio tokens → Kanade (tokens → mel) → vocoder (mel → waveform) → WAV file.
    ///
    /// - Throws: `PipelineError.noAudioTokensGenerated` when decoding yields no audio-range
    ///   tokens, plus any error from model loading, prediction, or file I/O.
    mutating func run() async throws {
        print("Plapre Pico CoreML TTS Pipeline")
        print("================================\n")
        print("Text: \(text)")
        print("Speaker: \(speaker)")
        print("Output: \(output)\n")
        let pipelineStart = CFAbsoluteTimeGetCurrent()

        // Load speaker and tokenize
        let speakerEmb = try loadSpeaker(speaker)
        print("Loaded speaker embedding (\(speakerEmb.count) dims)")
        let tokenizer = try await measureAsync("Tokenizer load") {
            try await AutoTokenizer.from(modelFolder: PlapreConfig.repoRoot)
        }
        let textTokens = tokenizer.encode(text: text, addSpecialTokens: false).map { Int32($0) }
        print("Tokenized: \(textTokens.count) tokens: \(textTokens)")

        // Build input sequence: [EOS, text_marker, tokens..., audio_marker]
        var inputSeq: [Int32] =
            [PlapreConfig.eosToken, PlapreConfig.textMarkerToken] + textTokens + [
                PlapreConfig.audioMarkerToken
            ]
        let inputLen = inputSeq.count
        print("Input sequence: \(inputLen) tokens")
        // Pad with EOS up to the fixed prefill length; only the first `inputLen`
        // tokens are actually fed to the model below.
        while inputSeq.count < PlapreConfig.prefillSequenceLength {
            inputSeq.append(PlapreConfig.eosToken)
        }

        // Load RoPE tables (precomputed rotary position embeddings, one row per position).
        print("\nLoading RoPE tables...")
        let ropeCosF32 = try loadRopeTable("rope_cos.npy")
        let ropeSinF32 = try loadRopeTable("rope_sin.npy")
        let ropeCos16: [Float16] = ropeCosF32.map { Float16($0) }
        let ropeSin16: [Float16] = ropeSinF32.map { Float16($0) }
        print("RoPE cos: \(ropeCos16.count) values, sin: \(ropeSin16.count) values")

        // Compile models
        print("\nCompiling models...")
        let decodeModel = try measure("Compile PlaprePico") {
            try compileModel(
                at: PlapreConfig.modelURL(for: "PlaprePico", useInt8: int8))
        }
        let kanadeModel = try measure("Compile KanadeDecoder") {
            try compileModel(
                at: PlapreConfig.modelURL(for: "KanadeDecoder", useInt8: false))
        }
        let vocoderModel = try measure("Compile Vocoder") {
            try compileModel(
                at: PlapreConfig.modelURL(for: "Vocoder", useInt8: false))
        }

        // Pre-allocate MLMultiArrays (performance-critical: single allocation, reused for all
        // steps). Shapes are compile-time constants; `try` (not `try!`) so an unexpected
        // allocation failure propagates as an error instead of trapping.
        let pInputIds = try MLMultiArray(shape: [1, 1], dataType: .int32)
        let pCausalMask = try MLMultiArray(
            shape: [1, 1, 1, NSNumber(value: PlapreConfig.maxContextLength)], dataType: .float16)
        let pCos = try MLMultiArray(
            shape: [1, 1, 1, NSNumber(value: PlapreConfig.headDimension)], dataType: .float16)
        let pSin = try MLMultiArray(
            shape: [1, 1, 1, NSNumber(value: PlapreConfig.headDimension)], dataType: .float16)
        let pUpdateMask = try MLMultiArray(
            shape: [1, 1, NSNumber(value: PlapreConfig.maxContextLength), 1], dataType: .float16)
        let pSpeakerEmb = makeFloat16Array(
            speakerEmb, shape: [1, PlapreConfig.speakerEmbeddingDimension])
        let pIsSpeaker = try MLMultiArray(shape: [1], dataType: .float16)

        // Initialize causal mask to all -inf (-65504 is the most negative finite Float16).
        // NOTE(review): this is re-done before prefill below; kept for safety.
        pCausalMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
            for i in 0..<PlapreConfig.maxContextLength { ptr[i] = Float16(-65504.0) }
        }
        let inputProvider = DecodeInputProvider([
            "input_ids": pInputIds,
            "causal_mask": pCausalMask,
            "cos": pCos,
            "sin": pSin,
            "update_mask": pUpdateMask,
            "speaker_embedding": pSpeakerEmb,
            "is_speaker_step": pIsSpeaker,
        ])

        // State for stateful model (KV cache lives inside `state`)
        var state = decodeModel.makeState()
        var lastLogits = [Float16](repeating: 0, count: PlapreConfig.vocabSize)

        /// Feeds one token at `pos` through the stateful decoder, updating the shared
        /// pre-allocated input arrays in place and copying the output logits to `lastLogits`.
        func runDecodeStep(token: Int32, pos: Int, isSpeaker: Bool = false) throws {
            pInputIds.withUnsafeMutableBufferPointer(ofType: Int32.self) { ptr, _ in ptr[0] = token
            }
            // Unmask attention up to and including the current position.
            pCausalMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                ptr[pos] = Float16(0.0)
            }
            // Copy the RoPE row for this position into the per-step cos/sin inputs.
            let ropeOffset = pos * PlapreConfig.headDimension
            pCos.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                for i in 0..<PlapreConfig.headDimension { ptr[i] = ropeCos16[ropeOffset + i] }
            }
            pSin.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                for i in 0..<PlapreConfig.headDimension { ptr[i] = ropeSin16[ropeOffset + i] }
            }
            // Update mask is one-hot at the current position (clear previous step's bit).
            pUpdateMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                if pos > 0 { ptr[pos - 1] = Float16(0.0) }
                ptr[pos] = Float16(1.0)
            }
            pIsSpeaker.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                ptr[0] = isSpeaker ? Float16(1.0) : Float16(0.0)
            }
            let output = try decodeModel.prediction(from: inputProvider, using: state)
            // Force-unwrap is intentional: a missing "logits" output is a model-contract
            // violation (programmer error), not a recoverable condition.
            let arr = output.featureValue(for: "logits")!.multiArrayValue!
            arr.withUnsafeBufferPointer(ofType: Float16.self) { ptr in
                for i in 0..<PlapreConfig.vocabSize { lastLogits[i] = ptr[i] }
            }
        }

        // === Prefill ===
        print("\n--- Prefill (token-by-token) ---")
        let inputTokens = Array(inputSeq.prefix(inputLen))
        print("Processing \(inputTokens.count) input tokens...")
        let prefillStart = CFAbsoluteTimeGetCurrent()
        // Fresh KV cache and masks for the prefill pass.
        state = decodeModel.makeState()
        pCausalMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
            for i in 0..<PlapreConfig.maxContextLength { ptr[i] = Float16(-65504.0) }
        }
        pUpdateMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
            for i in 0..<PlapreConfig.maxContextLength { ptr[i] = Float16(0.0) }
        }
        // First step carries the speaker embedding; remaining steps are plain tokens.
        try runDecodeStep(token: inputTokens[0], pos: 0, isSpeaker: true)
        for i in 1..<inputTokens.count {
            try runDecodeStep(token: inputTokens[i], pos: i, isSpeaker: false)
        }
        let prefillElapsed = CFAbsoluteTimeGetCurrent() - prefillStart
        let prefillTokPerSec = Double(inputTokens.count) / prefillElapsed
        print(
            " ⏱ Prefill: \(formatTime(prefillElapsed)) (\(inputTokens.count) tokens, \(String(format: "%.1f", prefillTokPerSec)) tok/s)"
        )
        let firstToken = sampleFromLogits(
            lastLogits, temperature: PlapreConfig.defaultTemperature, topK: PlapreConfig.defaultTopK
        )
        print("First generated token: \(firstToken)")

        // === Autoregressive decode ===
        print("\n--- Decode ---")
        var generatedTokens: [Int32] = [firstToken]
        let maxTokens = min(
            PlapreConfig.maxGenerationTokens, PlapreConfig.maxContextLength - inputLen - 1)
        print("Generating up to \(maxTokens) tokens...")
        let decodeStart = CFAbsoluteTimeGetCurrent()
        var nextToken = firstToken
        var consecutiveNonAudio = 0
        // Clamp the upper bound: for inputs near maxContextLength, maxTokens can be < 1
        // and `1..<maxTokens` would be an inverted Range, trapping at runtime.
        for step in 1..<max(maxTokens, 1) {
            let pos = inputLen + step - 1
            try runDecodeStep(token: nextToken, pos: pos)
            nextToken = sampleFromLogits(
                lastLogits, temperature: PlapreConfig.defaultTemperature,
                topK: PlapreConfig.defaultTopK)
            generatedTokens.append(nextToken)
            if nextToken == PlapreConfig.eosToken {
                print(" EOS at step \(step)")
                break
            }
            if nextToken >= PlapreConfig.audioTokenOffset && nextToken <= PlapreConfig.audioTokenMax
            {
                consecutiveNonAudio = 0
            } else {
                // Bail out if the model drifts away from the audio-token range for too long.
                consecutiveNonAudio += 1
                if consecutiveNonAudio >= PlapreConfig.nonAudioStopThreshold {
                    print(
                        " Stopping: \(PlapreConfig.nonAudioStopThreshold) consecutive non-audio tokens at step \(step)"
                    )
                    break
                }
            }
            if step % 25 == 0 {
                let elapsed = CFAbsoluteTimeGetCurrent() - decodeStart
                let tokPerSec = Double(step) / elapsed
                print(
                    " Step \(step) (\(Float(step) / Float(PlapreConfig.audioTokensPerSecond))s audio) — \(formatTime(elapsed)) elapsed, \(String(format: "%.1f", tokPerSec)) tok/s"
                )
            }
        }
        let decodeElapsed = CFAbsoluteTimeGetCurrent() - decodeStart
        let decodeSteps = generatedTokens.count - 1
        let decodeTokPerSec = Double(decodeSteps) / decodeElapsed
        let audioSeconds =
            Float(
                generatedTokens.filter {
                    $0 >= PlapreConfig.audioTokenOffset && $0 <= PlapreConfig.audioTokenMax
                }.count) / Float(PlapreConfig.audioTokensPerSecond)
        // Avoid printing "inf" when no audio tokens were produced (that case throws below).
        let rtf = audioSeconds > 0 ? Float(decodeElapsed) / audioSeconds : Float(0)
        print(
            " ⏱ Decode: \(formatTime(decodeElapsed)) (\(decodeSteps) steps, \(String(format: "%.1f", decodeTokPerSec)) tok/s)"
        )
        print(
            " ⏱ Audio generated: \(String(format: "%.1f", audioSeconds))s — RTF \(String(format: "%.2f", rtf))x (1.0 = realtime)"
        )

        // === Audio synthesis ===
        let audioTokens = generatedTokens.filter {
            $0 >= PlapreConfig.audioTokenOffset && $0 <= PlapreConfig.audioTokenMax
        }
        print(
            "\nGenerated \(generatedTokens.count) tokens, \(audioTokens.count) audio (\(Float(audioTokens.count) / Float(PlapreConfig.audioTokensPerSecond))s)"
        )
        guard !audioTokens.isEmpty else {
            throw PipelineError.noAudioTokensGenerated
        }

        // === Kanade + Vocoder in chunks ===
        let numChunks =
            (audioTokens.count + PlapreConfig.kanadeChunkSize - 1) / PlapreConfig.kanadeChunkSize
        print("\n--- Kanade + Vocoder (\(numChunks) chunk\(numChunks == 1 ? "" : "s")) ---")
        var waveform: [Float] = []
        let audioDecodeStart = CFAbsoluteTimeGetCurrent()
        for chunkIdx in 0..<numChunks {
            let chunkStart = CFAbsoluteTimeGetCurrent()
            let start = chunkIdx * PlapreConfig.kanadeChunkSize
            let end = min(start + PlapreConfig.kanadeChunkSize, audioTokens.count)
            let chunkTokens = Array(audioTokens[start..<end])
            // Kanade expects 0-based codebook indices; pad the last chunk by repeating
            // the final token (padding is trimmed from the waveform below).
            var kanadeIndices = chunkTokens.map { $0 - Int32(PlapreConfig.audioTokenOffset) }
            let actualCount = kanadeIndices.count
            while kanadeIndices.count < PlapreConfig.kanadeChunkSize {
                kanadeIndices.append(kanadeIndices.last ?? 0)
            }
            // Kanade: tokens → mel
            let kanadeStart = CFAbsoluteTimeGetCurrent()
            let kanadeInput: [String: MLFeatureValue] = [
                "token_indices": .init(
                    multiArray: makeInt32Array(kanadeIndices, shape: [PlapreConfig.kanadeChunkSize])
                ),
                "speaker_embedding": .init(
                    multiArray: makeFloat32Array(
                        speakerEmb, shape: [1, PlapreConfig.speakerEmbeddingDimension])),
            ]
            let kanadeProvider = try MLDictionaryFeatureProvider(dictionary: kanadeInput)
            let kanadeOutput = try await kanadeModel.prediction(from: kanadeProvider)
            let mel = kanadeOutput.featureValue(for: "mel")!.multiArrayValue!
            let kanadeElapsed = CFAbsoluteTimeGetCurrent() - kanadeStart
            // Vocoder: mel → waveform
            let vocoderStart = CFAbsoluteTimeGetCurrent()
            let vocoderInput: [String: MLFeatureValue] = [
                "mel": .init(multiArray: mel)
            ]
            let vocoderProvider = try MLDictionaryFeatureProvider(dictionary: vocoderInput)
            let vocoderOutput = try await vocoderModel.prediction(from: vocoderProvider)
            let chunkWaveform = readFloat32Array(
                vocoderOutput.featureValue(for: "waveform")!.multiArrayValue!)
            let vocoderElapsed = CFAbsoluteTimeGetCurrent() - vocoderStart
            // Keep only the samples corresponding to real (non-padding) tokens.
            let samplesPerToken = chunkWaveform.count / PlapreConfig.kanadeChunkSize
            let usableSamples = actualCount * samplesPerToken
            waveform.append(contentsOf: chunkWaveform.prefix(usableSamples))
            let chunkElapsed = CFAbsoluteTimeGetCurrent() - chunkStart
            let chunkDuration = String(
                format: "%.1f", Float(usableSamples) / Float(PlapreConfig.sampleRate))
            print(
                " Chunk \(chunkIdx + 1)/\(numChunks): \(actualCount) tokens → \(chunkDuration)s audio — Kanade \(formatTime(kanadeElapsed)), Vocoder \(formatTime(vocoderElapsed)), total \(formatTime(chunkElapsed))"
            )
        }
        let audioDecodeElapsed = CFAbsoluteTimeGetCurrent() - audioDecodeStart
        print(
            " ⏱ Audio decode total: \(formatTime(audioDecodeElapsed)) (\(numChunks) chunk\(numChunks == 1 ? "" : "s"))"
        )
        print(
            "Total waveform: \(waveform.count) samples (\(String(format: "%.1f", Float(waveform.count) / Float(PlapreConfig.sampleRate)))s)"
        )

        // === Write WAV ===
        let outputURL = URL(fileURLWithPath: output)
        try writeWAV(waveform, to: outputURL)
        print("\nSaved to \(output)")

        // === Timing Summary ===
        let pipelineElapsed = CFAbsoluteTimeGetCurrent() - pipelineStart
        let totalAudioDuration = Float(waveform.count) / Float(PlapreConfig.sampleRate)
        print("\n========== Timing Summary ==========")
        print(" Total pipeline: \(formatTime(pipelineElapsed))")
        print(" Audio output: \(String(format: "%.1f", totalAudioDuration))s")
        print(
            " Overall RTF: \(String(format: "%.2f", Float(pipelineElapsed) / totalAudioDuration))x"
        )
        print("====================================")
        print("Done!")
    }
}
/// Errors specific to the synthesis pipeline (as opposed to CoreML/IO errors,
/// which are propagated as-is).
enum PipelineError: LocalizedError {
    /// The decode loop finished without emitting any tokens in the audio range.
    case noAudioTokensGenerated

    var errorDescription: String? {
        let message: String
        switch self {
        case .noAudioTokensGenerated:
            message = "No audio tokens were generated. The model may have failed to produce speech."
        }
        return message
    }
}