File size: 9,524 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 |
import os
import json
from typing import List, Dict, Any, Optional
from pathlib import Path
from datasets import load_dataset
from .benchmark import Benchmark
from .measures import exact_match_score, f1_score, acc_score
from ..core.logging import logger
def download_real_mm_rag_data(save_dir: str = "./data/real_mm_rag") -> str:
    """Download the REAL-MM-RAG FinReport dataset from the Hugging Face Hub.

    Saves each page image under ``<save_dir>/images`` and writes the
    query/answer metadata (without the PIL image objects, which are not
    JSON-serializable) to ``<save_dir>/real_mm_rag_finreport.json``. If the
    metadata file and a non-empty images directory already exist, the
    download is skipped.

    Args:
        save_dir: Directory to save the dataset files.

    Returns:
        str: Path to the saved dataset directory.

    Raises:
        Exception: Re-raises any download or file I/O error after logging it.
    """
    try:
        os.makedirs(save_dir, exist_ok=True)
        dataset_path = os.path.join(save_dir, "real_mm_rag_finreport.json")
        images_dir = os.path.join(save_dir, "images")

        # Skip the (slow) download when a previous run already produced output.
        if os.path.exists(dataset_path) and os.path.isdir(images_dir):
            # Case-insensitive extension check so e.g. ".PNG" files also count.
            image_files = [
                f for f in os.listdir(images_dir)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            ]
            if image_files:
                logger.info(f"Dataset already exists at {save_dir} with {len(image_files)} images")
                return save_dir

        logger.info("Downloading REAL-MM-RAG FinReport dataset...")
        dataset = load_dataset("ibm-research/REAL-MM-RAG_FinReport", split="test")
        os.makedirs(images_dir, exist_ok=True)

        # Process dataset: save images to disk and collect JSON metadata.
        metadata_list = []
        saved_images = 0
        for i, example in enumerate(dataset):
            metadata = {
                'id': example['id'],
                'query': example['query'],
                'answer': example['answer'],
                'image_filename': example['image_filename'],
            }
            # Optional query rephrasings at three levels of difficulty.
            for level in ('rephrase_level_1', 'rephrase_level_2', 'rephrase_level_3'):
                if example.get(level):
                    metadata[level] = example[level]
            metadata_list.append(metadata)

            # Persist the PIL page image, if present, next to the metadata.
            if example['image'] is not None:
                image_path = os.path.join(images_dir, example['image_filename'])
                example['image'].save(image_path)
                saved_images += 1
            # Progress log reports processed samples and actual images saved
            # (the original message over-counted when images were missing).
            if i % 100 == 0:
                logger.info(f"Processed {i + 1}/{len(dataset)} samples ({saved_images} images saved)...")

        # Save metadata as JSON (without image objects).
        with open(dataset_path, 'w', encoding='utf-8') as f:
            json.dump(metadata_list, f, indent=2)

        logger.info(f"Dataset downloaded to {save_dir}")
        logger.info(f"Total samples: {len(dataset)}")
        logger.info(f"Images saved to: {images_dir}")
        return save_dir
    except Exception as e:
        logger.error(f"Failed to download REAL-MM-RAG dataset: {str(e)}")
        raise
class RealMMRAG(Benchmark):
    """REAL-MM-RAG FinReport benchmark for multimodal retrieval evaluation.

    This benchmark contains financial report pages with associated queries,
    designed to test multimodal retrieval capabilities on real-world documents.
    """

    def __init__(self, path: Optional[str] = None, mode: str = "test", **kwargs):
        """Initialize the benchmark, downloading the dataset on first use.

        Args:
            path: Dataset directory. Defaults to ``~/.evoagentx/data/real_mm_rag``.
            mode: Dataset split mode, forwarded to ``Benchmark``.
            **kwargs: Extra keyword arguments forwarded to ``Benchmark``.
        """
        path = os.path.expanduser(path or "~/.evoagentx/data/real_mm_rag")
        # Set up file paths before calling super().__init__ which calls _load_data
        self.dataset_file = Path(path) / "real_mm_rag_finreport.json"
        self.images_dir = Path(path) / "images"
        super().__init__(name=type(self).__name__, path=path, mode=mode, **kwargs)

    def _load_data(self):
        """Load the dataset from the JSON metadata file, downloading it if absent.

        Raises:
            Exception: Re-raises any read/parse error after logging it.
        """
        if not self.dataset_file.exists():
            download_real_mm_rag_data(save_dir=self.path)
        try:
            with open(self.dataset_file, 'r', encoding='utf-8') as f:
                self._test_data = json.load(f)
            logger.info(f"Loaded {len(self._test_data)} samples from REAL-MM-RAG dataset")
        except Exception as e:
            logger.error(f"Failed to load dataset: {str(e)}")
            raise

    def _get_label(self, example: Any) -> Any:
        """Return the ground-truth answer for an example."""
        return example["answer"]

    def _get_id(self, example: Any) -> Any:
        """Return the unique identifier for an example."""
        return example["id"]

    def evaluate(self, prediction: Any, label: Any) -> dict:
        """Score a prediction against the label with EM, F1 and accuracy.

        Args:
            prediction: Model output string.
            label: Ground-truth answer string.

        Returns:
            Dict with keys ``"f1"``, ``"em"`` and ``"acc"``.
        """
        # For multimodal, we can use simple string matching
        em = exact_match_score(prediction=prediction, ground_truth=label)
        f1 = f1_score(prediction=prediction, ground_truth=label)
        acc = acc_score(prediction=prediction, ground_truths=[label])
        return {"f1": f1, "em": em, "acc": acc}

    @property
    def data(self) -> List[Dict[str, Any]]:
        """Get the raw dataset."""
        return self._test_data

    def get_sample(self, index: int) -> Dict[str, Any]:
        """Get a single sample by index.

        Args:
            index: Sample index (negative indices follow Python semantics).

        Returns:
            Dict containing query, image_filename, answer, rephrases, and the
            resolved ``image_path``. A shallow copy is returned so the cached
            dataset is never mutated.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        if index >= len(self._test_data):
            raise IndexError(f"Index {index} out of range for dataset size {len(self._test_data)}")
        # Copy so injecting 'image_path' does not mutate self._test_data
        # (the original wrote into the shared dict on every call).
        sample = dict(self._test_data[index])
        sample['image_path'] = str(self.images_dir / sample['image_filename'])
        return sample

    def get_samples(self, start: int = 0, end: Optional[int] = None) -> List[Dict[str, Any]]:
        """Get a range of samples.

        Args:
            start: Start index (inclusive).
            end: End index (exclusive). If None, goes to the end of the dataset.

        Returns:
            List of samples.
        """
        # Explicit None check: 'end or len(...)' would turn end=0 into "all".
        if end is None:
            end = len(self._test_data)
        return [self.get_sample(i) for i in range(start, min(end, len(self._test_data)))]

    def get_random_samples(self, n: int, seed: int = 42) -> List[Dict[str, Any]]:
        """Get n random samples from the dataset.

        Args:
            n: Number of samples to return (capped at the dataset size).
            seed: Random seed for reproducibility.

        Returns:
            List of random samples.
        """
        import random
        # Local Random instance: same samples for the same seed, but without
        # reseeding the global random state as a side effect.
        rng = random.Random(seed)
        indices = rng.sample(range(len(self._test_data)), min(n, len(self._test_data)))
        return [self.get_sample(i) for i in indices]

    def get_query_variations(self, sample: Dict[str, Any]) -> List[str]:
        """Get all query variations for a sample.

        Args:
            sample: A sample from the dataset.

        Returns:
            List of query variations (original + up to 3 rephrase levels).
        """
        queries = [sample['query']]
        for level in ('rephrase_level_1', 'rephrase_level_2', 'rephrase_level_3'):
            if sample.get(level):
                queries.append(sample[level])
        return queries

    def get_stats(self) -> Dict[str, Any]:
        """Get dataset statistics.

        Returns:
            Dictionary with sample counts, rephrase coverage, unique image
            count and average queries per image (0.0 for an empty dataset).
        """
        total_samples = len(self._test_data)
        has_rephrase_1 = sum(1 for s in self._test_data if s.get('rephrase_level_1'))
        has_rephrase_2 = sum(1 for s in self._test_data if s.get('rephrase_level_2'))
        has_rephrase_3 = sum(1 for s in self._test_data if s.get('rephrase_level_3'))
        unique_images = set(s['image_filename'] for s in self._test_data)
        return {
            "total_samples": total_samples,
            "unique_images": len(unique_images),
            "samples_with_rephrase_1": has_rephrase_1,
            "samples_with_rephrase_2": has_rephrase_2,
            "samples_with_rephrase_3": has_rephrase_3,
            # Guard against ZeroDivisionError on an empty dataset.
            "avg_queries_per_image": total_samples / len(unique_images) if unique_images else 0.0,
        }
if __name__ == "__main__":
    # Smoke-test the benchmark end to end: download, load, and inspect.
    data_dir = "./debug/data/real_mm_rag"
    download_real_mm_rag_data(data_dir)
    benchmark = RealMMRAG(data_dir)

    # Dataset-level statistics.
    print("REAL-MM-RAG Dataset Statistics:")
    for key, value in benchmark.get_stats().items():
        print(f"  {key}: {value}")

    # A few randomly chosen samples with their query variations.
    print("\nSample queries:")
    for idx, sample in enumerate(benchmark.get_random_samples(3), 1):
        print(f"\nSample {idx}:")
        print(f"  Image: {sample['image_filename']}")
        print(f"  Query: {sample['query']}")
        print(f"  Answer: {sample['answer']}")
        variations = benchmark.get_query_variations(sample)
        if len(variations) > 1:
            print(f"  Query variations: {len(variations)}")
            for level, text in enumerate(variations[1:], 1):
                print(f"    Level {level}: {text[:100]}...")
|