# Uploaded by StarMist0012 - "Add files using upload-large-folder tool" (commit 3270dae, verified)
"""Base class for local JSONL-based datasets (async-only)."""
import json
from typing import Optional, Dict, Any
import torch
from torch.utils.data import Dataset
from taoTrain.config import TrainingConfig
from taoTrain.data.chunk_manager import ChunkManager
from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper
class BaseJSONLDataset(Dataset):
    """
    Base class for local JSONL-based datasets with async-only streaming.

    Designed for use with AsyncBatchIterator and TokenizationQueue, which are
    expected to drive chunk loading and tokenization. This class itself only
    wires up the chunk manager and tokenizer and exposes index/chunk helpers;
    subclasses supply ``__getitem__`` and may override the preprocessing hooks.
    """

    def __init__(self, config: TrainingConfig, split: str = "train"):
        """
        Initialize JSONL dataset with chunked loading.

        Args:
            config: Training configuration. ``config.dataset`` must provide
                ``jsonl_path``, ``enable_streaming``, ``text_field`` and the
                tokenizer settings; the chunking options are read only when
                streaming is enabled.
            split: Dataset split (train, validation, test) - not used for
                JSONL but kept for compatibility.

        Raises:
            ValueError: If ``config.dataset.jsonl_path`` is empty.

        Note:
            Requires AsyncBatchIterator and TokenizationQueue for data loading.
            See taoTrain/data/async_loader.py for usage.
        """
        super().__init__()
        self.config = config
        self.split = split
        self.tokenizer = None

        dataset_config = self.config.dataset
        jsonl_path = dataset_config.jsonl_path
        if not jsonl_path:
            raise ValueError("jsonl_path must be provided for local JSONL datasets")

        # The chunk manager exists only in streaming mode; every helper below
        # treats ``chunk_manager is None`` as "no chunking".
        if dataset_config.enable_streaming:
            self.chunk_manager = ChunkManager(
                jsonl_path,
                chunk_size_gb=dataset_config.chunk_size_gb,
                samples_per_chunk=dataset_config.samples_per_chunk,
                enable_metadata_cache=dataset_config.enable_chunk_metadata_cache,
                chunk_cache_dir=dataset_config.chunk_cache_dir,
                max_samples=dataset_config.max_samples,
            )
            print(f"✓ {self.chunk_manager}")
        else:
            self.chunk_manager = None

        # Currently loaded chunk (populated lazily by _load_chunk).
        self._current_chunk_num = None
        self._current_chunk_data = None  # {"text": [...]} or preprocessed data
        self._text_field = dataset_config.text_field

        print("✓ Loading tokenizer...")
        self._load_tokenizer()
        print("✓ Dataset initialization complete (async mode - chunks loaded on-demand).")

    def _load_tokenizer(self):
        """
        Load the tokenizer (local SentencePiece or HuggingFace).

        When ``tokenizer_path`` is set, ``tokenizer_type`` selects the loader;
        if it is ``None`` a path ending in ``.model`` is treated as
        SentencePiece and anything else as HuggingFace. Without a path, the
        HuggingFace tokenizer named by ``config.tokenizer_name`` (default
        ``'gpt2'``) is loaded.

        Raises:
            ImportError: If the required tokenizer package is not installed.
            ValueError: If a tokenizer at ``tokenizer_path`` fails to load.
        """
        dataset_config = self.config.dataset
        tokenizer_path = dataset_config.tokenizer_path

        if tokenizer_path:
            tokenizer_type = dataset_config.tokenizer_type
            # Auto-detect tokenizer type based on file extension
            if tokenizer_type is None:
                tokenizer_type = (
                    'sentencepiece' if tokenizer_path.endswith('.model') else 'huggingface'
                )

            if tokenizer_type == 'sentencepiece':
                try:
                    import sentencepiece as spm
                    sp = spm.SentencePieceProcessor()
                    sp.Load(tokenizer_path)
                    # Wrap SentencePiece in a compatible interface
                    self.tokenizer = SentencePieceTokenizerWrapper(sp)
                except ImportError as e:
                    raise ImportError(
                        "SentencePiece not installed. Install with: pip install sentencepiece"
                    ) from e
                except Exception as e:
                    raise ValueError(
                        f"Failed to load SentencePiece tokenizer from {dataset_config.tokenizer_path}: {e}"
                    ) from e
            else:
                try:
                    from transformers import AutoTokenizer
                    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
                except ImportError as e:
                    raise ImportError(
                        "HuggingFace tokenizers require the optional 'transformers' dependency"
                    ) from e
                except Exception as e:
                    raise ValueError(
                        f"Failed to load HuggingFace tokenizer from {dataset_config.tokenizer_path}: {e}"
                    ) from e
        else:
            # No explicit path: fall back to a named HuggingFace tokenizer.
            try:
                from transformers import AutoTokenizer
            except ImportError as e:
                raise ImportError(
                    "Default GPT-2 tokenizer requires the optional 'transformers' dependency"
                ) from e
            tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2')
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Set pad token if not set (for HuggingFace tokenizers): reuse EOS so
        # padding works for tokenizers that ship without a pad token.
        if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None:
            if hasattr(self.tokenizer, 'eos_token'):
                self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_chunk(self, chunk_num: int):
        """
        Load a specific chunk from the JSONL file into memory.

        Does nothing when streaming is disabled or when the requested chunk is
        already the current one. Records lacking the configured text field are
        skipped. After loading, ``_preprocess_chunk`` is invoked.

        Args:
            chunk_num: Chunk number to load (0-indexed)
        """
        if not self.chunk_manager:
            return
        if chunk_num == self._current_chunk_num and self._current_chunk_data is not None:
            return  # Already loaded

        chunk_examples = self.chunk_manager.read_chunk(chunk_num)
        field = self._text_field
        self._current_chunk_data = {
            "text": [record[field] for record in chunk_examples if field in record]
        }
        self._current_chunk_num = chunk_num

        # Subclass hook; a no-op in this base class.
        self._preprocess_chunk()

    def _locate_index(self, idx: int) -> tuple:
        """
        Map a global index to ``(chunk_num, local_idx)``.

        Walks ``chunk_manager.chunk_line_ranges`` (pairs of
        ``(start_line, end_line)``), subtracting each chunk's length until the
        index falls inside a chunk. Must only be called when a chunk manager
        is present.

        An index beyond the last range falls back to ``(0, 0)``; this keeps
        the historical behaviour of the two public helpers rather than raising.
        """
        for chunk_num, (start_line, end_line) in enumerate(self.chunk_manager.chunk_line_ranges):
            chunk_len = end_line - start_line
            if idx < chunk_len:
                return chunk_num, idx
            idx -= chunk_len
        # Shouldn't reach here for a valid index.
        return 0, 0

    def _get_chunk_for_idx(self, idx: int) -> int:
        """
        Determine which chunk contains the given global index.

        Args:
            idx: Global index

        Returns:
            Chunk number (0-indexed); 0 when streaming is disabled or the
            index lies past the last chunk.
        """
        if not self.chunk_manager:
            return 0
        return self._locate_index(idx)[0]

    def _get_local_idx_in_chunk(self, global_idx: int) -> int:
        """
        Convert global index to local index within the chunk.

        Args:
            global_idx: Global index

        Returns:
            Local index within the chunk; the index unchanged when streaming
            is disabled, and 0 when it lies past the last chunk.
        """
        if not self.chunk_manager:
            return global_idx
        return self._locate_index(global_idx)[1]

    def _preprocess(self):
        """Preprocess dataset (hook for subclasses; no-op here)."""
        pass

    def _preprocess_chunk(self):
        """
        Preprocess current chunk (hook for subclasses; no-op here).

        Called by ``_load_chunk`` right after ``_current_chunk_data`` is set.
        """
        pass

    def __len__(self) -> int:
        """
        Return dataset length.

        With a chunk manager this is ``chunk_manager.effective_lines``;
        otherwise it is the number of texts in the currently held chunk data,
        or 0 when nothing is loaded.
        """
        if self.chunk_manager:
            return self.chunk_manager.effective_lines
        if self._current_chunk_data and "text" in self._current_chunk_data:
            return len(self._current_chunk_data["text"])
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get item (to be implemented by subclasses).

        Raises:
            NotImplementedError: Always, in this base class. Raising instead
                of silently returning ``None`` surfaces a missing override at
                the call site.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement __getitem__"
        )