smart-summarizer / evaluation /dataset_loader.py
Rajak13's picture
Add comprehensive CNN/DailyMail evaluation system - dataset loading, model evaluation, topic analysis, and comparison
cf5d247 verified
raw
history blame
5.23 kB
"""
Dataset Loader for CNN/DailyMail Dataset
Handles loading, splitting, and preprocessing of evaluation data
"""
from datasets import load_dataset
import pandas as pd
import json
import os
from typing import Dict, List, Tuple
import logging
logger = logging.getLogger(__name__)
class CNNDailyMailLoader:
    """Load and manage the CNN/DailyMail dataset for summarization evaluation.

    Wraps the HuggingFace ``datasets`` loader and adds helpers for building,
    JSON-caching, and topic-categorizing evaluation subsets.
    """

    def __init__(self, cache_dir: str = "data/cache"):
        """Create a loader whose JSON cache files live under *cache_dir*."""
        self.cache_dir = cache_dir
        self.dataset = None  # populated lazily by load_dataset()
        os.makedirs(cache_dir, exist_ok=True)

    def load_dataset(self, version: str = "3.0.0") -> Dict:
        """Load CNN/DailyMail from HuggingFace and cache it on the instance.

        Args:
            version: Dataset configuration name (e.g. "3.0.0").

        Returns:
            The loaded DatasetDict with train/validation/test splits.

        Raises:
            Exception: re-raised from ``datasets.load_dataset`` on failure.
        """
        logger.info("Loading CNN/DailyMail dataset version %s", version)
        try:
            self.dataset = load_dataset("abisee/cnn_dailymail", version)
        except Exception as e:
            logger.error("Failed to load dataset: %s", e)
            raise
        logger.info("Dataset loaded successfully")
        return self.dataset

    def get_splits(self) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """Return (train, validation, test) splits as lists of record dicts."""
        if self.dataset is None:  # lazy-load on first use
            self.load_dataset()
        train_data = list(self.dataset['train'])
        val_data = list(self.dataset['validation'])
        test_data = list(self.dataset['test'])
        logger.info("Train: %d, Val: %d, Test: %d",
                    len(train_data), len(val_data), len(test_data))
        return train_data, val_data, test_data

    def create_evaluation_subset(self, split: str = "test", size: int = 100) -> List[Dict]:
        """Return the first *size* items of *split*, cleaned to a fixed schema.

        Each returned dict has keys 'id', 'article', 'highlights', 'url';
        'id' and 'url' default to '' when absent from the raw record.
        """
        if self.dataset is None:  # lazy-load on first use
            self.load_dataset()
        subset = list(self.dataset[split])[:size]
        return [
            {
                'id': item.get('id', ''),
                'article': item['article'],
                'highlights': item['highlights'],
                'url': item.get('url', ''),
            }
            for item in subset
        ]

    def save_evaluation_data(self, data: List[Dict], filename: str):
        """Write *data* as pretty-printed UTF-8 JSON under the cache dir."""
        filepath = os.path.join(self.cache_dir, filename)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        logger.info("Saved %d items to %s", len(data), filepath)

    def load_evaluation_data(self, filename: str) -> List[Dict]:
        """Read a JSON cache file; return [] (with a warning) if missing."""
        filepath = os.path.join(self.cache_dir, filename)
        if not os.path.exists(filepath):
            logger.warning("File not found: %s", filepath)
            return []
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        logger.info("Loaded %d items from %s", len(data), filepath)
        return data

    def get_topic_categories(self) -> Dict[str, List[str]]:
        """Define topic categories for evaluation (topic -> keyword list)."""
        return {
            'politics': ['election', 'government', 'president', 'congress', 'senate', 'political'],
            'business': ['company', 'market', 'stock', 'economy', 'financial', 'business'],
            'technology': ['tech', 'computer', 'software', 'internet', 'digital', 'AI'],
            'sports': ['game', 'team', 'player', 'sport', 'match', 'championship'],
            'health': ['medical', 'health', 'doctor', 'hospital', 'disease', 'treatment'],
            'entertainment': ['movie', 'actor', 'celebrity', 'film', 'music', 'entertainment']
        }

    def categorize_by_topic(self, data: List[Dict]) -> Dict[str, List[Dict]]:
        """Assign each item to the first topic whose keyword appears in its article.

        Matching is a case-insensitive substring search over the article text;
        items matching no topic go under the 'other' key.
        """
        # BUG FIX: compare keywords case-insensitively. The original lowercased
        # the article but not the keywords, so 'AI' (technology) never matched.
        lowered = {
            topic: [kw.lower() for kw in keywords]
            for topic, keywords in self.get_topic_categories().items()
        }
        categorized = {topic: [] for topic in lowered}
        categorized['other'] = []
        for item in data:
            article_text = item['article'].lower()
            for topic, keywords in lowered.items():
                if any(keyword in article_text for keyword in keywords):
                    categorized[topic].append(item)
                    break
            else:  # no topic matched
                categorized['other'].append(item)
        # Log distribution
        for topic, items in categorized.items():
            logger.info("%s: %d articles", topic, len(items))
        return categorized
if __name__ == "__main__":
    # Example usage: build a 200-article evaluation subset, bucket it by
    # topic, and cache everything as JSON under the loader's cache dir.
    loader = CNNDailyMailLoader()
    loader.load_dataset()
    eval_data = loader.create_evaluation_subset(size=200)
    categorized = loader.categorize_by_topic(eval_data)
    loader.save_evaluation_data(eval_data, "cnn_dailymail_eval_200.json")
    # One file per non-empty topic bucket.
    for topic, items in categorized.items():
        if items:
            loader.save_evaluation_data(items, f"cnn_dailymail_{topic}.json")