File size: 6,273 Bytes
3694da1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""SQLite + numpy vector store.

Zero external dependencies (SQLite is built into Python). For our scale
(< 10K chunks total per user), a brute-force cosine-sim scan in numpy is
~5-20ms — there's no need for an ANN index. If we ever need it, we can
swap in FAISS or Chroma without touching the API.
"""

from __future__ import annotations

import sqlite3
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path

import numpy as np

DEFAULT_DB = Path("data/knowledge.sqlite")


class KnowledgeStore:
    """Persistent store of document chunks and their embedding vectors.

    Documents and chunks live in two SQLite tables. Embeddings are stored as
    raw float32 bytes and scanned with numpy at query time. Every public
    method opens (and closes) its own short-lived connection.
    """

    def __init__(self, db_path: str | Path = DEFAULT_DB) -> None:
        """Open the store at *db_path*, creating the directory and schema if needed."""
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_schema()

    @contextmanager
    def _conn(self):
        """Yield a connection that commits on success and rolls back on error."""
        c = sqlite3.connect(self.db_path)
        c.row_factory = sqlite3.Row
        # Foreign-key enforcement is per connection in SQLite; it is required
        # for the ON DELETE CASCADE declared on chunks.document_id.
        c.execute("PRAGMA foreign_keys = ON")
        try:
            yield c
        except BaseException:
            c.rollback()
            raise
        else:
            c.commit()
        finally:
            c.close()

    def _init_schema(self) -> None:
        """Create the documents/chunks tables and the chunk index if missing."""
        with self._conn() as c:
            c.executescript(
                """
                CREATE TABLE IF NOT EXISTS documents (
                    id           TEXT PRIMARY KEY,
                    name         TEXT NOT NULL,
                    format       TEXT NOT NULL,
                    size_bytes   INTEGER DEFAULT 0,
                    chunk_count  INTEGER DEFAULT 0,
                    uploaded_at  TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS chunks (
                    id           INTEGER PRIMARY KEY AUTOINCREMENT,
                    document_id  TEXT NOT NULL,
                    chunk_index  INTEGER NOT NULL,
                    text         TEXT NOT NULL,
                    embedding    BLOB NOT NULL,
                    FOREIGN KEY (document_id) REFERENCES documents(id) ON DELETE CASCADE
                );
                CREATE INDEX IF NOT EXISTS idx_chunks_doc ON chunks(document_id);
                """
            )

    # ── writes ──────────────────────────────────────────────────────────
    def add_document(
        self,
        name: str,
        fmt: str,
        chunks: list[str],
        embeddings: np.ndarray,
        size_bytes: int = 0,
    ) -> str:
        """Store a document and its chunks in one transaction.

        Args:
            name: Display name of the document.
            fmt: Format label, stored verbatim.
            chunks: Chunk texts in document order; the position becomes
                ``chunk_index``.
            embeddings: One vector per chunk — a 2-D array (or nested
                sequence) with ``len(chunks)`` rows. Stored as float32 bytes.
            size_bytes: Size of the original document, summed by ``stats``.

        Returns:
            The new 12-hex-character document id.

        Raises:
            ValueError: if the number of embeddings differs from the number of
                chunks, or the embeddings are not a 2-D matrix.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"chunks ({len(chunks)}) and embeddings ({len(embeddings)}) must match"
            )
        vectors = np.asarray(embeddings, dtype=np.float32)
        if chunks and vectors.ndim != 2:
            raise ValueError(
                f"embeddings must be 2-D (one row per chunk), got shape {vectors.shape}"
            )
        doc_id = uuid.uuid4().hex[:12]
        # Naive-UTC ISO text, same format as datetime.utcnow().isoformat(), so
        # the text ORDER BY in list_documents stays consistent with old rows.
        uploaded_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        with self._conn() as c:
            c.execute(
                "INSERT INTO documents (id, name, format, size_bytes, chunk_count, uploaded_at) "
                "VALUES (?, ?, ?, ?, ?, ?)",
                (doc_id, name, fmt, size_bytes, len(chunks), uploaded_at),
            )
            c.executemany(
                "INSERT INTO chunks (document_id, chunk_index, text, embedding) "
                "VALUES (?, ?, ?, ?)",
                (
                    (doc_id, i, chunk, vec.tobytes())
                    for i, (chunk, vec) in enumerate(zip(chunks, vectors))
                ),
            )
        return doc_id

    def delete_document(self, document_id: str) -> bool:
        """Delete a document and its chunks; return True if the document existed."""
        with self._conn() as c:
            # Redundant with ON DELETE CASCADE, but explicit and harmless.
            c.execute("DELETE FROM chunks WHERE document_id = ?", (document_id,))
            cur = c.execute("DELETE FROM documents WHERE id = ?", (document_id,))
            return cur.rowcount > 0

    def clear_all(self) -> None:
        """Delete every chunk and document."""
        with self._conn() as c:
            c.execute("DELETE FROM chunks")
            c.execute("DELETE FROM documents")

    # ── reads ───────────────────────────────────────────────────────────
    def list_documents(self) -> list[dict]:
        """Return all document rows as dicts, most recent ``uploaded_at`` first."""
        with self._conn() as c:
            rows = c.execute(
                "SELECT id, name, format, size_bytes, chunk_count, uploaded_at "
                "FROM documents ORDER BY uploaded_at DESC"
            ).fetchall()
        return [dict(r) for r in rows]

    def stats(self) -> dict:
        """Return document count, chunk count and summed ``size_bytes``."""
        with self._conn() as c:
            doc_count = c.execute("SELECT COUNT(*) FROM documents").fetchone()[0]
            chunk_count = c.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
            total_size = c.execute(
                "SELECT COALESCE(SUM(size_bytes), 0) FROM documents"
            ).fetchone()[0]
        return {
            "document_count": doc_count,
            "chunk_count": chunk_count,
            "total_bytes": total_size,
        }

    def search(self, query_embedding: np.ndarray, top_k: int | None = 3) -> list[dict]:
        """Cosine-similarity search across all chunks.

        Both the query and the stored vectors are L2-normalised before the dot
        product, so scores are true cosine similarities even if the stored
        embeddings were not normalised. Zero vectors score 0.

        Args:
            query_embedding: Query vector (array or sequence); a ``(1, d)``
                array is flattened.
            top_k: Maximum number of hits. ``<= 0`` returns ``[]``; ``None``
                returns every chunk, ranked.

        Returns:
            Hits with provenance metadata, best first. Ties keep insertion
            order.

        Raises:
            ValueError: if the query dimension differs from the stored one.
        """
        if top_k is not None and top_k <= 0:
            return []
        with self._conn() as c:
            rows = c.execute(
                """
                SELECT chunks.id, chunks.document_id, chunks.chunk_index, chunks.text,
                       chunks.embedding, documents.name AS doc_name,
                       documents.format AS doc_format
                FROM chunks
                JOIN documents ON chunks.document_id = documents.id
                ORDER BY chunks.id
                """
            ).fetchall()
        if not rows:
            return []

        emb_matrix = np.stack(
            [np.frombuffer(r["embedding"], dtype=np.float32) for r in rows]
        )
        q = np.asarray(query_embedding, dtype=np.float32).ravel()
        if q.shape[0] != emb_matrix.shape[1]:
            raise ValueError(
                f"query dimension {q.shape[0]} does not match stored "
                f"embedding dimension {emb_matrix.shape[1]}"
            )
        q_norm = np.linalg.norm(q)
        if q_norm > 0:
            q = q / q_norm
        row_norms = np.linalg.norm(emb_matrix, axis=1)
        row_norms[row_norms == 0] = 1.0  # zero rows then score 0, not NaN
        scores = (emb_matrix @ q) / row_norms
        # Stable sort on negated scores: descending, deterministic tie order.
        top = np.argsort(-scores, kind="stable")[:top_k]

        return [
            {
                "chunk_id": int(rows[i]["id"]),
                "document_id": rows[i]["document_id"],
                "document_name": rows[i]["doc_name"],
                "document_format": rows[i]["doc_format"],
                "chunk_index": int(rows[i]["chunk_index"]),
                "text": rows[i]["text"],
                "score": float(scores[i]),
            }
            for i in top
        ]