ishmeet-yo committed on
Commit
f6c9e8d
·
verified ·
1 Parent(s): 4e3e1a0

Update app/rag.py

Browse files
Files changed (1) hide show
  1. app/rag.py +191 -158
app/rag.py CHANGED
@@ -1,158 +1,191 @@
1
- import os
2
- import pickle
3
- import hashlib
4
- import numpy as np
5
- from sentence_transformers import SentenceTransformer
6
- from sklearn.feature_extraction.text import TfidfVectorizer
7
- from sklearn.preprocessing import normalize
8
-
9
# Directory where the embedding matrices, TF-IDF artifacts, chunk list and
# dataset hash are cached (see save_cache / load_cache).
CACHE_DIR = "app/cache"
# Directory scanned for the .txt source documents (see load_documents).
DATA_DIR = "app/data"
# Sentence-transformers model used to embed both the chunks and the queries.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
12
-
13
-
14
def compute_hash(files):
    """Return an MD5 hex digest of the combined contents of ``files``.

    The digest is compared with the value stored in the cache directory to
    decide whether the cached embeddings are still valid, so it must not
    depend on the order in which the paths were discovered.

    Args:
        files: Iterable of paths to the dataset files.

    Returns:
        The hexadecimal MD5 digest of the file contents, concatenated in
        sorted path order.

    Raises:
        OSError: If one of the files cannot be opened or read.
    """
    digest = hashlib.md5()
    # Sort so that the arbitrary order of os.listdir() cannot change the hash.
    for path in sorted(files):
        with open(path, "rb") as fp:
            # Feed the file in fixed-size blocks instead of loading it whole.
            for block in iter(lambda: fp.read(1 << 20), b""):
                digest.update(block)
    return digest.hexdigest()
20
-
21
-
22
def load_documents():
    """Read every ``.txt`` file found directly inside ``DATA_DIR``.

    Returns:
        A ``(texts, files)`` tuple: ``texts`` holds the decoded contents and
        ``files`` the matching paths, both in sorted filename order.

    Raises:
        RuntimeError: If ``DATA_DIR`` contains no ``.txt`` files.
        OSError: If ``DATA_DIR`` or one of the files cannot be read.
    """
    # os.listdir() returns names in arbitrary order; sorting keeps the chunk
    # order reproducible from run to run.
    files = [
        os.path.join(DATA_DIR, name)
        for name in sorted(os.listdir(DATA_DIR))
        if name.endswith(".txt")
    ]

    if not files:
        # Fail here with a clear message rather than later inside the
        # embedding / TF-IDF libraries.
        raise RuntimeError(f"No .txt files found in {DATA_DIR}")

    texts = []
    for path in files:
        # Undecodable bytes are dropped rather than aborting the whole load.
        with open(path, encoding="utf-8", errors="ignore") as fp:
            texts.append(fp.read())

    return texts, files
35
-
36
-
37
def chunk_text(text, size=500, overlap=100):
    """Split ``text`` into overlapping chunks of whitespace-separated words.

    Args:
        text: The text to split.
        size: Maximum number of words per chunk.
        overlap: Number of words shared by consecutive chunks.

    Returns:
        A list of chunks, each a single-space-joined run of up to ``size``
        words. Empty or whitespace-only text yields an empty list.

    Raises:
        ValueError: If ``size`` is not positive or ``overlap`` is not
            smaller than ``size``.
    """
    if size <= 0 or overlap >= size:
        # A non-positive stride would never advance through the words.
        raise ValueError("size must be positive and overlap must be < size")

    words = text.split()
    stride = size - overlap
    chunks = []
    for start in range(0, len(words), stride):
        chunks.append(" ".join(words[start:start + size]))
        if start + size >= len(words):
            # This window already reached the last word; another step would
            # only repeat words that the current chunk contains.
            break
    return chunks
48
-
49
def chunk_documents(texts):
    """Chunk every document and return all chunks as one flat list.

    Args:
        texts: Iterable of document strings.

    Returns:
        The chunks of each document, produced by ``chunk_text`` with its
        default settings, concatenated in document order.
    """
    return [piece for document in texts for piece in chunk_text(document)]
54
-
55
-
56
def build_embeddings(chunks):
    """Build every retrieval head for ``chunks``.

    Args:
        chunks: Non-empty list of text chunks to index.

    Returns:
        A dict with the keys ``"semantic"``, ``"narrative"`` and ``"entity"``
        (normalized embedding matrices with one row per chunk), ``"tfidf"``
        (the fitted vectorizer), ``"tfidf_matrix"`` (the vectorizer output
        for ``chunks``) and ``"model"`` (the loaded SentenceTransformer).

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    if not chunks:
        # Checked before the model is loaded so the failure is cheap and clear.
        raise ValueError("Cannot build embeddings from an empty chunk list")

    model = SentenceTransformer(MODEL_NAME)

    semantic = normalize(model.encode(chunks))
    narrative = normalize(
        model.encode(["Story context: " + c for c in chunks])
    )
    # The entity head encodes exactly the same inputs as the semantic head,
    # so reuse that result instead of running the model a second time.
    entity = semantic.copy()

    tfidf = TfidfVectorizer()
    tfidf_matrix = tfidf.fit_transform(chunks)

    return {
        "semantic": semantic,
        "narrative": narrative,
        "entity": entity,
        "tfidf": tfidf,
        "tfidf_matrix": tfidf_matrix,
        "model": model,
    }
76
-
77
-
78
def save_cache(chunks, heads, dataset_hash):
    """Write the chunks, retrieval heads and dataset hash under ``CACHE_DIR``.

    Args:
        chunks: The chunk list to pickle.
        heads: Dict providing the ``"semantic"``, ``"narrative"``,
            ``"entity"``, ``"tfidf"`` and ``"tfidf_matrix"`` entries.
        dataset_hash: String stored in ``hash.txt``.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)

    def _dump(stem, payload):
        # Pickle one object to <CACHE_DIR>/<stem>.pkl.
        with open(f"{CACHE_DIR}/{stem}.pkl", "wb") as f:
            pickle.dump(payload, f)

    # The embedding matrices go to .npy files named after their head.
    for head_name in ("semantic", "narrative", "entity"):
        np.save(f"{CACHE_DIR}/{head_name}.npy", heads[head_name])

    _dump("chunks", chunks)
    _dump("tfidf", heads["tfidf"])
    _dump("tfidf_matrix", heads["tfidf_matrix"])

    # hash.txt is written only after every artifact above.
    with open(f"{CACHE_DIR}/hash.txt", "w") as f:
        f.write(dataset_hash)
96
-
97
def load_cache():
    """Load the cached chunks and retrieval heads from ``CACHE_DIR``.

    Returns:
        A ``(chunks, heads)`` tuple. ``heads`` contains the three embedding
        matrices, the pickled TF-IDF vectorizer and matrix, and a freshly
        loaded SentenceTransformer under ``"model"``.
    """
    def _unpickle(stem):
        # NOTE: pickle can execute arbitrary code while loading; this is only
        # safe as long as the cache directory is written by this application.
        with open(f"{CACHE_DIR}/{stem}.pkl", "rb") as f:
            return pickle.load(f)

    chunks = _unpickle("chunks")

    heads = {
        head_name: np.load(f"{CACHE_DIR}/{head_name}.npy")
        for head_name in ("semantic", "narrative", "entity")
    }
    heads["tfidf"] = _unpickle("tfidf")
    heads["tfidf_matrix"] = _unpickle("tfidf_matrix")

    # The model itself is not cached; it is loaded again by name.
    heads["model"] = SentenceTransformer(MODEL_NAME)
    return chunks, heads
115
-
116
def load_data():
    """Return ``(chunks, heads)``, reusing the on-disk cache when it is valid.

    The cache is considered valid when the key stored in ``hash.txt`` equals
    the key computed from the current dataset files and ``MODEL_NAME``.
    Otherwise, or when the cached artifacts cannot be read, the heads are
    rebuilt and the cache is rewritten.

    Returns:
        A ``(chunks, heads)`` tuple as produced by ``load_cache`` or by
        ``chunk_documents`` plus ``build_embeddings``.
    """
    texts, files = load_documents()

    # Tie the cache to the model as well as the data, so that changing
    # MODEL_NAME cannot silently reuse embeddings from another model.
    cache_key = f"{compute_hash(files)}:{MODEL_NAME}"
    hash_path = f"{CACHE_DIR}/hash.txt"

    try:
        with open(hash_path) as f:
            cached_key = f.read().strip()
    except FileNotFoundError:
        cached_key = None

    if cached_key == cache_key:
        print("Loading embeddings from cache")
        try:
            return load_cache()
        except (OSError, EOFError, ValueError, pickle.UnpicklingError) as err:
            # A missing or truncated artifact should trigger a rebuild, not
            # prevent start-up.
            print(f"Cache unreadable ({err}); rebuilding")

    print("Building embeddings")
    # Chunking is only needed on a rebuild; a cache hit returns stored chunks.
    chunks = chunk_documents(texts)
    heads = build_embeddings(chunks)
    save_cache(chunks, heads, cache_key)
    return chunks, heads
137
-
138
-
139
def retrieve_chunks(query, chunks, heads, k=5):
    """Return the ``k`` chunks that score highest for ``query``.

    A chunk's score is a weighted sum of three signals: the dot product with
    the plain query embedding (0.45), the dot product with the
    ``"Story question: "``-prefixed query embedding (0.35) and the TF-IDF
    dot product (0.20).

    Args:
        query: The user question.
        chunks: The chunk list that ``heads`` was built from.
        heads: Dict with ``"model"``, ``"semantic"``, ``"narrative"``,
            ``"tfidf"`` and ``"tfidf_matrix"``.
        k: Maximum number of chunks to return.

    Returns:
        Up to ``k`` chunks, best first. Empty when ``k`` is not positive.
    """
    if k <= 0:
        # A negative k would otherwise slice from the wrong end of the ranking.
        return []

    # Encode both query variants in one batch: a single model call per query.
    q_sem, q_nav = normalize(
        heads["model"].encode([query, "Story question: " + query])
    )

    sem_score = heads["semantic"] @ q_sem
    nav_score = heads["narrative"] @ q_nav

    q_tfidf = heads["tfidf"].transform([query])
    # The sparse product is densified and flattened to one score per chunk.
    key_score = (heads["tfidf_matrix"] @ q_tfidf.T).toarray().ravel()

    final = 0.45 * sem_score + 0.35 * nav_score + 0.20 * key_score

    idx = np.argsort(final)[::-1][:k]
    return [chunks[i] for i in idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import hashlib
4
+ import numpy as np
5
+
6
+ from sentence_transformers import SentenceTransformer
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.preprocessing import normalize
9
+
10
# Directory where the embedding matrices, TF-IDF artifacts, chunk list and
# dataset hash are cached (see save_cache / load_cache).
CACHE_DIR = "app/cache"
# Directory scanned for the .txt source documents (see load_documents).
DATA_DIR = "app/data"
# Sentence-transformers model used to embed both the chunks and the queries.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Default chunk length for chunk_text, in whitespace-separated words.
CHUNK_SIZE = 500
# Default number of words shared by two consecutive chunks.
CHUNK_OVERLAP = 100
16
+
17
+
18
def compute_hash(files):
    """Return an MD5 fingerprint of the dataset files.

    The fingerprint is compared with the value stored in the cache directory
    to decide whether cached embeddings can be reused. It is independent of
    the order of ``files`` but sensitive to where one file ends and the next
    begins, because documents are chunked file by file.

    Args:
        files: Iterable of paths to the dataset files.

    Returns:
        A hexadecimal MD5 digest computed over the per-file MD5 digests,
        taken in sorted path order.

    Raises:
        OSError: If one of the files cannot be opened or read.
    """
    fingerprint = hashlib.md5()
    for path in sorted(files):
        file_digest = hashlib.md5()
        with open(path, "rb") as fp:
            # Stream the file in blocks rather than reading it whole.
            while True:
                block = fp.read(1 << 20)
                if not block:
                    break
                file_digest.update(block)
        # Folding in one digest per file keeps file boundaries visible:
        # moving text from one file into another changes the fingerprint
        # even though the concatenated bytes stay the same.
        fingerprint.update(file_digest.digest())
    return fingerprint.hexdigest()
24
+
25
+
26
def load_documents():
    """Read every ``.txt`` file found directly inside ``DATA_DIR``.

    Returns:
        A ``(texts, files)`` tuple: ``texts`` holds the decoded contents and
        ``files`` the matching paths, both in sorted filename order.

    Raises:
        RuntimeError: If ``DATA_DIR`` contains no ``.txt`` files.
        OSError: If ``DATA_DIR`` or one of the files cannot be read.
    """
    # os.listdir() returns names in arbitrary order; sorting keeps the chunk
    # order (and therefore tie-breaking between equal scores) reproducible.
    files = [
        os.path.join(DATA_DIR, name)
        for name in sorted(os.listdir(DATA_DIR))
        if name.endswith(".txt")
    ]

    if not files:
        # Report the configured directory so the message stays correct if
        # DATA_DIR is ever changed.
        raise RuntimeError(f"No .txt files found in {DATA_DIR}")

    texts = []
    for path in files:
        # Undecodable bytes are dropped rather than aborting the whole load.
        with open(path, encoding="utf-8", errors="ignore") as fp:
            texts.append(fp.read())

    return texts, files
42
+
43
+
44
def chunk_text(text, size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split ``text`` into overlapping chunks of whitespace-separated words.

    Args:
        text: The text to split.
        size: Maximum number of words per chunk.
        overlap: Number of words shared by consecutive chunks.

    Returns:
        A list of chunks, each a single-space-joined run of up to ``size``
        words. Empty or whitespace-only text yields an empty list.

    Raises:
        ValueError: If ``size`` is not positive or ``overlap`` is not
            smaller than ``size``.
    """
    if size <= 0 or overlap >= size:
        # A non-positive stride would never advance through the words.
        raise ValueError("size must be positive and overlap must be < size")

    words = text.split()
    total = len(words)
    stride = size - overlap

    chunks = []
    start = 0
    while start < total:
        end = start + size
        chunks.append(" ".join(words[start:end]))
        if end >= total:
            # The window already covers the last word; stepping again would
            # only emit a chunk fully contained in this one.
            break
        start += stride

    return chunks
55
+
56
+
57
def chunk_documents(texts):
    """Chunk each document in ``texts`` and flatten the result.

    Args:
        texts: Iterable of document strings.

    Returns:
        One list holding the ``chunk_text`` output of every document, in the
        order the documents were given.
    """
    return [part for doc in texts for part in chunk_text(doc)]
62
+
63
+
64
def build_embeddings(chunks):
    """Build every retrieval head for ``chunks``.

    Three embedding heads are computed with the same model: the raw chunks
    (``"semantic"``), the chunks prefixed with ``"Story context: "``
    (``"narrative"``) and the chunks prefixed with ``"Entities mentioned: "``
    (``"entity"``). A TF-IDF index over unigrams and bigrams with English
    stop words removed is fitted on the raw chunks.

    Args:
        chunks: Non-empty list of text chunks to index.

    Returns:
        A dict with the keys ``"semantic"``, ``"narrative"``, ``"entity"``,
        ``"tfidf"`` (the fitted vectorizer), ``"tfidf_matrix"`` and
        ``"model"`` (the loaded SentenceTransformer).

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    if not chunks:
        # Checked before the model is loaded so the failure is cheap and clear.
        raise ValueError("Cannot build embeddings from an empty chunk list")

    model = SentenceTransformer(MODEL_NAME)

    def _embed(texts):
        # Single place for the encode settings shared by all three heads.
        return normalize(
            model.encode(texts, batch_size=32, show_progress_bar=True)
        )

    semantic = _embed(chunks)
    narrative = _embed(["Story context: " + c for c in chunks])
    entity = _embed(["Entities mentioned: " + c for c in chunks])

    tfidf = TfidfVectorizer(ngram_range=(1, 2), stop_words="english")
    tfidf_matrix = tfidf.fit_transform(chunks)

    return {
        "semantic": semantic,
        "narrative": narrative,
        "entity": entity,
        "tfidf": tfidf,
        "tfidf_matrix": tfidf_matrix,
        "model": model,
    }
101
+
102
+
103
def save_cache(chunks, heads, dataset_hash):
    """Persist the retrieval state to ``CACHE_DIR``.

    Args:
        chunks: The chunk list, pickled to ``chunks.pkl``.
        heads: Dict providing the ``"semantic"``, ``"narrative"`` and
            ``"entity"`` matrices (saved as ``.npy``) and the ``"tfidf"`` /
            ``"tfidf_matrix"`` objects (pickled).
        dataset_hash: String written to ``hash.txt``.
    """
    os.makedirs(CACHE_DIR, exist_ok=True)

    for matrix_name in ("semantic", "narrative", "entity"):
        np.save(f"{CACHE_DIR}/{matrix_name}.npy", heads[matrix_name])

    with open(f"{CACHE_DIR}/chunks.pkl", "wb") as out:
        pickle.dump(chunks, out)

    # Both TF-IDF artifacts are pickled under their own key name.
    for pickled_name in ("tfidf", "tfidf_matrix"):
        with open(f"{CACHE_DIR}/{pickled_name}.pkl", "wb") as out:
            pickle.dump(heads[pickled_name], out)

    # hash.txt is written only after every artifact above.
    with open(f"{CACHE_DIR}/hash.txt", "w") as out:
        out.write(dataset_hash)
121
+
122
+
123
def load_cache():
    """Restore the chunks and retrieval heads saved under ``CACHE_DIR``.

    Returns:
        A ``(chunks, heads)`` tuple. ``heads`` holds the three embedding
        matrices, the TF-IDF vectorizer and matrix, and a SentenceTransformer
        loaded by name under ``"model"``.
    """
    # NOTE: pickle can execute arbitrary code while loading; the .pkl files
    # must only ever come from this application's own save_cache().
    with open(f"{CACHE_DIR}/chunks.pkl", "rb") as src:
        chunks = pickle.load(src)

    heads = {}
    for matrix_name in ("semantic", "narrative", "entity"):
        heads[matrix_name] = np.load(f"{CACHE_DIR}/{matrix_name}.npy")

    for pickled_name in ("tfidf", "tfidf_matrix"):
        with open(f"{CACHE_DIR}/{pickled_name}.pkl", "rb") as src:
            heads[pickled_name] = pickle.load(src)

    # The model is not part of the cache; it is loaded once here.
    heads["model"] = SentenceTransformer(MODEL_NAME)

    return chunks, heads
143
+
144
+
145
# Bump whenever build_embeddings changes how a head is computed (prefixes,
# TF-IDF settings, ...), so caches written by older code are rebuilt.
_CACHE_FORMAT = "2"


def load_data():
    """Return ``(chunks, heads)``, reusing the on-disk cache when it is valid.

    The cache is considered valid when the key stored in ``hash.txt`` equals
    the key computed from the cache format, the current dataset files, the
    model name and the chunking parameters. Otherwise, or when the cached
    artifacts cannot be read, the heads are rebuilt and the cache rewritten.

    Returns:
        A ``(chunks, heads)`` tuple as produced by ``load_cache`` or by
        ``chunk_documents`` plus ``build_embeddings``.
    """
    texts, files = load_documents()

    # Everything that influences the cached artifacts goes into the key, not
    # just the file contents; otherwise changing the model or the chunking
    # parameters would silently reuse stale embeddings.
    cache_key = ":".join([
        _CACHE_FORMAT,
        compute_hash(files),
        MODEL_NAME,
        str(CHUNK_SIZE),
        str(CHUNK_OVERLAP),
    ])
    hash_path = f"{CACHE_DIR}/hash.txt"

    try:
        with open(hash_path) as f:
            cached_key = f.read().strip()
    except FileNotFoundError:
        cached_key = None

    if cached_key == cache_key:
        print("Loading embeddings from cache")
        try:
            return load_cache()
        except (OSError, EOFError, ValueError, pickle.UnpicklingError) as err:
            # A missing or truncated artifact should trigger a rebuild, not
            # prevent start-up.
            print(f"Cache unreadable ({err}); rebuilding")

    print("Building embeddings")
    # Chunking is only needed on a rebuild; a cache hit returns stored chunks.
    chunks = chunk_documents(texts)
    heads = build_embeddings(chunks)
    save_cache(chunks, heads, cache_key)

    return chunks, heads
166
+
167
+
168
def retrieve_chunks(query, chunks, heads, k=5):
    """Return the ``k`` chunks that score highest for ``query``.

    A chunk's score is a weighted sum of four signals: the dot product with
    the plain query embedding (0.40), with the ``"Story question: "``-prefixed
    query embedding (0.30), with the ``"Entities in question: "``-prefixed
    query embedding (0.15), and the TF-IDF dot product (0.15).

    Args:
        query: The user question.
        chunks: The chunk list that ``heads`` was built from.
        heads: Dict with ``"model"``, ``"semantic"``, ``"narrative"``,
            ``"entity"``, ``"tfidf"`` and ``"tfidf_matrix"``.
        k: Maximum number of chunks to return.

    Returns:
        Up to ``k`` chunks, best first. Empty when ``k`` is not positive.
    """
    if k <= 0:
        # A negative k would otherwise slice from the wrong end of the ranking.
        return []

    # Encode all three query variants in one batch: a single model call per
    # query instead of three.
    q_sem, q_nav, q_ent = normalize(
        heads["model"].encode([
            query,
            "Story question: " + query,
            "Entities in question: " + query,
        ])
    )

    sem_score = heads["semantic"] @ q_sem
    nav_score = heads["narrative"] @ q_nav
    ent_score = heads["entity"] @ q_ent

    q_tfidf = heads["tfidf"].transform([query])
    # The sparse product is densified and flattened to one score per chunk.
    key_score = (heads["tfidf_matrix"] @ q_tfidf.T).toarray().ravel()

    final_score = (
        0.40 * sem_score
        + 0.30 * nav_score
        + 0.15 * ent_score
        + 0.15 * key_score
    )

    top_idx = np.argsort(final_score)[::-1][:k]

    return [chunks[i] for i in top_idx]