# NOTE(review): the three lines above/below were scraped HuggingFace Spaces page
# chrome ("Spaces: Sleeping") and are not part of this module's content.
| """ | |
| Dataset Loader for CNN/DailyMail Dataset | |
| Handles loading, splitting, and preprocessing of evaluation data | |
| """ | |
| from datasets import load_dataset | |
| import pandas as pd | |
| import json | |
| import os | |
| from typing import Dict, List, Tuple | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class CNNDailyMailLoader:
    """Load and manage the CNN/DailyMail dataset for summarization evaluation.

    Wraps HuggingFace ``datasets`` loading, split access, evaluation-subset
    creation, JSON caching under ``cache_dir``, and a simple keyword-based
    topic categorizer.
    """

    def __init__(self, cache_dir: str = "data/cache"):
        """Create a loader that caches JSON artifacts under *cache_dir*."""
        self.cache_dir = cache_dir
        self.dataset = None  # populated lazily by load_dataset()
        os.makedirs(cache_dir, exist_ok=True)

    def load_dataset(self, version: str = "3.0.0") -> Dict:
        """Load the CNN/DailyMail dataset from the HuggingFace Hub.

        Args:
            version: Dataset config name (e.g. "3.0.0").

        Returns:
            The loaded DatasetDict (train/validation/test splits).

        Raises:
            Exception: whatever ``datasets.load_dataset`` raises; logged
                and re-raised unchanged so callers can handle it.
        """
        logger.info(f"Loading CNN/DailyMail dataset version {version}")
        try:
            self.dataset = load_dataset("abisee/cnn_dailymail", version)
        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise
        else:
            logger.info("Dataset loaded successfully")
            return self.dataset

    def get_splits(self) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """Return the (train, validation, test) splits as lists of dicts."""
        # Explicit None check: a loaded-but-empty DatasetDict is falsy and
        # the old `if not self.dataset` would have reloaded it needlessly.
        if self.dataset is None:
            self.load_dataset()
        train_data = list(self.dataset['train'])
        val_data = list(self.dataset['validation'])
        test_data = list(self.dataset['test'])
        logger.info(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")
        return train_data, val_data, test_data

    def create_evaluation_subset(self, split: str = "test", size: int = 100) -> List[Dict]:
        """Return the first *size* items of *split*, normalized to dicts with
        keys 'id', 'article', 'highlights', 'url'.

        Args:
            split: Which split to draw from ('train', 'validation', 'test').
            size: Maximum number of items to return.
        """
        if self.dataset is None:
            self.load_dataset()
        subset = list(self.dataset[split])[:size]
        # Normalize records; 'id'/'url' may be absent, 'article'/'highlights'
        # are required fields of the dataset.
        return [
            {
                'id': item.get('id', ''),
                'article': item['article'],
                'highlights': item['highlights'],
                'url': item.get('url', ''),
            }
            for item in subset
        ]

    def save_evaluation_data(self, data: List[Dict], filename: str):
        """Save *data* as pretty-printed UTF-8 JSON under the cache dir."""
        filepath = os.path.join(self.cache_dir, filename)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        logger.info(f"Saved {len(data)} items to {filepath}")

    def load_evaluation_data(self, filename: str) -> List[Dict]:
        """Load previously cached JSON data; return [] if the file is missing."""
        filepath = os.path.join(self.cache_dir, filename)
        if not os.path.exists(filepath):
            logger.warning(f"File not found: {filepath}")
            return []
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        logger.info(f"Loaded {len(data)} items from {filepath}")
        return data

    def get_topic_categories(self) -> Dict[str, List[str]]:
        """Return the keyword lists used for substring-based topic matching.

        All keywords are lowercase because ``categorize_by_topic`` matches
        against lowercased article text. (The previous 'AI' keyword could
        never match for that reason; 'artificial intelligence' replaces it —
        a bare 'ai' would false-positive inside words like 'said'.)
        """
        return {
            'politics': ['election', 'government', 'president', 'congress', 'senate', 'political'],
            'business': ['company', 'market', 'stock', 'economy', 'financial', 'business'],
            'technology': ['tech', 'computer', 'software', 'internet', 'digital', 'artificial intelligence'],
            'sports': ['game', 'team', 'player', 'sport', 'match', 'championship'],
            'health': ['medical', 'health', 'doctor', 'hospital', 'disease', 'treatment'],
            'entertainment': ['movie', 'actor', 'celebrity', 'film', 'music', 'entertainment']
        }

    def categorize_by_topic(self, data: List[Dict]) -> Dict[str, List[Dict]]:
        """Bucket articles by the first topic whose keyword appears in the text.

        Matching is case-insensitive substring matching (so 'tech' matches
        'technology'); items matching no topic go into 'other'.
        """
        categories = self.get_topic_categories()
        categorized = {topic: [] for topic in categories}
        categorized['other'] = []
        for item in data:
            article_text = item['article'].lower()
            assigned = False
            for topic, keywords in categories.items():
                # Lowercase keywords defensively so an upper/mixed-case
                # keyword can never silently fail against the lowered text.
                if any(keyword.lower() in article_text for keyword in keywords):
                    categorized[topic].append(item)
                    assigned = True
                    break
            if not assigned:
                categorized['other'].append(item)
        # Log the resulting distribution for quick sanity checks.
        for topic, items in categorized.items():
            logger.info(f"{topic}: {len(items)} articles")
        return categorized
if __name__ == "__main__":
    # Demo: build a 200-article evaluation set, bucket it by topic,
    # and cache everything as JSON under the loader's cache dir.
    loader = CNNDailyMailLoader()
    dataset = loader.load_dataset()
    eval_data = loader.create_evaluation_subset(size=200)
    categorized = loader.categorize_by_topic(eval_data)
    loader.save_evaluation_data(eval_data, "cnn_dailymail_eval_200.json")
    # Persist only non-empty topic buckets, one JSON file per topic.
    non_empty = {topic: items for topic, items in categorized.items() if items}
    for topic, items in non_empty.items():
        loader.save_evaluation_data(items, f"cnn_dailymail_{topic}.json")