""" Dataset Loader for CNN/DailyMail Dataset Handles loading, splitting, and preprocessing of evaluation data """ from datasets import load_dataset import pandas as pd import json import os from typing import Dict, List, Tuple import logging logger = logging.getLogger(__name__) class CNNDailyMailLoader: """Load and manage CNN/DailyMail dataset for summarization evaluation""" def __init__(self, cache_dir: str = "data/cache"): self.cache_dir = cache_dir self.dataset = None os.makedirs(cache_dir, exist_ok=True) def load_dataset(self, version: str = "3.0.0") -> Dict: """Load CNN/DailyMail dataset from HuggingFace""" logger.info(f"Loading CNN/DailyMail dataset version {version}") try: self.dataset = load_dataset("abisee/cnn_dailymail", version) logger.info("Dataset loaded successfully") return self.dataset except Exception as e: logger.error(f"Failed to load dataset: {e}") raise def get_splits(self) -> Tuple[List[Dict], List[Dict], List[Dict]]: """Get train, validation, and test splits""" if not self.dataset: self.load_dataset() train_data = list(self.dataset['train']) val_data = list(self.dataset['validation']) test_data = list(self.dataset['test']) logger.info(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}") return train_data, val_data, test_data def create_evaluation_subset(self, split: str = "test", size: int = 100) -> List[Dict]: """Create a smaller subset for evaluation""" if not self.dataset: self.load_dataset() data = list(self.dataset[split]) subset = data[:size] # Clean and format data evaluation_data = [] for item in subset: evaluation_data.append({ 'id': item.get('id', ''), 'article': item['article'], 'highlights': item['highlights'], 'url': item.get('url', '') }) return evaluation_data def save_evaluation_data(self, data: List[Dict], filename: str): """Save evaluation data to JSON file""" filepath = os.path.join(self.cache_dir, filename) with open(filepath, 'w', encoding='utf-8') as f: json.dump(data, f, indent=2, ensure_ascii=False) logger.info(f"Saved {len(data)} items to {filepath}") def load_evaluation_data(self, filename: str) -> List[Dict]: """Load evaluation data from JSON file""" filepath = os.path.join(self.cache_dir, filename) if os.path.exists(filepath): with open(filepath, 'r', encoding='utf-8') as f: data = json.load(f) logger.info(f"Loaded {len(data)} items from {filepath}") return data else: logger.warning(f"File not found: {filepath}") return [] def get_topic_categories(self) -> Dict[str, List[str]]: """Define topic categories for evaluation""" return { 'politics': ['election', 'government', 'president', 'congress', 'senate', 'political'], 'business': ['company', 'market', 'stock', 'economy', 'financial', 'business'], 'technology': ['tech', 'computer', 'software', 'internet', 'digital', 'AI'], 'sports': ['game', 'team', 'player', 'sport', 'match', 'championship'], 'health': ['medical', 'health', 'doctor', 'hospital', 'disease', 'treatment'], 'entertainment': ['movie', 'actor', 'celebrity', 'film', 'music', 'entertainment'] } def categorize_by_topic(self, data: List[Dict]) -> Dict[str, List[Dict]]: """Categorize articles by topic""" categories = self.get_topic_categories() categorized = {topic: [] for topic in categories.keys()} categorized['other'] = [] for item in data: article_text = item['article'].lower() assigned = False for topic, keywords in categories.items(): if any(keyword in article_text for keyword in keywords): categorized[topic].append(item) assigned = True break if not assigned: categorized['other'].append(item) # Log distribution for topic, items in categorized.items(): logger.info(f"{topic}: {len(items)} articles") return categorized if __name__ == "__main__": # Example usage loader = CNNDailyMailLoader() # Load dataset dataset = loader.load_dataset() # Create evaluation subset eval_data = loader.create_evaluation_subset(size=200) # Categorize by topic categorized = loader.categorize_by_topic(eval_data) # Save data loader.save_evaluation_data(eval_data, "cnn_dailymail_eval_200.json") for topic, items in categorized.items(): if items: loader.save_evaluation_data(items, f"cnn_dailymail_{topic}.json")