import os
import json
import asyncio
import random

import pandas as pd
import nest_asyncio
from llama_index.core import (
    VectorStoreIndex,
    Settings,
    Document,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.prompts import PromptTemplate
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Allow nested event loops (needed when this runs inside an environment that
# already has a running loop, e.g. a notebook).
nest_asyncio.apply()

GROUND_TRUTH_PATH = "retrieval_ground_truth_pairs_30.json"
DATASET_PATH = "data/combined_dataset.json"
RESULTS_CSV_PATH = "retrieval_eval_results.csv"
EMBED_MODEL_NAME = "BAAI/bge-small-en-v1.5"
SAMPLE_SIZE = 30
SAMPLE_SEED = 42
TOP_K = 5

# Built once at import time instead of once per node.
QUERY_PROMPT = PromptTemplate(
    """
You are creating a retrieval evaluation dataset for a mental well-being RAG system.

Given the counseling interaction below, write ONE realistic user query that
someone might ask if they needed this kind of counseling support.

Rules:
- Write only the user query.
- Do not answer the query.
- Keep it natural and concise.
- Do not mention that this is based on a document.

Counseling interaction:
{node_text}

User query:
"""
)


async def generate_query_for_node(llm, node_text):
    """
    Generate one realistic user query from a counseling document.

    Args:
        llm: LlamaIndex LLM used for generation (must support ``apredict``).
        node_text: Text of the chunk the generated query should target.

    Returns:
        The generated query with surrounding whitespace removed.
    """
    response = await llm.apredict(
        QUERY_PROMPT,
        node_text=node_text,
    )
    return response.strip()


def _load_records(path):
    """Load dataset records from a JSON array/object file or a JSONL file.

    The dataset uses a ``.json`` extension but was originally read as JSONL,
    so both layouts are accepted. A whole-file parse is tried first; if that
    fails (multi-line JSONL raises "Extra data"), each non-blank line is
    parsed as its own record.
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        data = None

    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        # A single record (also what a one-line JSONL file parses to).
        return [data]

    # Split on "\n" only, matching text-mode file iteration; str.splitlines()
    # would also break on U+2028, which JSON allows unescaped inside strings.
    return [json.loads(line) for line in text.split("\n") if line.strip()]


def _build_documents(sample_data):
    """Turn sampled counseling records into LlamaIndex documents.

    A record is skipped when it is not a dict or when both its ``Context`` and
    ``Response`` are blank. The original check on the formatted text could
    never trigger because the "User:"/"Therapist:" labels are always present.
    The text layout is unchanged so chunks stay byte-identical to those
    stored in an existing ground-truth file.
    """
    documents = []
    for i, entry in enumerate(sample_data):
        if not isinstance(entry, dict):
            continue
        # ``or ""`` also covers keys that are present but null.
        context = entry.get("Context") or ""
        response = entry.get("Response") or ""
        if not (str(context).strip() or str(response).strip()):
            continue

        text_content = (
            f"User: {context}\n\n"
            f"Therapist: {response}"
        )
        documents.append(
            Document(
                text=text_content,
                metadata={
                    "sample_id": i
                },
            )
        )
    return documents


async def _generate_ground_truth(llm, nodes):
    """Generate one synthetic query per node.

    Returns a list of dicts with ``query_id``, ``query``,
    ``expected_node_id`` and ``source_text``. Nodes for which the LLM
    returns a blank query are skipped with a warning.
    """
    qa_pairs = []
    for idx, node in enumerate(nodes):
        print(f"Generating query {idx + 1}/{len(nodes)}...")
        node_text = node.get_content()
        query = await generate_query_for_node(
            llm=llm,
            node_text=node_text,
        )
        if not query:
            print(f"Warning: empty query generated for node {idx}; skipping.")
            continue
        qa_pairs.append(
            {
                "query_id": idx,
                "query": query,
                "expected_node_id": node.node_id,
                "source_text": node_text,
            }
        )
    return qa_pairs


def _reconcile_ground_truth(qa_pairs, nodes):
    """Re-point cached ``expected_node_id`` values at this run's nodes.

    LlamaIndex assigns fresh random IDs to documents and nodes on every run.
    IDs stored in a cached ground-truth file therefore never match the newly
    built index, which made every cached query a miss. Each pair's stored
    ``source_text`` is used to find the equivalent node. ``qa_pairs`` is
    mutated in place.
    """
    id_by_text = {}
    for node in nodes:
        id_by_text.setdefault(node.get_content(), node.node_id)

    unmatched = 0
    for pair in qa_pairs:
        node_id = id_by_text.get(pair.get("source_text"))
        if node_id is None:
            unmatched += 1
        else:
            pair["expected_node_id"] = node_id

    if unmatched:
        print(
            f"Warning: {unmatched}/{len(qa_pairs)} ground-truth pairs do not "
            f"match any current node; delete {GROUND_TRUTH_PATH} to regenerate."
        )


def _find_rank(retrieved_nodes, expected_node_id, expected_text):
    """Return the 1-based rank of the expected chunk, or None if absent.

    A retrieved node counts as the expected one when its ID matches, or when
    its content is identical to the stored source text. Duplicate records can
    yield identical chunks with different IDs.
    """
    for rank, item in enumerate(retrieved_nodes, start=1):
        if item.node.node_id == expected_node_id:
            return rank
        if expected_text is not None and item.node.get_content() == expected_text:
            return rank
    return None


async def main():
    """Build a sampled counseling index and report Hit Rate / MRR at TOP_K."""
    # ==========================================
    # 1. MODEL CONFIGURATION
    # ==========================================
    print("Initializing models...")
    llm = Ollama(
        model="llama3:latest",
        request_timeout=600.0,
    )
    embed_model = HuggingFaceEmbedding(
        model_name=EMBED_MODEL_NAME
    )
    Settings.llm = llm
    Settings.embed_model = embed_model

    # ==========================================
    # 2. LOAD DATASET (JSON ARRAY OR JSONL)
    # ==========================================
    json_path = DATASET_PATH
    if not os.path.exists(json_path):
        print(f"Error: {json_path} not found.")
        return

    print(f"Loading dataset from {json_path}...")
    raw_data = _load_records(json_path)
    print(f"Loaded {len(raw_data)} total records.")

    # ==========================================
    # 3. RANDOM SAMPLING
    # ==========================================
    sample_size = min(SAMPLE_SIZE, len(raw_data))
    # A private RNG seeded with the same value yields exactly the sample that
    # random.seed(42) + random.sample produced, so cached ground truth stays
    # valid, without reseeding the global generator.
    sample_data = random.Random(SAMPLE_SEED).sample(raw_data, sample_size)
    print(f"Randomly sampled {sample_size} records.")

    # ==========================================
    # 4. CREATE DOCUMENTS
    # ==========================================
    documents = _build_documents(sample_data)
    print(f"Prepared {len(documents)} documents.")
    if not documents:
        print("Error: No valid documents were created. Check dataset keys.")
        return

    # ==========================================
    # 5. CREATE NODES
    # ==========================================
    print("Creating nodes...")
    parser = SentenceSplitter(
        chunk_size=768,
        chunk_overlap=100,
    )
    nodes = parser.get_nodes_from_documents(documents)
    print(f"Generated {len(nodes)} nodes.")
    if not nodes:
        print("Error: No nodes were created.")
        return

    # ==========================================
    # 6. BUILD VECTOR INDEX
    # ==========================================
    print("Building vector index...")
    index = VectorStoreIndex(nodes)
    retriever = index.as_retriever(
        similarity_top_k=TOP_K
    )

    # ==========================================
    # 7. GENERATE OR LOAD SYNTHETIC GROUND TRUTH
    # ==========================================
    if os.path.exists(GROUND_TRUTH_PATH):
        print(f"Loading existing ground truth from {GROUND_TRUTH_PATH}...")
        with open(GROUND_TRUTH_PATH, "r", encoding="utf-8") as f:
            qa_pairs = json.load(f)
        _reconcile_ground_truth(qa_pairs, nodes)
    else:
        print("Generating synthetic retrieval queries...")
        qa_pairs = await _generate_ground_truth(llm, nodes)
        with open(GROUND_TRUTH_PATH, "w", encoding="utf-8") as f:
            json.dump(
                qa_pairs,
                f,
                indent=2,
                ensure_ascii=False,
            )
        print(f"Saved {GROUND_TRUTH_PATH}")

    if not qa_pairs:
        print("Error: No ground-truth queries to evaluate.")
        return

    # ==========================================
    # 8. MANUAL RETRIEVAL EVALUATION
    # ==========================================
    print("Running retrieval evaluation...")
    hit_col = f"hit_rate@{TOP_K}"
    mrr_col = f"mrr@{TOP_K}"
    results = []
    for pair in qa_pairs:
        query = pair["query"]
        expected_node_id = pair["expected_node_id"]
        retrieved_nodes = await retriever.aretrieve(query)
        rank = _find_rank(
            retrieved_nodes,
            expected_node_id,
            pair.get("source_text"),
        )
        results.append(
            {
                "query_id": pair["query_id"],
                "query": query,
                "expected_node_id": expected_node_id,
                "retrieved_node_ids": [
                    item.node.node_id for item in retrieved_nodes
                ],
                hit_col: int(rank is not None),
                mrr_col: 1.0 / rank if rank is not None else 0.0,
                "rank": rank,
            }
        )

    # ==========================================
    # 9. COMPUTE METRICS
    # ==========================================
    df = pd.DataFrame(results)
    hit_rate = df[hit_col].mean()
    mrr = df[mrr_col].mean()
    df.to_csv(
        RESULTS_CSV_PATH,
        index=False,
    )

    # ==========================================
    # 10. FINAL RESULTS
    # ==========================================
    print("\n" + "=" * 50)
    print(" RAG RETRIEVAL PERFORMANCE")
    print("=" * 50)
    print(f"Dataset Source: {json_path}")
    print(f"Embedding Model: {EMBED_MODEL_NAME}")
    print(f"Documents Used: {len(documents)}")
    print(f"Nodes Used: {len(nodes)}")
    print(f"Total Queries: {len(qa_pairs)}")
    print("-" * 50)
    print(f"Hit Rate @ {TOP_K}: {hit_rate:.4f}")
    print(f"MRR @ {TOP_K}: {mrr:.4f}")
    print("=" * 50)
    print("Evaluation complete!")
    print(f"Detailed results saved to {RESULTS_CSV_PATH}")


if __name__ == "__main__":
    asyncio.run(main())