"""Base class for local JSONL-based datasets (async-only).""" import json from typing import Optional, Dict, Any import torch from torch.utils.data import Dataset from taoTrain.config import TrainingConfig from taoTrain.data.chunk_manager import ChunkManager from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper class BaseJSONLDataset(Dataset): """ Base class for local JSONL-based datasets with async-only streaming. Designed for use with AsyncBatchIterator and TokenizationQueue. All data loading and preprocessing happens asynchronously in background threads. """ def __init__(self, config: TrainingConfig, split: str = "train"): """ Initialize JSONL dataset with chunked loading. Args: config: Training configuration split: Dataset split (train, validation, test) - not used for JSONL but kept for compatibility Note: Requires AsyncBatchIterator and TokenizationQueue for data loading. See taoTrain/data/async_loader.py for usage. """ self.config = config self.split = split self.tokenizer = None # Initialize chunk manager for streaming dataset_config = self.config.dataset jsonl_path = dataset_config.jsonl_path if not jsonl_path: raise ValueError("jsonl_path must be provided for local JSONL datasets") # Create chunk manager enable_streaming = dataset_config.enable_streaming chunk_size_gb = dataset_config.chunk_size_gb samples_per_chunk = dataset_config.samples_per_chunk enable_metadata_cache = dataset_config.enable_chunk_metadata_cache chunk_cache_dir = dataset_config.chunk_cache_dir max_samples = dataset_config.max_samples if enable_streaming: self.chunk_manager = ChunkManager( jsonl_path, chunk_size_gb=chunk_size_gb, samples_per_chunk=samples_per_chunk, enable_metadata_cache=enable_metadata_cache, chunk_cache_dir=chunk_cache_dir, max_samples=max_samples ) print(f"✓ {self.chunk_manager}") else: self.chunk_manager = None # Current chunk data self._current_chunk_num = None self._current_chunk_data = None # {"text": [...]} or preprocessed data self._text_field = dataset_config.text_field # Load tokenizer print("✓ Loading tokenizer...") self._load_tokenizer() print("✓ Dataset initialization complete (async mode - chunks loaded on-demand).") def _load_tokenizer(self): """Load tokenizer (from local SentencePiece or HuggingFace).""" dataset_config = self.config.dataset # Check if tokenizer_path is specified if dataset_config.tokenizer_path: tokenizer_type = dataset_config.tokenizer_type # Auto-detect tokenizer type based on file extension if tokenizer_type is None: if dataset_config.tokenizer_path.endswith('.model'): tokenizer_type = 'sentencepiece' else: tokenizer_type = 'huggingface' if tokenizer_type == 'sentencepiece': # Load SentencePiece tokenizer try: import sentencepiece as spm sp = spm.SentencePieceProcessor() sp.Load(dataset_config.tokenizer_path) # Wrap SentencePiece in a compatible interface self.tokenizer = SentencePieceTokenizerWrapper(sp) except ImportError: raise ImportError("SentencePiece not installed. Install with: pip install sentencepiece") except Exception as e: raise ValueError(f"Failed to load SentencePiece tokenizer from {dataset_config.tokenizer_path}: {e}") else: # Load HuggingFace tokenizer from path try: from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(dataset_config.tokenizer_path) except ImportError as e: raise ImportError("HuggingFace tokenizers require the optional 'transformers' dependency") from e except Exception as e: raise ValueError(f"Failed to load HuggingFace tokenizer from {dataset_config.tokenizer_path}: {e}") else: # Default to GPT-2 tokenizer try: from transformers import AutoTokenizer except ImportError as e: raise ImportError("Default GPT-2 tokenizer requires the optional 'transformers' dependency") from e tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2') self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) # Set pad token if not set (for HuggingFace tokenizers) if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None: if hasattr(self.tokenizer, 'eos_token'): self.tokenizer.pad_token = self.tokenizer.eos_token def _load_chunk(self, chunk_num: int): """ Load a specific chunk from JSONL file. Args: chunk_num: Chunk number to load (0-indexed) """ if not self.chunk_manager: return if chunk_num == self._current_chunk_num and self._current_chunk_data is not None: # Already loaded return # Read chunk chunk_examples = self.chunk_manager.read_chunk(chunk_num) # Convert to text data texts = [] for obj in chunk_examples: if self._text_field in obj: texts.append(obj[self._text_field]) self._current_chunk_data = {"text": texts} self._current_chunk_num = chunk_num # Preprocess chunk (tokenization happens in background via AsyncBatchIterator) self._preprocess_chunk() def _get_chunk_for_idx(self, idx: int) -> int: """ Determine which chunk contains the given global index. Args: idx: Global index Returns: Chunk number (0-indexed) """ if not self.chunk_manager: return 0 current_line = 0 for chunk_num, (start_line, end_line) in enumerate(self.chunk_manager.chunk_line_ranges): if idx < (end_line - start_line): return chunk_num idx -= (end_line - start_line) # Shouldn't reach here return 0 def _get_local_idx_in_chunk(self, global_idx: int) -> int: """ Convert global index to local index within the chunk. Args: global_idx: Global index Returns: Local index within the chunk """ if not self.chunk_manager: return global_idx current_line = 0 for chunk_num, (start_line, end_line) in enumerate(self.chunk_manager.chunk_line_ranges): chunk_size = end_line - start_line if global_idx < chunk_size: return global_idx global_idx -= chunk_size return 0 def _preprocess(self): """Preprocess dataset (to be implemented by subclasses).""" pass def _preprocess_chunk(self): """ Preprocess current chunk (to be implemented by subclasses). This is called after a chunk is loaded by AsyncBatchIterator. """ pass def __len__(self) -> int: """Return dataset length.""" if self.chunk_manager: return self.chunk_manager.effective_lines elif self._current_chunk_data and "text" in self._current_chunk_data: return len(self._current_chunk_data.get("text", [])) return 0 def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: """Get item (to be implemented by subclasses).""" pass