AiCoder / brain_server /memory /knowledge_vector.py
MrA7A3's picture
Upload 35 files
4f96544 verified
"""Knowledge vector store using sentence-transformers + FAISS."""
import os
import sqlite3
import logging
from typing import List
from api.deps import load_config, get_logger
# Module-level logger shared by KnowledgeVectorStore; get_logger is provided by
# api.deps (its handler/level configuration is not visible from this file).
logger = get_logger("kapo.memory.knowledge")
class KnowledgeVectorStore:
    """Store text snippets as sentence-transformer embeddings in a FAISS index.

    Vectors live in a FAISS index file at ``FAISS_INDEX_PATH`` from the config
    (default ``./faiss.index``). The raw text and its source label live in a
    SQLite table ``vectors`` in ``<index_path>.meta.db``.

    The two stores are linked purely by position: the vector with FAISS id
    ``i`` (0-based, in insertion order) corresponds to the SQLite row with
    ``id = i + 1``. Every write must therefore keep both stores in lockstep.
    """

    def __init__(self):
        cfg = load_config()
        self.index_path = cfg.get("FAISS_INDEX_PATH") or "./faiss.index"
        self.meta_db = self.index_path + ".meta.db"
        self.embed_model = (
            cfg.get("EMBED_MODEL") or "sentence-transformers/all-MiniLM-L6-v2"
        )
        # Loaded lazily on first use and then reused: constructing a
        # SentenceTransformer loads the model weights, which is expensive.
        self._embedder = None
        self._init_meta()

    def _init_meta(self) -> None:
        """Create the ``vectors`` metadata table if it does not exist yet."""
        conn = sqlite3.connect(self.meta_db)
        try:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS vectors (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    source TEXT,
                    content TEXT
                )
                """
            )
            conn.commit()
        finally:
            conn.close()

    def _load_embedder(self):
        """Return the SentenceTransformer for ``self.embed_model``, loading it once per instance."""
        # getattr keeps this working for instances created without __init__.
        embedder = getattr(self, "_embedder", None)
        if embedder is None:
            from sentence_transformers import SentenceTransformer

            embedder = SentenceTransformer(self.embed_model)
            self._embedder = embedder
        return embedder

    def _load_index(self, dim: int):
        """Read the FAISS index from disk, or create an empty IndexFlatL2 of width ``dim``."""
        import faiss

        if os.path.exists(self.index_path):
            return faiss.read_index(self.index_path)
        return faiss.IndexFlatL2(dim)

    @staticmethod
    def _as_float32(vectors):
        """Coerce encoder output to the C-contiguous float32 matrix FAISS expects."""
        import numpy as np

        return np.ascontiguousarray(vectors, dtype="float32")

    def add_texts(self, texts: List[str], source: str = "unknown") -> None:
        """Embed ``texts`` and append them to the index and the metadata table.

        Args:
            texts: Snippets to store. An empty list is a no-op.
            source: Label stored alongside every snippet in this batch.

        Errors are logged via ``logger.exception`` and not re-raised. The
        return value does not tell callers whether the batch was stored.
        """
        if not texts:
            return
        try:
            import faiss

            embedder = self._load_embedder()
            vectors = self._as_float32(
                embedder.encode(texts, show_progress_bar=False)
            )
            index = self._load_index(vectors.shape[1])

            conn = sqlite3.connect(self.meta_db)
            try:
                # Stage the metadata rows first, inside an open transaction.
                # Then persist the index, and commit only if the index write
                # succeeded. If either step fails, the rows (and the
                # AUTOINCREMENT counter) are rolled back, so FAISS positions
                # and SQLite ids cannot drift apart.
                conn.executemany(
                    "INSERT INTO vectors(source, content) VALUES(?,?)",
                    [(source, t) for t in texts],
                )
                index.add(vectors)
                faiss.write_index(index, self.index_path)
                conn.commit()
            except Exception:
                conn.rollback()
                raise
            finally:
                conn.close()
        except Exception:
            logger.exception("Failed to add texts")

    def query(self, q: str, top_k: int = 3) -> List[dict]:
        """Return up to ``top_k`` stored snippets nearest to ``q``.

        Each result is ``{"source": ..., "content": ...}``, in the order FAISS
        ranks the hits.

        Returns an empty list in these cases:
        - ``top_k`` is not positive;
        - no index file exists yet;
        - the index is empty;
        - any error occurs (the error is logged).
        """
        try:
            # Cheap checks first, so the embedding model is never loaded
            # when there is nothing to search.
            if top_k <= 0 or not os.path.exists(self.index_path):
                return []
            import faiss

            index = faiss.read_index(self.index_path)
            if index.ntotal == 0:
                return []
            qv = self._as_float32(self._load_embedder().encode([q]))
            _distances, ids = index.search(qv, top_k)

            conn = sqlite3.connect(self.meta_db)
            try:
                results = []
                for idx in ids[0]:
                    # FAISS pads the result with -1 when fewer than top_k
                    # vectors are available.
                    if idx < 0:
                        continue
                    row = conn.execute(
                        "SELECT source, content FROM vectors WHERE id=?",
                        (int(idx) + 1,),
                    ).fetchone()
                    if row:
                        results.append({"source": row[0], "content": row[1]})
            finally:
                conn.close()
            return results
        except Exception:
            logger.exception("Query failed")
            return []