# ============================================================
# rag.py — Lightweight RAG using TF-IDF (scikit-learn only)
# NO extra dependencies needed — scikit-learn already in requirements.txt
# Drop this file next to app.py on HuggingFace Space
# ============================================================
import os
import glob
import pickle

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# HuggingFace Spaces: repo root is read-only, /tmp is writable
_DEFAULT_INDEX_PATH = "/tmp/kb_index.pkl"


# ─────────────────────────────────────────────
# Build / Load Knowledge Base Index
# ─────────────────────────────────────────────
def _load_chunks(kb_path: str = "knowledge_base", chunk_size: int = 300) -> list[dict]:
    """Read all .txt files under *kb_path* and split them into overlapping chunks.

    Each chunk is a dict ``{"text": str, "source": <basename>}``.  Non-empty
    lines are accumulated until roughly *chunk_size* characters, then flushed;
    the last two lines are carried into the next chunk as overlap so context
    isn't cut at chunk boundaries.

    Unreadable or non-UTF-8 files are skipped (best-effort loading).
    """
    chunks: list[dict] = []
    txt_files = glob.glob(os.path.join(kb_path, "**/*.txt"), recursive=True)
    txt_files += glob.glob(os.path.join(kb_path, "*.txt"))
    # sorted() makes chunk order — and therefore the TF-IDF index —
    # deterministic across runs; bare set() iteration order is not.
    for fpath in sorted(set(txt_files)):
        try:
            with open(fpath, "r", encoding="utf-8") as f:
                text = f.read()
        except (OSError, UnicodeDecodeError):
            continue  # best-effort: skip files we cannot read/decode
        fname = os.path.basename(fpath)
        lines = [ln.strip() for ln in text.split("\n") if ln.strip()]
        current: list[str] = []
        current_len = 0
        for line in lines:
            current.append(line)
            current_len += len(line)
            if current_len >= chunk_size:
                chunks.append({"text": " ".join(current), "source": fname})
                # overlap: keep the last 2 lines for the next chunk
                current = current[-2:]
                current_len = sum(len(ln) for ln in current)
        if current:  # flush the trailing partial chunk
            chunks.append({"text": " ".join(current), "source": fname})
    return chunks


def build_index(kb_path: str = "knowledge_base", index_path: str = _DEFAULT_INDEX_PATH) -> dict:
    """Build a TF-IDF index from the knowledge base and persist it to disk.

    Returns a dict with keys ``chunks``, ``texts``, ``vectorizer``, ``matrix``,
    or an empty dict when the knowledge base yields no chunks.  Persistence is
    best-effort: *index_path* first (/tmp is writable on HF Spaces), then the
    local directory (local dev); an in-memory index is still fully functional
    if both writes fail.
    """
    chunks = _load_chunks(kb_path)
    if not chunks:
        return {}

    texts = [c["text"] for c in chunks]
    vectorizer = TfidfVectorizer(
        ngram_range=(1, 2),
        max_features=8000,
        sublinear_tf=True,
        strip_accents="unicode",
    )
    matrix = vectorizer.fit_transform(texts)
    index = {
        "chunks": chunks,
        "texts": texts,
        "vectorizer": vectorizer,
        "matrix": matrix,
    }

    # Best-effort save: first writable candidate wins; failure is non-fatal.
    for path in (index_path, "kb_index.pkl"):
        try:
            with open(path, "wb") as f:
                pickle.dump(index, f)
            break
        except (OSError, pickle.PicklingError):
            continue
    return index


def load_index(index_path: str = _DEFAULT_INDEX_PATH, kb_path: str = "knowledge_base") -> dict:
    """Load an existing index, or build a fresh one if none is usable.

    Checks *index_path* first (HF Spaces /tmp), then the local dir (local
    dev); a missing or corrupted pickle falls through to an auto-rebuild.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code — only load
    index files this app wrote itself; never point this at untrusted data.
    """
    for path in (index_path, "kb_index.pkl"):
        if os.path.exists(path):
            try:
                with open(path, "rb") as f:
                    return pickle.load(f)
            except Exception:
                pass  # corrupted/unreadable — rebuild below

    # Auto-build if the index is missing or corrupted
    if os.path.isdir(kb_path):
        return build_index(kb_path, index_path)
    return {}


# ─────────────────────────────────────────────
# Retrieval
# ─────────────────────────────────────────────
def retrieve(query: str, index: dict, k: int = 3, min_score: float = 0.05) -> str:
    """Return the top-*k* most relevant knowledge-base chunks for *query*.

    Chunks below *min_score* cosine similarity are dropped and duplicate
    texts are emitted once.  The result is a ``"[source] text"`` string,
    double-newline separated, ready to inject into an LLM prompt; returns
    "" when the index is empty, the query is blank, or nothing matches.
    Retrieval is optional, so any internal error degrades to "" rather
    than breaking the caller.
    """
    if not index or not query.strip():
        return ""
    try:
        q_vec = index["vectorizer"].transform([query])
        scores = cosine_similarity(q_vec, index["matrix"]).flatten()
        chunks = index["chunks"]

        results: list[str] = []
        seen: set[str] = set()
        # argsort ascending, reversed → highest-scoring chunks first
        for i in np.argsort(scores)[::-1][:k]:
            if scores[i] < min_score:
                continue
            text = chunks[i]["text"]
            if text in seen:
                continue
            seen.add(text)
            src = chunks[i]["source"].replace(".txt", "")
            results.append(f"[{src}] {text}")
        return "\n\n".join(results)  # "" when no chunk qualified
    except Exception:
        return ""


# ─────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────
def kb_status(index: dict) -> str:
    """Return a short human-readable status string for the loaded index."""
    if not index:
        return "❌ Knowledge base not loaded"
    chunks = index.get("chunks", [])
    sources = {c["source"] for c in chunks}
    return f"✅ KB: {len(sources)} files · {len(chunks)} chunks"