Jatin Mehra committed on
Commit
8d0a63e
·
1 Parent(s): bcaf7fa

Add dataset processing and evaluation script for RAG system with memory management

Browse files
Files changed (1) hide show
  1. gen_dataset.py +158 -0
gen_dataset.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- Dependencies ---------------------------------------------------------
# The original import block listed `json` twice and `SentenceTransformer`
# twice; deduplicated and grouped stdlib / third-party / project-local.

# Standard library
import csv
import gc
import json
import os

# Third-party
import dotenv
import torch  # used below to clear the CUDA cache when a GPU is present
from datasets import load_dataset
from langchain.memory import ConversationBufferMemory
from langchain_community.tools.tavily_search import TavilySearchResults
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer, util

# Project-local
from preprocessing import model_selection, create_embeddings, build_faiss_index, retrieve_similar_chunks, agentic_rag

# Evaluation corpus: records with 'context', 'question' and 'answer' fields
# (the loop below reads exactly those three keys from ds['train']).
ds = load_dataset("neural-bridge/rag-dataset-12000")
19
+
20
# --- Run configuration ----------------------------------------------------
SAMPLE_SIZE = 80  # Number of documents to test per run
BATCH_SIZE = 1    # Save results after every X iterations
OUTPUT_FILE = 'rag_test_output.json'

# Load environment variables BEFORE constructing tools: TavilySearchResults
# may read its API key from the environment at construction time, and the
# original created the tool first, so a key present only in .env was not
# visible yet.
dotenv.load_dotenv()
tools = [TavilySearchResults(max_results=5)]
28
def chunk_text(text, max_length=250):
    """Split *text* into consecutive character chunks of at most *max_length*.

    Each chunk is a dict holding the raw slice under ``"text"`` and a
    sequential ``"chunk_id"`` under ``"metadata"``.  An empty string yields
    an empty list.
    """
    return [
        {
            "text": text[start:start + max_length],
            "metadata": {"chunk_id": start // max_length},
        }
        for start in range(0, len(text), max_length)
    ]
36
+
37
def clear_memory():
    """Free host memory via the garbage collector and drop cached CUDA
    allocations when a GPU is available."""
    gc.collect()  # reclaim unreachable Python objects first
    cuda = torch.cuda
    if cuda.is_available():
        cuda.empty_cache()  # return cached GPU blocks to the driver
42
+
43
# --- Resume support: initialize or load previously saved results ----------
# A re-run after an interruption continues at the first unprocessed document,
# inferred from how many results are already on disk.
if os.path.exists(OUTPUT_FILE):
    # Explicit encoding: the file is written as UTF-8 below, and relying on
    # the platform default can mis-decode it (e.g. on Windows).
    with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
        try:
            output_data = json.load(f)
            start_idx = len(output_data)  # resume after the last saved result
            print(f"Resuming from index {start_idx}")
        except json.JSONDecodeError:
            # Corrupted or partially written file: start fresh rather than crash.
            output_data = []
            start_idx = 0
else:
    output_data = []  # first run: nothing saved yet
    start_idx = 0
56
+
57
# --- Main evaluation loop -------------------------------------------------
# Runs the agentic RAG pipeline over a slice of the dataset, saving results
# incrementally so an interrupted run can be resumed by the logic above.

# Hoisted loop invariant (the original recomputed this min() three times).
end_idx = min(start_idx + SAMPLE_SIZE, len(ds['train']))

# The LLM handle and the embedding model do not depend on the document, so
# load them once instead of re-instantiating both on every iteration as the
# original did.
llm = model_selection("meta-llama/llama-4-scout-17b-16e-instruct")
model = SentenceTransformer('BAAI/bge-large-en-v1.5')

# Predefine the index so the except handler below can always report it; in
# the original, a failure before the first iteration made the handler itself
# raise NameError on `i`, masking the real error.
i = start_idx

try:
    for i in range(start_idx, end_idx):
        print(f"Processing document {i}/{end_idx}")

        # Per-document data
        current_context_text = ds['train'][i]['context']

        # Chunk the context and build a FAISS index for similarity search.
        chunks = chunk_text(current_context_text, max_length=100)
        embeddings, chunks = create_embeddings(chunks, model)
        index = build_faiss_index(embeddings)
        query = ds['train'][i]['question']

        # Retrieve the top-5 chunks most similar to the question.
        similar_chunks = retrieve_similar_chunks(query, index, chunks, model, k=5)

        # Fresh conversation memory per document so answers cannot leak
        # between unrelated questions.
        agent_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

        print(f"Query: {query}")
        response = agentic_rag(llm, tools, query=query, context_chunks=similar_chunks,
                               memory=agent_memory, Use_Tavily=False)

        print("Assistant:", response["output"])
        print("Ground Truth:", ds['train'][i]['answer'])
        print("===" * 50)

        # Record the result for the scoring phase below.
        output_data.append({
            "query": query,
            "assistant_response": response["output"],
            "ground_truth": ds['train'][i]['answer'],
            "context": current_context_text
        })

        # Persist periodically (and on the final document) so progress
        # survives crashes and the resume logic can pick up from here.
        if (i + 1) % BATCH_SIZE == 0 or i == end_idx - 1:
            with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
                json.dump(output_data, f, indent=4)
            print(f"\nSaved results for {len(output_data)} documents to {OUTPUT_FILE}")

        # Drop per-document objects and clear caches to cap memory use.
        del current_context_text, chunks, embeddings, index, similar_chunks, response
        clear_memory()

except Exception as e:
    # Top-level boundary: save whatever we have before exiting.
    print(f"Error occurred at document index {i}: {str(e)}")
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=4)
    print(f"\nSaved partial results for {len(output_data)} documents to {OUTPUT_FILE}")

print(f"\nCompleted processing {len(output_data)} documents. Results saved to {OUTPUT_FILE}")
110
+
111
+
112
# ---------------------------------------------------------------------------
# Scoring phase: compare each saved assistant response against its ground
# truth using embedding cosine similarity and ROUGE-L F1.
# ---------------------------------------------------------------------------

# Load model
# Embedding model for semantic similarity (same model used for retrieval above).
model = SentenceTransformer('BAAI/bge-large-en-v1.5')
# ROUGE-L scorer with stemming; only the F-measure is read below.
rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

# File paths
input_file = 'rag_test_output.json'  # produced by the generation loop above
output_file = 'rag_scores.csv'       # per-query scores are written here
# Minimum cosine similarity for a response to be marked PASS.
semantic_threshold = 0.75
120
+
121
# Load the saved (query, response, ground-truth) records produced earlier.
with open(input_file, 'r', encoding='utf-8') as f:
    records = json.load(f)

results = []

# Score every record with both metrics and a PASS/FAIL verdict.
for record in records:
    question = record.get("query", "")
    prediction = record.get("assistant_response", "")
    reference = record.get("ground_truth", "")
    context = record.get("context", "")  # loaded alongside but not scored

    # Semantic similarity: cosine between response and ground-truth embeddings.
    pred_emb = model.encode(prediction, convert_to_tensor=True)
    ref_emb = model.encode(reference, convert_to_tensor=True)
    cos_sim = util.pytorch_cos_sim(pred_emb, ref_emb).item()

    # ROUGE-L F-measure (LCS-based, so the F1 is symmetric in its arguments).
    lcs_f1 = rouge.score(prediction, reference)['rougeL'].fmeasure

    results.append({
        "query": question,
        "semantic_similarity": round(cos_sim, 4),
        "rougeL_f1": round(lcs_f1, 4),
        "status": "PASS" if cos_sim >= semantic_threshold else "FAIL",
    })
151
+
152
# Persist the per-query scores as a CSV report.
score_columns = ["query", "semantic_similarity", "rougeL_f1", "status"]
with open(output_file, 'w', newline='', encoding='utf-8') as out:
    csv_writer = csv.DictWriter(out, fieldnames=score_columns)
    csv_writer.writeheader()
    for row in results:
        csv_writer.writerow(row)

print(f"Scores saved to '{output_file}'")