Tell_Me / evaluate_retriever.py
Anonymous
Added retrieval evaluator
478aa65
import os
import json
import asyncio
import random
import pandas as pd
import nest_asyncio
from llama_index.core import (
VectorStoreIndex,
Settings,
Document,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.prompts import PromptTemplate
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# Patch asyncio so an event loop that is already running can be re-entered
# (nest_asyncio's documented purpose, e.g. when this runs inside a notebook).
nest_asyncio.apply()
# Cache file for the synthetic query/node pairs: main() loads it when it
# exists and otherwise writes it after generating the queries.
GROUND_TRUTH_PATH = "retrieval_ground_truth_pairs_30.json"
async def generate_query_for_node(llm, node_text):
    """Ask the LLM for a single user query that matches ``node_text``.

    The counseling interaction in ``node_text`` is substituted into a fixed
    instruction prompt and sent through ``llm.apredict``.

    Args:
        llm: Object exposing an awaitable ``apredict(prompt, **kwargs)``.
        node_text: Text of the counseling interaction to write a query for.

    Returns:
        The model's reply with leading and trailing whitespace removed.
    """
    # The template body is kept flush-left on purpose: its exact text,
    # including line starts, is what gets sent to the model.
    template_text = """
You are creating a retrieval evaluation dataset for a mental well-being RAG system.
Given the counseling interaction below, write ONE realistic user query that someone might ask
if they needed this kind of counseling support.
Rules:
- Write only the user query.
- Do not answer the query.
- Keep it natural and concise.
- Do not mention that this is based on a document.
Counseling interaction:
{node_text}
User query:
"""
    query_prompt = PromptTemplate(template_text)
    raw_query = await llm.apredict(query_prompt, node_text=node_text)
    return raw_query.strip()
def _read_jsonl(path):
    """Parse ``path`` as JSON Lines: one JSON value per non-blank line."""
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


def _build_documents(records):
    """Wrap each non-empty record in a ``Document``.

    A record is skipped when both its ``Context`` and ``Response`` values are
    missing or blank. ``sample_id`` is the record's position in ``records``,
    so IDs stay stable even when earlier records are skipped.
    """
    documents = []
    for sample_id, entry in enumerate(records):
        context = entry.get("Context") or ""
        response = entry.get("Response") or ""
        # Test the field values themselves: the formatted text always
        # contains the "User:"/"Therapist:" labels and is never blank.
        if not (str(context).strip() or str(response).strip()):
            continue
        documents.append(
            Document(
                text=f"User: {context}\n\nTherapist: {response}",
                metadata={"sample_id": sample_id},
            )
        )
    return documents


def _relink_ground_truth(qa_pairs, nodes):
    """Point cached pairs at the node IDs of the current run.

    Each pair's ``expected_node_id`` is replaced, in place, by the ID of the
    current node whose content equals the pair's ``source_text``.

    Returns:
        The number of pairs whose ``source_text`` matches no current node
        (their ``expected_node_id`` is left unchanged).
    """
    id_by_text = {node.get_content(): node.node_id for node in nodes}
    unmatched = 0
    for pair in qa_pairs:
        current_id = id_by_text.get(pair.get("source_text"))
        if current_id is None:
            unmatched += 1
        else:
            pair["expected_node_id"] = current_id
    return unmatched


def _score_retrieval(expected_id, retrieved_ids):
    """Return ``(hit, reciprocal_rank, rank)`` for a single query.

    ``rank`` is the 1-based position of ``expected_id`` in ``retrieved_ids``,
    or ``None`` (with ``hit == 0`` and ``reciprocal_rank == 0.0``) on a miss.
    """
    try:
        rank = retrieved_ids.index(expected_id) + 1
    except ValueError:
        return 0, 0.0, None
    return 1, 1.0 / rank, rank


async def main():
    """Evaluate top-5 retrieval over a 30-record sample of the dataset.

    Steps: configure the LLM and embedding model, load the JSON Lines dataset,
    draw a seeded random sample, split it into nodes, build a vector index,
    load or generate one synthetic query per node, then report hit rate and
    MRR at 5. Per-query results are written to ``retrieval_eval_results.csv``
    and generated queries are cached in ``GROUND_TRUTH_PATH``.

    Returns early (after printing an error) when the dataset is missing or
    when no documents, nodes, or queries are available.
    """
    # ==========================================
    # 1. MODEL CONFIGURATION
    # ==========================================
    print("Initializing models...")
    embed_model_name = "BAAI/bge-small-en-v1.5"
    llm = Ollama(
        model="llama3:latest",
        request_timeout=600.0,
    )
    embed_model = HuggingFaceEmbedding(model_name=embed_model_name)
    Settings.llm = llm
    Settings.embed_model = embed_model

    # ==========================================
    # 2. LOAD DATASET (JSON LINES, DESPITE THE .json EXTENSION)
    # ==========================================
    json_path = "data/combined_dataset.json"
    if not os.path.exists(json_path):
        print(f"Error: {json_path} not found.")
        return
    print(f"Loading dataset from {json_path}...")
    raw_data = _read_jsonl(json_path)
    print(f"Loaded {len(raw_data)} total records.")

    # ==========================================
    # 3. RANDOM SAMPLING
    # ==========================================
    sample_size = min(30, len(raw_data))
    # Fixed seed so every run draws the same records; the cached ground
    # truth is only meaningful for that same sample.
    random.seed(42)
    sample_data = random.sample(raw_data, sample_size)
    print(f"Randomly sampled {sample_size} records.")

    # ==========================================
    # 4. CREATE DOCUMENTS
    # ==========================================
    documents = _build_documents(sample_data)
    print(f"Prepared {len(documents)} documents.")
    if not documents:
        print("Error: No valid documents were created. Check dataset keys.")
        return

    # ==========================================
    # 5. CREATE NODES
    # ==========================================
    print("Creating nodes...")
    parser = SentenceSplitter(
        chunk_size=768,
        chunk_overlap=100,
    )
    nodes = parser.get_nodes_from_documents(documents)
    print(f"Generated {len(nodes)} nodes.")
    if not nodes:
        print("Error: No nodes were created.")
        return

    # ==========================================
    # 6. BUILD VECTOR INDEX
    # ==========================================
    print("Building vector index...")
    index = VectorStoreIndex(nodes)
    retriever = index.as_retriever(similarity_top_k=5)

    # ==========================================
    # 7. GENERATE OR LOAD SYNTHETIC GROUND TRUTH
    # ==========================================
    if os.path.exists(GROUND_TRUTH_PATH):
        print(f"Loading existing ground truth from {GROUND_TRUTH_PATH}...")
        with open(GROUND_TRUTH_PATH, "r", encoding="utf-8") as f:
            qa_pairs = json.load(f)
        # Node IDs default to fresh random UUIDs on every run, so the IDs
        # stored in the cache never match the nodes just created. Re-resolve
        # them through the stored source text before scoring.
        unmatched = _relink_ground_truth(qa_pairs, nodes)
        if unmatched:
            print(
                f"Warning: {unmatched} cached queries reference text that is "
                f"not in the current index and will count as misses. "
                f"Delete {GROUND_TRUTH_PATH} to regenerate."
            )
    else:
        print("Generating synthetic retrieval queries...")
        qa_pairs = []
        for idx, node in enumerate(nodes):
            print(f"Generating query {idx + 1}/{len(nodes)}...")
            node_text = node.get_content()
            query = await generate_query_for_node(
                llm=llm,
                node_text=node_text,
            )
            qa_pairs.append(
                {
                    "query_id": idx,
                    "query": query,
                    "expected_node_id": node.node_id,
                    "source_text": node_text,
                }
            )
        with open(GROUND_TRUTH_PATH, "w", encoding="utf-8") as f:
            json.dump(
                qa_pairs,
                f,
                indent=2,
                ensure_ascii=False,
            )
        print(f"Saved {GROUND_TRUTH_PATH}")

    # ==========================================
    # 8. MANUAL RETRIEVAL EVALUATION
    # ==========================================
    print("Running retrieval evaluation...")
    results = []
    for pair in qa_pairs:
        query = pair["query"]
        expected_node_id = pair["expected_node_id"]
        retrieved_nodes = await retriever.aretrieve(query)
        retrieved_ids = [item.node.node_id for item in retrieved_nodes]
        hit, reciprocal_rank, rank = _score_retrieval(
            expected_node_id, retrieved_ids
        )
        results.append(
            {
                "query_id": pair["query_id"],
                "query": query,
                "expected_node_id": expected_node_id,
                "retrieved_node_ids": retrieved_ids,
                "hit_rate@5": hit,
                "mrr@5": reciprocal_rank,
                "rank": rank,
            }
        )
    if not results:
        # An empty DataFrame has no metric columns to average.
        print("Error: No queries to evaluate.")
        return

    # ==========================================
    # 9. COMPUTE METRICS
    # ==========================================
    df = pd.DataFrame(results)
    hit_rate = df["hit_rate@5"].mean()
    mrr = df["mrr@5"].mean()
    df.to_csv(
        "retrieval_eval_results.csv",
        index=False,
    )

    # ==========================================
    # 10. FINAL RESULTS
    # ==========================================
    print("\n" + "=" * 50)
    print(" RAG RETRIEVAL PERFORMANCE")
    print("=" * 50)
    print(f"Dataset Source: {json_path}")
    print(f"Embedding Model: {embed_model_name}")
    print(f"Documents Used: {len(documents)}")
    print(f"Nodes Used: {len(nodes)}")
    print(f"Total Queries: {len(qa_pairs)}")
    print("-" * 50)
    print(f"Hit Rate @ 5: {hit_rate:.4f}")
    print(f"MRR @ 5: {mrr:.4f}")
    print("=" * 50)
    print("Evaluation complete!")
    print("Detailed results saved to retrieval_eval_results.csv")
# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    evaluation = main()
    asyncio.run(evaluation)