Spaces:
Running
Running
| import json | |
| import re | |
| from datetime import datetime | |
| from pathlib import Path | |
| import numpy as np | |
class KnowledgeBase:
    """Local store of text chunks and their embeddings, persisted to two files.

    ``<data_dir>/knowledge/index.json`` holds the chunk metadata
    (``text``, ``source``, ``added_at``) and ``vectors.npz`` holds one
    embedding row per chunk: row ``i`` of the matrix belongs to
    ``index["chunks"][i]``. Both are loaded lazily and cached on the instance;
    the embedding model is only loaded when something needs to be embedded.
    """

    # Width of the empty placeholder matrix used before any embeddings exist.
    # The first real batch replaces the placeholder outright (see
    # add_document), so the value only matters for the empty case.
    # NOTE(review): presumably the output width of the default model — confirm.
    _EMPTY_DIM = 384

    # Sentence boundary: whitespace that follows ".", "!" or "?".
    _SENTENCE_END = re.compile(r"(?<=[.!?])\s+")

    def __init__(self, data_dir, embedding_model_id, cache_dir):
        """Create (if needed) the knowledge directory and set up lazy state.

        Args:
            data_dir: base directory; files live in ``<data_dir>/knowledge``.
            embedding_model_id: model name passed to ``SentenceTransformer``.
            cache_dir: download/cache folder for the embedding model.
        """
        self.knowledge_dir = Path(data_dir) / "knowledge"
        self.knowledge_dir.mkdir(parents=True, exist_ok=True)
        self.index_path = self.knowledge_dir / "index.json"
        self.vectors_path = self.knowledge_dir / "vectors.npz"
        self.embedding_model_id = embedding_model_id
        self.cache_dir = cache_dir
        self._model = None
        self._index = None
        self._vectors = None

    def _empty_vectors(self):
        """Return an empty ``(0, _EMPTY_DIM)`` float32 embedding matrix."""
        return np.empty((0, self._EMPTY_DIM), dtype=np.float32)

    def _load_index(self):
        """Return the cached chunk index, reading ``index.json`` on first use.

        A missing, unreadable, undecodable, or wrongly shaped file yields an
        empty index ``{"chunks": []}``.
        """
        if self._index is not None:
            return self._index
        index = None
        if self.index_path.exists():
            try:
                index = json.loads(self.index_path.read_text(encoding="utf-8"))
            except (ValueError, OSError):
                # ValueError covers both JSONDecodeError and UnicodeDecodeError.
                index = None
        # Valid JSON of the wrong shape would otherwise fail later on
        # index["chunks"].
        if not isinstance(index, dict) or not isinstance(index.get("chunks"), list):
            index = {"chunks": []}
        self._index = index
        return self._index

    def _load_vectors(self):
        """Return the cached embedding matrix, reading ``vectors.npz`` on first use.

        A missing or unreadable file, or one that does not hold a 2-D
        ``vectors`` array, yields an empty float32 matrix.
        """
        if self._vectors is not None:
            return self._vectors
        vectors = None
        if self.vectors_path.exists():
            try:
                # np.load returns an NpzFile that keeps the archive open;
                # the context manager closes it once the array is read.
                with np.load(str(self.vectors_path)) as data:
                    vectors = np.asarray(data["vectors"], dtype=np.float32)
            except Exception:
                # Deliberate best-effort: a corrupt store behaves as empty.
                vectors = None
        if vectors is None or vectors.ndim != 2:
            vectors = self._empty_vectors()
        self._vectors = vectors
        return self._vectors

    def _save(self):
        """Write the cached index and vectors back to disk.

        Each file is written to a ``.tmp`` sibling first and then renamed over
        the target. An interrupted write therefore leaves the previous file in
        place instead of a truncated one.
        """
        index = self._load_index()
        vectors = self._load_vectors()
        index_tmp = self.index_path.with_name(self.index_path.name + ".tmp")
        vectors_tmp = self.vectors_path.with_name(self.vectors_path.name + ".tmp")
        index_tmp.write_text(
            json.dumps(index, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        # Pass a file object: given a filename without ".npz",
        # savez_compressed would append the suffix itself.
        with open(vectors_tmp, "wb") as fh:
            np.savez_compressed(fh, vectors=vectors)
        index_tmp.replace(self.index_path)
        vectors_tmp.replace(self.vectors_path)

    def _load_embedding_model(self):
        """Return the cached SentenceTransformer, constructing it on first use.

        The import is local so ``sentence_transformers`` is only required once
        something actually has to be embedded.
        """
        if self._model is not None:
            return self._model
        from sentence_transformers import SentenceTransformer
        self._model = SentenceTransformer(
            self.embedding_model_id,
            cache_folder=str(self.cache_dir),
        )
        return self._model

    def _embed(self, texts):
        """Embed a list of strings into an ``(len(texts), dim)`` float32 array.

        Rows are L2-normalized (``normalize_embeddings=True``), so a dot
        product between rows is their cosine similarity.
        """
        model = self._load_embedding_model()
        embeddings = model.encode(
            texts, show_progress_bar=False, normalize_embeddings=True,
        )
        return np.array(embeddings, dtype=np.float32)

    @staticmethod
    def _hard_split(text, max_chars):
        """Split ``text`` into pieces of at most ``max_chars`` characters.

        Each cut is placed at the last space inside the window. If that space
        would leave a piece shorter than ``min(100, max_chars // 5)``
        characters, the cut is made at exactly ``max_chars`` instead.
        A non-positive ``max_chars`` returns ``[text]`` unchanged (no split is
        possible).
        """
        if max_chars <= 0:
            return [text]
        min_break = min(100, max_chars // 5)
        pieces = []
        raw = text
        while len(raw) > max_chars:
            split_at = raw[:max_chars].rfind(" ")
            if split_at < min_break:
                split_at = max_chars
            pieces.append(raw[:split_at].strip())
            raw = raw[split_at:].strip()
        if raw:
            pieces.append(raw)
        return pieces

    def chunk_text(self, text, max_chars=500, overlap=50):
        """Split ``text`` into chunks suitable for embedding.

        Sentences are packed greedily, joined by single spaces, while the
        chunk stays within ``max_chars``. A sentence that is itself longer
        than ``max_chars`` (including text with no sentence-ending punctuation
        at all) is first split at word boundaries. When a chunk is closed, up
        to ``overlap`` of its trailing characters are carried into the next
        chunk for context, but only if they fit. With a positive
        ``max_chars``, no chunk exceeds it. Chunks shorter than 20 characters
        are dropped.

        Returns:
            list[str]: chunks in document order (possibly empty).
        """
        pieces = []
        for sentence in self._SENTENCE_END.split(text.strip()):
            sentence = sentence.strip()
            if sentence:
                pieces.extend(self._hard_split(sentence, max_chars))

        chunks = []
        current = ""
        for piece in pieces:
            if not current:
                current = piece
                continue
            combined = current + " " + piece
            if len(combined) <= max_chars:
                current = combined
                continue
            chunks.append(current)
            # carry over a small overlap for context continuity, but never let
            # it push the new chunk past max_chars
            if overlap > 0 and len(current) > overlap:
                carried = current[-overlap:].lstrip() + " " + piece
                current = carried if len(carried) <= max_chars else piece
            else:
                current = piece
        if current:
            chunks.append(current)
        return [c for c in chunks if len(c.strip()) >= 20]

    def add_document(self, text, source_name):
        """Chunk, embed, and store ``text`` under ``source_name``, then save.

        Raises whatever ``np.vstack`` raises (ValueError) if the new
        embeddings' width differs from the stored matrix. In that case the
        cached index is left untouched.

        Returns:
            int: number of chunks added (0 when ``text`` yields no chunks).
        """
        chunks = self.chunk_text(text)
        if not chunks:
            return 0
        print(f" Embedding {len(chunks)} chunks... ", end="", flush=True)
        embeddings = self._embed(chunks)
        print("done.")
        index = self._load_index()
        vectors = self._load_vectors()
        # Build the new matrix before mutating the index, so a shape error
        # cannot leave chunks that have no matching vector rows.
        new_vectors = embeddings if vectors.size == 0 else np.vstack([vectors, embeddings])
        now = datetime.now().isoformat(timespec="seconds")
        index["chunks"].extend(
            {"text": chunk, "source": source_name, "added_at": now}
            for chunk in chunks
        )
        self._vectors = new_vectors
        self._save()
        return len(chunks)

    def search(self, query, top_k=3, threshold=0.3):
        """Return up to ``top_k`` stored chunks most similar to ``query``.

        Similarity is the dot product of the stored rows with the query
        embedding (cosine similarity, since ``_embed`` L2-normalizes). Hits
        scoring below ``threshold`` are skipped.

        Returns:
            list[dict]: ``{"text", "source", "score"}`` dicts, best first.
        """
        index = self._load_index()
        vectors = self._load_vectors()
        # Chunks and rows are paired by position. If the two files ever
        # disagree in length, only the common prefix can be scored safely.
        count = min(len(index["chunks"]), len(vectors))
        if count == 0 or top_k <= 0:
            return []
        query_vec = self._embed([query])[0]
        # cosine similarity (vectors are already L2-normalized)
        similarities = vectors[:count] @ query_vec
        top_indices = np.argsort(similarities)[::-1][:top_k]
        results = []
        for i in top_indices:
            score = float(similarities[i])
            if score < threshold:
                continue
            chunk = index["chunks"][i]
            results.append({
                "text": chunk["text"],
                "source": chunk["source"],
                "score": score,
            })
        return results

    def list_sources(self):
        """Summarize stored chunks per source, in first-seen order.

        Returns:
            dict: ``{source: {"count": int, "added_at": str}}``, where
            ``added_at`` comes from that source's first chunk ("" if absent).
        """
        sources = {}
        for chunk in self._load_index()["chunks"]:
            entry = sources.setdefault(
                chunk["source"], {"count": 0, "added_at": chunk.get("added_at", "")}
            )
            entry["count"] += 1
        return sources

    def format_source_list(self):
        """Render ``list_sources()`` as a numbered, human-readable listing."""
        sources = self.list_sources()
        if not sources:
            return "No study materials in knowledge base yet. Use /learn to add some."
        lines = ["", " Knowledge Base", " " + "-" * 50]
        for idx, (name, info) in enumerate(sources.items(), start=1):
            # "YYYY-MM-DDTHH:MM:SS" trimmed to minutes, with "T" as a space.
            added = info["added_at"][:16].replace("T", " ") if info["added_at"] else ""
            lines.append(f" {idx}. {name}")
            lines.append(f" {info['count']} chunks | Added: {added}")
        lines.append("")
        return "\n".join(lines)

    def remove_source(self, index_number):
        """Delete every chunk of the ``index_number``-th source (1-based).

        Numbering follows ``list_sources()`` order, the same order
        ``format_source_list()`` prints. Saves to disk on success.

        Returns:
            str | None: the removed source name, or None if out of range.
        """
        source_names = list(self.list_sources())
        if not 1 <= index_number <= len(source_names):
            return None
        target = source_names[index_number - 1]
        idx = self._load_index()
        vectors = self._load_vectors()
        keep = [i for i, c in enumerate(idx["chunks"]) if c["source"] != target]
        idx["chunks"] = [idx["chunks"][i] for i in keep]
        if keep and vectors.size > 0:
            self._vectors = vectors[keep]
        else:
            self._vectors = self._empty_vectors()
        self._save()
        return target

    def is_empty(self):
        """Return True when the index holds no chunks."""
        return len(self._load_index()["chunks"]) == 0