|
|
import os |
|
|
import json |
|
|
from typing import List, Dict, Any, Optional |
|
|
from pathlib import Path |
|
|
|
|
|
from datasets import load_dataset |
|
|
from .benchmark import Benchmark |
|
|
from .measures import exact_match_score, f1_score, acc_score |
|
|
from ..core.logging import logger |
|
|
|
|
|
|
|
|
def download_real_mm_rag_data(save_dir: str = "./data/real_mm_rag") -> str:
    """Download the REAL-MM-RAG FinReport dataset from Hugging Face.

    The download is skipped when the metadata JSON and a non-empty images
    directory already exist under ``save_dir``.

    Args:
        save_dir: Directory to save the dataset files

    Returns:
        str: Path to the saved dataset directory

    Raises:
        Exception: Any error raised while downloading or saving is logged
            and re-raised unchanged.
    """
    try:
        os.makedirs(save_dir, exist_ok=True)

        # Compute both target paths once; they are reused for the cache
        # check, the image dump, and the final metadata write.
        dataset_path = os.path.join(save_dir, "real_mm_rag_finreport.json")
        images_dir = os.path.join(save_dir, "images")

        # Reuse an existing download: require the metadata file plus at
        # least one saved image before short-circuiting.
        if os.path.exists(dataset_path) and os.path.exists(images_dir):
            image_files = [f for f in os.listdir(images_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
            if len(image_files) > 0:
                logger.info(f"Dataset already exists at {save_dir} with {len(image_files)} images")
                return save_dir

        logger.info("Downloading REAL-MM-RAG FinReport dataset...")
        dataset = load_dataset("ibm-research/REAL-MM-RAG_FinReport", split="test")

        os.makedirs(images_dir, exist_ok=True)

        # Collect JSON-serializable metadata per sample and write each page
        # image to disk alongside it.
        metadata_list = []
        for i, example in enumerate(dataset):
            metadata = {
                'id': example['id'],
                'query': example['query'],
                'answer': example['answer'],
                'image_filename': example['image_filename']
            }

            # Optional query rephrasings (three increasing difficulty levels);
            # only stored when present and non-empty.
            for level in ('rephrase_level_1', 'rephrase_level_2', 'rephrase_level_3'):
                if level in example and example[level]:
                    metadata[level] = example[level]

            metadata_list.append(metadata)

            if example['image'] is not None:
                image_path = os.path.join(images_dir, example['image_filename'])
                # PIL image object provided by the datasets library.
                example['image'].save(image_path)

                # Periodic progress logging while saving images.
                if i % 100 == 0:
                    logger.info(f"Saved {i+1}/{len(dataset)} images...")

        with open(dataset_path, 'w') as f:
            json.dump(metadata_list, f, indent=2)

        logger.info(f"Dataset downloaded to {save_dir}")
        logger.info(f"Total samples: {len(dataset)}")
        logger.info(f"Images saved to: {images_dir}")

        return save_dir

    except Exception as e:
        # Top-level boundary: log for operators, then propagate to callers.
        logger.error(f"Failed to download REAL-MM-RAG dataset: {str(e)}")
        raise
|
|
|
|
|
|
|
|
class RealMMRAG(Benchmark):
    """REAL-MM-RAG FinReport benchmark for multimodal retrieval evaluation.

    This benchmark contains financial report pages with associated queries,
    designed to test multimodal retrieval capabilities on real-world documents.
    """

    # Keys under which optional query rephrasings may be stored per sample.
    REPHRASE_LEVELS = ('rephrase_level_1', 'rephrase_level_2', 'rephrase_level_3')

    def __init__(self, path: Optional[str] = None, mode: str = "test", **kwargs):
        """Initialize the benchmark, downloading the dataset on first use.

        Args:
            path: Directory containing (or to receive) the dataset files.
                Defaults to ``~/.evoagentx/data/real_mm_rag``.
            mode: Split mode forwarded to the ``Benchmark`` base class.
            **kwargs: Extra arguments forwarded to ``Benchmark.__init__``.
        """
        path = os.path.expanduser(path or "~/.evoagentx/data/real_mm_rag")

        self.dataset_file = Path(path) / "real_mm_rag_finreport.json"
        self.images_dir = Path(path) / "images"

        # NOTE(review): base-class init presumably triggers _load_data() —
        # confirm against Benchmark's implementation.
        super().__init__(name=type(self).__name__, path=path, mode=mode, **kwargs)

    def _load_data(self):
        """Load the dataset from the JSON metadata file.

        Downloads the dataset first if the metadata file is missing.

        Raises:
            Exception: Any error raised while reading/parsing the file is
                logged and re-raised unchanged.
        """
        if not self.dataset_file.exists():
            download_real_mm_rag_data(save_dir=self.path)

        try:
            with open(self.dataset_file, 'r') as f:
                self._test_data = json.load(f)

            logger.info(f"Loaded {len(self._test_data)} samples from REAL-MM-RAG dataset")

        except Exception as e:
            logger.error(f"Failed to load dataset: {str(e)}")
            raise

    def _get_label(self, example: Any) -> Any:
        """Return the ground-truth answer for an example."""
        return example["answer"]

    def _get_id(self, example: Any) -> Any:
        """Return the unique identifier of an example."""
        return example["id"]

    def evaluate(self, prediction: Any, label: Any) -> dict:
        """Score a prediction against the ground-truth label.

        Args:
            prediction: Model output string.
            label: Ground-truth answer.

        Returns:
            dict with ``f1``, ``em`` (exact match), and ``acc`` scores.
        """
        em = exact_match_score(prediction=prediction, ground_truth=label)
        f1 = f1_score(prediction=prediction, ground_truth=label)
        acc = acc_score(prediction=prediction, ground_truths=[label])
        return {"f1": f1, "em": em, "acc": acc}

    @property
    def data(self) -> List[Dict[str, Any]]:
        """Get the raw dataset."""
        return self._test_data

    def get_sample(self, index: int) -> Dict[str, Any]:
        """Get a single sample by index.

        Args:
            index: Sample index

        Returns:
            Dict containing query, image_filename, answer, rephrases, and a
            resolved ``image_path``.

        Raises:
            IndexError: If ``index`` exceeds the dataset size.
        """
        if index >= len(self._test_data):
            raise IndexError(f"Index {index} out of range for dataset size {len(self._test_data)}")

        # Copy the record so attaching 'image_path' does not permanently
        # mutate the cached dataset entry in self._test_data.
        sample = dict(self._test_data[index])
        sample['image_path'] = str(self.images_dir / sample['image_filename'])
        return sample

    def get_samples(self, start: int = 0, end: Optional[int] = None) -> List[Dict[str, Any]]:
        """Get a range of samples.

        Args:
            start: Start index (inclusive)
            end: End index (exclusive). If None, goes to end of dataset

        Returns:
            List of samples
        """
        # Explicit None check so end=0 yields an empty range rather than
        # being treated as "to the end" (as `end or len(...)` would).
        if end is None:
            end = len(self._test_data)

        return [self.get_sample(i) for i in range(start, min(end, len(self._test_data)))]

    def get_random_samples(self, n: int, seed: int = 42) -> List[Dict[str, Any]]:
        """Get n random samples from the dataset.

        Args:
            n: Number of samples to return
            seed: Random seed for reproducibility

        Returns:
            List of random samples
        """
        import random

        # Use a dedicated Random instance so the module-global RNG state of
        # the host application is left untouched.
        rng = random.Random(seed)
        indices = rng.sample(range(len(self._test_data)), min(n, len(self._test_data)))
        return [self.get_sample(i) for i in indices]

    def get_query_variations(self, sample: Dict[str, Any]) -> List[str]:
        """Get all query variations for a sample.

        Args:
            sample: A sample from the dataset

        Returns:
            List of query variations (original + up to 3 rephrase levels)
        """
        queries = [sample['query']]

        for level in self.REPHRASE_LEVELS:
            if level in sample and sample[level]:
                queries.append(sample[level])

        return queries

    def get_stats(self) -> Dict[str, Any]:
        """Get dataset statistics.

        Returns:
            Dictionary with dataset statistics
        """
        total_samples = len(self._test_data)

        rephrase_counts = {
            level: sum(1 for s in self._test_data if s.get(level))
            for level in self.REPHRASE_LEVELS
        }

        unique_images = {s['image_filename'] for s in self._test_data}

        return {
            "total_samples": total_samples,
            "unique_images": len(unique_images),
            "samples_with_rephrase_1": rephrase_counts['rephrase_level_1'],
            "samples_with_rephrase_2": rephrase_counts['rephrase_level_2'],
            "samples_with_rephrase_3": rephrase_counts['rephrase_level_3'],
            # Guard against an empty dataset (ZeroDivisionError).
            "avg_queries_per_image": total_samples / len(unique_images) if unique_images else 0.0,
        }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: fetch the dataset into a debug directory, then print
    # overall statistics followed by a few randomly chosen samples.
    debug_dir = "./debug/data/real_mm_rag"

    download_real_mm_rag_data(debug_dir)
    benchmark = RealMMRAG(debug_dir)

    print("REAL-MM-RAG Dataset Statistics:")
    for key, value in benchmark.get_stats().items():
        print(f"  {key}: {value}")

    print("\nSample queries:")
    for idx, item in enumerate(benchmark.get_random_samples(3), 1):
        print(f"\nSample {idx}:")
        print(f"  Image: {item['image_filename']}")
        print(f"  Query: {item['query']}")
        print(f"  Answer: {item['answer']}")

        variations = benchmark.get_query_variations(item)
        if len(variations) > 1:
            print(f"  Query variations: {len(variations)}")
            for level_no, rephrased in enumerate(variations[1:], 1):
                print(f"    Level {level_no}: {rephrased[:100]}...")
|