File size: 2,938 Bytes
4f96544 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 | """Knowledge vector store using sentence-transformers + FAISS."""
import os
import sqlite3
import logging
from typing import List
from api.deps import load_config, get_logger
# Logger shared by KnowledgeVectorStore; add_texts and query report their failures through it.
logger = get_logger("kapo.memory.knowledge")
class KnowledgeVectorStore:
    """Store text snippets as sentence-transformer embeddings in a FAISS index.

    Vectors are kept in a FAISS file at ``index_path``; the original text and
    its source label are kept in a SQLite table (``vectors``) stored next to it
    at ``index_path + ".meta.db"``.

    The two stores are linked purely by position: the FAISS vector at position
    ``i`` corresponds to the SQLite row with ``id == i + 1``. Both must
    therefore only ever be appended to together; ``add_texts`` keeps the SQLite
    insert uncommitted until the FAISS file has been written so that a failure
    in either step does not leave them out of step.
    """

    def __init__(self):
        cfg = load_config()
        self.index_path = cfg.get("FAISS_INDEX_PATH") or "./faiss.index"
        self.meta_db = self.index_path + ".meta.db"
        self.embed_model = cfg.get("EMBED_MODEL") or "sentence-transformers/all-MiniLM-L6-v2"
        # Created lazily on first use and then reused: building the model is expensive.
        self._embedder = None
        self._init_meta()

    def _init_meta(self):
        """Create the ``vectors`` metadata table if it does not exist yet."""
        conn = sqlite3.connect(self.meta_db)
        try:
            with conn:  # commits on success, rolls back on error
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS vectors (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        source TEXT,
                        content TEXT
                    )
                    """
                )
        finally:
            conn.close()

    def _load_embedder(self):
        """Return the SentenceTransformer for ``embed_model``, loading it once per instance."""
        # getattr keeps this working for instances created without __init__.
        embedder = getattr(self, "_embedder", None)
        if embedder is None:
            from sentence_transformers import SentenceTransformer
            embedder = SentenceTransformer(self.embed_model)
            self._embedder = embedder
        return embedder

    def _load_index(self, dim: int):
        """Read the FAISS index from disk, or create an empty flat L2 index of width ``dim``."""
        import faiss
        if os.path.exists(self.index_path):
            return faiss.read_index(self.index_path)
        return faiss.IndexFlatL2(dim)

    def add_texts(self, texts: List[str], source: str = "unknown"):
        """Embed ``texts`` and append them to the index and the metadata table.

        Args:
            texts: Snippets to store. An empty list is a no-op.
            source: Label stored alongside every snippet in this batch.

        This is best-effort: any failure is logged via ``logger.exception`` and
        swallowed, and nothing is returned.
        """
        if not texts:
            return
        try:
            import faiss
            embedder = self._load_embedder()
            embeddings = embedder.encode(texts, show_progress_bar=False)
            dim = len(embeddings[0])
            index = self._load_index(dim)
            if index.d != dim:
                raise ValueError(
                    f"embedding dimension {dim} does not match index dimension {index.d}"
                )
            conn = sqlite3.connect(self.meta_db)
            try:
                # The inserts stay uncommitted until the FAISS file is written; if
                # writing the index fails, `with conn` rolls them back so row ids
                # keep matching FAISS positions.
                with conn:
                    conn.executemany(
                        "INSERT INTO vectors(source, content) VALUES(?,?)",
                        [(source, t) for t in texts],
                    )
                    index.add(embeddings)
                    faiss.write_index(index, self.index_path)
            finally:
                conn.close()
        except Exception:
            logger.exception("Failed to add texts")

    def query(self, q: str, top_k: int = 3):
        """Return up to ``top_k`` stored snippets nearest to ``q``.

        Args:
            q: Query text to embed and search for.
            top_k: Maximum number of results; non-positive values yield ``[]``.

        Returns:
            A list of ``{"source": ..., "content": ...}`` dicts in FAISS ranking
            order. Returns ``[]`` when no index exists yet or on any error
            (which is logged).
        """
        try:
            # Cheap checks first so the embedding model is only loaded when needed.
            if top_k <= 0 or not os.path.exists(self.index_path):
                return []
            import faiss
            index = faiss.read_index(self.index_path)
            if index.ntotal == 0:
                return []
            embedder = self._load_embedder()
            qv = embedder.encode([q], show_progress_bar=False)
            _, ids = index.search(qv, top_k)
            conn = sqlite3.connect(self.meta_db)
            try:
                results = []
                for idx in ids[0]:
                    if idx < 0:  # FAISS pads with -1 when fewer than top_k vectors exist
                        continue
                    row = conn.execute(
                        "SELECT source, content FROM vectors WHERE id=?", (int(idx) + 1,)
                    ).fetchone()
                    if row:
                        results.append({"source": row[0], "content": row[1]})
                return results
            finally:
                conn.close()
        except Exception:
            logger.exception("Query failed")
            return []
|