Spaces:
Sleeping
Sleeping
| """ | |
| Quick Test Script for RGB RAG Evaluation | |
| Tests a small sample to verify everything works before full evaluation. | |
| """ | |
| import os | |
| import sys | |
| # Add parent directory to path | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from src.llm_client import GroqLLMClient | |
| from src.data_loader import RGBDataLoader | |
| from src.evaluator import RGBEvaluator | |
| from src.prompts import format_prompt, get_prompt_template | |
# ASCII status markers: the original glyphs were garbled, and non-ASCII
# symbols can raise UnicodeEncodeError on consoles / redirected output that
# are not UTF-8.
_OK = "[OK]"
_FAIL = "[FAIL]"


def _print_section(title):
    """Print a dashed banner introducing one sub-test."""
    print("\n" + "-" * 60)
    print(title)
    print("-" * 60)


def _load_samples(load_fn, num_samples, task_name):
    """
    Call a loader method with ``max_samples=num_samples`` and return a list.

    Raises:
        ValueError: if the loader returned no samples. Previously this showed
            up as a bare IndexError that did not name the task.
    """
    samples = list(load_fn(max_samples=num_samples))
    if not samples:
        raise ValueError(
            f"no samples loaded for {task_name} (num_samples={num_samples})"
        )
    return samples


def _build_prompt(sample, template_name):
    """Format a sample's question and documents with the named prompt template."""
    return format_prompt(
        question=sample.question,
        documents=sample.documents,
        template=get_prompt_template(template_name),
    )


def _check_answers(client, evaluator, samples, template_name):
    """Ask the LLM each sample's question and print whether the answer is correct."""
    total = len(samples)
    for idx, sample in enumerate(samples, start=1):
        prompt = _build_prompt(sample, template_name)
        print(f" [{idx}/{total}] Question: {sample.question[:60]}...")
        print(f" Expected: {sample.answer[:60]}...")
        response = client.generate(prompt)
        print(f" Response: {response[:100]}...")
        print(f" Correct: {evaluator.is_correct(response, sample.answer)}")


def _test_noise_robustness(client, loader, evaluator, num_samples):
    """Test 2: noise-robustness samples, scored with ``is_correct``."""
    samples = _load_samples(
        loader.load_noise_robustness, num_samples, "noise robustness"
    )
    _check_answers(client, evaluator, samples, "default")
    print(f" {_OK} Noise robustness test complete!")


def _test_negative_rejection(client, loader, evaluator, num_samples):
    """Test 3: negative-rejection samples, scored with ``is_rejection``."""
    samples = _load_samples(
        loader.load_negative_rejection, num_samples, "negative rejection"
    )
    total = len(samples)
    for idx, sample in enumerate(samples, start=1):
        prompt = _build_prompt(sample, "negative")
        print(f" [{idx}/{total}] Question: {sample.question[:60]}...")
        response = client.generate(prompt)
        print(f" Response: {response[:100]}...")
        is_rejection = evaluator.is_rejection(response)
        print(f" Rejected: {is_rejection} (should reject since docs don't have answer)")
    print(f" {_OK} Negative rejection test complete!")


def _test_information_integration(client, loader, evaluator, num_samples):
    """Test 4: information-integration samples, scored with ``is_correct``."""
    samples = _load_samples(
        loader.load_information_integration, num_samples, "information integration"
    )
    _check_answers(client, evaluator, samples, "default")
    print(f" {_OK} Information integration test complete!")


def _test_counterfactual_robustness(client, loader, evaluator, num_samples):
    """Test 5: counterfactual samples, scored with ``detects_error`` / ``corrects_error``."""
    samples = _load_samples(
        loader.load_counterfactual_robustness, num_samples, "counterfactual robustness"
    )
    total = len(samples)
    for idx, sample in enumerate(samples, start=1):
        prompt = _build_prompt(sample, "counterfactual")
        print(f" [{idx}/{total}] Question: {sample.question[:60]}...")
        print(f" Correct Answer: {sample.answer}")
        print(f" Fake Answer (in docs): {sample.counterfactual_answer}")
        response = client.generate(prompt)
        print(f" Response: {response[:100]}...")
        detects_error = evaluator.detects_error(response, sample.counterfactual_answer)
        corrects_error = evaluator.corrects_error(
            response, sample.answer, sample.counterfactual_answer
        )
        print(f" Detects Error: {detects_error}")
        print(f" Corrects Error: {corrects_error}")
    print(f" {_OK} Counterfactual robustness test complete!")


def quick_test(num_samples: int = 3) -> bool:
    """
    Run a quick smoke test of the RGB evaluation pipeline.

    Checks that GROQ_API_KEY is set, builds the LLM client, data loader and
    evaluator, confirms the LLM answers a trivial prompt, then runs samples
    from each of the four task loaders (noise robustness, negative rejection,
    information integration, counterfactual robustness).

    Args:
        num_samples: Number of samples to test per task. Passed to each loader
            as ``max_samples``; every sample returned is sent to the LLM.

    Returns:
        True if every step succeeded. False if the API key is missing, setup
        fails, the connectivity check fails, or any per-task test raised.
    """
    print("=" * 60)
    print("RGB RAG Evaluation - Quick Test")
    print("=" * 60)

    # Check for API key
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        print(f"\n{_FAIL} ERROR: GROQ_API_KEY environment variable not set!")
        print("\nTo set it:")
        print(" 1. Get your free API key from: https://console.groq.com/")
        print(" 2. Create a .env file with: GROQ_API_KEY=your_key_here")
        print(" 3. Or set environment variable: $env:GROQ_API_KEY='your_key_here'")
        return False
    # Only a short prefix is echoed so the full secret never hits the console.
    print(f"\n{_OK} API Key found: {api_key[:8]}...")

    # Initialize components
    print("\nInitializing components...")
    try:
        client = GroqLLMClient(model="llama-3.1-8b-instant")  # Use fast model for testing
        loader = RGBDataLoader()
        evaluator = RGBEvaluator()
        print(f"{_OK} Using model: {client.model}")
    except Exception as e:
        print(f"{_FAIL} Error initializing: {e}")
        return False

    # Test 1: connectivity. This is fatal, because every later test needs the LLM.
    _print_section("Test 1: Simple Generation")
    try:
        response = client.generate("What is 2 + 2? Answer with just the number.")
        print(" Prompt: 'What is 2 + 2?'")
        print(f" Response: '{response}'")
        print(f" {_OK} LLM connection working!")
    except Exception as e:
        print(f" {_FAIL} Error: {e}")
        return False

    task_tests = (
        ("Test 2: Noise Robustness Sample", _test_noise_robustness),
        ("Test 3: Negative Rejection Sample", _test_negative_rejection),
        ("Test 4: Information Integration Sample", _test_information_integration),
        ("Test 5: Counterfactual Robustness Sample", _test_counterfactual_robustness),
    )
    failed = []
    for title, run_task in task_tests:
        _print_section(title)
        try:
            run_task(client, loader, evaluator, num_samples)
        except Exception as e:
            # Report and keep going so one broken task doesn't hide the others,
            # but remember the failure so the return value reflects it.
            print(f" {_FAIL} Error: {e}")
            failed.append(title)

    print("\n" + "=" * 60)
    print("Quick Test Complete!")
    print("=" * 60)
    if failed:
        print(f"\n{_FAIL} {len(failed)} test(s) failed:")
        for title in failed:
            print(f" - {title}")
    print("\nTo run full evaluation:")
    print(" python run_evaluation.py --max-samples 50")
    print("\nOr for a complete run:")
    print(" python run_evaluation.py")
    return not failed
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Quick test for RGB evaluation")
    parser.add_argument("-n", "--num-samples", type=int, default=3,
                        help="Number of samples to test per task")
    args = parser.parse_args()
    # Propagate the smoke-test outcome as the process exit status so shells and
    # CI can detect failure. Previously the result was discarded and the script
    # always exited 0.
    sys.exit(0 if quick_test(args.num_samples) else 1)