File size: 2,938 Bytes
06ce7ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Knowledge vector store using sentence-transformers + FAISS."""
import os
import sqlite3
import logging
from typing import List
from api.deps import load_config, get_logger

# Module-level logger for this file, obtained from the project's get_logger
# helper under the name "kapo.memory.knowledge".
logger = get_logger("kapo.memory.knowledge")


class KnowledgeVectorStore:
    """Persist text snippets in a FAISS index with SQLite-backed metadata.

    Embeddings are produced with a sentence-transformers model and stored in
    a FAISS index file at ``index_path``. The original text and its source
    label live in a SQLite database at ``index_path + ".meta.db"``.

    The two stores are linked by position: the vector at FAISS position ``p``
    corresponds to the ``vectors`` row whose ``id`` is ``p + 1``.
    """

    def __init__(self):
        """Read paths and model name from config and ensure the metadata table exists."""
        cfg = load_config()
        self.index_path = cfg.get("FAISS_INDEX_PATH") or "./faiss.index"
        self.meta_db = self.index_path + ".meta.db"
        self.embed_model = cfg.get("EMBED_MODEL") or "sentence-transformers/all-MiniLM-L6-v2"
        # Lazily populated by _get_embedder(); loading the model is expensive.
        self._embedder = None
        self._init_meta()

    def _init_meta(self) -> None:
        """Create the ``vectors`` metadata table if it does not exist yet."""
        conn = sqlite3.connect(self.meta_db)
        try:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS vectors (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    source TEXT,
                    content TEXT
                )
                """
            )
            conn.commit()
        finally:
            # Always release the handle, even if table creation fails.
            conn.close()

    def _load_embedder(self):
        """Instantiate a fresh SentenceTransformer for ``self.embed_model``."""
        # Imported lazily so the module can be imported without the heavy dependency.
        from sentence_transformers import SentenceTransformer
        return SentenceTransformer(self.embed_model)

    def _get_embedder(self):
        """Return the cached embedder, loading it on first use."""
        # getattr keeps this safe for instances created without __init__.
        embedder = getattr(self, "_embedder", None)
        if embedder is None:
            embedder = self._load_embedder()
            self._embedder = embedder
        return embedder

    def _load_index(self, dim: int):
        """Read the FAISS index from disk, or create an empty L2 index of size ``dim``."""
        import faiss
        if os.path.exists(self.index_path):
            return faiss.read_index(self.index_path)
        return faiss.IndexFlatL2(dim)

    def add_texts(self, texts: List[str], source: str = "unknown") -> None:
        """Embed ``texts`` and append them to the index and the metadata table.

        Args:
            texts: Snippets to store. An empty list is a no-op.
            source: Label recorded alongside every snippet.

        Failures are logged ("Failed to add texts") and swallowed; nothing is
        raised to the caller.
        """
        if not texts:
            # Nothing to embed; avoids loading the model and an IndexError below.
            return
        try:
            import faiss

            embeddings = self._get_embedder().encode(texts, show_progress_bar=False)
            dim = len(embeddings[0])
            index = self._load_index(dim)
            if index.d != dim:
                raise ValueError(
                    f"Embedding dimension {dim} does not match index dimension {index.d}"
                )

            # Row ids are derived from the FAISS position (position + 1) rather
            # than left to AUTOINCREMENT, so the two stores cannot drift apart.
            first_id = int(index.ntotal) + 1
            index.add(embeddings)

            conn = sqlite3.connect(self.meta_db)
            try:
                conn.executemany(
                    "INSERT OR REPLACE INTO vectors(id, source, content) VALUES(?,?,?)",
                    [(first_id + offset, source, text) for offset, text in enumerate(texts)],
                )
                # Persist the index before committing: if the write fails, the
                # uncommitted rows are discarded when the connection closes.
                faiss.write_index(index, self.index_path)
                conn.commit()
            finally:
                conn.close()
        except Exception:
            logger.exception("Failed to add texts")

    def query(self, q: str, top_k: int = 3) -> List[dict]:
        """Return up to ``top_k`` stored snippets nearest to ``q``.

        Args:
            q: Query text to embed and search for.
            top_k: Maximum number of results; non-positive values yield ``[]``.

        Returns:
            A list of ``{"source": ..., "content": ...}`` dicts in the order
            FAISS returned them. Returns ``[]`` when no index file exists, the
            index is empty, or any error occurs (errors are logged as
            "Query failed").
        """
        try:
            # Cheap checks first, so the model is not loaded for nothing.
            if top_k <= 0 or not os.path.exists(self.index_path):
                return []

            import faiss

            index = faiss.read_index(self.index_path)
            total = int(index.ntotal)
            if total == 0:
                return []

            query_vec = self._get_embedder().encode([q], show_progress_bar=False)
            _distances, ids = index.search(query_vec, min(top_k, total))

            results = []
            conn = sqlite3.connect(self.meta_db)
            try:
                cur = conn.cursor()
                for idx in ids[0]:
                    position = int(idx)
                    if position < 0:
                        # FAISS pads missing neighbours with -1.
                        continue
                    cur.execute(
                        "SELECT source, content FROM vectors WHERE id=?",
                        (position + 1,),
                    )
                    row = cur.fetchone()
                    if row:
                        results.append({"source": row[0], "content": row[1]})
            finally:
                conn.close()
            return results
        except Exception:
            logger.exception("Query failed")
            return []