eduai / core /knowledge_base.py
Shreesha-2011's picture
Deploy EduAI backend - LLM + TTS + RAG + 15 maths chapters
1d9dc1f
Raw
History Blame Contribute Delete
7.13 kB
import json
import re
from datetime import datetime
from pathlib import Path
import numpy as np
class KnowledgeBase:
    """On-disk retrieval store of text chunks and their embedding vectors.

    Chunk metadata is kept in ``<data_dir>/knowledge/index.json`` and the
    embedding matrix in ``<data_dir>/knowledge/vectors.npz``. Row ``i`` of the
    matrix belongs to ``index["chunks"][i]``. Both files are loaded lazily on
    first use and cached on the instance; the embedding model is likewise
    loaded only when text actually has to be embedded.
    """

    # Column count of the placeholder (0, dim) matrix used before any vectors
    # exist. NOTE(review): presumably the output width of the configured
    # embedding model — confirm it matches ``embedding_model_id``.
    _DEFAULT_DIM = 384

    def __init__(self, data_dir, embedding_model_id, cache_dir):
        """Prepare storage paths under *data_dir*; nothing is loaded yet.

        Args:
            data_dir: Root data directory; ``knowledge/`` is created inside it.
            embedding_model_id: Model name passed to ``SentenceTransformer``.
            cache_dir: Model cache folder passed to ``SentenceTransformer``
                as ``cache_folder`` (``None`` is passed through unchanged).
        """
        self.knowledge_dir = Path(data_dir) / "knowledge"
        self.knowledge_dir.mkdir(parents=True, exist_ok=True)
        self.index_path = self.knowledge_dir / "index.json"
        self.vectors_path = self.knowledge_dir / "vectors.npz"
        self.embedding_model_id = embedding_model_id
        self.cache_dir = cache_dir
        self._model = None
        self._index = None
        self._vectors = None

    # ------------------------------------------------------------------ storage

    def _load_index(self):
        """Return the cached chunk index, reading ``index.json`` on first use.

        A missing, unreadable or malformed file yields ``{"chunks": []}``
        instead of raising, so callers can always use ``index["chunks"]``.
        """
        if self._index is not None:
            return self._index
        index = None
        if self.index_path.exists():
            try:
                index = json.loads(self.index_path.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, UnicodeDecodeError, OSError):
                index = None
        # Valid JSON of the wrong shape (e.g. a bare list) would otherwise make
        # every later ``index["chunks"]`` access fail.
        if not isinstance(index, dict):
            index = {"chunks": []}
        elif not isinstance(index.get("chunks"), list):
            index["chunks"] = []
        self._index = index
        return self._index

    def _load_vectors(self):
        """Return the cached (n, dim) float32 embedding matrix.

        Reads ``vectors.npz`` on first use. A missing, corrupt or non-2-D file
        is treated as empty: a ``(0, _DEFAULT_DIM)`` matrix.
        """
        if self._vectors is not None:
            return self._vectors
        vectors = None
        if self.vectors_path.exists():
            try:
                # np.load on an .npz returns an NpzFile that keeps the archive
                # open; the context manager closes it once the array is read.
                with np.load(str(self.vectors_path)) as data:
                    vectors = np.asarray(data["vectors"], dtype=np.float32)
                if vectors.ndim != 2:
                    vectors = None
            except Exception:
                # Deliberately best-effort: any unreadable file means "empty".
                vectors = None
        if vectors is None:
            vectors = np.empty((0, self._DEFAULT_DIM), dtype=np.float32)
        self._vectors = vectors
        return self._vectors

    def _aligned(self):
        """Return ``(index, vectors)`` trimmed to a common length.

        Row ``i`` of the matrix must describe ``index["chunks"][i]``. If the two
        files fell out of sync (e.g. one was unreadable and replaced by an empty
        default), the unmatched tail of the longer one is dropped in memory so
        lookups can neither go out of range nor pair a chunk with another
        chunk's vector. Assumes the surviving entries form a common prefix.
        Nothing is written to disk here.
        """
        index = self._load_index()
        vectors = self._load_vectors()
        n = min(len(index["chunks"]), len(vectors))
        if len(index["chunks"]) != n:
            del index["chunks"][n:]
        if len(vectors) != n:
            vectors = vectors[:n]
            self._vectors = vectors
        return index, vectors

    def _save(self):
        """Persist the index and vectors, replacing each file atomically.

        Each file is written to a ``.tmp`` sibling first and then moved into
        place, so an interrupted write cannot leave a truncated file behind.
        """
        index = self._load_index()
        vectors = self._load_vectors()

        tmp_index = self.index_path.with_name(self.index_path.name + ".tmp")
        tmp_index.write_text(
            json.dumps(index, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        tmp_index.replace(self.index_path)

        # Writing through a file object stops numpy from appending ".npz" to
        # the temporary name.
        tmp_vectors = self.vectors_path.with_name(self.vectors_path.name + ".tmp")
        with open(tmp_vectors, "wb") as fh:
            np.savez_compressed(fh, vectors=vectors)
        tmp_vectors.replace(self.vectors_path)

    # ---------------------------------------------------------------- embedding

    def _load_embedding_model(self):
        """Return the cached SentenceTransformer, loading it on first use."""
        if self._model is None:
            # Imported lazily so the heavy dependency is only needed when
            # something actually has to be embedded.
            from sentence_transformers import SentenceTransformer

            # str(None) would become a literal folder named "None".
            cache_folder = None if self.cache_dir is None else str(self.cache_dir)
            self._model = SentenceTransformer(
                self.embedding_model_id,
                cache_folder=cache_folder,
            )
        return self._model

    def _embed(self, texts):
        """Embed *texts* into an L2-normalised float32 array, one row per text."""
        model = self._load_embedding_model()
        embeddings = model.encode(
            texts, show_progress_bar=False, normalize_embeddings=True,
        )
        return np.array(embeddings, dtype=np.float32)

    # ----------------------------------------------------------------- chunking

    @staticmethod
    def _split_long(piece, max_chars):
        """Split *piece* into parts of at most *max_chars*, preferring spaces.

        A word-boundary split is used only when the last space in the window
        is not too early (at least ``min(100, max_chars // 5)`` characters in);
        otherwise the piece is hard-cut at *max_chars*.
        """
        min_split = min(100, max_chars // 5)
        parts = []
        rest = piece.strip()
        while len(rest) > max_chars:
            split_at = rest[:max_chars].rfind(" ")
            # split_at >= 1 guarantees progress on every iteration.
            if split_at < 1 or split_at < min_split:
                split_at = max_chars
            parts.append(rest[:split_at].strip())
            rest = rest[split_at:].strip()
        if rest:
            parts.append(rest)
        return parts

    def chunk_text(self, text, max_chars=500, overlap=50):
        """Split *text* into chunks of at most *max_chars* characters.

        The text is split after ``.``, ``!`` or ``?`` followed by whitespace,
        and the sentences are packed greedily into chunks. Any sentence longer
        than *max_chars* is first split on word boundaries. This also covers
        text with no sentence punctuation at all.

        When a chunk is closed, up to *overlap* trailing characters of it are
        prepended to the next chunk for context continuity. This happens only
        if the result still fits in *max_chars*.

        Chunks shorter than 20 characters are dropped.

        Raises:
            ValueError: if *max_chars* is not positive.
        """
        if max_chars < 1:
            raise ValueError("max_chars must be a positive integer")

        pieces = []
        for sentence in re.split(r"(?<=[.!?])\s+", text.strip()):
            sentence = sentence.strip()
            if sentence:
                pieces.extend(self._split_long(sentence, max_chars))

        chunks = []
        current = ""
        for piece in pieces:
            if not current:
                current = piece
                continue
            combined = current + " " + piece
            if len(combined) <= max_chars:
                current = combined
                continue
            chunks.append(current)
            current = piece
            # carry over a small overlap for context continuity, but never let
            # it push the new chunk past max_chars
            if overlap > 0 and len(chunks[-1]) > overlap:
                tail = chunks[-1][-overlap:].lstrip()
                carried = f"{tail} {piece}" if tail else piece
                if len(carried) <= max_chars:
                    current = carried
        if current:
            chunks.append(current)
        return [c for c in chunks if len(c.strip()) >= 20]

    # ------------------------------------------------------------ public API

    def add_document(self, text, source_name):
        """Chunk, embed and store *text* under *source_name*.

        Returns:
            The number of chunks added (0 if *text* produced no chunks).
        """
        chunks = self.chunk_text(text)
        if not chunks:
            return 0
        print(f" Embedding {len(chunks)} chunks... ", end="", flush=True)
        embeddings = self._embed(chunks)
        print("done.")
        # Align first so new rows land at the same positions as new chunks.
        index, vectors = self._aligned()
        now = datetime.now().isoformat(timespec="seconds")
        index["chunks"].extend(
            {"text": chunk, "source": source_name, "added_at": now}
            for chunk in chunks
        )
        if vectors.size == 0:
            self._vectors = embeddings
        else:
            self._vectors = np.vstack([vectors, embeddings])
        self._save()
        return len(chunks)

    def search(self, query, top_k=3, threshold=0.3):
        """Return up to *top_k* chunks most similar to *query*.

        Each result is ``{"text", "source", "score"}``, best first; results
        scoring below *threshold* are omitted. Returns ``[]`` when the store
        is empty or *top_k* is not positive.
        """
        index, vectors = self._aligned()
        if not index["chunks"] or vectors.size == 0 or top_k <= 0:
            return []
        query_vec = self._embed([query])[0]
        # cosine similarity (vectors are already L2-normalized)
        similarities = vectors @ query_vec
        best_first = np.argsort(similarities)[::-1][:top_k]
        results = []
        for i in best_first:
            score = float(similarities[i])
            if score < threshold:
                continue
            chunk = index["chunks"][i]
            results.append({
                "text": chunk["text"],
                "source": chunk["source"],
                "score": score,
            })
        return results

    def list_sources(self):
        """Map each source name to its chunk count and first ``added_at``.

        Sources appear in order of first occurrence in the index.
        """
        sources = {}
        for chunk in self._load_index()["chunks"]:
            entry = sources.setdefault(
                chunk["source"], {"count": 0, "added_at": chunk.get("added_at", "")},
            )
            entry["count"] += 1
        return sources

    def format_source_list(self):
        """Render the stored sources as a numbered, human-readable listing."""
        sources = self.list_sources()
        if not sources:
            return "No study materials in knowledge base yet. Use /learn to add some."
        lines = ["", " Knowledge Base", " " + "-" * 50]
        for idx, (name, info) in enumerate(sources.items(), start=1):
            # ISO timestamp trimmed to minutes, e.g. "2024-01-01 10:00".
            added = info["added_at"][:16].replace("T", " ") if info["added_at"] else ""
            lines.append(f" {idx}. {name}")
            lines.append(f" {info['count']} chunks | Added: {added}")
        lines.append("")
        return "\n".join(lines)

    def remove_source(self, index_number):
        """Delete every chunk of the *index_number*-th source (1-based).

        The numbering matches ``format_source_list``. Returns the removed
        source name, or ``None`` if *index_number* is out of range.
        """
        index, vectors = self._aligned()
        source_names = list(self.list_sources())
        if not 1 <= index_number <= len(source_names):
            return None
        target = source_names[index_number - 1]
        keep = [i for i, chunk in enumerate(index["chunks"]) if chunk["source"] != target]
        index["chunks"] = [index["chunks"][i] for i in keep]
        # vectors[:0] keeps the real embedding width instead of assuming one.
        self._vectors = vectors[keep] if keep else vectors[:0]
        self._save()
        return target

    def is_empty(self):
        """Return True when the index holds no chunks."""
        return not self._load_index()["chunks"]