| | import Foundation |
| | import CoreML |
| |
|
/// Decoded contents of the `--input` JSON file.
///
/// Property names double as the JSON keys (snake_case), so the synthesized
/// `Codable` conformance needs no custom `CodingKeys`.
struct InputPayload: Codable {
    /// Prompt text; echoed back unchanged in the output payload.
    let prompt: String
    /// Token ids of the prompt, placed at the start of the sequence buffer.
    let prompt_ids: [Int32]
    /// Fixed sequence length fed to the model (the driver clamps it to at least 1).
    let seq_len: Int
    /// Requested number of generated tokens (clamped to the room left after the prompt).
    let max_new_tokens: Int
    /// Requested number of unmasking iterations (the driver clamps it to at least 1).
    let steps: Int
    /// Id written into every not-yet-decided generation slot.
    let mask_token_id: Int32
    /// Id at which the trimmed `generated_ids` output is cut off.
    let eos_token_id: Int32
    /// Id used to fill the buffer past the prompt + generation region.
    let pad_token_id: Int32
    /// Optional compute-unit name; `nil` when the key is absent or `null`.
    let compute_units: String?
}
| |
|
/// Per-iteration bookkeeping emitted in `step_stats` of the output JSON.
///
/// Declaration order is kept stable because it determines both the memberwise
/// initializer's parameter order and the synthesized encoding.
struct StepStat: Codable {
    /// 1-based iteration number (a trailing clean-up pass uses `steps + 1`).
    let step: Int
    /// Generation slots still masked when the iteration began.
    let masked_before: Int
    /// Slots committed to a predicted token during this iteration.
    let fixed_this_step: Int
    /// Generation slots still masked when the iteration ended.
    let masked_after: Int
    /// Mean model score over the slots committed this iteration; 0 when none were committed.
    let avg_fixed_score: Float
}
| |
|
/// Result document written to the `--output` JSON file.
///
/// Field order is preserved on purpose: it fixes the memberwise initializer's
/// parameter order and the synthesized encoding.
struct OutputPayload: Codable {
    /// Prompt text copied from the input payload.
    let prompt: String
    /// Prompt ids actually used (after truncation to the sequence length).
    let prompt_ids: [Int32]
    /// The complete token buffer, including trailing padding.
    let final_ids: [Int32]
    /// Generated tokens up to, but not including, the first EOS id.
    let generated_ids: [Int32]
    /// Every token in the generation region, EOS and anything after it included.
    let generated_ids_untrimmed: [Int32]
    /// Number of prompt tokens at the start of `final_ids`.
    let prompt_len: Int
    /// Prompt length plus the number of generation slots.
    let total_len: Int
    /// One entry per executed unmasking iteration.
    let step_stats: [StepStat]
    /// Seconds spent compiling (for `.mlpackage`) and loading the model.
    let load_seconds: Double
    /// Seconds spent inside model prediction calls, summed over all calls.
    let total_predict_seconds: Double
    /// Seconds spent in the whole decoding loop, including the final clean-up pass.
    let loop_seconds: Double
}
| |
|
/// Looks up a `--flag value` pair on the process command line.
///
/// - Parameter name: The flag to search for (only its first occurrence is considered).
/// - Returns: The argument immediately after the flag, or `nil` when the flag is
///   missing or is the last argument.
func argValue(_ name: String) -> String? {
    let argv = CommandLine.arguments
    guard let flagIndex = argv.firstIndex(of: name) else { return nil }
    let valueIndex = argv.index(after: flagIndex)
    return argv.indices.contains(valueIndex) ? argv[valueIndex] : nil
}
| |
|
/// Writes the command-line synopsis to stderr and terminates with exit status 2.
func usageAndExit() -> Never {
    let usage = "Usage: swift scripts/llada_diffuse.swift --model <path.mlmodelc|mlpackage> --input <input.json> --output <output.json>\n"
    fputs(usage, stderr)
    exit(2)
}
| |
|
/// Maps a user-supplied compute-unit name to `MLComputeUnits`.
///
/// Matching is case-insensitive and ignores `_`, `-` and spaces. It therefore accepts this
/// script's short names (`cpu`, `cpugpu`, `cpune`) and coremltools-style enum names
/// (`CPU_ONLY`, `CPU_AND_GPU`, `CPU_AND_NE`, `ALL`). Those coremltools names previously
/// fell through silently to `.all`.
///
/// - Parameter value: The raw name, or `nil` when none was provided.
/// - Returns: The matching compute units. Returns `.all` for `nil` and for unrecognized
///   names; an unrecognized name also prints a warning to stderr so typos are visible.
func computeUnits(from value: String?) -> MLComputeUnits {
    guard let value = value else { return .all }
    let key = value.lowercased().filter { !"_- ".contains($0) }
    switch key {
    case "cpu", "cpuonly":
        return .cpuOnly
    case "cpugpu", "cpuandgpu":
        return .cpuAndGPU
    case "cpune", "cpuandne", "cpuandneuralengine":
        return .cpuAndNeuralEngine
    case "all":
        return .all
    default:
        fputs("WARNING: unrecognized compute_units \"\(value)\"; using all\n", stderr)
        return .all
    }
}
| |
|
/// Packs `values` into a newly allocated `.int32` `MLMultiArray` of the given shape.
///
/// Values are written by linear index (`MLMultiArray`'s `Int` subscript). Any element
/// beyond `values.count` is explicitly set to 0, so the tail is well-defined regardless
/// of the allocation's initial contents.
///
/// - Parameters:
///   - values: Elements to store, in linear-index order.
///   - shape: Dimensions of the array to create.
/// - Throws: If the array cannot be allocated, or if `values` holds more elements than
///   `shape` can store (previously this trapped on an out-of-range subscript).
func int32Array(_ values: [Int32], shape: [NSNumber]) throws -> MLMultiArray {
    let arr = try MLMultiArray(shape: shape, dataType: .int32)
    guard values.count <= arr.count else {
        throw NSError(domain: "llada_diffuse", code: 3, userInfo: [
            NSLocalizedDescriptionKey: "int32Array: \(values.count) values do not fit shape \(shape) (\(arr.count) elements)"
        ])
    }
    for i in 0..<arr.count {
        arr[i] = NSNumber(value: i < values.count ? values[i] : 0)
    }
    return arr
}
| |
|
// MARK: - Entry point
//
// Iterative unmasking decode against a Core ML model:
//   1. Read the CLI flags and the JSON input payload.
//   2. Build a fixed-length token buffer laid out as
//      [prompt | mask_token_id * maxNew | pad_token_id ...].
//   3. Each iteration, predict every position. Then commit the highest-scoring
//      still-masked slots so the masked count follows a linear schedule down to 0
//      at the final step.
//   4. Write the decoded tokens, per-step statistics and timings to the output JSON.
// Any thrown error is reported on stderr and the process exits with status 1.
do {
    guard let modelPath = argValue("--model"),
          let inputPath = argValue("--input"),
          let outputPath = argValue("--output") else {
        usageAndExit()
    }

    let inputData = try Data(contentsOf: URL(fileURLWithPath: inputPath))
    let input = try JSONDecoder().decode(InputPayload.self, from: inputData)

    // Clamp the requested sizes to what fits in the model's fixed sequence length.
    let seqLen = max(1, input.seq_len)
    let promptLen = min(input.prompt_ids.count, seqLen)
    if input.prompt_ids.count > seqLen {
        fputs("WARNING: prompt has \(input.prompt_ids.count) tokens; truncating to seq_len=\(seqLen)\n", stderr)
    }
    let maxNew = max(0, min(input.max_new_tokens, seqLen - promptLen))
    let totalLen = promptLen + maxNew
    let stepCount = max(1, input.steps)
    // maxNew >= 0, so totalLen >= promptLen and this range is always valid (possibly empty).
    let generationRange = promptLen..<totalLen

    var tokenBuffer = Array(repeating: input.pad_token_id, count: seqLen)
    for i in 0..<promptLen {
        tokenBuffer[i] = input.prompt_ids[i]
    }
    for i in generationRange {
        tokenBuffer[i] = input.mask_token_id
    }

    // Attend to the prompt and the generation slots only; padding is masked out.
    let attentionMask: [Int32] = (0..<seqLen).map { $0 < totalLen ? 1 : 0 }
    // Prompt positions are fixed from the start; generation slots become fixed as they are committed.
    var fixed = (0..<seqLen).map { $0 < promptLen }

    let cfg = MLModelConfiguration()
    cfg.computeUnits = computeUnits(from: input.compute_units)

    let modelURL = URL(fileURLWithPath: modelPath)
    let modelLoadStart = ProcessInfo.processInfo.systemUptime
    // compileModel(at:) writes its .mlmodelc into a temporary location that nothing else
    // removes, so delete it when this scope exits (after the output has been written).
    var compiledModelURL: URL?
    defer {
        if let url = compiledModelURL {
            try? FileManager.default.removeItem(at: url)
        }
    }
    let resolvedModelURL: URL
    if modelURL.pathExtension.lowercased() == "mlpackage" {
        let compiled = try MLModel.compileModel(at: modelURL)
        compiledModelURL = compiled
        resolvedModelURL = compiled
    } else {
        resolvedModelURL = modelURL
    }
    let model = try MLModel(contentsOf: resolvedModelURL, configuration: cfg)
    let modelLoadSeconds = ProcessInfo.processInfo.systemUptime - modelLoadStart

    /// Runs one forward pass over the whole buffer.
    /// - Returns: Per-position predicted ids, per-position scores, and the seconds spent
    ///   inside `prediction(from:)` alone.
    /// - Throws: If prediction fails, or the expected outputs are missing or too short.
    func predict(ids: [Int32], mask: [Int32]) throws -> ([Int32], [Float], Double) {
        let shape: [NSNumber] = [1, NSNumber(value: seqLen)]
        let idsMA = try int32Array(ids, shape: shape)
        let maskMA = try int32Array(mask, shape: shape)
        let provider = try MLDictionaryFeatureProvider(dictionary: [
            "input_ids": MLFeatureValue(multiArray: idsMA),
            "attention_mask": MLFeatureValue(multiArray: maskMA)
        ])
        let t0 = ProcessInfo.processInfo.systemUptime
        let out = try model.prediction(from: provider)
        let dt = ProcessInfo.processInfo.systemUptime - t0

        // NOTE(review): these look like converter-generated output names; they must be
        // updated if the model is re-exported.
        guard let predMA = out.featureValue(for: "var_4801")?.multiArrayValue,
              let scoreMA = out.featureValue(for: "var_4806")?.multiArrayValue else {
            throw NSError(domain: "llada_diffuse", code: 1, userInfo: [NSLocalizedDescriptionKey: "Model outputs var_4801/var_4806 not found"])
        }
        // Fail with a clear message instead of trapping on an out-of-range subscript below.
        guard predMA.count >= seqLen, scoreMA.count >= seqLen else {
            throw NSError(domain: "llada_diffuse", code: 2, userInfo: [
                NSLocalizedDescriptionKey: "Model outputs have \(predMA.count)/\(scoreMA.count) elements; expected at least \(seqLen)"
            ])
        }
        let pred = (0..<seqLen).map { predMA[$0].int32Value }
        let score = (0..<seqLen).map { scoreMA[$0].floatValue }
        return (pred, score, dt)
    }

    var stepStats: [StepStat] = []
    var totalPredictSeconds = 0.0
    let loopStart = ProcessInfo.processInfo.systemUptime

    for step in 1...stepCount {
        let maskedPositions = generationRange.filter { !fixed[$0] }
        if maskedPositions.isEmpty {
            break
        }

        // Linear schedule: after this step, floor(maxNew * (stepCount - step) / stepCount)
        // slots stay masked. That target is 0 at step == stepCount, so everything is committed.
        let remainingAfter = Int(Double(maxNew) * Double(max(0, stepCount - step)) / Double(stepCount))
        let leaveMasked = min(remainingAfter, maskedPositions.count)
        let fixCount = max(0, maskedPositions.count - leaveMasked)

        var scoreSum: Float = 0
        // fixCount depends only on the schedule. When it is 0 (e.g. steps > max_new_tokens),
        // the prediction would be discarded and the buffer is unchanged, so skip the pass.
        if fixCount > 0 {
            let (pred, score, predictSeconds) = try predict(ids: tokenBuffer, mask: attentionMask)
            totalPredictSeconds += predictSeconds
            // Commit the fixCount most confident masked slots. The rest already hold
            // mask_token_id, because only committed slots are ever overwritten.
            let ranked = maskedPositions.sorted { score[$0] > score[$1] }
            for pos in ranked.prefix(fixCount) {
                tokenBuffer[pos] = pred[pos]
                fixed[pos] = true
                scoreSum += score[pos]
            }
        }

        stepStats.append(StepStat(
            step: step,
            masked_before: maskedPositions.count,
            fixed_this_step: fixCount,
            masked_after: maskedPositions.count - fixCount,
            avg_fixed_score: fixCount > 0 ? scoreSum / Float(fixCount) : 0
        ))
    }

    // Defensive clean-up pass. The schedule commits every slot by step == stepCount, so this
    // should not trigger. If anything is still masked, fill it from one more prediction and
    // record it as step stepCount + 1.
    let remainingPositions = generationRange.filter { !fixed[$0] }
    if !remainingPositions.isEmpty {
        let (pred, _, predictSeconds) = try predict(ids: tokenBuffer, mask: attentionMask)
        totalPredictSeconds += predictSeconds
        for pos in remainingPositions {
            tokenBuffer[pos] = pred[pos]
            fixed[pos] = true
        }
        stepStats.append(StepStat(
            step: stepCount + 1,
            masked_before: remainingPositions.count,
            fixed_this_step: remainingPositions.count,
            masked_after: 0,
            avg_fixed_score: 0
        ))
    }

    let loopSeconds = ProcessInfo.processInfo.systemUptime - loopStart
    let untrimmedGenerated = Array(tokenBuffer[generationRange])
    // Everything before the first EOS id (the whole region if there is none).
    let generated = Array(untrimmedGenerated.prefix(while: { $0 != input.eos_token_id }))

    let output = OutputPayload(
        prompt: input.prompt,
        prompt_ids: Array(tokenBuffer.prefix(promptLen)),
        final_ids: tokenBuffer,
        generated_ids: generated,
        generated_ids_untrimmed: untrimmedGenerated,
        prompt_len: promptLen,
        total_len: totalLen,
        step_stats: stepStats,
        load_seconds: modelLoadSeconds,
        total_predict_seconds: totalPredictSeconds,
        loop_seconds: loopSeconds
    )

    let encoded = try JSONEncoder().encode(output)
    try encoded.write(to: URL(fileURLWithPath: outputPath))
    print("Wrote \(outputPath)")
    print(String(format: "load=%.2fs predict_total=%.2fs loop=%.2fs", output.load_seconds, output.total_predict_seconds, output.loop_seconds))
} catch {
    fputs("ERROR: \(error)\n", stderr)
    exit(1)
}
| |
|