Spaces:
Sleeping
Sleeping
| """Dataset loading and preparation utilities""" | |
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Union | |
| from dataclasses import dataclass | |
| import json | |
| try: | |
| from datasets import load_dataset, Dataset, DatasetDict | |
| HF_AVAILABLE = True | |
| except ImportError: | |
| HF_AVAILABLE = False | |
| Dataset = None | |
| DatasetDict = None | |
| from .config import load_config, DatasetConfig | |
@dataclass
class DataSample:
    """Single question/answer record of a dataset.

    Attributes:
        id: Identifier of the sample.
        question: Question / problem text.
        answer: Reference answer text.
        metadata: Additional per-sample information.
    """

    id: str
    question: str
    answer: str
    metadata: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Return the sample as a plain dictionary.

        Returns:
            Dictionary with the keys ``id``, ``question``, ``answer`` and
            ``metadata``.
        """
        return {
            "id": self.id,
            "question": self.question,
            "answer": self.answer,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> DataSample:
        """Create a sample from a dictionary such as the one ``to_dict`` returns.

        Args:
            data: Mapping with the required keys ``question`` and ``answer``
                and the optional keys ``id`` and ``metadata``.

        Returns:
            The constructed sample.

        Raises:
            KeyError: If ``question`` or ``answer`` is missing.
        """
        return cls(
            id=data.get("id", ""),
            question=data["question"],
            answer=data["answer"],
            # ``or {}`` also covers an explicit null metadata entry.
            metadata=data.get("metadata") or {},
        )
class DatasetLoader:
    """Dataset loader with support for multiple formats.

    Samples are read from a local ``<split>.jsonl`` file when one exists and
    from HuggingFace otherwise. Loaded splits are cached in memory per split
    name, so requesting a different split never returns another split's data.
    """

    def __init__(self, dataset_name: str, config_path: Optional[str] = None):
        """Initialize dataset loader

        Args:
            dataset_name: Name of dataset (e.g., 'gsm8k')
            config_path: Optional path to configuration file
        """
        self.dataset_name = dataset_name
        self.config = load_config(config_path)
        self.dataset_config = self.config.get_dataset_config(dataset_name)
        # Most recently loaded split; backs __len__ and __getitem__.
        self._dataset: Optional[List[DataSample]] = None
        # In-memory cache keyed by the requested split name.
        self._datasets: Dict[str, List[DataSample]] = {}

    def load(self, split: str = "train", force_reload: bool = False) -> List[DataSample]:
        """Load dataset

        Args:
            split: Dataset split to load ('train', 'test', etc.)
            force_reload: Bypass the in-memory cache and the local file and
                load from HuggingFace again

        Returns:
            List of DataSample objects

        Raises:
            RuntimeError: If the data cannot be taken from a local file and
                the HuggingFace datasets library is not installed.
        """
        if not force_reload and split in self._datasets:
            # Serve the cached copy of *this* split only.
            samples = self._datasets[split]
            self._dataset = samples
            return samples

        # Check if dataset is already downloaded locally
        dataset_path = Path(self.dataset_config.path)
        local_file = dataset_path / f"{split}.jsonl"

        if local_file.exists() and not force_reload:
            print(f"Loading {self.dataset_name} from local file: {local_file}")
            samples = self._load_from_jsonl(local_file)
        elif HF_AVAILABLE:
            print(f"Loading {self.dataset_name} from HuggingFace: {self.dataset_config.huggingface_id}")
            samples = self._load_from_huggingface(split)
            # Cache locally
            self._save_to_jsonl(samples, local_file)
        else:
            raise RuntimeError(
                "Dataset not found locally and HuggingFace datasets library is not available. "
                "Install with: pip install datasets"
            )

        self._datasets[split] = samples
        self._dataset = samples
        return samples

    def _load_from_huggingface(self, split: str) -> List[DataSample]:
        """Load dataset from HuggingFace

        Args:
            split: Requested split name; translated through the split mapping
                of the dataset configuration when an entry exists.

        Returns:
            List of DataSample objects

        Raises:
            ImportError: If the datasets library is not installed.
        """
        if not HF_AVAILABLE:
            raise ImportError("datasets library not available")

        # Get the actual split name from config
        actual_split = self.dataset_config.split.get(split, split)

        # Parse dataset ID and config name
        # Format can be: "username/dataset" or "username/dataset:config_name"
        dataset_id = self.dataset_config.huggingface_id
        config_name = None
        if ":" in dataset_id:
            dataset_id, config_name = dataset_id.split(":", 1)

        # For GSM8K specifically, default to 'main' config if not specified
        if "gsm8k" in dataset_id.lower() and config_name is None:
            config_name = "main"

        # Load dataset with optional cache_dir
        load_args = [dataset_id]
        if config_name is not None:
            load_args.append(config_name)

        load_kwargs = {"split": actual_split}
        if self.dataset_config.cache_dir is not None:
            cache_dir = Path(self.dataset_config.cache_dir)
            cache_dir.mkdir(parents=True, exist_ok=True)
            load_kwargs["cache_dir"] = str(cache_dir)

        dataset = load_dataset(*load_args, **load_kwargs)

        # Convert to our format
        return [self._convert_to_sample(item, idx) for idx, item in enumerate(dataset)]

    def _convert_to_sample(self, item: Dict[str, Any], idx: int) -> DataSample:
        """Convert dataset item to DataSample

        This method should be customized for each dataset format

        Args:
            item: Raw record of the source dataset.
            idx: Position of the record, used to build a fallback id.

        Returns:
            The converted DataSample; the raw record is kept under the
            ``original`` metadata key.
        """
        # GSM8K format
        if self.dataset_name == "gsm8k":
            return DataSample(
                id=f"gsm8k_{idx}",
                question=item.get("question", ""),
                answer=item.get("answer", ""),
                metadata={"original": item},
            )

        # MATH-500 format
        if self.dataset_name == "math-500":
            return DataSample(
                id=item.get("unique_id", f"math500_{idx}").replace("/", "_").replace(".json", ""),
                question=item.get("problem", ""),
                answer=item.get("answer", ""),
                metadata={
                    "solution": item.get("solution", ""),
                    "subject": item.get("subject", ""),
                    "level": item.get("level", ""),
                    "original": item,
                },
            )

        # Generic format
        return DataSample(
            id=item.get("id", f"{self.dataset_name}_{idx}"),
            question=item.get("question", item.get("input", "")),
            answer=item.get("answer", item.get("output", "")),
            metadata={"original": item},
        )

    def _load_from_jsonl(self, file_path: Path) -> List[DataSample]:
        """Load dataset from JSONL file

        Blank lines are skipped. Records that already carry ``question`` and
        ``answer`` keys are read as serialized DataSample objects; any other
        record is converted from its raw format.

        Args:
            file_path: Path of the JSONL file to read.

        Returns:
            List of DataSample objects
        """
        samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if not line.strip():
                    continue
                data = json.loads(line)
                # Check if it's already in DataSample format (has 'question' and 'answer' keys)
                if "question" in data and "answer" in data:
                    samples.append(DataSample.from_dict(data))
                else:
                    # Convert from raw format (e.g., MATH-500 with 'problem' key)
                    samples.append(self._convert_to_sample(data, idx))
        return samples

    def _save_to_jsonl(self, samples: List[DataSample], file_path: Path):
        """Save dataset to JSONL file

        Args:
            samples: Samples to write, one JSON object per line.
            file_path: Destination file; missing parent directories are created.
        """
        file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            for sample in samples:
                f.write(json.dumps(sample.to_dict()) + "\n")
        print(f"Cached dataset to {file_path}")

    def get_sample(self, idx: int, split: str = "train") -> DataSample:
        """Get a single sample by index

        Args:
            idx: Index of sample
            split: Dataset split

        Returns:
            DataSample object
        """
        samples = self.load(split)
        return samples[idx]

    def __len__(self) -> int:
        """Get the size of the most recently loaded split (0 if none is loaded)"""
        if self._dataset is None:
            return 0
        return len(self._dataset)

    def __getitem__(self, idx: int) -> DataSample:
        """Get sample by index from the most recently loaded split

        Raises:
            RuntimeError: If no split has been loaded yet.
        """
        if self._dataset is None:
            raise RuntimeError("Dataset not loaded. Call load() first.")
        return self._dataset[idx]
def prepare_dataset(
    dataset_name: str,
    split: str = "train",
    config_path: Optional[str] = None,
    max_samples: Optional[int] = None,
) -> List[DataSample]:
    """Load one split of a dataset, optionally truncated.

    Args:
        dataset_name: Name of dataset
        split: Dataset split to load
        config_path: Optional configuration file path
        max_samples: Keep only the first ``max_samples`` entries; ``None``
            keeps everything (for debugging)

    Returns:
        List of DataSample objects
    """
    loaded = DatasetLoader(dataset_name, config_path).load(split)
    selected = loaded if max_samples is None else loaded[:max_samples]
    print(f"Loaded {len(selected)} samples from {dataset_name} ({split} split)")
    return selected
def create_dummy_gsm8k_dataset(output_dir: Optional[Path] = None):
    """Write a three-sample dummy GSM8K dataset for testing.

    The same samples are written to both ``train.jsonl`` and ``test.jsonl``.

    Args:
        output_dir: Output directory. If None, the path of the ``gsm8k``
            entry of the default configuration is used.
    """
    if output_dir is None:
        output_dir = Path(load_config().get_dataset_config("gsm8k").path)
    output_dir.mkdir(parents=True, exist_ok=True)

    # (question, answer, difficulty) for each dummy sample, in id order.
    raw_entries = (
        (
            "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
            "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs every day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmers' market.\n#### 18",
            "easy",
        ),
        (
            "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
            "It takes 2/2=<<2/2=1>>1 bolt of white fiber\nSo the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fiber\n#### 3",
            "easy",
        ),
        (
            "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?",
            "The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\nHe increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\nSo the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\nSo he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\n#### 70000",
            "medium",
        ),
    )
    dummy_samples = [
        DataSample(
            id=f"gsm8k_{number}",
            question=question_text,
            answer=answer_text,
            metadata={"source": "dummy", "difficulty": level},
        )
        for number, (question_text, answer_text, level) in enumerate(raw_entries)
    ]

    # Both splits get identical content, so serialize once.
    jsonl_lines = [json.dumps(entry.to_dict()) + "\n" for entry in dummy_samples]
    for split in ("train", "test"):
        output_file = output_dir / f"{split}.jsonl"
        with open(output_file, "w", encoding="utf-8") as handle:
            handle.writelines(jsonl_lines)
        print(f"Created dummy {split} dataset at {output_file}")
if __name__ == "__main__":
    # Script entry point: executing this module directly writes the dummy
    # GSM8K train/test files.
    print("Creating dummy GSM8K dataset...")
    create_dummy_gsm8k_dataset()