"""
Quick Test Script for RGB RAG Evaluation
Tests a small sample to verify everything works before full evaluation.
"""
import os
import sys

# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from src.llm_client import GroqLLMClient
from src.data_loader import RGBDataLoader
from src.evaluator import RGBEvaluator
from src.prompts import format_prompt, get_prompt_template
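
# Optionally load GROQ_API_KEY from a local .env file, matching the setup
# instructions printed below. This assumes python-dotenv is installed; if it
# is not, the import is skipped and the regular environment variable is used.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
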
def quick_test(num_samples: int = 3):
"""
Run a quick test with a few samples.
Args:
num_samples: Number of samples to test per task.
"""
print("="*60)
print("RGB RAG Evaluation - Quick Test")
print("="*60)
# Check for API key
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
print("\n❌ ERROR: GROQ_API_KEY environment variable not set!")
print("\nTo set it:")
print(" 1. Get your free API key from: https://console.groq.com/")
print(" 2. Create a .env file with: GROQ_API_KEY=your_key_here")
print(" 3. Or set environment variable: $env:GROQ_API_KEY='your_key_here'")
return False
print(f"\nβœ“ API Key found: {api_key[:8]}...")
# Initialize components
print("\nInitializing components...")
try:
client = GroqLLMClient(model="llama-3.1-8b-instant") # Use fast model for testing
loader = RGBDataLoader()
evaluator = RGBEvaluator()
print(f"βœ“ Using model: {client.model}")
except Exception as e:
print(f"❌ Error initializing: {e}")
return False

    # Test 1: Simple generation
    print("\n" + "-"*60)
    print("Test 1: Simple Generation")
    print("-"*60)
    try:
        response = client.generate("What is 2 + 2? Answer with just the number.")
        print(f" Prompt: 'What is 2 + 2?'")
        print(f" Response: '{response}'")
        print(" ✓ LLM connection working!")
    except Exception as e:
        print(f" ❌ Error: {e}")
        return False

    # Test 2: Noise Robustness sample
    print("\n" + "-"*60)
    print("Test 2: Noise Robustness Sample")
    print("-"*60)
    try:
        samples = loader.load_noise_robustness(max_samples=num_samples)
        sample = samples[0]
        prompt = format_prompt(
            question=sample.question,
            documents=sample.documents,
            template=get_prompt_template("default")
        )
        print(f" Question: {sample.question[:60]}...")
        print(f" Expected: {sample.answer[:60]}...")
        response = client.generate(prompt)
        print(f" Response: {response[:100]}...")
        is_correct = evaluator.is_correct(response, sample.answer)
        print(f" Correct: {is_correct}")
        print(" ✓ Noise robustness test complete!")
    except Exception as e:
        print(f" ❌ Error: {e}")

    # Test 3: Negative Rejection sample
    print("\n" + "-"*60)
    print("Test 3: Negative Rejection Sample")
    print("-"*60)
    try:
        samples = loader.load_negative_rejection(max_samples=num_samples)
        sample = samples[0]
        prompt = format_prompt(
            question=sample.question,
            documents=sample.documents,
            template=get_prompt_template("negative")
        )
        print(f" Question: {sample.question[:60]}...")
        response = client.generate(prompt)
        print(f" Response: {response[:100]}...")
        is_rejection = evaluator.is_rejection(response)
        print(f" Rejected: {is_rejection} (should reject since docs don't have answer)")
        print(" ✓ Negative rejection test complete!")
    except Exception as e:
        print(f" ❌ Error: {e}")

    # Test 4: Information Integration sample
    print("\n" + "-"*60)
    print("Test 4: Information Integration Sample")
    print("-"*60)
    try:
        samples = loader.load_information_integration(max_samples=num_samples)
        sample = samples[0]
        prompt = format_prompt(
            question=sample.question,
            documents=sample.documents,
            template=get_prompt_template("default")
        )
        print(f" Question: {sample.question[:60]}...")
        print(f" Expected: {sample.answer[:60]}...")
        response = client.generate(prompt)
        print(f" Response: {response[:100]}...")
        is_correct = evaluator.is_correct(response, sample.answer)
        print(f" Correct: {is_correct}")
        print(" ✓ Information integration test complete!")
    except Exception as e:
        print(f" ❌ Error: {e}")

    # Test 5: Counterfactual Robustness sample
    print("\n" + "-"*60)
    print("Test 5: Counterfactual Robustness Sample")
    print("-"*60)
    try:
        samples = loader.load_counterfactual_robustness(max_samples=num_samples)
        sample = samples[0]
        prompt = format_prompt(
            question=sample.question,
            documents=sample.documents,
            template=get_prompt_template("counterfactual")
        )
        print(f" Question: {sample.question[:60]}...")
        print(f" Correct Answer: {sample.answer}")
        print(f" Fake Answer (in docs): {sample.counterfactual_answer}")
        response = client.generate(prompt)
        print(f" Response: {response[:100]}...")
        detects_error = evaluator.detects_error(response, sample.counterfactual_answer)
        corrects_error = evaluator.corrects_error(response, sample.answer, sample.counterfactual_answer)
        print(f" Detects Error: {detects_error}")
        print(f" Corrects Error: {corrects_error}")
        print(" ✓ Counterfactual robustness test complete!")
    except Exception as e:
        print(f" ❌ Error: {e}")
print("\n" + "="*60)
print("Quick Test Complete!")
print("="*60)
print("\nTo run full evaluation:")
print(" python run_evaluation.py --max-samples 50")
print("\nOr for a complete run:")
print(" python run_evaluation.py")
return True


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Quick test for RGB evaluation")
    parser.add_argument("-n", "--num-samples", type=int, default=3,
                        help="Number of samples to test per task")
    args = parser.parse_args()

    quick_test(args.num_samples)
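
# Usage examples (using the -n/--num-samples flag defined above):
#   python quick_test.py          # default: 3 samples per task
#   python quick_test.py -n 5     # 5 samples per task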