# testt/rag.py — "Update rag.py" (commit 767744c, verified)
# ============================================================
# rag.py — Lightweight RAG using TF-IDF (scikit-learn only)
# NO extra dependencies needed — scikit-learn already in requirements.txt
# Drop this file next to app.py on HuggingFace Space
# ============================================================
import os
import glob
import pickle
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# HuggingFace Spaces: repo root is read-only, /tmp is writable
# Default location of the pickled TF-IDF index (see build_index/load_index).
_DEFAULT_INDEX_PATH = "/tmp/kb_index.pkl"
# ─────────────────────────────────────────────
# Build / Load Knowledge Base Index
# ─────────────────────────────────────────────
def _load_chunks(kb_path: str = "knowledge_base", chunk_size: int = 300) -> list[dict]:
"""Read all .txt files in knowledge_base/ and split into overlapping chunks."""
chunks = []
txt_files = glob.glob(os.path.join(kb_path, "**/*.txt"), recursive=True)
txt_files += glob.glob(os.path.join(kb_path, "*.txt"))
txt_files = list(set(txt_files))
for fpath in txt_files:
try:
with open(fpath, "r", encoding="utf-8") as f:
text = f.read()
fname = os.path.basename(fpath)
# Split into sentences then group into chunks
lines = [l.strip() for l in text.split("\n") if l.strip()]
current, current_len = [], 0
for line in lines:
current.append(line)
current_len += len(line)
if current_len >= chunk_size:
chunks.append({"text": " ".join(current), "source": fname})
# overlap: keep last 2 lines
current = current[-2:]
current_len = sum(len(l) for l in current)
if current:
chunks.append({"text": " ".join(current), "source": fname})
except Exception:
pass
return chunks
def build_index(kb_path: str = "knowledge_base", index_path: str = _DEFAULT_INDEX_PATH) -> dict:
    """Build a TF-IDF index over the knowledge base and persist it to disk.

    Returns a dict with keys ``chunks`` (list of {"text", "source"}),
    ``texts`` (the chunk texts), ``vectorizer`` (fitted TfidfVectorizer)
    and ``matrix`` (the TF-IDF document matrix), or {} when the knowledge
    base yields no chunks. Persistence is best-effort: *index_path* first
    (/tmp is writable on HuggingFace Spaces), then ./kb_index.pkl (local
    dev); if both writes fail the index is still returned and fully
    usable in memory.
    """
    chunks = _load_chunks(kb_path)
    if not chunks:
        return {}
    texts = [c["text"] for c in chunks]
    vectorizer = TfidfVectorizer(
        ngram_range=(1, 2),        # unigrams + bigrams for short-phrase matching
        max_features=8000,
        sublinear_tf=True,         # dampen very frequent terms
        strip_accents="unicode",
    )
    matrix = vectorizer.fit_transform(texts)
    index = {
        "chunks": chunks,
        "texts": texts,
        "vectorizer": vectorizer,
        "matrix": matrix,
    }
    # Best-effort save: try each candidate location, stop on first success.
    for path in (index_path, "kb_index.pkl"):
        try:
            with open(path, "wb") as f:
                pickle.dump(index, f)
            break
        except (OSError, pickle.PickleError):
            continue  # read-only location — try the next one / stay in-memory
    return index
def load_index(index_path: str = _DEFAULT_INDEX_PATH, kb_path: str = "knowledge_base") -> dict:
    """
    Return a saved TF-IDF index, rebuilding it when none can be loaded.

    Candidate pickle locations are tried in order: *index_path* (the /tmp
    path used on HF Spaces), then ./kb_index.pkl (local dev). A missing or
    corrupt pickle falls through to a fresh build from *kb_path*; when no
    knowledge-base directory exists either, an empty dict is returned.
    """
    for candidate in (index_path, "kb_index.pkl"):
        if not os.path.exists(candidate):
            continue
        try:
            with open(candidate, "rb") as fh:
                return pickle.load(fh)
        except Exception:
            continue  # unreadable/corrupt pickle — keep looking, then rebuild
    # Nothing loadable on disk: build from scratch if the KB dir is present.
    return build_index(kb_path, index_path) if os.path.isdir(kb_path) else {}
# ─────────────────────────────────────────────
# Retrieval
# ─────────────────────────────────────────────
def retrieve(query: str, index: dict, k: int = 3, min_score: float = 0.05) -> str:
    """
    Fetch the top-k knowledge-base chunks most relevant to *query*.

    Returns a prompt-ready string of "[source] text" paragraphs separated
    by blank lines, or "" when the index is empty, the query is blank, or
    no chunk scores at least *min_score*.
    """
    if not index or not query.strip():
        return ""
    try:
        q_vec = index["vectorizer"].transform([query])
        sims = cosine_similarity(q_vec, index["matrix"]).flatten()
        # Indices of the k highest-scoring chunks, best first.
        ranked = np.argsort(sims)[::-1][:k]
        picked = []
        already = set()
        for pos in ranked:
            if sims[pos] < min_score:
                continue
            chunk = index["chunks"][pos]
            body = chunk["text"]
            if body in already:
                continue
            already.add(body)
            label = chunk["source"].replace(".txt", "")
            picked.append(f"[{label}] {body}")
        return "\n\n".join(picked) if picked else ""
    except Exception:
        # Best-effort: any failure degrades to "no context" rather than crashing.
        return ""
# ─────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────
def kb_status(index: dict) -> str:
    """Return a short human-readable status line for the KB index.

    Examples: "❌ Knowledge base not loaded", "✅ KB: 3 files · 42 chunks".
    (Also repairs the previously double-encoded/mojibake status markers.)
    """
    if not index:
        return "❌ Knowledge base not loaded"
    chunks = index.get("chunks", [])
    # Distinct source files represented among the chunks.
    sources = {c["source"] for c in chunks}
    return f"✅ KB: {len(sources)} files · {len(chunks)} chunks"