Anonymous commited on
Commit
478aa65
·
1 Parent(s): 5e1004e

Added retrieval evaluator

Browse files
Files changed (1) hide show
  1. evaluate_retriever.py +289 -0
evaluate_retriever.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ import random
5
+ import pandas as pd
6
+ import nest_asyncio
7
+
8
+ from llama_index.core import (
9
+ VectorStoreIndex,
10
+ Settings,
11
+ Document,
12
+ )
13
+
14
+ from llama_index.core.node_parser import SentenceSplitter
15
+ from llama_index.core.prompts import PromptTemplate
16
+
17
+ from llama_index.llms.ollama import Ollama
18
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
19
+
20
+
21
+ nest_asyncio.apply()
22
+
23
+
24
+ GROUND_TRUTH_PATH = "retrieval_ground_truth_pairs_30.json"
25
+
26
+
27
+ async def generate_query_for_node(llm, node_text):
28
+ """
29
+ Generate one realistic user query from a counseling document.
30
+ """
31
+
32
+ prompt = PromptTemplate(
33
+ """
34
+ You are creating a retrieval evaluation dataset for a mental well-being RAG system.
35
+
36
+ Given the counseling interaction below, write ONE realistic user query that someone might ask
37
+ if they needed this kind of counseling support.
38
+
39
+ Rules:
40
+ - Write only the user query.
41
+ - Do not answer the query.
42
+ - Keep it natural and concise.
43
+ - Do not mention that this is based on a document.
44
+
45
+ Counseling interaction:
46
+ {node_text}
47
+
48
+ User query:
49
+ """
50
+ )
51
+
52
+ response = await llm.apredict(
53
+ prompt,
54
+ node_text=node_text,
55
+ )
56
+
57
+ return response.strip()
58
+
59
+
60
+ async def main():
61
+
62
+ # ==========================================
63
+ # 1. MODEL CONFIGURATION
64
+ # ==========================================
65
+ print("Initializing models...")
66
+
67
+ llm = Ollama(
68
+ model="llama3:latest",
69
+ request_timeout=600.0,
70
+ )
71
+
72
+ embed_model = HuggingFaceEmbedding(
73
+ model_name="BAAI/bge-small-en-v1.5"
74
+ )
75
+
76
+ Settings.llm = llm
77
+ Settings.embed_model = embed_model
78
+
79
+ # ==========================================
80
+ # 2. LOAD DATASET FROM JSONL FILE
81
+ # ==========================================
82
+ json_path = "data/combined_dataset.json"
83
+
84
+ if not os.path.exists(json_path):
85
+ print(f"Error: {json_path} not found.")
86
+ return
87
+
88
+ print(f"Loading dataset from {json_path}...")
89
+
90
+ raw_data = []
91
+
92
+ with open(json_path, "r", encoding="utf-8") as f:
93
+ for line in f:
94
+ if line.strip():
95
+ raw_data.append(json.loads(line))
96
+
97
+ print(f"Loaded {len(raw_data)} total records.")
98
+
99
+ # ==========================================
100
+ # 3. RANDOM SAMPLING
101
+ # ==========================================
102
+ sample_size = min(30, len(raw_data))
103
+
104
+ random.seed(42)
105
+
106
+ sample_data = random.sample(raw_data, sample_size)
107
+
108
+ print(f"Randomly sampled {sample_size} records.")
109
+
110
+ # ==========================================
111
+ # 4. CREATE DOCUMENTS
112
+ # ==========================================
113
+ documents = []
114
+
115
+ for i, entry in enumerate(sample_data):
116
+
117
+ context = entry.get("Context", "")
118
+ response = entry.get("Response", "")
119
+
120
+ text_content = (
121
+ f"User: {context}\n\n"
122
+ f"Therapist: {response}"
123
+ )
124
+
125
+ if text_content.strip():
126
+ documents.append(
127
+ Document(
128
+ text=text_content,
129
+ metadata={
130
+ "sample_id": i
131
+ }
132
+ )
133
+ )
134
+
135
+ print(f"Prepared {len(documents)} documents.")
136
+
137
+ if len(documents) == 0:
138
+ print("Error: No valid documents were created. Check dataset keys.")
139
+ return
140
+
141
+ # ==========================================
142
+ # 5. CREATE NODES
143
+ # ==========================================
144
+ print("Creating nodes...")
145
+
146
+ parser = SentenceSplitter(
147
+ chunk_size=768,
148
+ chunk_overlap=100,
149
+ )
150
+
151
+ nodes = parser.get_nodes_from_documents(documents)
152
+
153
+ print(f"Generated {len(nodes)} nodes.")
154
+
155
+ if len(nodes) == 0:
156
+ print("Error: No nodes were created.")
157
+ return
158
+
159
+ # ==========================================
160
+ # 6. BUILD VECTOR INDEX
161
+ # ==========================================
162
+ print("Building vector index...")
163
+
164
+ index = VectorStoreIndex(nodes)
165
+
166
+ retriever = index.as_retriever(
167
+ similarity_top_k=5
168
+ )
169
+
170
+ # ==========================================
171
+ # 7. GENERATE OR LOAD SYNTHETIC GROUND TRUTH
172
+ # ==========================================
173
+ if os.path.exists(GROUND_TRUTH_PATH):
174
+ print(f"Loading existing ground truth from {GROUND_TRUTH_PATH}...")
175
+
176
+ with open(GROUND_TRUTH_PATH, "r", encoding="utf-8") as f:
177
+ qa_pairs = json.load(f)
178
+
179
+ else:
180
+ print("Generating synthetic retrieval queries...")
181
+
182
+ qa_pairs = []
183
+
184
+ for idx, node in enumerate(nodes):
185
+ print(f"Generating query {idx + 1}/{len(nodes)}...")
186
+
187
+ node_text = node.get_content()
188
+
189
+ query = await generate_query_for_node(
190
+ llm=llm,
191
+ node_text=node_text,
192
+ )
193
+
194
+ qa_pairs.append(
195
+ {
196
+ "query_id": idx,
197
+ "query": query,
198
+ "expected_node_id": node.node_id,
199
+ "source_text": node_text,
200
+ }
201
+ )
202
+
203
+ with open(GROUND_TRUTH_PATH, "w", encoding="utf-8") as f:
204
+ json.dump(
205
+ qa_pairs,
206
+ f,
207
+ indent=2,
208
+ ensure_ascii=False,
209
+ )
210
+
211
+ print(f"Saved {GROUND_TRUTH_PATH}")
212
+
213
+ # ==========================================
214
+ # 8. MANUAL RETRIEVAL EVALUATION
215
+ # ==========================================
216
+ print("Running retrieval evaluation...")
217
+
218
+ results = []
219
+
220
+ for pair in qa_pairs:
221
+ query = pair["query"]
222
+ expected_node_id = pair["expected_node_id"]
223
+
224
+ retrieved_nodes = await retriever.aretrieve(query)
225
+
226
+ retrieved_ids = [
227
+ item.node.node_id
228
+ for item in retrieved_nodes
229
+ ]
230
+
231
+ hit = 0
232
+ reciprocal_rank = 0.0
233
+ rank = None
234
+
235
+ if expected_node_id in retrieved_ids:
236
+ hit = 1
237
+ rank = retrieved_ids.index(expected_node_id) + 1
238
+ reciprocal_rank = 1.0 / rank
239
+
240
+ results.append(
241
+ {
242
+ "query_id": pair["query_id"],
243
+ "query": query,
244
+ "expected_node_id": expected_node_id,
245
+ "retrieved_node_ids": retrieved_ids,
246
+ "hit_rate@5": hit,
247
+ "mrr@5": reciprocal_rank,
248
+ "rank": rank,
249
+ }
250
+ )
251
+
252
+ # ==========================================
253
+ # 9. COMPUTE METRICS
254
+ # ==========================================
255
+ df = pd.DataFrame(results)
256
+
257
+ hit_rate = df["hit_rate@5"].mean()
258
+ mrr = df["mrr@5"].mean()
259
+
260
+ df.to_csv(
261
+ "retrieval_eval_results.csv",
262
+ index=False,
263
+ )
264
+
265
+ # ==========================================
266
+ # 10. FINAL RESULTS
267
+ # ==========================================
268
+ print("\n" + "=" * 50)
269
+ print(" RAG RETRIEVAL PERFORMANCE")
270
+ print("=" * 50)
271
+
272
+ print(f"Dataset Source: {json_path}")
273
+ print("Embedding Model: BAAI/bge-small-en-v1.5")
274
+ print(f"Documents Used: {len(documents)}")
275
+ print(f"Nodes Used: {len(nodes)}")
276
+ print(f"Total Queries: {len(qa_pairs)}")
277
+
278
+ print("-" * 50)
279
+
280
+ print(f"Hit Rate @ 5: {hit_rate:.4f}")
281
+ print(f"MRR @ 5: {mrr:.4f}")
282
+
283
+ print("=" * 50)
284
+ print("Evaluation complete!")
285
+ print("Detailed results saved to retrieval_eval_results.csv")
286
+
287
+
288
+ if __name__ == "__main__":
289
+ asyncio.run(main())