# agllm2-dev / retrieval_evaluation.py
# Author: arbabarshad — "starting oct 5" (commit b9629f4)
"""
Retrieval Evaluation Script for AgLLM
Generates questions from chunks and evaluates retrieval performance with precision@k and nDCG@k metrics
"""
import os
import json
import random
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import pandas as pd
from tqdm import tqdm
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
import openai
from dotenv import load_dotenv
import time
load_dotenv()
@dataclass
class EvaluationSample:
    """Holds a chunk, its generated question, and metadata"""
    chunk_id: str  # chunk identifier (the 'source' metadata value in this pipeline)
    chunk_content: str  # raw chunk text the question was generated from
    metadata: Dict  # chunk metadata (keys seen here include matched_specie_0, region, source)
    question: str  # generated question that chunk_content should answer
    ground_truth_chunk_id: str  # The chunk that contains the answer
class QuestionGenerator:
    """Generates questions from chunks using GPT-4"""

    # Maximum characters of chunk text included in the prompt, to keep
    # request size bounded.
    MAX_CHUNK_CHARS = 4000

    def __init__(self, api_key: Optional[str] = None):
        """Store the OpenAI API key (argument or OPENAI_API_KEY env var).

        Raises:
            ValueError: if no API key is available.
        """
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OpenAI API key not found")

    def generate_question(self, chunk_content: str, metadata: Dict) -> str:
        """Generate a question where the chunk contains the answer.

        Falls back to a generic species question if the API call fails,
        so a batch run keeps going on transient errors.
        """
        # Build context from metadata
        context_parts = []
        species_name = metadata.get('matched_specie_0')
        if species_name:
            context_parts.append(f"Species: {species_name}")
        if 'common_name' in metadata:
            context_parts.append(f"Common Name: {metadata['common_name']}")
        if 'region' in metadata:
            context_parts.append(f"Region: {metadata['region']}")
        context = " | ".join(context_parts) if context_parts else ""

        # Bug fix: the original embedded the literal text
        # "# Limit chunk size for prompt" inside the f-string (a comment
        # that leaked into the prompt) and never actually truncated.
        truncated_content = chunk_content[:self.MAX_CHUNK_CHARS]

        prompt = f"""Given the following agricultural information chunk, generate ONE specific question that this chunk directly answers.
The question should be natural and the kind a farmer or agricultural expert might ask.
The answer to your question MUST be found in the provided chunk.
Context: {context}
Chunk Content:
{truncated_content}
Generate a single, clear question (no explanations, just the question):"""
        try:
            # Imported lazily so the fallback below also covers a missing
            # or broken openai installation.
            from openai import OpenAI
            client = OpenAI(api_key=self.api_key)
            response = client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You are an agricultural expert who creates precise questions from agricultural information."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=100,
                temperature=0.7
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error generating question: {e}")
            # Fallback question so the evaluation pipeline can continue
            species = metadata.get('matched_specie_0', 'this species')
            return f"What IPM information is available for {species}?"
class RetrievalEvaluator:
    """Evaluates retrieval performance of a persisted Chroma store.

    Compares four filter pipelines (no filter, species, region,
    species+region) using precision@k and nDCG@k over generated
    question/chunk pairs.
    """

    def __init__(self, persist_directory: str, embedding_model=None):
        """Open the persisted Chroma database.

        Args:
            persist_directory: path to the persisted Chroma directory.
            embedding_model: embedding function; defaults to OpenAIEmbeddings().
        """
        self.persist_directory = persist_directory
        self.embedding = embedding_model or OpenAIEmbeddings()
        self.vectordb = Chroma(
            persist_directory=persist_directory,
            embedding_function=self.embedding
        )

    def retrieve_chunks(self, query: str, k: int = 5, filter_dict: Optional[Dict] = None) -> "List[Tuple[Document, float]]":
        """Retrieve top-k chunks for a query with optional metadata filtering.

        Returns (Document, score) pairs as produced by Chroma.
        """
        if filter_dict:
            return self.vectordb.similarity_search_with_score(query, k=k, filter=filter_dict)
        return self.vectordb.similarity_search_with_score(query, k=k)

    def calculate_precision_at_k(self, retrieved_ids: List[str], ground_truth_id: str, k: int) -> float:
        """Calculate precision@k - binary: 1 if ground truth in top-k, 0 otherwise"""
        return 1.0 if ground_truth_id in retrieved_ids[:k] else 0.0

    def calculate_ndcg_at_k(self, retrieved_ids: List[str], ground_truth_id: str, k: int) -> float:
        """Calculate nDCG@k - gives credit for ranking ground truth higher.

        With a single relevant document of relevance 1, the ideal DCG is
        1.0 (ground truth at rank 1), so nDCG@k reduces to 1/log2(rank+1).
        """
        for position, chunk_id in enumerate(retrieved_ids[:k]):
            if chunk_id == ground_truth_id:
                # ranks are 1-based, hence position + 2 inside the log
                return float(1.0 / np.log2(position + 2))
        return 0.0

    def evaluate_retrieval_pipelines(self, samples: "List[EvaluationSample]", k_values: Optional[List[int]] = None) -> Dict:
        """Evaluate the four retrieval pipelines over all samples.

        Args:
            samples: question/ground-truth pairs to evaluate.
            k_values: metric cutoffs; defaults to [1, 3, 5].

        Returns:
            {pipeline: {metric: {'mean', 'std', 'count'}}} with plain-float
            values so the result survives json.dump().
        """
        # Fix: the original used a mutable default argument ([1, 3, 5]).
        if k_values is None:
            k_values = [1, 3, 5]

        def _empty_metrics() -> Dict[str, list]:
            # One list of per-sample scores per metric@k.
            return ({f'precision@{k}': [] for k in k_values}
                    | {f'ndcg@{k}': [] for k in k_values})

        pipeline_names = ('no_filter', 'species_only', 'region_only', 'species_and_region')
        results = {name: _empty_metrics() for name in pipeline_names}
        max_k = max(k_values)

        for sample in tqdm(samples, desc="Evaluating samples"):
            question = sample.question
            ground_truth_id = sample.ground_truth_chunk_id
            metadata = sample.metadata

            species_value = metadata.get('matched_specie_0')
            region_value = metadata.get('region')

            # ChromaDB filter clauses. A chunk can carry the species in any
            # of its three matched_specie_* slots, hence the $or.
            species_clause = {
                '$or': [
                    {'matched_specie_0': {'$eq': species_value}},
                    {'matched_specie_1': {'$eq': species_value}},
                    {'matched_specie_2': {'$eq': species_value}}
                ]
            } if species_value else None
            region_clause = {'region': {'$eq': region_value}} if region_value else None

            filters = {
                'no_filter': None,
                'species_only': species_clause,
                'region_only': region_clause,
                'species_and_region': (
                    {'$and': [species_clause, region_clause]}
                    if species_clause and region_clause else None
                )
            }

            for filter_name, filter_dict in filters.items():
                # Skip filtered pipelines when this sample lacks the metadata.
                if filter_name != 'no_filter' and filter_dict is None:
                    continue
                retrieved_results = self.retrieve_chunks(question, k=max_k, filter_dict=filter_dict)
                # Chunk identity is the 'source' metadata field, matching
                # the ids produced by load_chunks_from_vectordb.
                retrieved_ids = [doc.metadata.get('source', '') for doc, _score in retrieved_results]
                for k in k_values:
                    results[filter_name][f'precision@{k}'].append(
                        self.calculate_precision_at_k(retrieved_ids, ground_truth_id, k))
                    results[filter_name][f'ndcg@{k}'].append(
                        self.calculate_ndcg_at_k(retrieved_ids, ground_truth_id, k))

        # Average each metric. Fix: cast numpy scalars to float — np.float64
        # values would make the downstream json.dump() raise TypeError.
        averaged_results = {}
        for pipeline, metrics in results.items():
            averaged_results[pipeline] = {}
            for metric_name, values in metrics.items():
                if values:  # only report metrics that collected samples
                    averaged_results[pipeline][metric_name] = {
                        'mean': float(np.mean(values)),
                        'std': float(np.std(values)),
                        'count': len(values)
                    }
        return averaged_results
def load_chunks_from_vectordb(persist_directory: str, sample_size: Optional[int] = None) -> List[Dict]:
    """Load chunks from Chroma vectorDB.

    Args:
        persist_directory: path of the persisted Chroma database.
        sample_size: if set, randomly sample at most this many chunks.

    Returns:
        List of dicts with 'id' (the 'source' metadata value),
        'content' and 'metadata' keys.
    """
    embeddings = OpenAIEmbeddings()
    vectordb = Chroma(
        persist_directory=persist_directory,
        embedding_function=embeddings
    )
    # Fix: fetch every stored document via Chroma's get() instead of the
    # original similarity_search("", k=10000) hack, which embeds an empty
    # query and silently caps the result at 10000 documents.
    collection = vectordb.get(include=["documents", "metadatas"])
    chunks = [
        {
            'id': metadata.get('source', ''),
            'content': content,
            'metadata': metadata
        }
        for content, metadata in zip(collection['documents'], collection['metadatas'])
    ]
    if sample_size and len(chunks) > sample_size:
        chunks = random.sample(chunks, sample_size)
    return chunks
def _print_metadata_analysis(chunks: List[Dict]) -> None:
    """Print how many chunks carry the metadata fields the filter pipelines rely on."""
    print("\n Metadata Analysis:")
    matched_specie_count = sum(1 for chunk in chunks if chunk['metadata'].get('matched_specie_0'))
    region_count = sum(1 for chunk in chunks if chunk['metadata'].get('region'))
    both_species_region = sum(
        1 for chunk in chunks
        if chunk['metadata'].get('matched_specie_0') and chunk['metadata'].get('region')
    )
    print(f" - Chunks with 'matched_specie_0' field: {matched_specie_count}")
    print(f" - Chunks with 'region' field: {region_count}")
    print(f" - Chunks with species and region: {both_species_region}")
    # Show sample metadata from the first chunk
    if chunks:
        sample_metadata = chunks[0]['metadata']
        print(f" - Sample metadata keys: {list(sample_metadata.keys())}")
        print(f" - Sample species: {sample_metadata.get('matched_specie_0')}")
        print(f" - Sample region: {sample_metadata.get('region')}")


def _debug_print_chunks(chunks: List[Dict]) -> None:
    """Dump species/region/source for every chunk (debug aid for filter coverage)."""
    print("\n##### ALL CHUNK METADATA DEBUG #####")
    for i, chunk in enumerate(chunks):
        metadata = chunk['metadata']
        species = metadata.get('matched_specie_0', 'MISSING')
        region = metadata.get('region', 'MISSING')
        # Fix: the original only appended "..." without actually truncating.
        source = metadata.get('source', 'unknown')[:60] + "..."  # Truncate for readability
        print(f"Chunk {i+1:2d}: Species='{species}' | Region='{region}' | Source={source}")
    print("##### END DEBUG #####\n")


def _generate_samples(chunks: List[Dict]) -> "List[EvaluationSample]":
    """Generate one question per chunk; each chunk is its own ground truth."""
    print("\n2. Generating questions from chunks...")
    question_generator = QuestionGenerator()
    samples = []
    for i, chunk in enumerate(tqdm(chunks, desc="Generating questions")):
        try:
            question = question_generator.generate_question(
                chunk['content'],
                chunk['metadata']
            )
            samples.append(EvaluationSample(
                chunk_id=chunk['id'],
                chunk_content=chunk['content'],
                metadata=chunk['metadata'],
                question=question,
                ground_truth_chunk_id=chunk['id']
            ))
            # Rate limiting for API
            if (i + 1) % 10 == 0:
                time.sleep(1)
        except Exception as e:
            print(f" Error processing chunk {i}: {e}")
            continue
    return samples


def _print_metric_table(results: Dict, title: str, metric_prefix: str, header_labels: List[str]) -> None:
    """Print one metrics table (mean values per pipeline at k = 1, 3, 5)."""
    pipelines = ['no_filter', 'species_only', 'region_only', 'species_and_region']
    pipeline_names = ['No Filter', 'Species Only', 'Region Only', 'Species + Region']
    print(title)
    print("-" * 70)
    print(f"{'Pipeline':<15} {header_labels[0]:<8} {header_labels[1]:<8} {header_labels[2]:<8}")
    print("-" * 70)
    for display_name, pipeline in zip(pipeline_names, pipelines):
        if pipeline in results:
            means = [results[pipeline].get(f'{metric_prefix}@{k}', {}).get('mean', 0.0) for k in (1, 3, 5)]
            print(f"{display_name:<15} {means[0]:<8.3f} {means[1]:<8.3f} {means[2]:<8.3f}")


def main():
    """Main evaluation pipeline: load chunks, generate questions, evaluate, report."""
    # Configuration
    VECTOR_DB_PATH = 'vector-databases-deployed/db5-agllm-data-isu-field-insects-all-species'
    SAMPLE_SIZE = 100  # Start with smaller sample for testing
    K_VALUES = [1, 3, 5]
    OUTPUT_FILE = 'retrieval_evaluation_results.json'

    print("Starting Retrieval Evaluation Pipeline")
    print("=" * 50)

    # Step 1: Load chunks from vector database
    print("\n1. Loading chunks from vector database...")
    chunks = load_chunks_from_vectordb(VECTOR_DB_PATH, sample_size=SAMPLE_SIZE)
    print(f" Loaded {len(chunks)} chunks")

    # Step 1.5: Analyze metadata availability (+ full debug dump)
    _print_metadata_analysis(chunks)
    _debug_print_chunks(chunks)

    # Step 2: Generate questions for chunks
    samples = _generate_samples(chunks)
    print(f" Generated {len(samples)} question-chunk pairs")

    # Step 3: Evaluate retrieval pipelines
    print("\n3. Evaluating retrieval pipelines...")
    evaluator = RetrievalEvaluator(VECTOR_DB_PATH)
    results = evaluator.evaluate_retrieval_pipelines(samples, k_values=K_VALUES)

    # Step 4: Display and save results
    print("\n4. Evaluation Results:")
    print("=" * 50)
    _print_metric_table(results, "\nPRECISION RESULTS:", 'precision', ['P@1', 'P@3', 'P@5'])
    _print_metric_table(results, "\nnDCG RESULTS:", 'ndcg', ['nDCG@1', 'nDCG@3', 'nDCG@5'])

    # Save detailed results
    with open(OUTPUT_FILE, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nDetailed results saved to {OUTPUT_FILE}")
if __name__ == "__main__":
main()