# NOTE: extraction residue from the hosting UI ("Spaces: Paused") removed.
import logging
from typing import Any, Dict, List

from datasets import Dataset, load_dataset
class DatasetManagementService:
    """Manage a Hugging Face Hub dataset of paper metadata.

    Records are keyed by their ``entry_id`` field: :meth:`update_dataset`
    upserts rows, :meth:`get_dataset_size` reports the row count, and
    :meth:`get_dataset_records` returns one page of rows as dicts.
    """

    def __init__(self, dataset_name: str):
        # Hub repository id, e.g. "user/my-dataset".
        self.dataset_name = dataset_name

    def update_dataset(self, new_metadata: List[Dict[str, Any]]) -> str:
        """Upsert paper records into the Hub dataset.

        Papers whose ``entry_id`` is not yet present are appended; existing
        papers are updated field-by-field. The dataset is pushed to the Hub
        only when something actually changed.

        Args:
            new_metadata: List of paper dicts; each must contain ``entry_id``.

        Returns:
            A human-readable status string (success, no-op, or failure).
        """
        # Guard: the original implementation indexed new_metadata[0] and
        # crashed (IndexError) on an empty list.
        if not new_metadata:
            return "No new data to update."
        try:
            # Load the existing dataset; fall back to an empty table if the
            # dataset does not exist yet (first push) or loading fails.
            try:
                dataset = load_dataset(self.dataset_name, split="train")
                current_data = dataset.to_dict()
            except Exception:
                current_data = {}
            # Seed the column structure from the first incoming record.
            if not current_data:
                current_data = {key: [] for key in new_metadata[0].keys()}
            current_data.setdefault('entry_id', [])

            updated = False
            for paper in new_metadata:
                entry_id = paper['entry_id']
                n_rows = len(current_data['entry_id'])
                if entry_id not in current_data['entry_id']:
                    # Append a new row. Columns introduced by this paper are
                    # back-filled with None so every column stays n_rows long
                    # (ragged columns would make Dataset.from_dict fail).
                    for key, value in paper.items():
                        column = current_data.setdefault(key, [None] * n_rows)
                        if len(column) < n_rows:
                            column.extend([None] * (n_rows - len(column)))
                        column.append(value)
                    # Columns this paper lacks get None for the new row.
                    for column in current_data.values():
                        if len(column) < n_rows + 1:
                            column.extend([None] * (n_rows + 1 - len(column)))
                    updated = True
                else:
                    # Update an existing row in place; create/pad any column
                    # the dataset did not have yet (the original raised
                    # KeyError here for unseen keys).
                    index = current_data['entry_id'].index(entry_id)
                    for key, value in paper.items():
                        column = current_data.setdefault(key, [None] * n_rows)
                        if len(column) < n_rows:
                            column.extend([None] * (n_rows - len(column)))
                        if column[index] != value:
                            column[index] = value
                            updated = True

            if updated:
                updated_dataset = Dataset.from_dict(current_data)
                updated_dataset.push_to_hub(self.dataset_name, split="train")
                return f"Successfully updated dataset with {len(new_metadata)} papers"
            return "No new data to update."
        except Exception as e:
            # Log before returning so failures are visible in the logs too,
            # consistent with the other methods of this class.
            logging.error(f"Failed to update dataset: {str(e)}")
            return f"Failed to update dataset: {str(e)}"

    def get_dataset_size(self) -> int:
        """Return the number of rows in the train split (0 on failure)."""
        try:
            dataset = load_dataset(self.dataset_name, split="train")
            size = len(dataset)
            logging.info(f"Dataset size: {size}")
            return size
        except Exception as e:
            logging.error(f"Error getting dataset size: {str(e)}")
            return 0

    def get_dataset_records(self, page: int, page_size: int) -> List[Dict[str, Any]]:
        """Return one page of records as a list of row dicts.

        Args:
            page: 1-based page number (values < 1 are treated as page 1).
            page_size: Number of rows per page.

        Returns:
            A list of row dicts, or a single-element list with an ``error``
            key if loading failed.
        """
        try:
            dataset = load_dataset(self.dataset_name, split="train")
            # Clamp so page=0 (or negative) cannot produce a negative slice
            # start, which would silently return the wrong rows.
            start_idx = max(page - 1, 0) * page_size
            end_idx = start_idx + page_size
            # Slicing a Dataset yields a dict of column-name -> list of values;
            # transpose it into a list of per-row dicts.
            records = dataset[start_idx:end_idx]
            records_list = [
                dict(zip(records.keys(), values))
                for values in zip(*records.values())
            ]
            logging.info(f"Records type: {type(records_list)}")
            logging.info(f"Number of records: {len(records_list)}")
            return records_list
        except Exception as e:
            logging.error(f"Error loading dataset records: {str(e)}")
            return [{"error": f"Error loading dataset: {str(e)}"}]
# Usage:
# dataset_service = DatasetManagementService("your_dataset_name")
# result = dataset_service.update_dataset(new_metadata)
# records = dataset_service.get_dataset_records(page=1, page_size=10)