File size: 5,145 Bytes
35765b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import sqlite3
import sqlite_vec
import struct
from typing import Optional
import os
# Path to the SQLite database file shared with the main SQLite store:
# "project_memory.db" located one directory above this module's directory.
DB_PATH = os.path.join(os.path.dirname(__file__), os.pardir, "project_memory.db")
def _serialize_vector(vec: list[float]) -> bytes:
"""Convert list of floats to bytes for sqlite-vec."""
return struct.pack(f'{len(vec)}f', *vec)
def _get_connection() -> sqlite3.Connection:
    """Open a connection to ``DB_PATH`` with the sqlite-vec extension loaded.

    Extension loading is enabled only while ``sqlite_vec.load`` runs and is
    disabled again right after. If any step fails, the freshly opened
    connection is closed before the error propagates, so it is not leaked.

    Returns:
        An open ``sqlite3.Connection``; the caller must close it.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)
    except BaseException:
        # Purely cleanup: never hand back (or leak) a half-initialised connection.
        conn.close()
        raise
    return conn
def init_vectorstore():
    """Create the embedding tables and index if they do not exist yet.

    Call once at startup. Creates:

    * ``embeddings``: one metadata row per ``id`` (project, user, task,
      text, created_at);
    * ``idx_embeddings_project``: index on ``embeddings.project_id`` for
      per-project filtering;
    * ``vec_embeddings``: a sqlite-vec ``vec0`` virtual table holding one
      768-dimensional float vector per ``id`` (768 dims for Gemini).

    Every statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
    An already existing ``vec_embeddings`` table is kept as is, even if it
    was created with a different dimension. The connection is always closed,
    even if a statement fails.
    """
    conn = _get_connection()
    try:
        # Metadata table for embeddings.
        conn.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                user_id TEXT,
                task_id TEXT,
                text TEXT,
                created_at TEXT
            )
        """)
        # Index for faster project filtering.
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_embeddings_project
            ON embeddings(project_id)
        """)
        # Virtual table for vector search (768 dims for Gemini).
        conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS vec_embeddings USING vec0(
                id TEXT PRIMARY KEY,
                embedding FLOAT[768]
            )
        """)
        conn.commit()
    finally:
        conn.close()
def add_embedding(
    log_entry_id: str,
    text: str,
    embedding: list[float],
    metadata: dict
) -> None:
    """Store (or overwrite) the embedding and metadata for ``log_entry_id``.

    Writes one row to ``embeddings`` and one vector to ``vec_embeddings``
    inside a single transaction. Either both are stored or, if anything
    fails, neither is. The connection is always closed afterwards.

    Args:
        log_entry_id: Primary key shared by both tables.
        text: Source text; only the first 2000 characters are stored.
        embedding: The vector to store. Its length must match the
            ``FLOAT[...]`` dimension of ``vec_embeddings``.
        metadata: ``project_id`` (required: the column is NOT NULL),
            ``user_id``, ``task_id`` and ``created_at`` are read from it;
            missing keys are stored as NULL.
    """
    conn = _get_connection()
    try:
        with conn:  # commit on success, roll back on any error
            conn.execute(
                """
                INSERT OR REPLACE INTO embeddings
                    (id, project_id, user_id, task_id, text, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    log_entry_id,
                    metadata.get("project_id"),
                    metadata.get("user_id"),
                    metadata.get("task_id"),
                    text[:2000],  # Truncate long text
                    metadata.get("created_at"),
                ),
            )
            # vec0 virtual tables do not apply INSERT OR REPLACE conflict
            # resolution (a duplicate id is rejected), so remove any existing
            # vector for this id explicitly before inserting the new one.
            conn.execute(
                "DELETE FROM vec_embeddings WHERE id = ?",
                (log_entry_id,),
            )
            conn.execute(
                "INSERT INTO vec_embeddings (id, embedding) VALUES (?, ?)",
                (log_entry_id, _serialize_vector(embedding)),
            )
    finally:
        conn.close()
def search(
    query_embedding: list[float],
    project_id: str,
    n_results: int = 10,
    filters: Optional[dict] = None
) -> list[dict]:
    """Return the embeddings of one project that are closest to a query vector.

    Filters are applied *before* ranking. This returns up to ``n_results``
    matches from the requested project no matter how many closer vectors
    other projects hold. A ``vec0`` KNN query (``MATCH ? AND k = ?``) picks
    its top-k over the whole table first and only then applies JOIN/WHERE
    filters, which could discard every candidate from this project.

    Ranking uses ``vec_distance_l2``, the L2 distance. That is also the
    default metric of a ``vec0`` table.

    Args:
        query_embedding: Query vector, same dimension as the stored vectors.
        project_id: Only rows with this ``project_id`` are considered.
        n_results: Maximum number of results; values <= 0 yield ``[]``.
        filters: Optional extra filters:
            ``user_id`` - exact match (skipped when falsy);
            ``date_from`` - inclusive lower bound on ``created_at``;
            ``date_to`` - exclusive upper bound on ``created_at``.
            Date values with an ``isoformat`` method (e.g. ``datetime``) are
            converted to ISO strings; ``created_at`` is compared as text.

    Returns:
        Dicts ``{"id", "metadata": {...}, "distance"}`` ordered by ascending
        distance.
    """
    # SQLite treats a negative LIMIT as "no limit", so reject it up front.
    if n_results <= 0:
        return []

    # Positional "?" placeholders bind in SQL-text order: the distance
    # argument in SELECT comes first, then the WHERE-clause values.
    query = """
        SELECT
            e.id,
            e.project_id,
            e.user_id,
            e.task_id,
            e.text,
            e.created_at,
            vec_distance_l2(v.embedding, ?) AS distance
        FROM embeddings e
        JOIN vec_embeddings v ON v.id = e.id
        WHERE e.project_id = ?
    """
    params: list = [_serialize_vector(query_embedding), project_id]

    filters = filters or {}
    if filters.get("user_id"):
        query += " AND e.user_id = ?"
        params.append(filters["user_id"])

    # Date bounds for time-based queries: the half-open range [date_from, date_to).
    # Only these two constant operators are ever interpolated into the SQL.
    for key, comparison in (("date_from", ">="), ("date_to", "<")):
        bound = filters.get(key)
        if not bound:
            continue
        if hasattr(bound, "isoformat"):
            bound = bound.isoformat()
        query += f" AND e.created_at {comparison} ?"
        params.append(bound)

    query += " ORDER BY distance LIMIT ?"
    params.append(n_results)

    conn = _get_connection()
    try:
        rows = conn.execute(query, params).fetchall()
    finally:
        conn.close()

    return [
        {
            "id": row_id,
            "metadata": {
                "project_id": row_project,
                "user_id": row_user,
                "task_id": row_task,
                "text": row_text,
                "created_at": row_created,
            },
            "distance": distance,
        }
        for row_id, row_project, row_user, row_task, row_text, row_created, distance in rows
    ]
def delete_by_project(project_id: str) -> None:
    """Delete every embedding (vector and metadata row) of a project.

    Both tables are cleared inside one transaction. If anything fails, no
    rows are removed. The connection is always closed afterwards.

    Args:
        project_id: Project whose rows are removed from ``vec_embeddings``
            and ``embeddings``.
    """
    conn = _get_connection()
    try:
        with conn:  # commit on success, roll back on any error
            # Collect the ids first; the vector table has no project column.
            ids = conn.execute(
                "SELECT id FROM embeddings WHERE project_id = ?",
                (project_id,),
            ).fetchall()
            # Each fetched row is a 1-tuple (id,), i.e. a ready parameter set.
            conn.executemany("DELETE FROM vec_embeddings WHERE id = ?", ids)
            conn.execute(
                "DELETE FROM embeddings WHERE project_id = ?",
                (project_id,),
            )
    finally:
        conn.close()
def count_embeddings(project_id: Optional[str] = None) -> int:
    """Return the number of rows in the ``embeddings`` table.

    Args:
        project_id: When given, only that project's rows are counted.
            ``None`` (the default) counts all rows.

    Returns:
        The row count. The connection is closed even if the query fails.
    """
    # Compare against None explicitly: an empty-string id is still a
    # project filter, not a request for the global count.
    if project_id is None:
        sql, params = "SELECT COUNT(*) FROM embeddings", ()
    else:
        sql = "SELECT COUNT(*) FROM embeddings WHERE project_id = ?"
        params = (project_id,)

    conn = _get_connection()
    try:
        (count,) = conn.execute(sql, params).fetchone()
    finally:
        conn.close()
    return count
|