import Foundation
import CoreML
import Tokenizers

/// On-device text generation with Qwen3-0.6B Core ML models: a stateless
/// prefill model for the prompt and a stateful (KV-cached) decode model for
/// token-by-token generation.
@MainActor
public final class Qwen3CoreML {

    // MARK: - Configuration
    public struct Config {
        public static let maxContextLength = 1024
        public static let maxTokens = 512
        public static let temperature: Float = 0.7
        // Top-K / top-P are reserved for future sampling strategies; the
        // current sampler (see sampleToken) uses temperature only.
        public static let topK = 40
        public static let topP: Float = 0.9

        public static let prefillModelName = "Qwen3-0.6B-Prefill-Int4"
        public static let decodeModelName = "Qwen3-0.6B-Decode-Int4"
        public static let tokenizerModelId = "Qwen/Qwen3-0.6B"
    }

    // MARK: - Properties

    private var prefillModel: MLModel?
    private var decodeModel: MLModel?
    private var tokenizer: Tokenizer?
    private var decodeState: MLState?

    private(set) var isModelsLoaded = false
    private(set) var isGenerating = false

    // Qwen3 special tokens: <|endoftext|> (151643) and <|im_end|> (151645)
    // both terminate generation; 151667/151668 delimit <think>...</think>
    // reasoning blocks.
    private let eosTokenIds: Set<Int> = [151643, 151645]
    private let bosTokenId = 151643
    private let thinkStartTokenId = 151667
    private let thinkEndTokenId = 151668

    private(set) var tokensPerSecond: Double = 0
    private(set) var currentPosition = 0

    // MARK: - Initialization

    public init() {
        print("🤖 Qwen3CoreML initialized")
    }

    // MARK: - Model Loading

    /// Loads the prefill model, the decode model, and the tokenizer.
    public func loadModels() async throws {
        guard !isModelsLoaded else {
            print("🤖 Qwen3: Models already loaded")
            return
        }

        print("🤖 Qwen3: Loading CoreML models and tokenizer...")

        do {
            prefillModel = try await loadModel(named: Config.prefillModelName)
            print("✅ Prefill model loaded")

            decodeModel = try await loadModel(named: Config.decodeModelName, withState: true)
            print("✅ Decode model loaded")

            tokenizer = try await AutoTokenizer.from(pretrained: Config.tokenizerModelId)
            print("✅ Tokenizer loaded")

            isModelsLoaded = true
            print("🎉 Qwen3 models loaded successfully")
        } catch {
            print("❌ Failed to load Qwen3 models: \(error.localizedDescription)")
            throw Qwen3Error.modelLoadingFailed(error.localizedDescription)
        }
    }

    /// Locates, compiles, and loads a single .mlpackage, optionally creating
    /// the KV-cache state for the stateful decode model. Returning the model
    /// (rather than writing through an `inout` parameter) avoids passing an
    /// actor-isolated property `inout` to an async call.
    private func loadModel(named modelName: String, withState: Bool = false) async throws -> MLModel {
        let config = MLModelConfiguration()
        config.computeUnits = .cpuAndNeuralEngine

        var modelURL: URL?

        // 1. Look in the app bundle.
        if let url = Bundle.main.url(forResource: modelName, withExtension: "mlpackage") {
            modelURL = url
        }
        // 2. Fall back to Application Support.
        else if let appSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first {
            let modelPath = appSupport
                .appendingPathComponent("Qwen3CoreML")
                .appendingPathComponent("Models")
                .appendingPathComponent("\(modelName).mlpackage")
            if FileManager.default.fileExists(atPath: modelPath.path) {
                modelURL = modelPath
            }
        }

        guard let modelURL else {
            throw Qwen3Error.modelNotFound(modelName)
        }

        // Compile the package on the fly, then load the compiled model.
        let compiledURL = try await MLModel.compileModel(at: modelURL)
        let model = try MLModel(contentsOf: compiledURL, configuration: config)

        if withState {
            decodeState = model.makeState()
        }
        return model
    }
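
    // Expected on-disk layout for the Application Support fallback above
    // (derived from the paths the code constructs):
    //
    //     ~/Library/Application Support/Qwen3CoreML/Models/
    //         Qwen3-0.6B-Prefill-Int4.mlpackage
    //         Qwen3-0.6B-Decode-Int4.mlpackage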

    // MARK: - Generation

    /// Streams generated text chunks as they are produced.
    public func generate(
        userMessage: String,
        systemPrompt: String = "You are a helpful assistant.",
        maxTokens: Int = Config.maxTokens,
        temperature: Float = Config.temperature,
        enableThinking: Bool = false
    ) -> AsyncStream<String> {
        AsyncStream { continuation in
            Task {
                await generateInternal(
                    userMessage: userMessage,
                    systemPrompt: systemPrompt,
                    maxTokens: maxTokens,
                    temperature: temperature,
                    enableThinking: enableThinking,
                    continuation: continuation
                )
            }
        }
    }

    /// Convenience wrapper that collects the full streamed response.
    public func generateSync(
        userMessage: String,
        systemPrompt: String = "You are a helpful assistant.",
        maxTokens: Int = Config.maxTokens,
        temperature: Float = Config.temperature,
        enableThinking: Bool = false
    ) async throws -> String {
        guard isModelsLoaded else {
            throw Qwen3Error.modelNotLoaded
        }

        var result = ""
        for await chunk in generate(
            userMessage: userMessage,
            systemPrompt: systemPrompt,
            maxTokens: maxTokens,
            temperature: temperature,
            enableThinking: enableThinking
        ) {
            result += chunk
        }
        return result
    }

    /// Clears the KV cache and position counter so a new conversation starts fresh.
    public func resetConversation() {
        decodeState = decodeModel?.makeState()
        currentPosition = 0
        print("🔄 Qwen3 conversation reset")
    }

    // MARK: - Private Generation Pipeline

    private func generateInternal(
        userMessage: String,
        systemPrompt: String,
        maxTokens: Int,
        temperature: Float,
        enableThinking: Bool,
        continuation: AsyncStream<String>.Continuation
    ) async {
        guard isModelsLoaded,
              let prefillModel = prefillModel,
              let decodeModel = decodeModel,
              let tokenizer = tokenizer,
              let decodeState = decodeState else {
            continuation.finish()
            return
        }

        isGenerating = true
        let startTime = Date()

        defer {
            isGenerating = false
            continuation.finish()
        }

        do {
            // Build the chat-formatted prompt and tokenize it.
            let chatPrompt = formatChatPrompt(
                userMessage: userMessage,
                systemPrompt: systemPrompt,
                enableThinking: enableThinking
            )
            let inputTokens = tokenizer.encode(text: chatPrompt)

            // Truncate from the front if prompt + generation budget would
            // overflow the context window, keeping a leading BOS token.
            var tokensToProcess = inputTokens
            if inputTokens.count + maxTokens > Config.maxContextLength {
                print("⚠️ Prompt too long, truncating...")
                let truncated = Array(inputTokens.suffix(Config.maxContextLength - maxTokens))
                tokensToProcess = truncated.first == bosTokenId ? truncated : [bosTokenId] + truncated
            }

            // Prefill: run the whole prompt through the prefill model and
            // sample the first generated token from its final logits.
            var nextToken = try await processTokens(tokensToProcess, model: prefillModel, temperature: temperature)

            var generatedTokens: [Int] = []
            var isInThinkingBlock = false

            for _ in 0..<maxTokens {
                if eosTokenIds.contains(nextToken) {
                    break
                }
                generatedTokens.append(nextToken)

                // Track <think>...</think> blocks so thinking output can be
                // suppressed when enableThinking is false.
                var suppressTag = false
                if nextToken == thinkStartTokenId {
                    isInThinkingBlock = true
                } else if nextToken == thinkEndTokenId {
                    isInThinkingBlock = false
                    suppressTag = !enableThinking  // hide the closing tag itself
                }

                if !suppressTag && (!isInThinkingBlock || enableThinking) {
                    continuation.yield(tokenizer.decode(tokens: [nextToken]))
                }

                // Decode: feed the token just emitted to get the next one.
                nextToken = try await generateNextToken(
                    inputToken: nextToken,
                    temperature: temperature,
                    decodeModel: decodeModel,
                    decodeState: decodeState
                )
            }

            // Report throughput.
            let elapsed = Date().timeIntervalSince(startTime)
            tokensPerSecond = Double(generatedTokens.count) / elapsed
            print("📊 Generation: \(generatedTokens.count) tokens in \(String(format: "%.2f", elapsed))s (\(String(format: "%.1f", tokensPerSecond)) tok/s)")
        } catch {
            print("❌ Generation failed: \(error.localizedDescription)")
        }
    }

    /// Runs the prompt through the prefill model, primes `currentPosition`,
    /// and samples the first generated token from the last position's logits.
    private func processTokens(_ tokens: [Int], model: MLModel, temperature: Float) async throws -> Int {
        let seqLen = tokens.count

        let causalMask = createCausalMask(seqLen: seqLen, totalLen: seqLen)
        let inputIdsTensor = MLTensor(
            shape: [1, seqLen],
            scalars: tokens.map { Int32($0) },
            scalarType: Int32.self
        )

        let inputs = try MLDictionaryFeatureProvider(dictionary: [
            "inputIds": MLFeatureValue(tensor: inputIdsTensor),
            "causalMask": MLFeatureValue(tensor: causalMask)
        ])

        let output = try await model.prediction(from: inputs)
        currentPosition = seqLen

        guard let logitsTensor = output.featureValue(for: "logits")?.tensorValue(of: Float16.self) else {
            throw Qwen3Error.inferenceError("No logits in prefill output")
        }
        return sampleToken(from: logitsTensor, temperature: temperature)
    }

    /// Runs one decode step: feeds the previously sampled token at the current
    /// position and samples the next token from the returned logits.
    private func generateNextToken(
        inputToken: Int,
        temperature: Float,
        decodeModel: MLModel,
        decodeState: MLState
    ) async throws -> Int {
        let positionTensor = MLTensor(
            shape: [1, 1],
            scalars: [Int32(currentPosition)],
            scalarType: Int32.self
        )

        let inputTensor = MLTensor(
            shape: [1, 1],
            scalars: [Int32(inputToken)],
            scalarType: Int32.self
        )

        let inputs = try MLDictionaryFeatureProvider(dictionary: [
            "inputIds": MLFeatureValue(tensor: inputTensor),
            "positionIds": MLFeatureValue(tensor: positionTensor),
        ])

        let output = try await decodeModel.prediction(from: inputs, using: decodeState)

        guard let logitsTensor = output.featureValue(for: "logits")?.tensorValue(of: Float16.self) else {
            throw Qwen3Error.inferenceError("No logits in model output")
        }

        let nextToken = sampleToken(from: logitsTensor, temperature: temperature)
        currentPosition += 1
        return nextToken
    }

    /// Greedy or temperature-based sampling over the logits of the last
    /// sequence position.
    private func sampleToken(from logitsTensor: MLTensor, temperature: Float) -> Int {
        // Logits come back shaped [1, seqLen, vocabSize]; sample from the
        // final position.
        let seqLen = logitsTensor.shape[1]
        let vocabSize = logitsTensor.shape[2]
        let lastOffset = (seqLen - 1) * vocabSize

        var logitsArray = [Float](repeating: 0, count: vocabSize)
        logitsTensor.withUnsafeBufferPointer(of: Float16.self) { buffer in
            for i in 0..<vocabSize {
                logitsArray[i] = Float(buffer[lastOffset + i])
            }
        }

        // Greedy decoding when temperature is zero or negative.
        if temperature <= 0 {
            return logitsArray.enumerated().max(by: { $0.element < $1.element })?.offset ?? 0
        }

        // Temperature sampling: scale logits, apply a numerically stable
        // softmax, then draw from the resulting distribution.
        let scaledLogits = logitsArray.map { $0 / temperature }
        let maxLogit = scaledLogits.max() ?? 0
        let expLogits = scaledLogits.map { exp($0 - maxLogit) }
        let sumExp = expLogits.reduce(0, +)
        let probs = expLogits.map { $0 / sumExp }

        let random = Float.random(in: 0..<1)
        var cumulative: Float = 0
        for (index, prob) in probs.enumerated() {
            cumulative += prob
            if random < cumulative {
                return index
            }
        }
        return vocabSize - 1
    }
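
    // Config.topK / Config.topP are declared above but not yet wired into the
    // sampler. A minimal sketch of top-K filtering that could run on the
    // scaled logits before the softmax (hypothetical helper, not part of the
    // original pipeline):
    private func topKFiltered(_ logits: [Float], k: Int) -> [Float] {
        guard k > 0, k < logits.count else { return logits }
        // Keep only the k highest logits; everything else becomes -inf so it
        // receives zero probability after the softmax. Ties at the threshold
        // may keep slightly more than k entries.
        let threshold = logits.sorted(by: >)[k - 1]
        return logits.map { $0 >= threshold ? $0 : -.infinity }
    }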

    /// Builds an additive causal attention mask: 0 where attention is allowed,
    /// -inf where future positions must be masked out.
    private func createCausalMask(seqLen: Int, totalLen: Int) -> MLTensor {
        var maskData = [Float16](repeating: -Float16.infinity, count: seqLen * totalLen)

        for i in 0..<seqLen {
            for j in 0..<(totalLen - seqLen + i + 1) {
                maskData[i * totalLen + j] = Float16(0)
            }
        }

        return MLTensor(
            shape: [1, 1, seqLen, totalLen],
            scalars: maskData,
            scalarType: Float16.self
        )
    }
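
    // For example, with seqLen == totalLen == 3 the mask built above is:
    //
    //     [   0, -inf, -inf ]
    //     [   0,    0, -inf ]
    //     [   0,    0,    0 ]
    //
    // i.e. each query position i may attend only to key positions j <= i.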

    /// Formats the prompt with Qwen's ChatML-style template. When thinking is
    /// disabled, the "/no_think" soft switch is appended to the user message
    /// (rather than after the assistant tag, where the model would read it as
    /// its own output).
    private func formatChatPrompt(userMessage: String, systemPrompt: String, enableThinking: Bool) -> String {
        let userContent = enableThinking ? userMessage : userMessage + " /no_think"
        return "<|im_start|>system\n\(systemPrompt)<|im_end|>\n<|im_start|>user\n\(userContent)<|im_end|>\n<|im_start|>assistant\n"
    }
}

// MARK: - Errors

public enum Qwen3Error: LocalizedError {
    case modelNotFound(String)
    case modelNotLoaded
    case modelLoadingFailed(String)
    case inferenceError(String)
    case tokenizationError

    public var errorDescription: String? {
        switch self {
        case .modelNotFound(let modelName):
            return "Model '\(modelName)' not found. Place it in the app bundle or ~/Library/Application Support/Qwen3CoreML/Models/"
        case .modelNotLoaded:
            return "Models are not loaded. Call loadModels() first."
        case .modelLoadingFailed(let message):
            return "Model loading failed: \(message)"
        case .inferenceError(let message):
            return "Inference error: \(message)"
        case .tokenizationError:
            return "Tokenization error"
        }
    }
}

// MARK: - Convenience Helpers

extension Qwen3CoreML {

    /// Corrects punctuation, capitalization, and grammar in the given text.
    public func correct(text: String) async throws -> String {
        return try await generateSync(
            userMessage: """
            Please correct the following text by fixing punctuation, capitalization, and grammatical errors.
            Keep the original language. Only output the corrected text, nothing else.

            Text: \(text)

            Corrected:
            """,
            systemPrompt: "You are a professional proofreader and text editor.",
            maxTokens: 256,
            temperature: 0.1
        ).trimmingCharacters(in: .whitespacesAndNewlines)
    }
}
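
// MARK: - Usage Example

// A minimal end-to-end sketch of the API above. `runQwen3Demo` is a
// hypothetical helper, not part of the original surface; it assumes the
// .mlpackage files and tokenizer are already in place (see
// Qwen3Error.modelNotFound for the expected locations).
@MainActor
func runQwen3Demo() async throws {
    let qwen = Qwen3CoreML()
    try await qwen.loadModels()

    // Stream tokens as they are generated.
    for await chunk in qwen.generate(userMessage: "Explain KV caching in one sentence.") {
        print(chunk, terminator: "")
    }

    // Reset the KV cache before an unrelated prompt reuses the context.
    qwen.resetConversation()

    // Or collect a full response in one call.
    let fixed = try await qwen.correct(text: "teh quick brown fox jumped over teh lazy dog")
    print(fixed)
}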