Update app.py

app.py CHANGED
@@ -47,10 +47,12 @@ class RAGEvaluator:
         self.current_dataset = None
         self.test_samples = []
 
-    def load_dataset(self, dataset_name: str, num_samples: int =
+    def load_dataset(self, dataset_name: str, num_samples: int = 5):
+        """Load a smaller subset of questions"""
         if dataset_name == "squad":
             dataset = load_dataset("squad_v2", split="validation")
+            # Select diverse questions based on length and type
+            samples = dataset.select(range(0, 1000, 100))[:num_samples]  # Take 10 spaced-out samples
             self.test_samples = [
                 {
                     "question": sample["question"],
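For reference, a minimal sketch of how the spaced-out sampling above behaves, assuming the Hugging Face datasets library; the toy column and sizes below are stand-ins for squad_v2, not the real data. One subtlety worth knowing: slicing the selected Dataset with [:num_samples] returns a plain dict of columns rather than another Dataset.

# Hedged sketch: a toy Dataset stands in for squad_v2.
from datasets import Dataset

toy = Dataset.from_dict({"question": [f"q{i}" for i in range(1000)]})

# range(0, 1000, 100) picks indices 0, 100, ..., 900 (10 rows);
# slicing with [:num_samples] then keeps the first num_samples of them,
# returned as a dict of columns rather than another Dataset.
num_samples = 5
samples = toy.select(range(0, 1000, 100))[:num_samples]
print(samples["question"])  # ['q0', 'q100', 'q200', 'q300', 'q400']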
@@ -62,7 +64,7 @@ class RAGEvaluator:
             ]
         elif dataset_name == "msmarco":
             dataset = load_dataset("ms_marco", "v2.1", split="train")
-            samples = dataset.select(range(
+            samples = dataset.select(range(0, 1000, 100))[:num_samples]
             self.test_samples = [
                 {
                     "question": sample["query"],
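The msmarco branch mirrors the squad one; the only difference is which column holds the question text. A hedged sketch of that field mapping, where question_key_for is a hypothetical helper illustrating the convention, not a function in app.py:

# Hypothetical helper (not in app.py) capturing the per-dataset field
# mapping implied by the two branches above.
def question_key_for(dataset_name: str) -> str:
    # squad_v2 rows carry "question"; ms_marco v2.1 rows carry "query"
    return "query" if dataset_name == "msmarco" else "question"

row = {"query": "what is retrieval-augmented generation?"}
print(row[question_key_for("msmarco")])  # what is retrieval-augmented generation?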
@@ -76,40 +78,60 @@ class RAGEvaluator:
         return self.test_samples
 
     def evaluate_configuration(self, vector_db, qa_chain, splitting_strategy: str, chunk_size: str) -> Dict:
+        """Evaluate with progress tracking"""
         if not self.test_samples:
             return {"error": "No dataset loaded"}
 
         results = []
-        })
+        total_questions = len(self.test_samples)
+
+        # Add progress tracking
+        for i, sample in enumerate(self.test_samples):
+            print(f"Evaluating question {i+1}/{total_questions}")
 
+            try:
+                response = qa_chain.invoke({
+                    "question": sample["question"],
+                    "chat_history": []
+                })
+
+                results.append({
+                    "question": sample["question"],
+                    "answer": response["answer"],
+                    "contexts": [doc.page_content for doc in response["source_documents"]],
+                    "ground_truths": [sample["ground_truth"]]
+                })
+            except Exception as e:
+                print(f"Error processing question {i+1}: {str(e)}")
+                continue
 
+        # Calculate RAGAS metrics
         eval_dataset = Dataset.from_list(results)
         metrics = [ContextRecall(), AnswerRelevancy(), Faithfulness(), ContextPrecision()]
-        scores = evaluate(eval_dataset, metrics=metrics)
-        scores['
-        scores['
-        scores['
 
+        try:
+            scores = evaluate(eval_dataset, metrics=metrics)
+
+            return {
+                "configuration": f"{splitting_strategy}_{chunk_size}",
+                "questions_evaluated": len(results),
+                "context_recall": float(scores['context_recall']),
+                "answer_relevancy": float(scores['answer_relevancy']),
+                "faithfulness": float(scores['faithfulness']),
+                "context_precision": float(scores['context_precision']),
+                "average_score": float(np.mean([
+                    scores['context_recall'],
+                    scores['answer_relevancy'],
+                    scores['faithfulness'],
+                    scores['context_precision']
+                ]))
+            }
+        except Exception as e:
+            return {
+                "configuration": f"{splitting_strategy}_{chunk_size}",
+                "error": str(e),
+                "questions_evaluated": len(results)
+            }
 
 # Text splitting and database functions
 def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
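A minimal sketch of the aggregation step at the end of evaluate_configuration. A real ragas evaluate(...) result is dict-like and keyed by metric name, but running it needs an LLM behind it, so the scores below are fabricated placeholders for illustration; the configuration label is likewise made up.

import numpy as np

# Stand-in for scores = evaluate(eval_dataset, metrics=metrics);
# values are placeholders, not real evaluation output.
scores = {
    "context_recall": 0.81,
    "answer_relevancy": 0.74,
    "faithfulness": 0.88,
    "context_precision": 0.69,
}

average_score = float(np.mean([
    scores["context_recall"],
    scores["answer_relevancy"],
    scores["faithfulness"],
    scores["context_precision"],
]))
print(f"recursive_1024 average_score={average_score:.2f}")  # recursive_1024 average_score=0.78

Wrapping both the per-question loop and the final evaluate call in try/except keeps a single failing question or configuration from aborting a whole comparison run, at the cost of averaging over fewer questions than requested.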