Spaces:
Build error
Build error
| import sqlite3 | |
| import sqlite_vec | |
| import struct | |
| from typing import Optional | |
| import os | |
# SQLite file shared with the main application database:
# "project_memory.db" located one directory above this module.
DB_PATH = os.path.join(
    os.path.dirname(__file__),
    os.pardir,
    "project_memory.db",
)
def _serialize_vector(vec: list[float]) -> bytes:
    """Pack *vec* into a blob of consecutive native-order C floats.

    This is the binary layout this module hands to sqlite-vec for both
    stored vectors and query vectors.
    """
    fmt = "%df" % len(vec)
    return struct.pack(fmt, *vec)
def _get_connection() -> sqlite3.Connection:
    """Open a connection to ``DB_PATH`` with the sqlite-vec extension loaded.

    Extension loading is enabled only while ``sqlite_vec.load`` runs and is
    disabled again afterwards.

    Returns:
        An open ``sqlite3.Connection``; the caller must close it.

    Raises:
        Whatever ``sqlite3.connect``, ``enable_load_extension`` or
        ``sqlite_vec.load`` raise. If the failure happens after the
        connection was opened, the connection is closed before the exception
        propagates, so no file handle is leaked.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)
    except BaseException:
        # Don't leak a half-initialised connection to the caller's traceback.
        conn.close()
        raise
    return conn
def init_vectorstore():
    """Create the embedding tables and index if they don't exist yet.

    Creates:
      * ``embeddings``: one metadata row per entry, keyed by ``id``;
      * ``idx_embeddings_project``: an index on ``embeddings.project_id``;
      * ``vec_embeddings``: a sqlite-vec ``vec0`` virtual table holding one
        768-dimensional float vector per ``id``.

    Every statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
    An existing ``vec_embeddings`` table is left unchanged even if it was
    created with a different dimension. Intended to be called once at startup.
    The connection is always closed, even if a statement fails.
    """
    conn = _get_connection()
    try:
        # Metadata for each embedding; `id` is shared with vec_embeddings.
        conn.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                user_id TEXT,
                task_id TEXT,
                text TEXT,
                created_at TEXT
            )
        """)
        # Lookups by project (search, delete, count) filter on project_id.
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_embeddings_project
            ON embeddings(project_id)
        """)
        # Vector index; 768 dimensions, sized for Gemini embeddings.
        conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS vec_embeddings USING vec0(
                id TEXT PRIMARY KEY,
                embedding FLOAT[768]
            )
        """)
        conn.commit()
    finally:
        conn.close()
def add_embedding(
    log_entry_id: str,
    text: str,
    embedding: list[float],
    metadata: dict
) -> None:
    """Insert, or overwrite, the embedding and metadata for ``log_entry_id``.

    Args:
        log_entry_id: Primary key shared by the ``embeddings`` row and the
            ``vec_embeddings`` row.
        text: Source text; only its first 2000 characters are stored.
        embedding: Vector to index. It must match the dimension of
            ``vec_embeddings`` (768 as created by ``init_vectorstore``).
        metadata: Reads ``project_id`` (required: the column is NOT NULL),
            ``user_id``, ``task_id`` and ``created_at``. Missing keys are
            stored as NULL.

    Both rows are written before a single ``commit``. If any statement fails,
    the connection is closed without committing, so neither row is persisted.
    """
    conn = _get_connection()
    try:
        # Store metadata (plain table: OR REPLACE overwrites an existing id).
        conn.execute(
            """
            INSERT OR REPLACE INTO embeddings (id, project_id, user_id, task_id, text, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                log_entry_id,
                metadata.get("project_id"),
                metadata.get("user_id"),
                metadata.get("task_id"),
                text[:2000],  # Truncate long text
                metadata.get("created_at"),
            ),
        )
        # sqlite-vec's vec0 tables reject a duplicate primary key rather than
        # honouring OR REPLACE, so drop any previous vector for this id first.
        conn.execute("DELETE FROM vec_embeddings WHERE id = ?", (log_entry_id,))
        conn.execute(
            "INSERT INTO vec_embeddings (id, embedding) VALUES (?, ?)",
            (log_entry_id, _serialize_vector(embedding)),
        )
        conn.commit()
    finally:
        conn.close()
def search(
    query_embedding: list[float],
    project_id: str,
    n_results: int = 10,
    filters: Optional[dict] = None
) -> list[dict]:
    """Return the stored entries of ``project_id`` nearest to ``query_embedding``.

    Args:
        query_embedding: Query vector; must match the dimension of
            ``vec_embeddings``.
        project_id: Only entries with exactly this ``project_id`` are returned.
        n_results: Maximum number of results. Values ``<= 0`` return ``[]``.
        filters: Optional extra constraints, with these recognised keys:
            ``user_id`` (exact match), ``date_from`` (``created_at >= value``)
            and ``date_to`` (``created_at < value``). Values with an
            ``isoformat()`` method are converted to strings first. Falsy values
            are ignored.
            ``created_at`` is compared as TEXT, so this presumably relies on
            stored values being ISO-8601 strings in the same format (confirm).

    Returns:
        Up to ``n_results`` dicts of the form
        ``{"id": ..., "metadata": {...}, "distance": ...}``, ordered by
        ascending distance. ``metadata`` holds project_id, user_id, task_id,
        text and created_at.

    Caveat:
        sqlite-vec picks the ``2 * n_results`` nearest vectors across the
        whole table (the ``k`` constraint) before the project/user/date
        conditions of the JOIN are applied. Fewer than ``n_results`` rows,
        even none, can therefore be returned although the project has more
        entries.
    """
    if n_results <= 0:
        return []
    sql, params = _build_search_sql(query_embedding, project_id, n_results, filters)
    conn = _get_connection()
    try:
        rows = conn.execute(sql, params).fetchall()
    finally:
        conn.close()
    return [_search_row_to_dict(row) for row in rows]


def _build_search_sql(
    query_embedding: list[float],
    project_id: str,
    n_results: int,
    filters: Optional[dict],
) -> tuple[str, list]:
    """Build the (sql, params) pair for the KNN query used by ``search``."""
    # Over-fetch candidates in the KNN stage, because filtering happens after it.
    clauses = ["v.embedding MATCH ?", "k = ?", "e.project_id = ?"]
    params: list = [_serialize_vector(query_embedding), n_results * 2, project_id]

    if filters:
        if filters.get("user_id"):
            clauses.append("e.user_id = ?")
            params.append(filters["user_id"])
        # Date range: inclusive lower bound, exclusive upper bound.
        for key, op in (("date_from", ">="), ("date_to", "<")):
            value = filters.get(key)
            if value:
                if hasattr(value, "isoformat"):
                    value = value.isoformat()
                clauses.append(f"e.created_at {op} ?")
                params.append(value)

    sql = (
        "SELECT e.id, e.project_id, e.user_id, e.task_id, e.text, "
        "e.created_at, v.distance "
        "FROM vec_embeddings v JOIN embeddings e ON v.id = e.id "
        "WHERE " + " AND ".join(clauses) + " ORDER BY v.distance LIMIT ?"
    )
    params.append(n_results)
    return sql, params


def _search_row_to_dict(row: tuple) -> dict:
    """Shape one row of the ``search`` query into the public result dict."""
    entry_id, project_id, user_id, task_id, text, created_at, distance = row
    return {
        "id": entry_id,
        "metadata": {
            "project_id": project_id,
            "user_id": user_id,
            "task_id": task_id,
            "text": text,
            "created_at": created_at,
        },
        "distance": distance,
    }
def delete_by_project(project_id: str) -> None:
    """Delete every embedding (metadata row and vector) of ``project_id``.

    Vectors are located through the ``embeddings`` metadata table, so a vector
    with no matching metadata row is not removed. Both deletions are
    committed together. If a statement fails, the connection is closed
    without committing.
    """
    conn = _get_connection()
    try:
        ids = conn.execute(
            "SELECT id FROM embeddings WHERE project_id = ?",
            (project_id,),
        ).fetchall()
        # Each fetched row is a 1-tuple, i.e. already a parameter set.
        conn.executemany("DELETE FROM vec_embeddings WHERE id = ?", ids)
        conn.execute("DELETE FROM embeddings WHERE project_id = ?", (project_id,))
        conn.commit()
    finally:
        conn.close()
def count_embeddings(project_id: Optional[str] = None) -> int:
    """Return the number of rows in the ``embeddings`` metadata table.

    Args:
        project_id: When given (including the empty string), count only rows
            with exactly this ``project_id``, matching how search and delete
            treat it. When ``None``, count every row.
    """
    conn = _get_connection()
    try:
        if project_id is None:
            row = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()
        else:
            row = conn.execute(
                "SELECT COUNT(*) FROM embeddings WHERE project_id = ?",
                (project_id,),
            ).fetchone()
    finally:
        conn.close()
    return row[0]