File size: 19,041 Bytes
3270dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
"""Chunk manager for streaming large JSONL datasets."""

import os
import json
import hashlib
from typing import Tuple, Optional, Dict, Any
from pathlib import Path
from tqdm import tqdm


class ChunkManager:
    """
    Manage chunked, seek-based reading of large JSONL files.

    The file is scanned once to record the byte offset and byte size of every
    line that parses as JSON. Those valid lines are numbered 0..N-1 (empty and
    unparsable lines are skipped and never numbered) and partitioned into
    contiguous half-open ``(start_line, end_line)`` ranges, one per chunk.

    Scan metadata can be cached on disk, keyed by the file's absolute path and
    validated against its size and mtime. Chunks can optionally be extracted
    into separate files (``extract_and_cache_chunks``) for faster re-reads.
    """

    # Version of the on-disk metadata cache format. Caches with a different
    # (or missing) version are ignored and the file is rescanned. Version 2
    # measures byte offsets without newline translation; earlier versions
    # under-counted CRLF line endings by one byte per line.
    _METADATA_CACHE_VERSION = 2

    def __init__(self, jsonl_path: str, chunk_size_gb: float = 5.0,
                 samples_per_chunk: Optional[int] = None,
                 enable_metadata_cache: bool = True, chunk_cache_dir: str = ".cache/chunks",
                 max_samples: Optional[int] = None):
        """
        Initialize the manager by loading cached scan metadata or scanning the file.

        Args:
            jsonl_path: Path to the JSONL file.
            chunk_size_gb: Approximate chunk size in GiB (ignored if
                ``samples_per_chunk`` is set).
            samples_per_chunk: Number of valid JSON lines per chunk; takes
                precedence over ``chunk_size_gb``. Must be at least 1.
            enable_metadata_cache: Load/save file-scan metadata from/to
                ``chunk_cache_dir``.
            chunk_cache_dir: Directory for metadata and chunk cache files.
            max_samples: If set, only the first ``max_samples`` valid lines
                are assigned to chunks.

        Raises:
            FileNotFoundError: If the JSONL file doesn't exist.
            ValueError: If ``samples_per_chunk`` is less than 1, the file is
                empty, or it contains no valid JSON lines.
        """
        self.jsonl_path = Path(jsonl_path)
        self.chunk_size_bytes = int(chunk_size_gb * 1024 ** 3)  # GiB -> bytes
        self.max_samples = max_samples  # Limit total samples if specified
        print(f"Initializing ChunkManager for {self.jsonl_path} with target chunk size {chunk_size_gb} GB")
        if samples_per_chunk is not None:
            print(f"   Overriding chunk size with {samples_per_chunk} samples per chunk")
        if max_samples is not None:
            print(f"   Limiting dataset to {max_samples} samples")
        if samples_per_chunk is not None and samples_per_chunk < 1:
            # Fail fast: a non-positive chunk length can never partition the file.
            raise ValueError(f"samples_per_chunk must be >= 1, got {samples_per_chunk}")
        self.samples_per_chunk = samples_per_chunk  # If set, overrides GB-based chunking
        self.enable_metadata_cache = enable_metadata_cache
        self.chunk_cache_dir = Path(chunk_cache_dir)

        if not self.jsonl_path.exists():
            raise FileNotFoundError(f"JSONL file not found: {self.jsonl_path}")

        self.file_size_bytes = os.path.getsize(self.jsonl_path)
        self.file_mtime = os.path.getmtime(self.jsonl_path)

        if self.file_size_bytes == 0:
            raise ValueError("JSONL file is empty")

        # Populated by _scan_file() or _load_metadata_cache().
        self.total_lines = 0
        self.effective_lines = 0
        self.line_sizes = []  # byte size of each valid JSON line
        self.valid_line_offsets = []  # byte offset of each valid JSON line (for seeking)
        self.chunk_line_ranges = []  # [(start_line, end_line), ...]

        cache_loaded = self.enable_metadata_cache and self._load_metadata_cache()
        if not cache_loaded:
            self._scan_file()

        # Chunk ranges are always derived from the current config, never
        # taken from the cache, so samples_per_chunk/max_samples/chunk size
        # changes are honored without rescanning the file.
        self._compute_chunk_ranges()

        if not cache_loaded and self.enable_metadata_cache:
            self._save_metadata_cache()

    def _path_hash(self) -> str:
        """Return a short identifier derived from the JSONL file's absolute path."""
        # MD5 serves only as a filename-friendly fingerprint, not for security.
        return hashlib.md5(str(self.jsonl_path.absolute()).encode()).hexdigest()[:8]

    def _get_cache_path(self) -> Path:
        """Get the metadata cache file path for this JSONL file."""
        return self.chunk_cache_dir / f"{self._path_hash()}.metadata.json"

    def _load_metadata_cache(self) -> bool:
        """
        Load scan metadata from the cache if it exists and matches the file.

        The cache is accepted only if its format version, the file's absolute
        path, size and mtime all match, and it holds one offset and one size
        per valid line. The cached chunk size and chunk ranges are NOT applied;
        chunk ranges are recomputed from the current configuration.

        Returns:
            True if the cache was loaded, False otherwise (the caller rescans).
        """
        cache_file = self._get_cache_path()
        if not cache_file.exists():
            return False

        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                cache_data = json.load(f)
        except (OSError, ValueError):
            # Unreadable or corrupt cache: fall back to scanning.
            return False
        if not isinstance(cache_data, dict):
            return False

        if (cache_data.get('cache_version') != self._METADATA_CACHE_VERSION or
                cache_data.get('file_size') != self.file_size_bytes or
                cache_data.get('file_mtime') != self.file_mtime or
                cache_data.get('jsonl_path') != str(self.jsonl_path.absolute())):
            return False

        total_lines = cache_data.get('total_lines', 0)
        line_sizes = cache_data.get('line_sizes', [])
        offsets = cache_data.get('valid_line_offsets', [])
        # read_chunk needs exactly one offset per valid line; reject partial caches.
        if (not isinstance(total_lines, int) or total_lines <= 0 or
                not isinstance(line_sizes, list) or not isinstance(offsets, list) or
                len(line_sizes) != total_lines or len(offsets) != total_lines):
            return False

        self.total_lines = total_lines
        self.line_sizes = line_sizes
        self.valid_line_offsets = offsets

        print(f"✓ Loaded scan metadata from cache: {cache_file.name}")
        print(f"   Found {self.total_lines:,} valid JSON lines")
        return True

    def _save_metadata_cache(self) -> None:
        """Save scan metadata to the cache file (best effort; failures only warn)."""
        cache_file = self._get_cache_path()

        cache_data = {
            'cache_version': self._METADATA_CACHE_VERSION,
            'jsonl_path': str(self.jsonl_path.absolute()),
            'file_size': self.file_size_bytes,
            'file_mtime': self.file_mtime,
            'total_lines': self.total_lines,
            'line_sizes': self.line_sizes,
            'valid_line_offsets': self.valid_line_offsets,
            'chunk_line_ranges': self.chunk_line_ranges,
            'chunk_size_bytes': self.chunk_size_bytes,
        }

        temp_file = cache_file.with_suffix('.tmp')
        try:
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            # Write atomically (temp file + rename). No indentation: the
            # per-line lists can hold millions of entries.
            with open(temp_file, 'w', encoding='utf-8') as f:
                json.dump(cache_data, f)
            temp_file.replace(cache_file)
            print(f"   Saved scan metadata to cache: {cache_file.name}")
        except (OSError, TypeError, ValueError) as e:
            print(f"   ⚠ Warning: failed to save cache: {e}")

    def _get_chunk_cache_dir(self) -> Path:
        """Get the directory for storing cached chunk data for this JSONL file."""
        return self.chunk_cache_dir / "chunks" / self._path_hash()

    def _get_chunk_cache_file(self, chunk_num: int) -> Path:
        """Get the cache file path for a specific chunk."""
        return self._get_chunk_cache_dir() / f"chunk_{chunk_num:06d}.jsonl"

    def _get_chunk_index_file(self) -> Path:
        """Get the index file that describes the cached chunks."""
        return self._get_chunk_cache_dir() / "index.json"

    def _chunk_cache_is_valid(self) -> bool:
        """
        Return True if the chunk cache was built for this file and these chunk ranges.

        Chunk files are named only by chunk number, so they may be reused only
        when the index records the same file (path, size, mtime) and exactly
        the current chunk ranges.
        """
        try:
            with open(self._get_chunk_index_file(), 'r', encoding='utf-8') as f:
                index_data = json.load(f)
        except (OSError, ValueError):
            return False
        if not isinstance(index_data, dict):
            return False
        return (
            index_data.get('jsonl_path') == str(self.jsonl_path.absolute())
            and index_data.get('file_size') == self.file_size_bytes
            and index_data.get('file_mtime') == self.file_mtime
            # JSON stores tuples as lists, so compare in list form.
            and index_data.get('chunk_ranges') == [list(r) for r in self.chunk_line_ranges]
        )

    def _write_chunk_index(self) -> None:
        """Atomically write the chunk-cache index for the current file and chunk ranges."""
        index_file = self._get_chunk_index_file()
        index_data = {
            'jsonl_path': str(self.jsonl_path.absolute()),
            'file_size': self.file_size_bytes,
            'file_mtime': self.file_mtime,
            'num_chunks': len(self.chunk_line_ranges),
            'chunk_ranges': self.chunk_line_ranges,
        }
        temp_file = index_file.with_suffix('.tmp')
        with open(temp_file, 'w', encoding='utf-8') as f:
            json.dump(index_data, f, indent=2)
        temp_file.replace(index_file)

    def extract_and_cache_chunks(self) -> Dict[str, Any]:
        """
        Extract every chunk from the original JSONL into its own cache file.

        Optional; call it manually to speed up repeated chunk reads at the
        cost of extra disk space. Chunk files already present for the current
        file and chunk ranges are reused. Files left over from a different
        file version or chunking configuration are deleted first. Each chunk
        file is written atomically, so an interrupted run can be resumed.

        Returns:
            Dictionary with cache information:
            - 'cache_dir': path to cache directory
            - 'num_chunks': number of chunks cached
            - 'total_size_gb': total size of cached chunks
        """
        chunk_dir = self._get_chunk_cache_dir()
        chunk_dir.mkdir(parents=True, exist_ok=True)
        num_chunks = len(self.chunk_line_ranges)

        if not self._chunk_cache_is_valid():
            # Existing files belong to another layout; reusing them would
            # serve the wrong lines for a given chunk number.
            for pattern in ("chunk_*.jsonl", "chunk_*.tmp"):
                for stale in chunk_dir.glob(pattern):
                    stale.unlink()
            # Record the layout before extracting so completed chunk files
            # from an interrupted run are recognized as valid next time.
            self._write_chunk_index()

        print(f"💾 Extracting {num_chunks} chunks to cache...")
        total_size = 0
        progress_every = max(1, num_chunks // 10)

        for chunk_num in range(num_chunks):
            cache_file = self._get_chunk_cache_file(chunk_num)

            if cache_file.exists():
                total_size += os.path.getsize(cache_file)
                continue

            chunk_examples = self.read_chunk(chunk_num, _from_cache=False)

            temp_file = cache_file.with_suffix('.tmp')
            with open(temp_file, 'w', encoding='utf-8') as f:
                f.writelines(json.dumps(obj) + '\n' for obj in chunk_examples)
            temp_file.replace(cache_file)

            total_size += os.path.getsize(cache_file)
            if (chunk_num + 1) % progress_every == 0:
                print(f"   - Cached {chunk_num + 1}/{num_chunks} chunks...")

        print(f"✓ Cached {num_chunks} chunks ({total_size / (1024**3):.2f} GB)")

        return {
            'cache_dir': str(chunk_dir),
            'num_chunks': num_chunks,
            'total_size_gb': total_size / (1024**3),
        }

    def clear_chunk_cache(self, keep_metadata: bool = False) -> None:
        """
        Delete cached chunk data.

        Args:
            keep_metadata: If True, only remove the chunk cache directory and
                keep the scan-metadata cache file.
        """
        chunk_dir = self._get_chunk_cache_dir()

        if chunk_dir.exists():
            import shutil
            shutil.rmtree(chunk_dir)
            print(f"✓ Cleared chunk cache: {chunk_dir}")

        if not keep_metadata:
            cache_file = self._get_cache_path()
            if cache_file.exists():
                cache_file.unlink()
                print(f"✓ Cleared metadata cache: {cache_file}")

    def _scan_file(self) -> None:
        """
        Scan the JSONL file once to number valid JSON lines and record their offsets.

        Fills ``valid_line_offsets`` (byte offset of each valid line),
        ``line_sizes`` (byte size of each valid line) and ``total_lines``.
        Empty and unparsable lines are skipped and not numbered.

        Raises:
            ValueError: If the file cannot be read/decoded or has no valid JSON lines.
        """
        print(f"📖 Scanning JSONL file: {self.jsonl_path}")
        print(f"   File size: {self.file_size_bytes / (1024**3):.2f} GB")

        offsets = []
        sizes = []
        current_offset = 0

        try:
            # newline='' disables newline translation, so the encoded length of
            # each line equals its on-disk size even with CRLF endings. With the
            # default mode, '\r\n' would be counted as one byte and the seek
            # offsets would drift.
            with open(self.jsonl_path, 'r', encoding='utf-8', newline='') as f:
                for line in tqdm(f, desc="Scanning JSONL", unit=" lines"):
                    line_bytes = len(line.encode('utf-8'))
                    if line.strip():
                        try:
                            json.loads(line)
                        except json.JSONDecodeError:
                            pass  # invalid JSON lines are not numbered
                        else:
                            offsets.append(current_offset)
                            sizes.append(line_bytes)
                    current_offset += line_bytes
        except Exception as e:
            raise ValueError(f"Error scanning JSONL file: {e}") from e

        self.valid_line_offsets = offsets
        self.line_sizes = sizes
        self.total_lines = len(offsets)

        if self.total_lines == 0:
            raise ValueError("No valid JSON lines found in JSONL file")

        print(f"✓ Found {self.total_lines:,} valid JSON lines")

        avg_line_size = sum(sizes) / len(sizes)
        print(f"   Average line size: {avg_line_size:.2f} bytes")
        print(f"   Chunk size target: {self.chunk_size_bytes / (1024**3):.2f} GB")

    def _compute_chunk_ranges(self) -> None:
        """
        Partition the valid lines into contiguous half-open chunk ranges.

        Uses ``samples_per_chunk`` lines per chunk when set. Otherwise it uses
        ``chunk_size_bytes`` divided by the average valid-line size (at least
        one line per chunk). When ``max_samples`` is set, only the first
        ``max_samples`` lines are covered.

        Raises:
            ValueError: If ``samples_per_chunk`` is set but less than 1.
        """
        if self.total_lines == 0:
            self.chunk_line_ranges = []
            return

        self.effective_lines = self.total_lines
        if self.max_samples is not None:
            self.effective_lines = min(self.total_lines, self.max_samples)

        if self.samples_per_chunk is not None:
            if self.samples_per_chunk < 1:
                # A non-positive step would never advance through the file.
                raise ValueError(f"samples_per_chunk must be >= 1, got {self.samples_per_chunk}")
            lines_per_chunk = self.samples_per_chunk
        else:
            avg_line_size = sum(self.line_sizes) / len(self.line_sizes) if self.line_sizes else 1
            lines_per_chunk = max(1, int(self.chunk_size_bytes / avg_line_size))

        self.chunk_line_ranges = [
            (start, min(start + lines_per_chunk, self.effective_lines))
            for start in range(0, self.effective_lines, lines_per_chunk)
        ]

        print(f"   Divided into {self.num_chunks} chunks (covering {self.effective_lines:,} lines)")

    def get_chunk_indices(self, chunk_num: int) -> Tuple[int, int]:
        """
        Get (start_line, end_line) for a given chunk number.

        Args:
            chunk_num: Chunk number (0-indexed)

        Returns:
            Tuple of (start_line, end_line) where end_line is exclusive

        Raises:
            IndexError: If chunk_num is out of range
        """
        if chunk_num < 0 or chunk_num >= len(self.chunk_line_ranges):
            raise IndexError(f"Chunk {chunk_num} out of range [0, {len(self.chunk_line_ranges)-1}]")

        return self.chunk_line_ranges[chunk_num]

    def read_chunk(self, chunk_num: int, _from_cache: bool = True) -> list[dict]:
        """
        Read one chunk and return its parsed JSON objects.

        If a valid extracted chunk cache exists (same file and same chunk
        ranges), the chunk is read from its cache file. Otherwise the original
        JSONL is read starting at the chunk's recorded byte offset. If no
        offset is available, it reads from the file start and skips the
        preceding valid lines.

        Args:
            chunk_num: Chunk number (0-indexed)
            _from_cache: Internal; False forces reading from the original file
                (used during cache extraction)

        Returns:
            List of parsed JSON objects from that chunk

        Raises:
            IndexError: If chunk_num is out of range
        """
        # Validate first so an out-of-range chunk raises IndexError even when
        # a leftover cache file exists for that number.
        start_line, end_line = self.get_chunk_indices(chunk_num)

        if _from_cache and self._chunk_cache_is_valid():
            cache_file = self._get_chunk_cache_file(chunk_num)
            if cache_file.exists():
                try:
                    with open(cache_file, 'r', encoding='utf-8') as f:
                        return [json.loads(line) for line in f if line.strip()]
                except (OSError, ValueError) as e:
                    print(f"   ⚠ Warning: failed to read chunk from cache, falling back to original: {e}")

        examples = []
        # newline='' must match _scan_file so line splitting and offsets agree.
        with open(self.jsonl_path, 'r', encoding='utf-8', newline='') as f:
            if start_line < len(self.valid_line_offsets):
                # NOTE: relies on CPython's TextIOWrapper accepting a plain byte
                # position (at a line boundary) as a seek cookie for UTF-8.
                f.seek(self.valid_line_offsets[start_line])
                current_line = start_line
            else:
                # No recorded offset: start at the top and skip the valid lines
                # before this chunk (O(start_line), but returns the right data).
                current_line = 0

            for line in f:
                if current_line >= end_line:
                    break
                if not line.strip():
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    # Invalid lines were not numbered by _scan_file either.
                    continue
                if current_line >= start_line:
                    examples.append(obj)
                current_line += 1

        return examples

    @property
    def num_chunks(self) -> int:
        """Return number of chunks."""
        return len(self.chunk_line_ranges)

    @num_chunks.setter
    def num_chunks(self, value: int) -> None:
        """Legacy setter; the count is always derived from chunk_line_ranges."""
        self._num_chunks = value

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"ChunkManager(file={self.jsonl_path.name}, "
            f"size={self.file_size_bytes/(1024**3):.2f}GB, "
            f"lines={self.effective_lines:,}, "
            f"chunks={self.num_chunks})"
        )