"""Evaluation metrics for RAG systems."""
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer, util
class RAGEvaluator:
"""Evaluate RAG system performance."""
def __init__(self, embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
"""
Initialize evaluator.
Args:
embedding_model_name: Model for semantic similarity
"""
self.embedding_model = SentenceTransformer(embedding_model_name)
def hit_at_k(
self,
retrieved_ids: List[str],
relevant_ids: List[str],
k: int = 5
) -> float:
"""
Calculate Hit@k metric.
Args:
retrieved_ids: List of retrieved document IDs
relevant_ids: List of relevant document IDs
k: Number of top results to consider
Returns:
Hit@k score (1 if any relevant doc in top-k, else 0)
"""
top_k = retrieved_ids[:k]
return 1.0 if any(rid in relevant_ids for rid in top_k) else 0.0
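# Worked example (hypothetical document IDs): hit_at_k(["d3", "d1", "d9"], ["d1"], k=2)
# returns 1.0, while k=1 returns 0.0 because the only relevant document is ranked second.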
def precision_at_k(
self,
retrieved_ids: List[str],
relevant_ids: List[str],
k: int = 5
) -> float:
"""
Calculate Precision@k.
Args:
retrieved_ids: List of retrieved document IDs
relevant_ids: List of relevant document IDs
k: Number of top results to consider
Returns:
Precision@k score
"""
top_k = retrieved_ids[:k]
if not top_k:
return 0.0
relevant_in_top_k = sum(1 for rid in top_k if rid in relevant_ids)
return relevant_in_top_k / len(top_k)
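# Worked example (hypothetical IDs): precision_at_k(["d1", "d4", "d2"], ["d1", "d2"], k=3)
# returns 2/3, since two of the three retrieved documents are relevant.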
def recall_at_k(
self,
retrieved_ids: List[str],
relevant_ids: List[str],
k: int = 5
) -> float:
"""
Calculate Recall@k.
Args:
retrieved_ids: List of retrieved document IDs
relevant_ids: List of relevant document IDs
k: Number of top results to consider
Returns:
Recall@k score
"""
if not relevant_ids:
return 0.0
top_k = retrieved_ids[:k]
relevant_in_top_k = sum(1 for rid in top_k if rid in relevant_ids)
return relevant_in_top_k / len(relevant_ids)
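# Worked example (hypothetical IDs): recall_at_k(["d1", "d4"], ["d1", "d2", "d3"], k=5)
# returns 1/3, because only one of the three relevant documents was retrieved.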
def mrr(
self,
retrieved_ids: List[str],
relevant_ids: List[str]
) -> float:
"""
Calculate the reciprocal rank for a single query (the per-query component of MRR).
Args:
retrieved_ids: List of retrieved document IDs
relevant_ids: List of relevant document IDs
Returns:
Reciprocal rank: 1 / rank of the first relevant document, or 0.0 if none is retrieved
"""
for i, rid in enumerate(retrieved_ids, 1):
if rid in relevant_ids:
return 1.0 / i
return 0.0
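# Worked example (hypothetical IDs): mrr(["d7", "d2", "d1"], ["d1"]) returns 1/3
# because the first relevant document appears at rank 3.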
def semantic_similarity(
self,
answer: str,
reference: str
) -> float:
"""
Calculate semantic similarity between answer and reference.
Args:
answer: Generated answer
reference: Reference answer
Returns:
Cosine similarity score
"""
embeddings = self.embedding_model.encode([answer, reference])
similarity = util.cos_sim(embeddings[0], embeddings[1])
return float(similarity[0][0])
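# Usage sketch (assumes the MiniLM model is available locally or downloadable):
#   evaluator.semantic_similarity("Report errors via the incident system.",
#                                 "Errors must be reported through the incident system.")
# returns a cosine similarity roughly in [-1, 1], close to 1.0 for paraphrases.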
def evaluate_retrieval(
self,
retrieved_results: List[Dict[str, Any]],
relevant_ids: List[str],
k_values: List[int] = [1, 3, 5, 10]
) -> Dict[str, Any]:
"""
Comprehensive retrieval evaluation.
Args:
retrieved_results: List of retrieval results
relevant_ids: List of relevant document IDs
k_values: List of k values for Hit@k, Precision@k, Recall@k
Returns:
Dictionary with all metrics
"""
retrieved_ids = [r["id"] for r in retrieved_results]
metrics = {
"mrr": self.mrr(retrieved_ids, relevant_ids)
}
for k in k_values:
metrics[f"hit@{k}"] = self.hit_at_k(retrieved_ids, relevant_ids, k)
metrics[f"precision@{k}"] = self.precision_at_k(retrieved_ids, relevant_ids, k)
metrics[f"recall@{k}"] = self.recall_at_k(retrieved_ids, relevant_ids, k)
return metrics
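# Usage sketch (hypothetical retrieval output; only the "id" field of each result is read):
#   evaluate_retrieval([{"id": "d1"}, {"id": "d4"}], ["d1"], k_values=[1, 3])
# yields {"mrr": 1.0, "hit@1": 1.0, "precision@1": 1.0, "recall@1": 1.0, ...}.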
def evaluate_generation(
self,
generated_answer: str,
reference_answer: str
) -> Dict[str, float]:
"""
Evaluate generated answer quality.
Args:
generated_answer: Generated answer
reference_answer: Reference answer
Returns:
Dictionary with generation metrics
"""
return {
"semantic_similarity": self.semantic_similarity(generated_answer, reference_answer)
}
def evaluate_rag_pipeline(
self,
rag_result: Dict[str, Any],
relevant_ids: List[str],
reference_answer: Optional[str] = None,
k_values: List[int] = [1, 3, 5]
) -> Dict[str, Any]:
"""
Evaluate complete RAG pipeline.
Args:
rag_result: Result from RAG query
relevant_ids: List of relevant document IDs
reference_answer: Optional reference answer
k_values: List of k values for metrics
Returns:
Dictionary with all evaluation metrics
"""
metrics = {
"pipeline": rag_result.get("pipeline", "Unknown"),
"retrieval_time": rag_result.get("retrieval_time", 0),
"generation_time": rag_result.get("generation_time", 0),
"total_time": rag_result.get("total_time", 0)
}
# Retrieval metrics
retrieval_metrics = self.evaluate_retrieval(
rag_result["contexts"],
relevant_ids,
k_values
)
metrics.update(retrieval_metrics)
# Generation metrics (if reference provided)
if reference_answer:
generation_metrics = self.evaluate_generation(
rag_result["answer"],
reference_answer
)
metrics.update(generation_metrics)
return metrics
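# Assumed shape of rag_result (inferred from the keys accessed above, not a fixed schema):
#   {"pipeline": "Hier-RAG", "contexts": [{"id": "d1"}, ...], "answer": "...",
#    "retrieval_time": 0.12, "generation_time": 0.80, "total_time": 0.92}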
def compare_pipelines(
self,
base_result: Dict[str, Any],
hier_result: Dict[str, Any],
relevant_ids: List[str],
reference_answer: Optional[str] = None,
k_values: List[int] = [1, 3, 5]
) -> Dict[str, Any]:
"""
Compare Base-RAG and Hier-RAG results.
Args:
base_result: Result from Base-RAG
hier_result: Result from Hier-RAG
relevant_ids: List of relevant document IDs
reference_answer: Optional reference answer
k_values: List of k values for metrics
Returns:
Dictionary with comparison metrics
"""
base_metrics = self.evaluate_rag_pipeline(
base_result,
relevant_ids,
reference_answer,
k_values
)
hier_metrics = self.evaluate_rag_pipeline(
hier_result,
relevant_ids,
reference_answer,
k_values
)
# Calculate improvements
comparison = {
"base_rag": base_metrics,
"hier_rag": hier_metrics,
"improvements": {}
}
# Speed improvements
if base_metrics["total_time"] > 0 and hier_metrics["total_time"] > 0:
comparison["improvements"]["speedup"] = base_metrics["total_time"] / hier_metrics["total_time"]
# Accuracy improvements
for k in k_values:
hit_key = f"hit@{k}"
if hit_key in base_metrics and hit_key in hier_metrics:
comparison["improvements"][f"{hit_key}_delta"] = hier_metrics[hit_key] - base_metrics[hit_key]
if "mrr" in base_metrics and "mrr" in hier_metrics:
comparison["improvements"]["mrr_delta"] = hier_metrics["mrr"] - base_metrics["mrr"]
if "semantic_similarity" in base_metrics and "semantic_similarity" in hier_metrics:
comparison["improvements"]["similarity_delta"] = (
hier_metrics["semantic_similarity"] - base_metrics["semantic_similarity"]
)
return comparison
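# Interpretation note: "speedup" > 1.0 means Hier-RAG answered faster than Base-RAG,
# and positive "*_delta" values mean Hier-RAG scored higher on that metric.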
class BenchmarkDataset:
"""Generate or load benchmark datasets for evaluation."""
def __init__(self):
"""Initialize benchmark dataset."""
self.queries = []
self.ground_truth = {}
def add_query(
self,
query: str,
relevant_ids: List[str],
reference_answer: Optional[str] = None
) -> None:
"""
Add a query to the benchmark.
Args:
query: Query text
relevant_ids: List of relevant document IDs
reference_answer: Optional reference answer
"""
self.queries.append(query)
self.ground_truth[query] = {
"relevant_ids": relevant_ids,
"reference_answer": reference_answer
}
def get_sample_hospital_queries(self) -> List[Dict[str, Any]]:
"""
Get sample queries for hospital domain.
Returns:
List of query dictionaries
"""
return [
{
"query": "What are the patient admission procedures?",
"domain": "Clinical Care",
"expected_doc_type": "protocol"
},
{
"query": "What are the infection control policies?",
"domain": "Quality & Safety",
"expected_doc_type": "policy"
},
{
"query": "How should medication errors be reported?",
"domain": "Quality & Safety",
"expected_doc_type": "policy"
},
{
"query": "What training is required for new nurses?",
"domain": "Education & Training",
"expected_doc_type": "manual"
},
{
"query": "What are the emergency response procedures?",
"domain": "Clinical Care",
"expected_doc_type": "protocol"
}
]
def get_sample_bank_queries(self) -> List[Dict[str, Any]]:
"""
Get sample queries for banking domain.
Returns:
List of query dictionaries
"""
return [
{
"query": "What are the KYC requirements for new accounts?",
"domain": "Compliance & Legal",
"expected_doc_type": "policy"
},
{
"query": "How do I process a personal loan application?",
"domain": "Retail Banking",
"expected_doc_type": "manual"
},
{
"query": "What is the credit risk assessment procedure?",
"domain": "Risk Management",
"expected_doc_type": "guideline"
},
{
"query": "What are the fraud prevention measures?",
"domain": "Risk Management",
"expected_doc_type": "policy"
},
{
"query": "How should suspicious transactions be reported?",
"domain": "Compliance & Legal",
"expected_doc_type": "policy"
}
]
def get_sample_fluid_simulation_queries(self) -> List[Dict[str, Any]]:
"""
Get sample queries for fluid simulation domain.
Returns:
List of query dictionaries
"""
return [
{
"query": "How does the SIMPLE algorithm work?",
"domain": "Numerical Methods",
"expected_doc_type": "paper"
},
{
"query": "What turbulence models are available?",
"domain": "Physical Models",
"expected_doc_type": "manual"
},
{
"query": "How do I set up a cavity flow benchmark?",
"domain": "Validation & Verification",
"expected_doc_type": "tutorial"
},
{
"query": "What mesh generation techniques are recommended?",
"domain": "Numerical Methods",
"expected_doc_type": "manual"
},
{
"query": "How do I enable parallel computing for simulations?",
"domain": "Software & Tools",
"expected_doc_type": "manual"
}
]
def load_from_file(self, filepath: str) -> None:
"""
Load benchmark dataset from JSON file.
Args:
filepath: Path to JSON file
"""
import json
with open(filepath, 'r', encoding='utf-8') as f:
data = json.load(f)
self.queries = data.get("queries", [])
self.ground_truth = data.get("ground_truth", {})
def save_to_file(self, filepath: str) -> None:
"""
Save benchmark dataset to JSON file.
Args:
filepath: Path to output JSON file
"""
import json
from pathlib import Path
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
data = {
"queries": self.queries,
"ground_truth": self.ground_truth
}
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
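# Minimal end-to-end sketch (hypothetical document IDs, answers, and output path;
# running it downloads the sentence-transformers model on first use).
if __name__ == "__main__":
    evaluator = RAGEvaluator()

    # Fake retrieval output: the evaluator only needs an "id" per context here.
    contexts = [{"id": "doc-3"}, {"id": "doc-1"}, {"id": "doc-7"}]
    relevant = ["doc-1", "doc-5"]
    print(evaluator.evaluate_retrieval(contexts, relevant, k_values=[1, 3]))

    print(evaluator.semantic_similarity(
        "Medication errors are reported through the incident system.",
        "Report medication errors via the incident reporting system."
    ))

    # Build and persist a tiny benchmark for later reuse.
    bench = BenchmarkDataset()
    bench.add_query("How should medication errors be reported?", relevant)
    bench.save_to_file("data/benchmark_sample.json")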