File size: 5,145 Bytes
35765b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import sqlite3
import sqlite_vec
import struct
from typing import Optional
import os

# Path to the shared SQLite database file (same file as the main application
# database), resolved relative to this module's parent directory.
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "project_memory.db")


def _serialize_vector(vec: list[float]) -> bytes:
    """Convert list of floats to bytes for sqlite-vec."""
    return struct.pack(f'{len(vec)}f', *vec)


def _get_connection():
    """Open a connection to DB_PATH with the sqlite-vec extension loaded.

    Extension loading is switched on only long enough to load sqlite-vec,
    then switched back off before the connection is handed to the caller.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.enable_load_extension(True)
    sqlite_vec.load(connection)
    connection.enable_load_extension(False)
    return connection


def init_vectorstore():
    """Initialize the embedding tables. Call once at startup.

    Creates the metadata table, its project index, and the vec0 virtual
    table for vector search. Idempotent: every statement uses IF NOT EXISTS.
    """
    conn = _get_connection()
    try:
        # Metadata table: one row per embedded log entry.
        conn.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                user_id TEXT,
                task_id TEXT,
                text TEXT,
                created_at TEXT
            )
        """)

        # Index to speed up the per-project filtering done by search().
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_embeddings_project
            ON embeddings(project_id)
        """)

        # Virtual table for vector search (768 dims for Gemini embeddings).
        conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS vec_embeddings USING vec0(
                id TEXT PRIMARY KEY,
                embedding FLOAT[768]
            )
        """)

        conn.commit()
    finally:
        # Release the connection even if one of the DDL statements raises;
        # the original leaked the handle on any exception.
        conn.close()


def add_embedding(
    log_entry_id: str,
    text: str,
    embedding: list[float],
    metadata: dict
) -> None:
    """Store one embedding plus its metadata, replacing any existing row.

    Args:
        log_entry_id: Primary key shared by the metadata and vector rows.
        text: Source text; truncated to 2000 characters before storage.
        embedding: Vector of floats (768 dims expected by the vec0 table).
        metadata: Optional keys: project_id, user_id, task_id, created_at.
    """
    conn = _get_connection()
    try:
        # Metadata row (INSERT OR REPLACE keeps the call idempotent per id).
        conn.execute("""
            INSERT OR REPLACE INTO embeddings (id, project_id, user_id, task_id, text, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """, (
            log_entry_id,
            metadata.get("project_id"),
            metadata.get("user_id"),
            metadata.get("task_id"),
            text[:2000],  # Truncate long text
            metadata.get("created_at")
        ))

        # Vector row, serialized to the binary blob sqlite-vec expects.
        conn.execute("""
            INSERT OR REPLACE INTO vec_embeddings (id, embedding)
            VALUES (?, ?)
        """, (log_entry_id, _serialize_vector(embedding)))

        # Single commit so metadata and vector land atomically.
        conn.commit()
    finally:
        # Release the connection even if an insert raises;
        # the original leaked the handle on any exception.
        conn.close()


def search(
    query_embedding: list[float],
    project_id: str,
    n_results: int = 10,
    filters: Optional[dict] = None
) -> list[dict]:
    """Search for similar documents within a project.

    Args:
        query_embedding: Query vector (same dimensionality as stored vectors).
        project_id: Restrict results to this project.
        n_results: Maximum number of results to return.
        filters: Optional dict with any of: user_id, date_from, date_to
            (date values may be strings or objects with .isoformat()).

    Returns:
        List of dicts with keys: id, metadata (project_id, user_id, task_id,
        text, created_at), and distance (smaller = more similar).
    """
    conn = _get_connection()
    try:
        # sqlite-vec's k-nearest-neighbor search takes k inside the MATCH
        # clause. We over-fetch (2x n_results) because the metadata filters
        # are applied after the vector search narrows to k candidates.
        query = """
            SELECT
                e.id,
                e.project_id,
                e.user_id,
                e.task_id,
                e.text,
                e.created_at,
                v.distance
            FROM vec_embeddings v
            JOIN embeddings e ON v.id = e.id
            WHERE v.embedding MATCH ?
                AND k = ?
                AND e.project_id = ?
        """
        params = [_serialize_vector(query_embedding), n_results * 2, project_id]

        if filters:
            if filters.get("user_id"):
                query += " AND e.user_id = ?"
                params.append(filters["user_id"])

            # Date filters for time-based queries; accept date/datetime
            # objects or pre-formatted ISO strings.
            if filters.get("date_from"):
                date_from = filters["date_from"]
                if hasattr(date_from, 'isoformat'):
                    date_from = date_from.isoformat()
                query += " AND e.created_at >= ?"
                params.append(date_from)

            if filters.get("date_to"):
                date_to = filters["date_to"]
                if hasattr(date_to, 'isoformat'):
                    date_to = date_to.isoformat()
                query += " AND e.created_at < ?"  # exclusive upper bound
                params.append(date_to)

        query += " ORDER BY v.distance LIMIT ?"
        params.append(n_results)

        results = conn.execute(query, params).fetchall()
    finally:
        # Release the connection even if the query raises;
        # the original leaked the handle on any exception.
        conn.close()

    return [
        {
            "id": row[0],
            "metadata": {
                "project_id": row[1],
                "user_id": row[2],
                "task_id": row[3],
                "text": row[4],
                "created_at": row[5]
            },
            "distance": row[6]
        }
        for row in results
    ]


def delete_by_project(project_id: str) -> None:
    """Delete all vectors and metadata rows for a project.

    Args:
        project_id: Project whose embeddings should be removed.
    """
    conn = _get_connection()
    try:
        # The vec0 virtual table has no project_id column, so collect the
        # ids from the metadata table first and delete vector rows by id.
        ids = conn.execute(
            "SELECT id FROM embeddings WHERE project_id = ?",
            (project_id,)
        ).fetchall()

        # Batched delete instead of one execute() call per row.
        conn.executemany("DELETE FROM vec_embeddings WHERE id = ?", ids)

        conn.execute("DELETE FROM embeddings WHERE project_id = ?", (project_id,))
        conn.commit()
    finally:
        # Release the connection even if a delete raises;
        # the original leaked the handle on any exception.
        conn.close()


def count_embeddings(project_id: Optional[str] = None) -> int:
    """Count stored embeddings, optionally filtered by project.

    Args:
        project_id: When truthy, count only this project's embeddings;
            otherwise count all rows.

    Returns:
        Number of rows in the embeddings metadata table.
    """
    conn = _get_connection()
    try:
        if project_id:
            result = conn.execute(
                "SELECT COUNT(*) FROM embeddings WHERE project_id = ?",
                (project_id,)
            ).fetchone()
        else:
            result = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()
    finally:
        # Release the connection even if the query raises;
        # the original leaked the handle on any exception.
        conn.close()
    return result[0]