1na37 committed on
Commit
99ed8ef
Β·
verified Β·
1 Parent(s): 648e946

Upload rag.py

Browse files
Files changed (1) hide show
  1. rag.py +145 -0
rag.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # rag.py β€” Lightweight RAG using TF-IDF (scikit-learn only)
3
+ # NO extra dependencies needed β€” scikit-learn already in requirements.txt
4
+ # Drop this file next to app.py on HuggingFace Space
5
+ # ============================================================
6
+
7
+ import os
8
+ import glob
9
+ import pickle
10
+ import numpy as np
11
+ from sklearn.feature_extraction.text import TfidfVectorizer
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+
14
+
15
+ # ─────────────────────────────────────────────
16
+ # Build / Load Knowledge Base Index
17
+ # ─────────────────────────────────────────────
18
+
19
+ def _load_chunks(kb_path: str = "knowledge_base", chunk_size: int = 300) -> list[dict]:
20
+ """Read all .txt files in knowledge_base/ and split into overlapping chunks."""
21
+ chunks = []
22
+ txt_files = glob.glob(os.path.join(kb_path, "**/*.txt"), recursive=True)
23
+ txt_files += glob.glob(os.path.join(kb_path, "*.txt"))
24
+ txt_files = list(set(txt_files))
25
+
26
+ for fpath in txt_files:
27
+ try:
28
+ with open(fpath, "r", encoding="utf-8") as f:
29
+ text = f.read()
30
+ fname = os.path.basename(fpath)
31
+
32
+ # Split into sentences then group into chunks
33
+ lines = [l.strip() for l in text.split("\n") if l.strip()]
34
+ current, current_len = [], 0
35
+ for line in lines:
36
+ current.append(line)
37
+ current_len += len(line)
38
+ if current_len >= chunk_size:
39
+ chunks.append({"text": " ".join(current), "source": fname})
40
+ # overlap: keep last 2 lines
41
+ current = current[-2:]
42
+ current_len = sum(len(l) for l in current)
43
+ if current:
44
+ chunks.append({"text": " ".join(current), "source": fname})
45
+ except Exception:
46
+ pass
47
+
48
+ return chunks
49
+
50
+
51
def build_index(kb_path: str = "knowledge_base", index_path: str = "kb_index.pkl") -> dict:
    """Build a TF-IDF index over the knowledge base and persist it to disk.

    Returns the in-memory index dict -- keys ``chunks``, ``texts``,
    ``vectorizer``, ``matrix`` -- or an empty dict when no chunks exist.
    The pickle written to *index_path* is only a best-effort cache.
    """
    chunks = _load_chunks(kb_path)
    if not chunks:
        return {}

    corpus = [chunk["text"] for chunk in chunks]
    tfidf = TfidfVectorizer(
        ngram_range=(1, 2),
        max_features=8000,
        sublinear_tf=True,
        strip_accents="unicode",
    )
    doc_matrix = tfidf.fit_transform(corpus)

    index = {
        "chunks": chunks,
        "texts": corpus,
        "vectorizer": tfidf,
        "matrix": doc_matrix,
    }

    # Best-effort cache write: a read-only filesystem (common on hosted
    # Spaces) must not prevent returning the freshly built index.
    try:
        with open(index_path, "wb") as f:
            pickle.dump(index, f)
    except Exception:
        pass

    return index
80
+
81
+
82
def load_index(index_path: str = "kb_index.pkl", kb_path: str = "knowledge_base") -> dict:
    """Return the pickled index at *index_path*, rebuilding it if necessary.

    Falls back to an empty dict when there is neither a usable pickle nor a
    knowledge-base directory to rebuild from.
    """
    if os.path.exists(index_path):
        try:
            # NOTE: unpickling trusts the cache file; it is produced locally
            # by build_index(), not taken from user input.
            with open(index_path, "rb") as f:
                return pickle.load(f)
        except Exception:
            # Corrupt or unreadable cache -- fall through and rebuild below.
            pass

    # No usable cache: build from scratch when the KB directory is present.
    return build_index(kb_path, index_path) if os.path.isdir(kb_path) else {}
96
+
97
+
98
+ # ─────────────────────────────────────────────
99
+ # Retrieval
100
+ # ─────────────────────────────────────────────
101
+
102
def retrieve(query: str, index: dict, k: int = 3, min_score: float = 0.05) -> str:
    """Fetch the top-*k* KB chunks most similar to *query* (TF-IDF cosine).

    Returns a prompt-ready string of ``[source] text`` blocks separated by
    blank lines, or ``""`` when the index is empty, the query is blank, no
    chunk clears *min_score*, or retrieval fails for any reason.
    """
    if not index or not query.strip():
        return ""

    try:
        q_vec = index["vectorizer"].transform([query])
        scores = cosine_similarity(q_vec, index["matrix"]).flatten()
        ranked = np.argsort(scores)[::-1][:k]

        chunks = index["chunks"]
        picked: list[str] = []
        seen: set[str] = set()
        for idx in ranked:
            if scores[idx] < min_score:
                continue
            body = chunks[idx]["text"]
            if body in seen:
                continue
            seen.add(body)
            label = chunks[idx]["source"].replace(".txt", "")
            picked.append(f"[{label}] {body}")

        return "\n\n".join(picked) if picked else ""
    except Exception:
        # Retrieval is an optional enhancement; never let it crash the caller.
        return ""
133
+
134
+
135
+ # ─────────────────────────────────────────────
136
+ # Helpers
137
+ # ─────────────────────────────────────────────
138
+
139
def kb_status(index: dict) -> str:
    """Summarize the loaded knowledge base as a one-line status string."""
    if not index:
        return "❌ Knowledge base not loaded"
    chunk_list = index.get("chunks", [])
    source_names = {chunk["source"] for chunk in chunk_list}
    return f"βœ… KB: {len(source_names)} files Β· {len(chunk_list)} chunks"