File size: 8,593 Bytes
3270dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""Base class for local JSONL-based datasets (async-only)."""

import json
from typing import Optional, Dict, Any
import torch
from torch.utils.data import Dataset
from taoTrain.config import TrainingConfig
from taoTrain.data.chunk_manager import ChunkManager
from taoTrain.data.tokenizer import SentencePieceTokenizerWrapper


class BaseJSONLDataset(Dataset):
    """
    Base class for local JSONL-based datasets with async-only streaming.

    Designed for use with AsyncBatchIterator and TokenizationQueue, which are
    expected to drive chunk loading and preprocessing in the background.
    This base class locates and reads raw chunks and loads the tokenizer.
    Subclasses implement ``_preprocess_chunk`` and ``__getitem__``.
    """

    def __init__(self, config: "TrainingConfig", split: str = "train"):
        """
        Initialize the JSONL dataset with chunked loading.

        Args:
            config: Training configuration; ``config.dataset`` supplies the
                JSONL path, chunking options, text field and tokenizer options.
            split: Dataset split (train, validation, test). It is stored on the
                instance but is not otherwise used for JSONL; it is kept for
                compatibility.

        Raises:
            ValueError: if ``config.dataset.jsonl_path`` is empty.

        Note:
            Requires AsyncBatchIterator and TokenizationQueue for data loading.
            See taoTrain/data/async_loader.py for usage.
        """
        self.config = config
        self.split = split
        self.tokenizer = None

        dataset_config = self.config.dataset
        jsonl_path = dataset_config.jsonl_path
        if not jsonl_path:
            raise ValueError("jsonl_path must be provided for local JSONL datasets")

        # Only streaming mode creates a chunk manager. Without it, _load_chunk is
        # a no-op and __len__ falls back to whatever _current_chunk_data holds.
        if dataset_config.enable_streaming:
            self.chunk_manager = ChunkManager(
                jsonl_path,
                chunk_size_gb=dataset_config.chunk_size_gb,
                samples_per_chunk=dataset_config.samples_per_chunk,
                enable_metadata_cache=dataset_config.enable_chunk_metadata_cache,
                chunk_cache_dir=dataset_config.chunk_cache_dir,
                max_samples=dataset_config.max_samples,
            )
            print(f"✓ {self.chunk_manager}")
        else:
            self.chunk_manager = None

        # Single-chunk cache: the number and contents of the most recently loaded chunk.
        self._current_chunk_num = None
        self._current_chunk_data = None  # {"text": [...]} or preprocessed data
        self._text_field = dataset_config.text_field

        print("✓ Loading tokenizer...")
        self._load_tokenizer()

        print("✓ Dataset initialization complete (async mode - chunks loaded on-demand).")

    def _load_tokenizer(self):
        """
        Load the tokenizer into ``self.tokenizer``.

        If ``config.dataset.tokenizer_path`` is set, it is loaded with one of two
        backends, chosen by ``config.dataset.tokenizer_type``:
        - ``'sentencepiece'``: loaded with SentencePiece and wrapped in
          ``SentencePieceTokenizerWrapper``.
        - anything else: loaded with HuggingFace ``AutoTokenizer``.
        When ``tokenizer_type`` is None, the backend is auto-detected: a ``.model``
        suffix means SentencePiece, anything else means HuggingFace.

        Without a path, the HuggingFace tokenizer named by
        ``config.tokenizer_name`` is loaded (default ``'gpt2'``).

        Raises:
            ImportError: if the required tokenizer library is not installed.
            ValueError: if loading from ``tokenizer_path`` fails.
        """
        dataset_config = self.config.dataset
        tokenizer_path = dataset_config.tokenizer_path

        if tokenizer_path:
            tokenizer_type = dataset_config.tokenizer_type
            if tokenizer_type is None:
                tokenizer_type = 'sentencepiece' if tokenizer_path.endswith('.model') else 'huggingface'

            if tokenizer_type == 'sentencepiece':
                # Keep the import separate so only a genuinely missing package
                # is reported as "not installed".
                try:
                    import sentencepiece as spm
                except ImportError as e:
                    raise ImportError("SentencePiece not installed. Install with: pip install sentencepiece") from e
                try:
                    sp = spm.SentencePieceProcessor()
                    sp.Load(tokenizer_path)
                    # Wrap SentencePiece in a compatible interface
                    self.tokenizer = SentencePieceTokenizerWrapper(sp)
                except Exception as e:
                    raise ValueError(f"Failed to load SentencePiece tokenizer from {tokenizer_path}: {e}") from e
            else:
                try:
                    from transformers import AutoTokenizer
                except ImportError as e:
                    raise ImportError("HuggingFace tokenizers require the optional 'transformers' dependency") from e
                try:
                    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
                except Exception as e:
                    raise ValueError(f"Failed to load HuggingFace tokenizer from {tokenizer_path}: {e}") from e
        else:
            try:
                from transformers import AutoTokenizer
            except ImportError as e:
                raise ImportError("Default GPT-2 tokenizer requires the optional 'transformers' dependency") from e
            tokenizer_name = getattr(self.config, 'tokenizer_name', 'gpt2')
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Tokenizers exposing pad_token (e.g. HuggingFace) may leave it unset;
        # fall back to the EOS token in that case.
        if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None:
            if hasattr(self.tokenizer, 'eos_token'):
                self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_chunk(self, chunk_num: int):
        """
        Load a specific chunk from the JSONL file into the single-chunk cache.

        Does nothing when streaming is disabled (no chunk manager) or when
        ``chunk_num`` is already cached. Otherwise it stores
        ``{"text": [...]}`` in ``self._current_chunk_data`` and then calls
        ``_preprocess_chunk``.

        Args:
            chunk_num: Chunk number to load (0-indexed).
        """
        if not self.chunk_manager:
            return

        if chunk_num == self._current_chunk_num and self._current_chunk_data is not None:
            return  # already cached

        chunk_examples = self.chunk_manager.read_chunk(chunk_num)

        text_field = self._text_field
        # Records without the configured text field are skipped.
        # NOTE(review): skipping shifts positions, so a local index from
        # _get_local_idx_in_chunk may not line up with this list — confirm
        # subclasses account for it.
        texts = [obj[text_field] for obj in chunk_examples if text_field in obj]

        self._current_chunk_data = {"text": texts}
        self._current_chunk_num = chunk_num

        # Subclass hook; tokenization happens in background via AsyncBatchIterator.
        self._preprocess_chunk()

    def _locate_idx(self, idx: int) -> tuple:
        """
        Map a global index to ``(chunk_num, local_idx)``.

        Walks ``self.chunk_manager.chunk_line_ranges`` (pairs of
        ``(start_line, end_line)``), which costs O(number of chunks). The caller
        must ensure ``self.chunk_manager`` is set. An index past the last chunk
        yields ``(0, 0)``, the best-effort fallback both public lookup helpers
        have always used.
        """
        remaining = idx
        for chunk_num, (start_line, end_line) in enumerate(self.chunk_manager.chunk_line_ranges):
            chunk_size = end_line - start_line
            if remaining < chunk_size:
                return chunk_num, remaining
            remaining -= chunk_size
        return 0, 0

    def _get_chunk_for_idx(self, idx: int) -> int:
        """
        Determine which chunk contains the given global index.

        Args:
            idx: Global index.

        Returns:
            Chunk number (0-indexed). Returns 0 when streaming is disabled or
            ``idx`` is past the last chunk.
        """
        if not self.chunk_manager:
            return 0
        return self._locate_idx(idx)[0]

    def _get_local_idx_in_chunk(self, global_idx: int) -> int:
        """
        Convert a global index to the index within its chunk.

        Args:
            global_idx: Global index.

        Returns:
            Local index within the chunk. When streaming is disabled,
            ``global_idx`` is returned unchanged; when it is past the last chunk,
            0 is returned.
        """
        if not self.chunk_manager:
            return global_idx
        return self._locate_idx(global_idx)[1]

    def _preprocess(self):
        """Preprocess dataset (to be implemented by subclasses)."""
        pass

    def _preprocess_chunk(self):
        """
        Preprocess the current chunk (to be implemented by subclasses).

        Called at the end of ``_load_chunk``, after ``_current_chunk_data`` has
        been replaced.
        """
        pass

    def __len__(self) -> int:
        """
        Return the dataset length.

        In streaming mode this is ``chunk_manager.effective_lines``. Otherwise it
        is the number of texts in the cached chunk, or 0 if nothing is cached.
        """
        if self.chunk_manager:
            return self.chunk_manager.effective_lines
        data = self._current_chunk_data
        if data and "text" in data:
            return len(data["text"])
        return 0

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get item (to be implemented by subclasses; the base returns None)."""
        pass