"""
Data Loader for RGB Dataset

Handles loading and preprocessing of RGB benchmark datasets:
- en_refine.json: For noise robustness and negative rejection
- en_int.json: For information integration
- en_fact.json: For counterfactual robustness

Dataset structure (from https://github.com/chen700564/RGB):
- en_refine.json: {id, query, answer, positive, negative}
- en_int.json: {id, query, answer, answer1, answer2, positive, negative}
- en_fact.json: {id, query, answer, fakeanswer, positive_wrong, positive, negative}

Despite the ``.json`` extension, every file is read as JSON Lines
(one JSON object per line).
"""

import json
import os
import random
from typing import List, Dict, Any, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum


class TaskType(Enum):
    """Types of RAG evaluation tasks."""
    NOISE_ROBUSTNESS = "noise_robustness"
    NEGATIVE_REJECTION = "negative_rejection"
    INFORMATION_INTEGRATION = "information_integration"
    COUNTERFACTUAL_ROBUSTNESS = "counterfactual_robustness"


@dataclass
class RGBSample:
    """A single sample from the RGB dataset."""
    id: int  # Record 'id' field, falling back to the row index
    question: str
    answer: str  # Ground truth as a string; list answers are "|"-joined alternatives
    documents: List[str]  # Retrieved documents/passages
    task_type: TaskType
    noise_level: Optional[int] = None  # Number of noise documents requested
    has_answer: Optional[bool] = None  # Whether docs contain the answer
    num_docs_needed: Optional[int] = None  # Docs needed for answer
    has_counterfactual: Optional[bool] = None  # Whether docs contain counterfactual
    counterfactual_answer: Optional[str] = None  # The counterfactual (wrong) answer
    raw_data: Optional[Dict] = None  # Original raw record (never mutated by the loader)


class RGBDataLoader:
    """
    Loader for RGB benchmark datasets.

    Implements data loading as per the RGB paper and repository.
    """

    def __init__(self, data_dir: str = "data", passage_num: int = 5,
                 seed: Optional[int] = None):
        """
        Initialize the data loader.

        Args:
            data_dir: Directory containing the RGB dataset files. It is created
                (and a download hint printed) if it does not exist.
            passage_num: Number of passages to include per sample (default 5).
            seed: Optional seed for document shuffling. When given, a private
                ``random.Random`` instance is used so results are reproducible
                and independent of the global random state. When ``None``, the
                module-level ``random`` functions are used.

        Raises:
            ValueError: If ``passage_num`` is less than 1.
        """
        if passage_num < 1:
            raise ValueError(f"passage_num must be at least 1, got {passage_num}")
        self.data_dir = data_dir
        self.passage_num = passage_num
        # Both random.Random instances and the random module expose .shuffle().
        self._rng = random.Random(seed) if seed is not None else random
        self._validate_data_dir()

    def _validate_data_dir(self) -> None:
        """Create the data directory (printing a download hint) if it is missing."""
        if not os.path.exists(self.data_dir):
            # exist_ok avoids a crash if another process creates it concurrently.
            os.makedirs(self.data_dir, exist_ok=True)
            print(f"Created data directory: {self.data_dir}")
            print("Please run: python download_datasets.py")

    def _get_file_path(self, filename: str) -> str:
        """Get full path to a data file."""
        return os.path.join(self.data_dir, filename)

    def _load_jsonl(self, filepath: str) -> List[Dict]:
        """Load a JSONL file (one JSON object per line), skipping blank lines."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def _load_dataset(self, filename: str) -> List[Dict]:
        """
        Load one RGB dataset file from ``data_dir``.

        Raises:
            FileNotFoundError: If the file does not exist, with a download hint.
        """
        filepath = self._get_file_path(filename)
        if not os.path.exists(filepath):
            raise FileNotFoundError(
                f"Dataset file not found: {filepath}\n"
                "Please run: python download_datasets.py"
            )
        return self._load_jsonl(filepath)

    @staticmethod
    def _limit(data: List[Dict], max_samples: Optional[int]) -> List[Dict]:
        """Return the first ``max_samples`` raw records (all of them when ``None``)."""
        if max_samples is None:
            return data
        # Clamp at 0 so a negative cap yields nothing instead of slicing off the tail.
        return data[:max(0, max_samples)]

    def _split_counts(self, noise_rate: float) -> Tuple[int, int]:
        """
        Split ``passage_num`` into ``(positive_count, negative_count)``.

        Raises:
            ValueError: If ``noise_rate`` is outside [0.0, 1.0] (or NaN). Such
                values would produce negative counts, which Python slicing
                silently interprets as "drop items from the end".
        """
        if not 0.0 <= noise_rate <= 1.0:
            raise ValueError(
                f"noise_rate must be between 0.0 and 1.0, got {noise_rate}"
            )
        # Round away float noise before truncating (100 * 0.29 == 28.999...).
        # Genuine fractions such as 5 * 0.7 == 3.5 still truncate to 3.
        neg_num = int(round(self.passage_num * noise_rate, 9))
        return self.passage_num - neg_num, neg_num

    def _format_answer(self, answer: Any) -> str:
        """
        Format an answer into a single string for comparison.

        Lists are flattened at any nesting depth into pipe-separated
        alternatives, e.g. [['variant1', 'variant2'], 'other_answer'] ->
        'variant1|variant2|other_answer'. Other values are converted with
        ``str``, and ``None`` values are skipped (a bare ``None`` gives "").

        NOTE(review): flattening loses which variants belong to which answer
        component. The "|" format is kept because downstream matching
        presumably relies on it; confirm with the evaluator before changing.
        """
        def _flatten(value: Any) -> List[str]:
            if value is None:
                return []
            if isinstance(value, list):
                return [text for part in value for text in _flatten(part)]
            return [str(value)]

        return "|".join(_flatten(answer))

    def load_noise_robustness(
        self,
        max_samples: Optional[int] = None,
        noise_rate: float = 0.4
    ) -> List[RGBSample]:
        """
        Load data for Noise Robustness evaluation.

        Uses en_refine.json - tests LLM's ability to handle noisy documents.

        Args:
            max_samples: Maximum number of raw records to read (None for all,
                0 for none). Records that yield no documents are skipped.
            noise_rate: Fraction of the ``passage_num`` documents drawn from
                the negative (noise) pool, in [0.0, 1.0] (typically 0.0 to 0.8).

        Returns:
            List of RGBSample objects for noise robustness evaluation.

        Raises:
            ValueError: If ``noise_rate`` is outside [0.0, 1.0].
            FileNotFoundError: If en_refine.json is missing.
        """
        pos_num, neg_num = self._split_counts(noise_rate)
        data = self._limit(self._load_dataset("en_refine.json"), max_samples)

        samples = []
        for idx, item in enumerate(data):
            positive_docs = item.get('positive', [])[:pos_num]
            negative_docs = item.get('negative', [])[:neg_num]

            # Concatenation builds a new list, so shuffling never touches raw_data.
            documents = positive_docs + negative_docs
            if not documents:
                continue
            self._rng.shuffle(documents)

            samples.append(RGBSample(
                id=item.get('id', idx),
                question=item.get('query', ''),
                answer=self._format_answer(item.get('answer', '')),
                documents=documents,
                task_type=TaskType.NOISE_ROBUSTNESS,
                noise_level=neg_num,
                # With noise_rate=1.0 (or no positives) the answer is absent.
                has_answer=bool(positive_docs),
                raw_data=item
            ))

        print(f"Loaded {len(samples)} samples for Noise Robustness (noise_rate={noise_rate})")
        return samples

    def load_negative_rejection(
        self,
        max_samples: Optional[int] = None
    ) -> List[RGBSample]:
        """
        Load data for Negative Rejection evaluation.

        Uses en_refine.json with noise_rate=1.0 (all negative documents).
        Tests LLM's ability to reject when documents don't contain the answer.

        Args:
            max_samples: Maximum number of raw records to read (None for all,
                0 for none).

        Returns:
            List of RGBSample objects for negative rejection evaluation.

        Raises:
            FileNotFoundError: If en_refine.json is missing.
        """
        data = self._limit(self._load_dataset("en_refine.json"), max_samples)

        samples = []
        for idx, item in enumerate(data):
            # For negative rejection, use only negative documents.
            negative_docs = item.get('negative', [])[:self.passage_num]
            if not negative_docs:
                continue

            samples.append(RGBSample(
                id=item.get('id', idx),
                question=item.get('query', ''),
                answer=self._format_answer(item.get('answer', '')),
                documents=negative_docs,
                task_type=TaskType.NEGATIVE_REJECTION,
                has_answer=False,  # Documents don't contain the answer
                raw_data=item
            ))

        print(f"Loaded {len(samples)} samples for Negative Rejection")
        return samples

    def load_information_integration(
        self,
        max_samples: Optional[int] = None
    ) -> List[RGBSample]:
        """
        Load data for Information Integration evaluation.

        Uses en_int.json - tests LLM's ability to integrate info from multiple docs.

        Args:
            max_samples: Maximum number of raw records to read (None for all,
                0 for none).

        Returns:
            List of RGBSample objects for information integration evaluation.

        Raises:
            FileNotFoundError: If en_int.json is missing.
        """
        data = self._limit(self._load_dataset("en_int.json"), max_samples)

        samples = []
        for idx, item in enumerate(data):
            # 'positive' holds one document group per answer component
            # (or, presumably, plain strings); take one document from each.
            positive_docs = item.get('positive', [])

            documents = []
            if isinstance(positive_docs, list):
                for doc_group in positive_docs:
                    if isinstance(doc_group, list) and doc_group:
                        documents.append(doc_group[0])  # Take first from each group
                    elif isinstance(doc_group, str):
                        documents.append(doc_group)

            # Pad with negative docs up to passage_num.
            neg_num = max(0, self.passage_num - len(documents))
            documents.extend(item.get('negative', [])[:neg_num])

            if not documents:
                continue
            self._rng.shuffle(documents)

            samples.append(RGBSample(
                id=item.get('id', idx),
                question=item.get('query', ''),
                answer=self._format_answer(item.get('answer', '')),
                documents=documents[:self.passage_num],
                task_type=TaskType.INFORMATION_INTEGRATION,
                num_docs_needed=len(positive_docs) if isinstance(positive_docs, list) else 1,
                raw_data=item
            ))

        print(f"Loaded {len(samples)} samples for Information Integration")
        return samples

    def load_counterfactual_robustness(
        self,
        max_samples: Optional[int] = None,
        noise_rate: float = 0.4
    ) -> List[RGBSample]:
        """
        Load data for Counterfactual Robustness evaluation.

        Uses en_fact.json - tests LLM's ability to detect/correct factual errors.

        Args:
            max_samples: Maximum number of raw records to read (None for all,
                0 for none).
            noise_rate: Fraction of the ``passage_num`` documents drawn from
                the negative pool; the rest come from ``positive_wrong``. The
                default 0.4 gives 3 wrong + 2 negative at passage_num=5.

        Returns:
            List of RGBSample objects for counterfactual robustness evaluation.

        Raises:
            ValueError: If ``noise_rate`` is outside [0.0, 1.0].
            FileNotFoundError: If en_fact.json is missing.
        """
        wrong_num, neg_num = self._split_counts(noise_rate)
        data = self._limit(self._load_dataset("en_fact.json"), max_samples)

        samples = []
        for idx, item in enumerate(data):
            # positive_wrong docs contain the fake answer; mix in negatives.
            wrong_docs = item.get('positive_wrong', [])
            correct_docs = item.get('positive', [])
            negative_docs = item.get('negative', [])

            documents = wrong_docs[:wrong_num] + negative_docs[:neg_num]
            if not documents:
                # Fallback to any available docs. Copy so the in-place shuffle
                # below cannot reorder the lists stored in raw_data.
                documents = list(wrong_docs or correct_docs or negative_docs)
            if not documents:
                continue

            self._rng.shuffle(documents)
            documents = documents[:self.passage_num]

            samples.append(RGBSample(
                id=item.get('id', idx),
                question=item.get('query', ''),
                answer=self._format_answer(item.get('answer', '')),
                documents=documents,
                task_type=TaskType.COUNTERFACTUAL_ROBUSTNESS,
                # False when the fallback selected only correct/negative docs.
                has_counterfactual=any(doc in wrong_docs for doc in documents),
                counterfactual_answer=self._format_answer(item.get('fakeanswer', '')),
                raw_data=item
            ))

        print(f"Loaded {len(samples)} samples for Counterfactual Robustness")
        return samples

    def load_all_for_task(
        self,
        task_type: Union[TaskType, str],
        max_samples: Optional[int] = None,
        **kwargs
    ) -> List[RGBSample]:
        """
        Load data for a specific task type.

        Args:
            task_type: The type of evaluation task, as a TaskType member or its
                string value (e.g. "noise_robustness").
            max_samples: Maximum samples to load.
            **kwargs: Additional arguments for specific loaders (e.g. noise_rate).

        Returns:
            List of RGBSample objects.

        Raises:
            ValueError: If ``task_type`` is not a known task.
        """
        task = TaskType(task_type)  # Accepts members and their string values
        loaders = {
            TaskType.NOISE_ROBUSTNESS: self.load_noise_robustness,
            TaskType.NEGATIVE_REJECTION: self.load_negative_rejection,
            TaskType.INFORMATION_INTEGRATION: self.load_information_integration,
            TaskType.COUNTERFACTUAL_ROBUSTNESS: self.load_counterfactual_robustness,
        }
        return loaders[task](max_samples, **kwargs)

    def get_dataset_stats(self) -> Dict[str, Any]:
        """Return record count and file size for each dataset file (or an error entry)."""
        stats = {}
        files = {
            "en_refine.json": "Noise Robustness & Negative Rejection",
            "en_int.json": "Information Integration",
            "en_fact.json": "Counterfactual Robustness"
        }

        for filename, description in files.items():
            filepath = self._get_file_path(filename)
            if os.path.exists(filepath):
                data = self._load_jsonl(filepath)
                stats[filename] = {
                    "description": description,
                    "num_samples": len(data),
                    "file_size_bytes": os.path.getsize(filepath)
                }
            else:
                stats[filename] = {"error": "File not found"}

        return stats


def test_loader():
    """
    Smoke-test the loader against the data in the default ``data`` directory.

    Prints dataset statistics and one sample per task. Tasks whose dataset
    file is missing are reported as skipped instead of failing.
    """
    loader = RGBDataLoader()

    print("=" * 60)
    print("RGB Dataset Loader Test")
    print("=" * 60)

    # Get stats
    stats = loader.get_dataset_stats()
    print("\nDataset Statistics:")
    for filename, info in stats.items():
        print(f" {filename}: {info}")

    # Test loading a few samples from each task
    print("\n" + "-" * 60)

    try:
        samples = loader.load_noise_robustness(max_samples=2)
        if samples:
            print("\nNoise Robustness Sample:")
            print(f" Question: {samples[0].question[:80]}...")
            print(f" Answer: {samples[0].answer}")
            print(f" Num Docs: {len(samples[0].documents)}")
    except FileNotFoundError as e:
        print(f" Skipping: {e}")

    try:
        samples = loader.load_negative_rejection(max_samples=2)
        if samples:
            print("\nNegative Rejection Sample:")
            print(f" Question: {samples[0].question[:80]}...")
            print(f" Num Docs: {len(samples[0].documents)}")
    except FileNotFoundError as e:
        print(f" Skipping: {e}")

    try:
        samples = loader.load_information_integration(max_samples=2)
        if samples:
            print("\nInformation Integration Sample:")
            print(f" Question: {samples[0].question[:80]}...")
            print(f" Answer: {samples[0].answer}")
    except FileNotFoundError as e:
        print(f" Skipping: {e}")

    try:
        samples = loader.load_counterfactual_robustness(max_samples=2)
        if samples:
            print("\nCounterfactual Robustness Sample:")
            print(f" Question: {samples[0].question[:80]}...")
            print(f" Correct Answer: {samples[0].answer}")
            print(f" Fake Answer: {samples[0].counterfactual_answer}")
    except FileNotFoundError as e:
        print(f" Skipping: {e}")

    print("\n" + "=" * 60)


if __name__ == "__main__":
    test_loader()