| import ArgumentParser |
| import CoreML |
| import Foundation |
| import Tokenizers |
|
|
| |
|
|
@main
struct PlapreCLI: AsyncParsableCommand {
    // `let` (not `var`): the configuration is immutable and a mutable static is
    // concurrency-unsafe; a constant still satisfies the protocol's `{ get }` requirement.
    static let configuration = CommandConfiguration(
        commandName: "plapre-cli",
        abstract: "Plapre Pico CoreML TTS Pipeline",
        discussion: "Danish text-to-speech using CoreML models."
    )

    /// Text to synthesize (positional argument).
    @Argument(help: "Text to synthesize")
    var text: String = "Hej, mit navn er Daniel."

    /// Speaker voice name; resolved to an embedding vector via `loadSpeaker(_:)`.
    @Option(name: .shortAndLong, help: "Speaker voice (available: tor, ida, liv, ask, kaj)")
    var speaker: String = "tor"

    /// Destination path for the rendered WAV file.
    @Option(name: .shortAndLong, help: "Output WAV file path")
    var output: String = "output.wav"

    /// When set, loads the int8-quantized PlaprePico decode model instead of the default.
    @Flag(name: .long, help: "Use int8 quantized model (smaller)")
    var int8 = false

    /// Runs the full pipeline:
    /// tokenize → autoregressive decode (PlaprePico) → mel decode (Kanade) → vocoder → WAV.
    ///
    /// - Throws: `PipelineError.noAudioTokensGenerated` when the decode loop produces
    ///   no audio tokens, plus any model-loading / prediction / file-writing errors.
    mutating func run() async throws {
        print("Plapre Pico CoreML TTS Pipeline")
        print("================================\n")

        print("Text: \(text)")
        print("Speaker: \(speaker)")
        print("Output: \(output)\n")

        let pipelineStart = CFAbsoluteTimeGetCurrent()

        // Speaker embedding conditions both the decoder's first step and the Kanade model.
        let speakerEmb = try loadSpeaker(speaker)
        print("Loaded speaker embedding (\(speakerEmb.count) dims)")

        let tokenizer = try await measureAsync("Tokenizer load") {
            try await AutoTokenizer.from(modelFolder: PlapreConfig.repoRoot)
        }
        let textTokens = tokenizer.encode(text: text, addSpecialTokens: false).map { Int32($0) }
        print("Tokenized: \(textTokens.count) tokens: \(textTokens)")

        // Prompt layout: [EOS][text-marker] <text tokens> [audio-marker].
        // Note: the previous version also padded this with EOS up to
        // `prefillSequenceLength`, but the token-by-token prefill below only ever
        // consumed the first `inputLen` tokens, so the padding was dead code.
        let inputSeq: [Int32] =
            [PlapreConfig.eosToken, PlapreConfig.textMarkerToken] + textTokens + [
                PlapreConfig.audioMarkerToken
            ]
        let inputLen = inputSeq.count
        print("Input sequence: \(inputLen) tokens")

        // RoPE tables are stored on disk as Float32; the model consumes Float16.
        print("\nLoading RoPE tables...")
        let ropeCosF32 = try loadRopeTable("rope_cos.npy")
        let ropeSinF32 = try loadRopeTable("rope_sin.npy")
        let ropeCos16: [Float16] = ropeCosF32.map { Float16($0) }
        let ropeSin16: [Float16] = ropeSinF32.map { Float16($0) }
        print("RoPE cos: \(ropeCos16.count) values, sin: \(ropeSin16.count) values")

        print("\nCompiling models...")
        let decodeModel = try measure("Compile PlaprePico") {
            try compileModel(
                at: PlapreConfig.modelURL(for: "PlaprePico", useInt8: int8))
        }
        let kanadeModel = try measure("Compile KanadeDecoder") {
            try compileModel(
                at: PlapreConfig.modelURL(for: "KanadeDecoder", useInt8: false))
        }
        let vocoderModel = try measure("Compile Vocoder") {
            try compileModel(
                at: PlapreConfig.modelURL(for: "Vocoder", useInt8: false))
        }

        // Reusable input buffers for the single-token decode model.
        // `try` (not `try!`): run() throws, so allocation failures propagate as errors
        // instead of crashing the process.
        let pInputIds = try MLMultiArray(shape: [1, 1], dataType: .int32)
        let pCausalMask = try MLMultiArray(
            shape: [1, 1, 1, NSNumber(value: PlapreConfig.maxContextLength)], dataType: .float16)
        let pCos = try MLMultiArray(
            shape: [1, 1, 1, NSNumber(value: PlapreConfig.headDimension)], dataType: .float16)
        let pSin = try MLMultiArray(
            shape: [1, 1, 1, NSNumber(value: PlapreConfig.headDimension)], dataType: .float16)
        let pUpdateMask = try MLMultiArray(
            shape: [1, 1, NSNumber(value: PlapreConfig.maxContextLength), 1], dataType: .float16)
        let pSpeakerEmb = makeFloat16Array(
            speakerEmb, shape: [1, PlapreConfig.speakerEmbeddingDimension])
        let pIsSpeaker = try MLMultiArray(shape: [1], dataType: .float16)

        let inputProvider = DecodeInputProvider([
            "input_ids": pInputIds,
            "causal_mask": pCausalMask,
            "cos": pCos,
            "sin": pSin,
            "update_mask": pUpdateMask,
            "speaker_embedding": pSpeakerEmb,
            "is_speaker_step": pIsSpeaker,
        ])

        // KV-cache state for the autoregressive decoder: created once and reused.
        // (The previous version created a state, then discarded it and made a second
        // one right before prefill; nothing ran in between, so one suffices.)
        let state = decodeModel.makeState()
        var lastLogits = [Float16](repeating: 0, count: PlapreConfig.vocabSize)

        /// Feeds one token at position `pos` through the decode model, updating the
        /// shared input buffers in place and copying the resulting logits into
        /// `lastLogits`. `isSpeaker` flags the very first step, which injects the
        /// speaker embedding instead of a token embedding (model-side behavior —
        /// inferred from the `is_speaker_step` input name; confirm against the model).
        func runDecodeStep(token: Int32, pos: Int, isSpeaker: Bool = false) throws {
            pInputIds.withUnsafeMutableBufferPointer(ofType: Int32.self) { ptr, _ in
                ptr[0] = token
            }
            // Unmask the current position. -65504 is the most negative finite
            // Float16, used as "-inf" in the additive attention mask.
            pCausalMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                ptr[pos] = Float16(0.0)
            }

            // Copy the per-position RoPE rotation row into the cos/sin inputs.
            let ropeOffset = pos * PlapreConfig.headDimension
            pCos.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                for i in 0..<PlapreConfig.headDimension { ptr[i] = ropeCos16[ropeOffset + i] }
            }
            pSin.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                for i in 0..<PlapreConfig.headDimension { ptr[i] = ropeSin16[ropeOffset + i] }
            }
            // Update mask is kept one-hot at `pos`: clear the previous step's slot,
            // set the current one.
            pUpdateMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                if pos > 0 { ptr[pos - 1] = Float16(0.0) }
                ptr[pos] = Float16(1.0)
            }
            pIsSpeaker.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
                ptr[0] = isSpeaker ? Float16(1.0) : Float16(0.0)
            }

            let output = try decodeModel.prediction(from: inputProvider, using: state)
            guard let arr = output.featureValue(for: "logits")?.multiArrayValue else {
                fatalError("decode model produced no 'logits' multiarray output")
            }
            arr.withUnsafeBufferPointer(ofType: Float16.self) { ptr in
                for i in 0..<PlapreConfig.vocabSize { lastLogits[i] = ptr[i] }
            }
        }

        print("\n--- Prefill (token-by-token) ---")
        print("Processing \(inputSeq.count) input tokens...")

        let prefillStart = CFAbsoluteTimeGetCurrent()
        // Start from a fully masked context and an all-zero update mask.
        pCausalMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
            for i in 0..<PlapreConfig.maxContextLength { ptr[i] = Float16(-65504.0) }
        }
        pUpdateMask.withUnsafeMutableBufferPointer(ofType: Float16.self) { ptr, _ in
            for i in 0..<PlapreConfig.maxContextLength { ptr[i] = Float16(0.0) }
        }
        // First step carries the speaker embedding; the rest are plain tokens.
        try runDecodeStep(token: inputSeq[0], pos: 0, isSpeaker: true)
        for i in 1..<inputSeq.count {
            try runDecodeStep(token: inputSeq[i], pos: i, isSpeaker: false)
        }

        let prefillElapsed = CFAbsoluteTimeGetCurrent() - prefillStart
        let prefillTokPerSec = Double(inputSeq.count) / prefillElapsed
        print(
            " ⏱ Prefill: \(formatTime(prefillElapsed)) (\(inputSeq.count) tokens, \(String(format: "%.1f", prefillTokPerSec)) tok/s)"
        )

        let firstToken = sampleFromLogits(
            lastLogits, temperature: PlapreConfig.defaultTemperature, topK: PlapreConfig.defaultTopK
        )
        print("First generated token: \(firstToken)")

        print("\n--- Decode ---")
        var generatedTokens: [Int32] = [firstToken]
        // Leave one slot of context headroom beyond prompt + generation.
        let maxTokens = min(
            PlapreConfig.maxGenerationTokens, PlapreConfig.maxContextLength - inputLen - 1)
        print("Generating up to \(maxTokens) tokens...")

        let decodeStart = CFAbsoluteTimeGetCurrent()
        var nextToken = firstToken
        var consecutiveNonAudio = 0

        for step in 1..<maxTokens {
            let pos = inputLen + step - 1
            try runDecodeStep(token: nextToken, pos: pos)
            nextToken = sampleFromLogits(
                lastLogits, temperature: PlapreConfig.defaultTemperature,
                topK: PlapreConfig.defaultTopK)
            generatedTokens.append(nextToken)

            if nextToken == PlapreConfig.eosToken {
                print(" EOS at step \(step)")
                break
            }

            // Stop early if the model drifts away from audio tokens for too long.
            if nextToken >= PlapreConfig.audioTokenOffset && nextToken <= PlapreConfig.audioTokenMax
            {
                consecutiveNonAudio = 0
            } else {
                consecutiveNonAudio += 1
                if consecutiveNonAudio >= PlapreConfig.nonAudioStopThreshold {
                    print(
                        " Stopping: \(PlapreConfig.nonAudioStopThreshold) consecutive non-audio tokens at step \(step)"
                    )
                    break
                }
            }

            // Periodic progress report.
            if step % 25 == 0 {
                let elapsed = CFAbsoluteTimeGetCurrent() - decodeStart
                let tokPerSec = Double(step) / elapsed
                print(
                    " Step \(step) (\(Float(step) / Float(PlapreConfig.audioTokensPerSecond))s audio) — \(formatTime(elapsed)) elapsed, \(String(format: "%.1f", tokPerSec)) tok/s"
                )
            }
        }

        let decodeElapsed = CFAbsoluteTimeGetCurrent() - decodeStart
        let decodeSteps = generatedTokens.count - 1
        let decodeTokPerSec = Double(decodeSteps) / decodeElapsed
        print(
            " ⏱ Decode: \(formatTime(decodeElapsed)) (\(decodeSteps) steps, \(String(format: "%.1f", decodeTokPerSec)) tok/s)"
        )

        // Filter once (the previous version computed this filter twice), and throw
        // BEFORE dividing by audioSeconds so an empty generation can't print RTF = inf.
        let audioTokens = generatedTokens.filter {
            $0 >= PlapreConfig.audioTokenOffset && $0 <= PlapreConfig.audioTokenMax
        }
        guard !audioTokens.isEmpty else {
            throw PipelineError.noAudioTokensGenerated
        }

        let audioSeconds = Float(audioTokens.count) / Float(PlapreConfig.audioTokensPerSecond)
        let rtf = Float(decodeElapsed) / audioSeconds
        print(
            " ⏱ Audio generated: \(String(format: "%.1f", audioSeconds))s — RTF \(String(format: "%.2f", rtf))x (1.0 = realtime)"
        )
        print(
            "\nGenerated \(generatedTokens.count) tokens, \(audioTokens.count) audio (\(Float(audioTokens.count) / Float(PlapreConfig.audioTokensPerSecond))s)"
        )

        // Ceiling division: number of fixed-size Kanade chunks.
        let numChunks =
            (audioTokens.count + PlapreConfig.kanadeChunkSize - 1) / PlapreConfig.kanadeChunkSize
        print("\n--- Kanade + Vocoder (\(numChunks) chunk\(numChunks == 1 ? "" : "s")) ---")

        var waveform: [Float] = []
        let audioDecodeStart = CFAbsoluteTimeGetCurrent()

        for chunkIdx in 0..<numChunks {
            let chunkStart = CFAbsoluteTimeGetCurrent()
            let start = chunkIdx * PlapreConfig.kanadeChunkSize
            let end = min(start + PlapreConfig.kanadeChunkSize, audioTokens.count)
            let chunkTokens = Array(audioTokens[start..<end])

            // Re-base vocabulary tokens to Kanade indices; pad the final chunk by
            // repeating the last index so the fixed-size model input is full.
            var kanadeIndices = chunkTokens.map { $0 - Int32(PlapreConfig.audioTokenOffset) }
            let actualCount = kanadeIndices.count
            while kanadeIndices.count < PlapreConfig.kanadeChunkSize {
                kanadeIndices.append(kanadeIndices.last ?? 0)
            }

            // Kanade: audio-token indices + speaker embedding → mel spectrogram.
            let kanadeStart = CFAbsoluteTimeGetCurrent()
            let kanadeInput: [String: MLFeatureValue] = [
                "token_indices": .init(
                    multiArray: makeInt32Array(kanadeIndices, shape: [PlapreConfig.kanadeChunkSize])
                ),
                "speaker_embedding": .init(
                    multiArray: makeFloat32Array(
                        speakerEmb, shape: [1, PlapreConfig.speakerEmbeddingDimension])),
            ]
            let kanadeProvider = try MLDictionaryFeatureProvider(dictionary: kanadeInput)
            let kanadeOutput = try await kanadeModel.prediction(from: kanadeProvider)
            guard let mel = kanadeOutput.featureValue(for: "mel")?.multiArrayValue else {
                fatalError("Kanade model produced no 'mel' multiarray output")
            }
            let kanadeElapsed = CFAbsoluteTimeGetCurrent() - kanadeStart

            // Vocoder: mel spectrogram → waveform samples.
            let vocoderStart = CFAbsoluteTimeGetCurrent()
            let vocoderInput: [String: MLFeatureValue] = [
                "mel": .init(multiArray: mel)
            ]
            let vocoderProvider = try MLDictionaryFeatureProvider(dictionary: vocoderInput)
            let vocoderOutput = try await vocoderModel.prediction(from: vocoderProvider)
            guard let waveformArray = vocoderOutput.featureValue(for: "waveform")?.multiArrayValue
            else {
                fatalError("Vocoder model produced no 'waveform' multiarray output")
            }
            let chunkWaveform = readFloat32Array(waveformArray)
            let vocoderElapsed = CFAbsoluteTimeGetCurrent() - vocoderStart

            // Keep only the samples backed by real (non-padding) tokens.
            let samplesPerToken = chunkWaveform.count / PlapreConfig.kanadeChunkSize
            let usableSamples = actualCount * samplesPerToken
            waveform.append(contentsOf: chunkWaveform.prefix(usableSamples))

            let chunkElapsed = CFAbsoluteTimeGetCurrent() - chunkStart
            let chunkDuration = String(
                format: "%.1f", Float(usableSamples) / Float(PlapreConfig.sampleRate))
            print(
                " Chunk \(chunkIdx + 1)/\(numChunks): \(actualCount) tokens → \(chunkDuration)s audio — Kanade \(formatTime(kanadeElapsed)), Vocoder \(formatTime(vocoderElapsed)), total \(formatTime(chunkElapsed))"
            )
        }

        let audioDecodeElapsed = CFAbsoluteTimeGetCurrent() - audioDecodeStart
        print(
            " ⏱ Audio decode total: \(formatTime(audioDecodeElapsed)) (\(numChunks) chunk\(numChunks == 1 ? "" : "s"))"
        )

        print(
            "Total waveform: \(waveform.count) samples (\(String(format: "%.1f", Float(waveform.count) / Float(PlapreConfig.sampleRate)))s)"
        )

        let outputURL = URL(fileURLWithPath: output)
        try writeWAV(waveform, to: outputURL)
        print("\nSaved to \(output)")

        let pipelineElapsed = CFAbsoluteTimeGetCurrent() - pipelineStart
        let totalAudioDuration = Float(waveform.count) / Float(PlapreConfig.sampleRate)
        print("\n========== Timing Summary ==========")
        print(" Total pipeline: \(formatTime(pipelineElapsed))")
        print(" Audio output: \(String(format: "%.1f", totalAudioDuration))s")
        print(
            " Overall RTF: \(String(format: "%.2f", Float(pipelineElapsed) / totalAudioDuration))x"
        )
        print("====================================")
        print("Done!")
    }
}
|
|
/// Failure modes of the synthesis pipeline, with user-facing descriptions
/// surfaced through `LocalizedError`.
enum PipelineError: LocalizedError {
    /// The decode loop finished without emitting a single audio token.
    case noAudioTokensGenerated

    /// Human-readable message shown when the command throws.
    var errorDescription: String? {
        switch self {
        case .noAudioTokensGenerated:
            return "No audio tokens were generated. The model may have failed to produce speech."
        }
    }
}
|
|