|
|
import sqlite3 |
|
|
import sqlite_vec |
|
|
import struct |
|
|
from typing import Optional |
|
|
import os |
|
|
|
|
|
|
|
|
# Directory containing this module; the database file sits one level above it.
_MODULE_DIR = os.path.dirname(__file__)
DB_PATH = os.path.join(_MODULE_DIR, "..", "project_memory.db")
|
|
|
|
|
|
|
|
def _serialize_vector(vec: list[float]) -> bytes: |
|
|
"""Convert list of floats to bytes for sqlite-vec.""" |
|
|
return struct.pack(f'{len(vec)}f', *vec) |
|
|
|
|
|
|
|
|
def _get_connection(db_path: Optional[str] = None) -> sqlite3.Connection:
    """Open a SQLite connection with the sqlite-vec extension loaded.

    Extension loading is switched on only for the duration of
    ``sqlite_vec.load`` and switched off again afterwards, even if loading
    fails. If any setup step raises, the connection is closed before the
    exception propagates, so a failed call never leaks a handle.

    Args:
        db_path: Database file to open. Defaults to the module-level
            ``DB_PATH`` when ``None``.

    Returns:
        An open connection; the caller is responsible for closing it.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        conn.enable_load_extension(True)
        try:
            sqlite_vec.load(conn)
        finally:
            # Never leave arbitrary extension loading enabled on the handle.
            conn.enable_load_extension(False)
    except Exception:
        conn.close()
        raise
    return conn
|
|
|
|
|
|
|
|
def init_vectorstore(embedding_dim: int = 768) -> None:
    """Create the metadata table, its project index, and the vec0 vector table.

    Intended to be called once at startup. Every statement uses
    ``IF NOT EXISTS``, so repeated calls are harmless. For the same reason, an
    already-existing ``vec_embeddings`` table keeps the dimension it was
    created with; ``embedding_dim`` only takes effect on a fresh database.

    Args:
        embedding_dim: Number of float components per stored vector.
            Defaults to 768, the width this table has always used.

    Raises:
        ValueError: if ``embedding_dim`` is not a positive ``int``.
    """
    # The dimension is interpolated into DDL (DDL cannot take bound
    # parameters), so only accept a genuine positive integer.
    if (
        isinstance(embedding_dim, bool)
        or not isinstance(embedding_dim, int)
        or embedding_dim <= 0
    ):
        raise ValueError(
            f"embedding_dim must be a positive int, got {embedding_dim!r}"
        )

    conn = _get_connection()
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS embeddings (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                user_id TEXT,
                task_id TEXT,
                text TEXT,
                created_at TEXT
            )
            """
        )
        # Searches always filter on project_id.
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_embeddings_project
            ON embeddings(project_id)
            """
        )
        conn.execute(
            f"""
            CREATE VIRTUAL TABLE IF NOT EXISTS vec_embeddings USING vec0(
                id TEXT PRIMARY KEY,
                embedding FLOAT[{embedding_dim}]
            )
            """
        )
        conn.commit()
    finally:
        conn.close()
|
|
|
|
|
|
|
|
def add_embedding(
    log_entry_id: str,
    text: str,
    embedding: list[float],
    metadata: dict
) -> None:
    """Store (or overwrite) one entry in both the metadata and vector tables.

    Both writes run in a single transaction, so either both tables are
    updated or neither is. The connection is always closed, even on error.

    Args:
        log_entry_id: Primary key shared by ``embeddings`` and
            ``vec_embeddings``. An existing entry with this id is replaced.
        text: Source text; only the first 2000 characters are stored.
        embedding: Vector to store. Presumably its length must match the
            dimension ``vec_embeddings`` was created with — confirm against
            ``init_vectorstore``.
        metadata: Reads ``project_id`` (the column is NOT NULL, so it must be
            present), plus ``user_id``, ``task_id`` and ``created_at``.
            Missing optional keys are stored as NULL.
    """
    # Serialize before opening the connection so a bad vector fails fast.
    blob = _serialize_vector(embedding)

    conn = _get_connection()
    try:
        with conn:  # commit on success, roll back on any exception
            conn.execute(
                """
                INSERT OR REPLACE INTO embeddings
                    (id, project_id, user_id, task_id, text, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    log_entry_id,
                    metadata.get("project_id"),
                    metadata.get("user_id"),
                    metadata.get("task_id"),
                    text[:2000],
                    metadata.get("created_at"),
                ),
            )
            # INSERT OR REPLACE is not a reliable upsert on vec0 virtual
            # tables (re-inserting an existing id can raise a UNIQUE
            # constraint error), so drop any previous vector explicitly.
            conn.execute(
                "DELETE FROM vec_embeddings WHERE id = ?", (log_entry_id,)
            )
            conn.execute(
                "INSERT INTO vec_embeddings (id, embedding) VALUES (?, ?)",
                (log_entry_id, blob),
            )
    finally:
        conn.close()
|
|
|
|
|
|
|
|
def search( |
|
|
query_embedding: list[float], |
|
|
project_id: str, |
|
|
n_results: int = 10, |
|
|
filters: Optional[dict] = None |
|
|
) -> list[dict]: |
|
|
"""Search for similar documents within a project.""" |
|
|
conn = _get_connection() |
|
|
|
|
|
|
|
|
|
|
|
query = """ |
|
|
SELECT |
|
|
e.id, |
|
|
e.project_id, |
|
|
e.user_id, |
|
|
e.task_id, |
|
|
e.text, |
|
|
e.created_at, |
|
|
v.distance |
|
|
FROM vec_embeddings v |
|
|
JOIN embeddings e ON v.id = e.id |
|
|
WHERE v.embedding MATCH ? |
|
|
AND k = ? |
|
|
AND e.project_id = ? |
|
|
""" |
|
|
params = [_serialize_vector(query_embedding), n_results * 2, project_id] |
|
|
|
|
|
if filters: |
|
|
if filters.get("user_id"): |
|
|
query += " AND e.user_id = ?" |
|
|
params.append(filters["user_id"]) |
|
|
|
|
|
|
|
|
if filters.get("date_from"): |
|
|
date_from = filters["date_from"] |
|
|
if hasattr(date_from, 'isoformat'): |
|
|
date_from = date_from.isoformat() |
|
|
query += " AND e.created_at >= ?" |
|
|
params.append(date_from) |
|
|
|
|
|
if filters.get("date_to"): |
|
|
date_to = filters["date_to"] |
|
|
if hasattr(date_to, 'isoformat'): |
|
|
date_to = date_to.isoformat() |
|
|
query += " AND e.created_at < ?" |
|
|
params.append(date_to) |
|
|
|
|
|
query += " ORDER BY v.distance LIMIT ?" |
|
|
params.append(n_results) |
|
|
|
|
|
results = conn.execute(query, params).fetchall() |
|
|
conn.close() |
|
|
|
|
|
return [ |
|
|
{ |
|
|
"id": row[0], |
|
|
"metadata": { |
|
|
"project_id": row[1], |
|
|
"user_id": row[2], |
|
|
"task_id": row[3], |
|
|
"text": row[4], |
|
|
"created_at": row[5] |
|
|
}, |
|
|
"distance": row[6] |
|
|
} |
|
|
for row in results |
|
|
] |
|
|
|
|
|
|
|
|
def delete_by_project(project_id: str) -> None:
    """Remove every stored entry of ``project_id`` from both tables.

    The ids to drop from ``vec_embeddings`` are looked up through the
    ``embeddings`` metadata table, which is cleared afterwards. All deletes
    run in one transaction (rolled back on error), and the connection is
    always closed.
    """
    conn = _get_connection()
    try:
        with conn:
            ids = conn.execute(
                "SELECT id FROM embeddings WHERE project_id = ?",
                (project_id,),
            ).fetchall()
            # Each fetched row is a 1-tuple, i.e. already a parameter set.
            conn.executemany("DELETE FROM vec_embeddings WHERE id = ?", ids)
            conn.execute(
                "DELETE FROM embeddings WHERE project_id = ?", (project_id,)
            )
    finally:
        conn.close()
|
|
|
|
|
|
|
|
def count_embeddings(project_id: Optional[str] = None) -> int:
    """Return the number of rows in the ``embeddings`` metadata table.

    Args:
        project_id: When given, count only that project's rows. This includes
            the empty string, which previously fell through to the unfiltered
            count. When ``None``, count rows across all projects.

    Returns:
        The row count.
    """
    conn = _get_connection()
    try:
        if project_id is None:
            row = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()
        else:
            row = conn.execute(
                "SELECT COUNT(*) FROM embeddings WHERE project_id = ?",
                (project_id,),
            ).fetchone()
    finally:
        conn.close()
    return row[0]
|
|
|