// Source: oraculumai — "examples: add runnable CoreML diffusion loop scripts" (commit 592f7ae, verified)
import Foundation
import CoreML
/// Request decoded from the `--input` JSON file: the prompt, its pre-tokenized
/// ids, and the diffusion-loop hyperparameters. Field names mirror the JSON keys.
struct InputPayload: Codable {
// Original prompt text; echoed verbatim into the output payload.
let prompt: String
// Token ids of the prompt; only the first `seq_len` entries are used.
let prompt_ids: [Int32]
// Fixed model sequence length (clamped to >= 1 by the driver).
let seq_len: Int
// Upper bound on generated tokens; clamped to the room left after the prompt.
let max_new_tokens: Int
// Number of diffusion denoising steps (clamped to >= 1 by the driver).
let steps: Int
// Id written into positions that are still masked (to be predicted).
let mask_token_id: Int32
// Generation is trimmed at the first occurrence of this id.
let eos_token_id: Int32
// Id used to fill buffer positions beyond the active window.
let pad_token_id: Int32
// Optional compute-units selector ("cpu", "cpugpu", "cpune", …); nil -> .all.
let compute_units: String?
}
/// Per-step telemetry recorded by the diffusion loop and emitted in the output JSON.
struct StepStat: Codable {
// 1-based step index (stepCount + 1 marks the final safety fill pass).
let step: Int
// Count of still-masked generation positions entering this step.
let masked_before: Int
// Count of positions committed (unmasked) during this step.
let fixed_this_step: Int
// Count of positions left masked after this step.
let masked_after: Int
// Mean model score over the positions fixed this step (0 when none fixed).
let avg_fixed_score: Float
}
/// Result written to the `--output` JSON file: token ids, per-step stats, and timings.
struct OutputPayload: Codable {
// Prompt text echoed from the input payload.
let prompt: String
// Prompt ids actually used (input ids truncated to the sequence window).
let prompt_ids: [Int32]
// Full token buffer after the loop (prompt + generation + padding).
let final_ids: [Int32]
// Generated ids trimmed at the first eos_token_id.
let generated_ids: [Int32]
// Generated ids before EOS trimming.
let generated_ids_untrimmed: [Int32]
// Number of prompt positions occupied in the buffer.
let prompt_len: Int
// prompt_len + number of generation positions.
let total_len: Int
// One entry per diffusion step (plus an optional safety pass).
let step_stats: [StepStat]
// Wall-clock seconds to load (and, for .mlpackage, compile) the model.
let load_seconds: Double
// Sum of wall-clock seconds spent inside model.prediction calls.
let total_predict_seconds: Double
// Wall-clock seconds for the whole diffusion loop (including predictions).
let loop_seconds: Double
}
/// Returns the argument that immediately follows the flag `name` on the
/// command line, or nil when the flag is absent or is the last argument.
func argValue(_ name: String) -> String? {
    let arguments = CommandLine.arguments
    for (position, argument) in arguments.enumerated() where argument == name {
        let valueIndex = position + 1
        // Flag present but nothing after it -> treat as missing.
        return valueIndex < arguments.count ? arguments[valueIndex] : nil
    }
    return nil
}
/// Writes the usage line to stderr and terminates the process with status 2.
func usageAndExit() -> Never {
    let usage = "Usage: swift scripts/llada_diffuse.swift --model <path.mlmodelc|mlpackage> --input <input.json> --output <output.json>\n"
    fputs(usage, stderr)
    exit(2)
}
/// Maps a user-supplied compute-units string (case-insensitive) to the
/// matching MLComputeUnits value; nil or unrecognized input yields .all.
func computeUnits(from value: String?) -> MLComputeUnits {
    guard let normalized = value?.lowercased() else { return .all }
    if normalized == "cpu" || normalized == "cpuonly" {
        return .cpuOnly
    }
    if normalized == "cpugpu" || normalized == "cpuandgpu" {
        return .cpuAndGPU
    }
    if normalized == "cpune" || normalized == "cpuandneuralengine" {
        return .cpuAndNeuralEngine
    }
    return .all
}
/// Allocates an Int32 MLMultiArray of `shape` and copies `values` into it
/// at consecutive linear offsets.
/// - Note: assumes `shape` describes at least `values.count` elements — verify at call sites.
func int32Array(_ values: [Int32], shape: [NSNumber]) throws -> MLMultiArray {
    let array = try MLMultiArray(shape: shape, dataType: .int32)
    for (offset, element) in values.enumerated() {
        array[offset] = NSNumber(value: element)
    }
    return array
}
// Entry point: parse CLI flags, decode the input JSON, run the masked-diffusion
// decode loop against the CoreML model, and write the result JSON.
// Any thrown error is printed to stderr and the process exits with status 1.
do {
guard let modelPath = argValue("--model"),
let inputPath = argValue("--input"),
let outputPath = argValue("--output") else {
usageAndExit()
}
let inputData = try Data(contentsOf: URL(fileURLWithPath: inputPath))
let input = try JSONDecoder().decode(InputPayload.self, from: inputData)
// Clamp the request so promptLen + maxNew never exceeds the model's sequence length.
let seqLen = max(1, input.seq_len)
let promptLen = min(input.prompt_ids.count, seqLen)
let maxNew = max(0, min(input.max_new_tokens, seqLen - promptLen))
let totalLen = promptLen + maxNew
let stepCount = max(1, input.steps)
// Token buffer layout: [prompt ids | mask tokens (generation window) | pad tokens].
var tokenBuffer = Array(repeating: input.pad_token_id, count: seqLen)
for i in 0..<promptLen {
tokenBuffer[i] = input.prompt_ids[i]
}
// Attention mask is 1 over the active window [0, totalLen), 0 over the padding tail.
var attentionMask = Array(repeating: Int32(0), count: seqLen)
if totalLen > 0 {
for i in 0..<totalLen {
attentionMask[i] = 1
}
}
// fixed[i] == true means position i is committed and will not be re-predicted.
// Prompt positions start fixed; generation positions start masked.
var fixed = Array(repeating: false, count: seqLen)
if promptLen > 0 {
for i in 0..<promptLen {
fixed[i] = true
}
}
if totalLen > promptLen {
for i in promptLen..<totalLen {
tokenBuffer[i] = input.mask_token_id
}
}
let cfg = MLModelConfiguration()
cfg.computeUnits = computeUnits(from: input.compute_units)
let modelURL = URL(fileURLWithPath: modelPath)
let modelLoadStart = Date()
// An .mlpackage must be compiled to an .mlmodelc before loading; a compiled
// model directory is loaded directly.
let resolvedModelURL: URL
if modelURL.pathExtension.lowercased() == "mlpackage" {
resolvedModelURL = try MLModel.compileModel(at: modelURL)
} else {
resolvedModelURL = modelURL
}
let model = try MLModel(contentsOf: resolvedModelURL, configuration: cfg)
let modelLoadSeconds = Date().timeIntervalSince(modelLoadStart)
// Runs one forward pass over the full buffer. Returns, per position, a
// predicted token id and a score, plus the wall-clock prediction time.
// NOTE(review): output names "var_4801"/"var_4806" are artifacts of the model
// export; the code treats them as (token ids, scores) — confirm against the
// converted model if the export changes.
func predict(ids: [Int32], mask: [Int32]) throws -> ([Int32], [Float], Double) {
let idsMA = try int32Array(ids, shape: [1, NSNumber(value: seqLen)])
let maskMA = try int32Array(mask, shape: [1, NSNumber(value: seqLen)])
let provider = try MLDictionaryFeatureProvider(dictionary: [
"input_ids": MLFeatureValue(multiArray: idsMA),
"attention_mask": MLFeatureValue(multiArray: maskMA)
])
let t0 = Date()
let out = try model.prediction(from: provider)
let dt = Date().timeIntervalSince(t0)
guard let predMA = out.featureValue(for: "var_4801")?.multiArrayValue,
let scoreMA = out.featureValue(for: "var_4806")?.multiArrayValue else {
throw NSError(domain: "llada_diffuse", code: 1, userInfo: [NSLocalizedDescriptionKey: "Model outputs var_4801/var_4806 not found"])
}
// Linear subscripting assumes the outputs are laid out as a flattened
// [1, seqLen] array, matching the input shape above.
var pred = Array(repeating: Int32(0), count: seqLen)
var score = Array(repeating: Float(0), count: seqLen)
for i in 0..<seqLen {
pred[i] = predMA[i].int32Value
score[i] = scoreMA[i].floatValue
}
return (pred, score, dt)
}
var stepStats: [StepStat] = []
var totalPredictSeconds = 0.0
let loopStart = Date()
// Diffusion loop: each step predicts all positions, then commits the
// highest-scoring masked positions according to a linear unmasking schedule.
for step in 1...stepCount {
// Collect generation positions that are still masked.
var maskedPositions: [Int] = []
if totalLen > promptLen {
for i in promptLen..<totalLen where !fixed[i] {
maskedPositions.append(i)
}
}
if maskedPositions.isEmpty {
break
}
let (pred, score, predictSeconds) = try predict(ids: tokenBuffer, mask: attentionMask)
totalPredictSeconds += predictSeconds
// Linear schedule: after step k, floor(maxNew * (steps - k) / steps)
// positions should remain masked; fix the rest this step.
let remainingAfter = Int(Double(maxNew) * Double(max(0, stepCount - step)) / Double(stepCount))
let leaveMasked = min(remainingAfter, maskedPositions.count)
let fixCount = max(0, maskedPositions.count - leaveMasked)
// Commit the most confident predictions first (descending score).
let ranked = maskedPositions.sorted { score[$0] > score[$1] }
var scoreSum: Float = 0
if fixCount > 0 {
for j in 0..<fixCount {
let pos = ranked[j]
tokenBuffer[pos] = pred[pos]
fixed[pos] = true
scoreSum += score[pos]
}
}
// Re-mask the positions that stay undecided so the next pass predicts them again.
if fixCount < ranked.count {
for j in fixCount..<ranked.count {
tokenBuffer[ranked[j]] = input.mask_token_id
}
}
var maskedAfter = 0
if totalLen > promptLen {
for i in promptLen..<totalLen where !fixed[i] {
maskedAfter += 1
}
}
let avgScore = fixCount > 0 ? scoreSum / Float(fixCount) : 0
stepStats.append(StepStat(
step: step,
masked_before: maskedPositions.count,
fixed_this_step: fixCount,
masked_after: maskedAfter,
avg_fixed_score: avgScore
))
}
// Safety: if any generation positions are still masked, fill them from one final pass.
var remainingPositions: [Int] = []
if totalLen > promptLen {
for i in promptLen..<totalLen where !fixed[i] {
remainingPositions.append(i)
}
}
if !remainingPositions.isEmpty {
let (pred, _, predictSeconds) = try predict(ids: tokenBuffer, mask: attentionMask)
totalPredictSeconds += predictSeconds
for pos in remainingPositions {
tokenBuffer[pos] = pred[pos]
fixed[pos] = true
}
// Recorded as step stepCount + 1 so it is distinguishable in the stats.
stepStats.append(StepStat(
step: stepCount + 1,
masked_before: remainingPositions.count,
fixed_this_step: remainingPositions.count,
masked_after: 0,
avg_fixed_score: 0
))
}
let loopSeconds = Date().timeIntervalSince(loopStart)
// Extract the generation window and trim it at the first EOS token.
let untrimmedGenerated = totalLen > promptLen ? Array(tokenBuffer[promptLen..<totalLen]) : []
var generated: [Int32] = []
generated.reserveCapacity(untrimmedGenerated.count)
for token in untrimmedGenerated {
if token == input.eos_token_id {
break
}
generated.append(token)
}
let output = OutputPayload(
prompt: input.prompt,
prompt_ids: Array(tokenBuffer.prefix(promptLen)),
final_ids: tokenBuffer,
generated_ids: generated,
generated_ids_untrimmed: untrimmedGenerated,
prompt_len: promptLen,
total_len: totalLen,
step_stats: stepStats,
load_seconds: modelLoadSeconds,
total_predict_seconds: totalPredictSeconds,
loop_seconds: loopSeconds
)
let encoded = try JSONEncoder().encode(output)
try encoded.write(to: URL(fileURLWithPath: outputPath))
print("Wrote \(outputPath)")
print(String(format: "load=%.2fs predict_total=%.2fs loop=%.2fs", output.load_seconds, output.total_predict_seconds, output.loop_seconds))
} catch {
fputs("ERROR: \(error)\n", stderr)
exit(1)
}