static-retrieval-mrl-en-v1-sqlite-vec / EmbeddingsBenchmark /Tests /EmbeddingsBenchmarkTests /EmbeddingsBenchmarkTests.swift
| import CoreML | |
| import SQLiteVec | |
| import Testing | |
| @testable import EmbeddingsBenchmarkLib | |
/// Creates an in-memory sqlite-vec database whose `embeddings` virtual table holds `data`.
///
/// Row `i` of `data` is inserted with `rowid = i`, so callers can address rows by their
/// index (presumably used as token ids by the tests — confirm against callers).
///
/// - Parameter data: Embedding rows to insert. All rows are expected to have the same length;
///   the `vec0` column dimension is taken from the first row, or 3 when `data` is empty.
/// - Returns: The populated in-memory `Database`.
/// - Throws: Any error raised by SQLiteVec while initializing, creating the table,
///   or inserting a row (e.g. a row whose length differs from the column dimension).
func createDatabase(_ data: [[Float]]) async throws -> Database {
    try SQLiteVec.initialize()

    // The vec0 column dimension must match every inserted vector, so derive it from the
    // data instead of hard-coding it. Empty input keeps the previous fixed size of 3.
    let dimension = data.first?.count ?? 3

    let db = try Database(.inMemory)
    try await db.execute("CREATE VIRTUAL TABLE embeddings USING vec0(embedding float[\(dimension)])")
    for (rowid, embedding) in data.enumerated() {
        try await db.execute(
            """
            INSERT INTO embeddings(rowid, embedding)
            VALUES (?, ?)
            """,
            params: [rowid, embedding]
        )
    }
    return db
}
/// Checks that the CoreML (`MLTensor`) and sqlite-vec variants of `queryEmbeddings`
/// return the same result for the same 3×3 embedding table and token ids.
///
/// Without `@Test`, Swift Testing never discovers this free function, so it would never run.
@Test
func testEmbeddingMethods() async throws {
    let data: [[Float]] = [
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0],
    ]

    // CoreML path: the same rows flattened into a 3×3 tensor.
    let embeddings = MLTensor(shape: [3, 3], scalars: data.flatMap { $0 })
    let coreMLResult = await queryEmbeddings(embeddings: embeddings, tokenIds: [0, 2])

    // sqlite-vec path: the same rows stored with rowid == row index.
    let db = try await createDatabase(data)
    let sqliteResult = try await queryEmbeddings(
        db: db,
        // NOTE(review): appears to need one `?` placeholder per token id — keep in sync with `tokenIds`.
        query: "(?, ?)",
        tokenIds: [0, 2],
        vectorSize: 3
    )

    // NOTE(review): exact equality across two numeric backends may be brittle for
    // non-trivial inputs; consider a tolerance-based comparison if this flakes.
    #expect(coreMLResult == sqliteResult)
}