"""Base class for HuggingFace-based datasets.""" from typing import Optional, Dict import torch from torch.utils.data import Dataset from datasets import load_dataset from transformers import AutoTokenizer from taoTrain.config import TrainingConfig class BaseHFDataset(Dataset): """Base class for HuggingFace-based datasets.""" def __init__(self, config: TrainingConfig, split: str = "train"): """ Initialize dataset. Args: config: Training configuration split: Dataset split (train, validation, test) """ self.config = config self.split = split self.data = None self.tokenizer = None # Load tokenizer self._load_tokenizer() # Load and preprocess dataset self._load_dataset() self._preprocess() def _load_tokenizer(self): """Load tokenizer from HuggingFace.""" # Default to GPT-2 tokenizer if not specified tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2') self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) # Set pad token if not set if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token def _load_dataset(self): """Load dataset from HuggingFace.""" dataset_config = self.config.dataset try: # Load dataset if dataset_config.config: self.data = load_dataset( dataset_config.dataset_name, dataset_config.config, split=self.split, cache_dir=dataset_config.cache_dir, trust_remote_code=True, ) else: self.data = load_dataset( dataset_config.dataset_name, split=self.split, cache_dir=dataset_config.cache_dir, trust_remote_code=True, ) except Exception as e: raise ValueError(f"Failed to load dataset {dataset_config.dataset_name}: {e}") # Limit samples if specified if dataset_config.max_samples: self.data = self.data.select(range(min(dataset_config.max_samples, len(self.data)))) def _preprocess(self): """Preprocess dataset (to be implemented by subclasses).""" pass def __len__(self) -> int: """Return dataset length.""" return len(self.data) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: """Get item (to be implemented by subclasses).""" pass