import json
import re
from datetime import datetime
from pathlib import Path

import numpy as np


class KnowledgeBase:
    """File-backed store of text chunks and their embedding vectors.

    Two files live under ``<data_dir>/knowledge``:

    * ``index.json``  -- ``{"chunks": [{"text", "source", "added_at"}, ...]}``
    * ``vectors.npz`` -- one array named ``vectors`` with one row per chunk

    Chunk ``i`` in the index corresponds to row ``i`` of the vector matrix.
    Both structures are loaded lazily and cached on the instance.
    """

    # Column count of the placeholder matrix used while nothing is stored;
    # the real width is taken from the first embeddings that get added.
    DEFAULT_EMBEDDING_DIM = 384
    # Chunks shorter than this many characters are discarded by chunk_text().
    MIN_CHUNK_CHARS = 20

    def __init__(self, data_dir, embedding_model_id, cache_dir):
        """Prepare the storage directory; nothing is loaded until first use.

        Args:
            data_dir: Base directory; a ``knowledge`` sub-directory is created.
            embedding_model_id: Identifier passed to ``SentenceTransformer``.
            cache_dir: Directory passed to the model as its ``cache_folder``.
        """
        self.knowledge_dir = Path(data_dir) / "knowledge"
        self.knowledge_dir.mkdir(parents=True, exist_ok=True)
        self.index_path = self.knowledge_dir / "index.json"
        self.vectors_path = self.knowledge_dir / "vectors.npz"
        self.embedding_model_id = embedding_model_id
        self.cache_dir = cache_dir
        self._model = None
        self._index = None
        self._vectors = None

    # ------------------------------------------------------------------
    # persistence helpers
    # ------------------------------------------------------------------

    @classmethod
    def _empty_vectors(cls):
        """Return an empty float32 matrix used when no vectors are stored."""
        return np.empty((0, cls.DEFAULT_EMBEDDING_DIM), dtype=np.float32)

    def _load_index(self):
        """Return the cached index dict, reading ``index.json`` on first use.

        A missing, unreadable, undecodable or structurally invalid file
        yields a fresh ``{"chunks": []}`` instead of raising.
        """
        if self._index is not None:
            return self._index

        loaded = None
        if self.index_path.exists():
            try:
                loaded = json.loads(self.index_path.read_text(encoding="utf-8"))
            except (ValueError, OSError):
                # ValueError covers both JSONDecodeError and UnicodeDecodeError.
                loaded = None

        # Valid JSON of the wrong shape would otherwise surface later as a
        # KeyError/TypeError far away from the actual cause.
        if not isinstance(loaded, dict) or not isinstance(loaded.get("chunks"), list):
            loaded = {"chunks": []}

        self._index = loaded
        return self._index

    def _load_vectors(self):
        """Return the cached vector matrix, reading ``vectors.npz`` on first use.

        Any failure to read the file, or a stored array that is not
        two-dimensional, yields an empty matrix instead of raising.
        """
        if self._vectors is not None:
            return self._vectors

        vectors = None
        if self.vectors_path.exists():
            try:
                # The context manager closes the underlying archive handle;
                # the array itself is fully read into memory before that.
                with np.load(str(self.vectors_path)) as data:
                    vectors = np.asarray(data["vectors"], dtype=np.float32)
            except Exception:
                # Deliberate best-effort: a corrupt file means "start empty".
                vectors = None

        if vectors is None or vectors.ndim != 2:
            vectors = self._empty_vectors()

        self._vectors = vectors
        return self._vectors

    def _aligned_state(self):
        """Return ``(index, vectors)`` with exactly one vector row per chunk.

        The two files are loaded independently and each one falls back to
        empty on its own, so they can disagree in length. Chunk ``i`` must
        match row ``i``, so both are truncated to their common prefix; the
        surplus entries cannot be paired and are dropped from the in-memory
        state (and from disk on the next save).
        """
        index = self._load_index()
        vectors = self._load_vectors()
        chunk_count = len(index["chunks"])
        row_count = vectors.shape[0]
        if chunk_count != row_count:
            common = min(chunk_count, row_count)
            del index["chunks"][common:]
            vectors = vectors[:common] if common else self._empty_vectors()
            self._vectors = vectors
        return index, vectors

    def _save(self):
        """Write the index and the vectors to disk.

        Each file is written to a temporary sibling and then moved into
        place, so an interrupted write does not leave a half-written file
        under the real name.
        """
        index = self._load_index()
        vectors = self._load_vectors()

        index_tmp = self.index_path.with_name(self.index_path.name + ".tmp")
        index_tmp.write_text(
            json.dumps(index, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        index_tmp.replace(self.index_path)

        vectors_tmp = self.vectors_path.with_name(self.vectors_path.name + ".tmp")
        # A file object is passed so numpy does not append ".npz" to the name.
        with open(vectors_tmp, "wb") as handle:
            np.savez_compressed(handle, vectors=vectors)
        vectors_tmp.replace(self.vectors_path)

    # ------------------------------------------------------------------
    # embedding
    # ------------------------------------------------------------------

    def _load_embedding_model(self):
        """Return the cached embedding model, constructing it on first use."""
        if self._model is not None:
            return self._model

        # Imported lazily so the class can be used (listing, chunking)
        # without importing sentence_transformers.
        from sentence_transformers import SentenceTransformer

        self._model = SentenceTransformer(
            self.embedding_model_id,
            cache_folder=str(self.cache_dir),
        )
        return self._model

    def _embed(self, texts):
        """Encode ``texts`` into a float32 array of normalized embeddings."""
        model = self._load_embedding_model()
        embeddings = model.encode(
            texts,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        return np.array(embeddings, dtype=np.float32)

    # ------------------------------------------------------------------
    # chunking
    # ------------------------------------------------------------------

    @staticmethod
    def _split_oversized(piece, max_chars):
        """Split ``piece`` into parts of at most ``max_chars`` characters.

        The split happens at the last whitespace inside the window; when
        that would leave a very short part (under a fifth of ``max_chars``)
        or no whitespace exists, the text is cut hard at ``max_chars``.
        A non-positive ``max_chars`` disables splitting.
        """
        if max_chars < 1 or len(piece) <= max_chars:
            return [piece]

        min_split = max(1, max_chars // 5)
        parts = []
        rest = piece
        while len(rest) > max_chars:
            window = rest[:max_chars]
            split_at = max(window.rfind(ch) for ch in (" ", "\n", "\t"))
            if split_at < min_split:
                split_at = max_chars
            head = rest[:split_at].strip()
            if head:
                parts.append(head)
            rest = rest[split_at:].strip()
        if rest:
            parts.append(rest)
        return parts

    def chunk_text(self, text, max_chars=500, overlap=50):
        """Split ``text`` into chunks suitable for embedding.

        The text is split into sentences after ``.``, ``!`` or ``?``.
        Sentences longer than ``max_chars`` (including text with no sentence
        punctuation at all) are further split at whitespace. Pieces are then
        packed greedily into chunks of at most ``max_chars`` characters.

        When a chunk is closed, its last ``overlap`` characters are carried
        to the front of the next chunk for context, so a chunk may exceed
        ``max_chars`` by up to ``overlap + 1`` characters.

        Args:
            text: Raw document text.
            max_chars: Target maximum chunk length before overlap is added.
            overlap: Number of trailing characters repeated in the next chunk.

        Returns:
            List of chunk strings; chunks under 20 characters are dropped.
        """
        stripped = text.strip()
        if not stripped:
            return []

        pieces = []
        for sentence in re.split(r"(?<=[.!?])\s+", stripped):
            sentence = sentence.strip()
            if sentence:
                pieces.extend(self._split_oversized(sentence, max_chars))

        chunks = []
        current = ""
        for piece in pieces:
            if not current:
                current = piece
                continue
            combined = current + " " + piece
            if len(combined) <= max_chars:
                current = combined
                continue
            chunks.append(current)
            # carry over a small overlap for context continuity
            if overlap > 0 and len(current) > overlap:
                current = current[-overlap:].lstrip() + " " + piece
            else:
                current = piece
        if current:
            chunks.append(current)

        return [c for c in chunks if len(c.strip()) >= self.MIN_CHUNK_CHARS]

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------

    def add_document(self, text, source_name):
        """Chunk, embed and store ``text`` under ``source_name``.

        Returns:
            Number of chunks added (0 when the text yields no usable chunk).

        Raises:
            ValueError: If the new embeddings do not have the same width as
                the vectors already stored.
        """
        chunks = self.chunk_text(text)
        if not chunks:
            return 0

        print(f" Embedding {len(chunks)} chunks... ", end="", flush=True)
        embeddings = self._embed(chunks)
        print("done.")

        index, vectors = self._aligned_state()

        # Build the new matrix before touching the index so a failure here
        # leaves the in-memory state unchanged.
        if vectors.size == 0:
            merged = embeddings
        else:
            if embeddings.ndim != 2 or embeddings.shape[1] != vectors.shape[1]:
                raise ValueError(
                    "embedding shape "
                    f"{embeddings.shape} does not match stored vectors "
                    f"{vectors.shape}"
                )
            merged = np.vstack([vectors, embeddings])

        now = datetime.now().isoformat(timespec="seconds")
        index["chunks"].extend(
            {"text": chunk, "source": source_name, "added_at": now}
            for chunk in chunks
        )
        self._vectors = merged
        self._save()
        return len(chunks)

    def search(self, query, top_k=3, threshold=0.3):
        """Return up to ``top_k`` stored chunks most similar to ``query``.

        Args:
            query: Query text to embed.
            top_k: Maximum number of results; non-positive values give ``[]``.
            threshold: Minimum similarity score a result must reach.

        Returns:
            List of ``{"text", "source", "score"}`` dicts, best match first.
        """
        if top_k <= 0:
            return []

        index, vectors = self._aligned_state()
        chunks = index["chunks"]
        # Checked before embedding so an empty store never loads the model.
        if not chunks or vectors.size == 0:
            return []

        query_vec = self._embed([query])[0]
        # cosine similarity (vectors are already L2-normalized)
        similarities = vectors @ query_vec
        ranked = np.argsort(similarities)[::-1][:top_k]

        results = []
        for position in ranked:
            score = float(similarities[position])
            if score < threshold:
                continue
            chunk = chunks[int(position)]
            results.append({
                "text": chunk["text"],
                "source": chunk["source"],
                "score": score,
            })
        return results

    def list_sources(self):
        """Return ``{source_name: {"count", "added_at"}}`` in first-seen order.

        ``added_at`` is taken from the first chunk seen for each source.
        """
        index = self._load_index()
        sources = {}
        for chunk in index["chunks"]:
            name = chunk["source"]
            if name not in sources:
                sources[name] = {"count": 0, "added_at": chunk.get("added_at", "")}
            sources[name]["count"] += 1
        return sources

    def format_source_list(self):
        """Return a human-readable, 1-based numbered listing of all sources."""
        sources = self.list_sources()
        if not sources:
            return "No study materials in knowledge base yet. Use /learn to add some."

        lines = ["", " Knowledge Base", " " + "-" * 50]
        for idx, (name, info) in enumerate(sources.items(), start=1):
            # Trim the ISO timestamp to minutes: "YYYY-MM-DD HH:MM".
            added = info["added_at"][:16].replace("T", " ") if info["added_at"] else ""
            lines.append(f" {idx}. {name}")
            lines.append(f" {info['count']} chunks | Added: {added}")
        lines.append("")
        return "\n".join(lines)

    def remove_source(self, index_number):
        """Remove every chunk of the source at 1-based ``index_number``.

        The numbering is the one produced by ``list_sources`` /
        ``format_source_list``.

        Returns:
            The removed source name, or ``None`` if the number is out of range.
        """
        source_names = list(self.list_sources())
        if not 1 <= index_number <= len(source_names):
            return None
        target = source_names[index_number - 1]

        index, vectors = self._aligned_state()
        keep = [i for i, c in enumerate(index["chunks"]) if c["source"] != target]
        index["chunks"] = [index["chunks"][i] for i in keep]
        if keep and vectors.size > 0:
            self._vectors = vectors[keep]
        else:
            self._vectors = self._empty_vectors()

        self._save()
        return target

    def is_empty(self):
        """Return True when the index holds no chunks."""
        return not self._load_index()["chunks"]