# ============================================================
# rag.py — Lightweight RAG using TF-IDF (scikit-learn only)
# NO extra dependencies needed — scikit-learn already in requirements.txt
# Drop this file next to app.py on HuggingFace Space
# ============================================================
import os
import glob
import pickle
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# HuggingFace Spaces: repo root is read-only, /tmp is writable
_DEFAULT_INDEX_PATH = "/tmp/kb_index.pkl"
# ─────────────────────────────────────────────
# Build / Load Knowledge Base Index
# ─────────────────────────────────────────────
def _load_chunks(kb_path: str = "knowledge_base", chunk_size: int = 300) -> list[dict]:
"""Read all .txt files in knowledge_base/ and split into overlapping chunks."""
chunks = []
txt_files = glob.glob(os.path.join(kb_path, "**/*.txt"), recursive=True)
txt_files += glob.glob(os.path.join(kb_path, "*.txt"))
txt_files = list(set(txt_files))
for fpath in txt_files:
try:
with open(fpath, "r", encoding="utf-8") as f:
text = f.read()
fname = os.path.basename(fpath)
# Split into sentences then group into chunks
lines = [l.strip() for l in text.split("\n") if l.strip()]
current, current_len = [], 0
for line in lines:
current.append(line)
current_len += len(line)
if current_len >= chunk_size:
chunks.append({"text": " ".join(current), "source": fname})
# overlap: keep last 2 lines
current = current[-2:]
current_len = sum(len(l) for l in current)
if current:
chunks.append({"text": " ".join(current), "source": fname})
except Exception:
pass
return chunks
def build_index(kb_path: str = "knowledge_base", index_path: str = _DEFAULT_INDEX_PATH) -> dict:
    """Build a TF-IDF index over the knowledge base and persist it to disk.

    Args:
        kb_path: Directory holding the .txt knowledge-base files.
        index_path: Preferred pickle location (/tmp is the only writable
            place on HuggingFace Spaces).

    Returns:
        Index dict with keys "chunks", "texts", "vectorizer", "matrix",
        or {} when no chunks were found.
    """
    chunks = _load_chunks(kb_path)
    if not chunks:
        return {}
    texts = [c["text"] for c in chunks]
    vectorizer = TfidfVectorizer(
        ngram_range=(1, 2),
        max_features=8000,
        sublinear_tf=True,
        strip_accents="unicode",
    )
    matrix = vectorizer.fit_transform(texts)
    index = {
        "chunks": chunks,
        "texts": texts,
        "vectorizer": vectorizer,
        "matrix": matrix,
    }
    # Persistence is best-effort: try /tmp first (writable on HF Spaces),
    # then the local directory (local dev). An in-memory index is still
    # fully functional if both writes fail.
    for path in (index_path, "kb_index.pkl"):
        try:
            with open(path, "wb") as f:
                pickle.dump(index, f)
            break
        except (OSError, pickle.PicklingError):
            continue
    return index
def load_index(index_path: str = _DEFAULT_INDEX_PATH, kb_path: str = "knowledge_base") -> dict:
"""
Load existing index or build a fresh one if not found.
Checks /tmp first (HF Spaces), then local dir (local dev).
"""
# Try /tmp path first
for path in [index_path, "kb_index.pkl"]:
if os.path.exists(path):
try:
with open(path, "rb") as f:
return pickle.load(f)
except Exception:
pass # Corrupted β rebuild below
# Auto-build if index missing or corrupted
if os.path.isdir(kb_path):
return build_index(kb_path, index_path)
return {}
# ─────────────────────────────────────────────
# Retrieval
# ─────────────────────────────────────────────
def retrieve(query: str, index: dict, k: int = 3, min_score: float = 0.05) -> str:
    """Fetch the top-k knowledge-base chunks most similar to *query*.

    The result is a blank-line-separated string of "[source] text"
    entries, ready to inject into an LLM prompt. Returns "" for an
    empty query, an empty index, or any internal failure — retrieval
    is best-effort and must never crash the caller.
    """
    if not query.strip() or not index:
        return ""
    try:
        q_vec = index["vectorizer"].transform([query])
        sims = cosine_similarity(q_vec, index["matrix"]).flatten()
        ranked = np.argsort(sims)[::-1][:k]
        lines = []
        emitted = set()
        for idx in ranked:
            if sims[idx] < min_score:
                continue
            chunk = index["chunks"][idx]
            body = chunk["text"]
            if body in emitted:
                continue
            emitted.add(body)
            label = chunk["source"].replace(".txt", "")
            lines.append(f"[{label}] {body}")
        return "\n\n".join(lines) if lines else ""
    except Exception:
        # Degrade silently to "no context" rather than breaking the chat flow.
        return ""
# ─────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────
def kb_status(index: dict) -> str:
    """Return a short human-readable status line for the loaded index.

    Args:
        index: Index dict produced by build_index()/load_index().

    Returns:
        An error marker when the index is empty/missing, otherwise a
        one-line summary of source-file and chunk counts.
    """
    if not index:
        return "❌ Knowledge base not loaded"
    # Single lookup instead of calling index.get(...) twice.
    chunks = index.get("chunks", [])
    n_chunks = len(chunks)
    sources = {c["source"] for c in chunks}
    # NOTE(review): the original literal was mojibake-garbled and split
    # across lines (a syntax error); reconstructed as a single f-string.
    return f"✅ KB: {len(sources)} files · {n_chunks} chunks"