Spaces:
Sleeping
Sleeping
| """Dataset loader for RAG Bench datasets.""" | |
| import os | |
| from typing import List, Dict, Optional | |
| from datasets import load_dataset | |
| import pandas as pd | |
| from tqdm import tqdm | |
class RAGBenchLoader:
    """Load and manage RAG Bench datasets published under rungalileo/ragbench."""

    # Dataset configurations available in the rungalileo/ragbench hub repo.
    SUPPORTED_DATASETS = [
        'covidqa',
        'cuad',
        'delucionqa',
        'emanual',
        'expertqa',
        'finqa',
        'hagrid',
        'hotpotqa',
        'msmarco',
        'pubmedqa',
        'tatqa',
        'techqa'
    ]

    # (candidate raw field names, canonical metric name) for the pre-computed
    # ground-truth scores shipped with RAGBench items (RAGBench paper metrics).
    # Field names are tried in order; the first convertible one wins.
    _SCORE_FIELDS = [
        (["relevance_score", "context_relevance", "relevance", "R"], "context_relevance"),
        (["utilization_score", "context_utilization", "utilization", "T"], "context_utilization"),
        (["completeness_score", "completeness", "C"], "completeness"),
        (["adherence_score", "adherence", "A", "overall_supported"], "adherence"),
    ]

    def __init__(self, cache_dir: str = "./data_cache"):
        """Initialize the dataset loader.

        Args:
            cache_dir: Directory to cache downloaded datasets.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def load_dataset(self, dataset_name: str, split: str = "test",
                     max_samples: Optional[int] = None) -> List[Dict]:
        """Load a RAG Bench dataset from rungalileo/ragbench.

        Args:
            dataset_name: Name of the dataset to load.
            split: Dataset split (train/validation/test).
            max_samples: Maximum number of samples to load (None loads all).

        Returns:
            List of dictionaries containing standardized dataset samples.

        Raises:
            ValueError: If dataset_name is not one of SUPPORTED_DATASETS.
        """
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Unsupported dataset: {dataset_name}. "
                             f"Supported: {self.SUPPORTED_DATASETS}")

        print(f"Loading {dataset_name} dataset ({split} split) from rungalileo/ragbench...")
        try:
            # Load from rungalileo/ragbench
            dataset = load_dataset("rungalileo/ragbench", dataset_name, split=split,
                                   cache_dir=self.cache_dir)
            # Truncate before processing so we never iterate more rows than needed.
            if max_samples is not None:
                dataset = dataset.select(range(min(max_samples, len(dataset))))
            processed_data = [self._process_ragbench_item(item, dataset_name)
                              for item in tqdm(dataset, desc=f"Processing {dataset_name}")]
            print(f"Loaded {len(processed_data)} samples from {dataset_name}")
            return processed_data
        except Exception as e:
            # Deliberate best-effort fallback: keeps pipelines runnable offline/in tests.
            print(f"Error loading {dataset_name}: {str(e)}")
            print("Falling back to sample data for testing...")
            return self._create_sample_data(dataset_name, max_samples or 10)

    @staticmethod
    def _extract_documents(item: Dict) -> List[str]:
        """Return the item's source passages as a list of strings.

        RAGBench items may store passages under several field names; the first
        field *present* in the item wins (priority: documents >
        retrieved_contexts > context), even if its value is empty. Scalar
        values are wrapped in a single-element list. Returns [] when no
        passage field exists.
        """
        for field in ("documents", "retrieved_contexts", "context"):
            if field in item:
                value = item[field]
                if isinstance(value, list):
                    return [str(doc) for doc in value]
                return [str(value)]
        return []

    @staticmethod
    def _coerce_score(value) -> float:
        """Convert a raw ground-truth score field value to float.

        Booleans map to 1.0/0.0; the strings 'true'/'yes' and 'false'/'no'
        (case-insensitive) map likewise; everything else goes through float().

        Raises:
            ValueError: If a string value is not convertible to float.
            TypeError: If the value type is not convertible (e.g. None).
        """
        if isinstance(value, bool):
            # Boolean adherence: True=1.0, False=0.0 (checked before numerics
            # because bool is a subclass of int).
            return 1.0 if value else 0.0
        if isinstance(value, str):
            lowered = value.lower()
            if lowered in ('true', 'yes'):
                return 1.0
            if lowered in ('false', 'no'):
                return 0.0
            return float(value)
        return float(value)

    def _process_ragbench_item(self, item: Dict, dataset_name: str) -> Dict:
        """Process a single RAGBench dataset item into standardized format.

        Args:
            item: Raw dataset item.
            dataset_name: Name of the dataset.

        Returns:
            Processed item dictionary with question, answer, context (joined
            passages for embedding/retrieval), documents (original passage
            list), dataset name, and any pre-computed ground-truth metric
            scores found on the item.
        """
        documents = self._extract_documents(item)

        # Extract ground-truth evaluation scores pre-computed by the RAGBench
        # paper; each metric keeps the first field name that converts cleanly.
        ground_truth_scores = {}
        for field_names, metric_name in self._SCORE_FIELDS:
            for field_name in field_names:
                if field_name in item:
                    try:
                        ground_truth_scores[metric_name] = self._coerce_score(item[field_name])
                        break  # Found this metric, move to next
                    except (ValueError, TypeError):
                        continue  # Try next field name

        processed = {
            "question": item.get("question", ""),
            "answer": item.get("answer", ""),
            "context": " ".join(documents),  # Combined text for embedding and retrieval
            "documents": documents,          # Original documents list
            "dataset": dataset_name,
            "ground_truth_scores": ground_truth_scores,
        }

        # Preserve additional metadata if the source item carries any.
        if "metadata" in item:
            processed["metadata"] = item["metadata"]
        return processed

    def load_all_datasets(self, split: str = "test", max_samples: Optional[int] = None) -> Dict[str, List[Dict]]:
        """Load all RAGBench datasets.

        Args:
            split: Dataset split to load.
            max_samples: Maximum samples per dataset.

        Returns:
            Dictionary mapping dataset names to their data; a dataset that
            fails to load maps to an empty list rather than aborting the run.
        """
        all_data = {}
        for dataset_name in self.SUPPORTED_DATASETS:
            print(f"\n{'='*50}")
            print(f"Loading {dataset_name}...")
            print(f"{'='*50}")
            try:
                all_data[dataset_name] = self.load_dataset(dataset_name, split, max_samples)
            except Exception as e:
                print(f"Failed to load {dataset_name}: {str(e)}")
                all_data[dataset_name] = []
        return all_data

    def _create_sample_data(self, dataset_name: str, num_samples: int) -> List[Dict]:
        """Create sample data for testing when the actual dataset is unavailable.

        Returned items use the same schema as _process_ragbench_item (including
        an empty ground_truth_scores dict) so downstream code does not need to
        special-case the fallback path.
        """
        sample_data = []
        for i in range(num_samples):
            # Create multiple sample documents per question
            sample_docs = [
                f"Document 1: This is the first sample document {i+1} for {dataset_name} dataset. "
                f"It contains relevant information to answer the question.",
                f"Document 2: This is the second sample document {i+1} providing additional context. "
                f"It includes more details about the topic.",
                f"Document 3: This is the third sample document {i+1} with supplementary information."
            ]
            sample_data.append({
                "question": f"Sample question {i+1} for {dataset_name}?",
                "answer": f"Sample answer {i+1}",
                "documents": sample_docs,
                "context": " ".join(sample_docs),  # Combined for backward compatibility
                "dataset": dataset_name,
                "ground_truth_scores": {}  # No pre-computed metrics for synthetic data
            })
        return sample_data

    def get_test_data(self, dataset_name: str, num_samples: int = 100) -> List[Dict]:
        """Get test data for TRACE evaluation.

        Args:
            dataset_name: Name of the dataset.
            num_samples: Number of test samples.

        Returns:
            List of test samples.
        """
        return self.load_dataset(dataset_name, split="test", max_samples=num_samples)

    def get_test_data_size(self, dataset_name: str) -> int:
        """Get the total number of test samples available in a dataset.

        Args:
            dataset_name: Name of the dataset.

        Returns:
            Total number of test samples available; falls back to 100 on any
            error so callers never crash on metadata lookups.
        """
        try:
            from datasets import load_dataset_builder
            # Builder metadata is enough to read split sizes without a full download.
            builder = load_dataset_builder("rungalileo/ragbench", dataset_name)
            splits = getattr(builder.info, 'splits', None)
            if splits:
                if 'test' in splits:
                    return splits['test'].num_examples
                if 'validation' in splits:
                    return splits['validation'].num_examples
                # Fall back to whichever split is listed first.
                first_split = next(iter(splits))
                return splits[first_split].num_examples
            # Last resort: materialize the test split and count it; reuse the
            # loader's configured cache directory like every other load.
            ds = load_dataset("rungalileo/ragbench", dataset_name, split="test",
                              trust_remote_code=True, cache_dir=self.cache_dir)
            return len(ds)
        except Exception as e:
            print(f"Error getting test data size: {e}")
            # Return a reasonable default
            return 100