"""Chunk manager for streaming large JSONL datasets.""" import os import json import hashlib from typing import Tuple, Optional, Dict, Any from pathlib import Path from tqdm import tqdm class ChunkManager: """ Manages chunked reading of large JSONL files. This class handles: - File scanning to count total lines without loading all text - Estimating chunk boundaries based on file size - Tracking which line ranges belong to each chunk """ def __init__(self, jsonl_path: str, chunk_size_gb: float = 5.0, samples_per_chunk: Optional[int] = None, enable_metadata_cache: bool = True, chunk_cache_dir: str = ".cache/chunks", max_samples: Optional[int] = None): """ Initialize ChunkManager. Args: jsonl_path: Path to JSONL file chunk_size_gb: Approximate chunk size in GB (ignored if samples_per_chunk is set) samples_per_chunk: Number of samples per chunk (takes precedence over chunk_size_gb) enable_metadata_cache: Enable caching of file scan metadata chunk_cache_dir: Directory to store cache files max_samples: Limit total samples to at most this many (if total_lines > max_samples) Raises: FileNotFoundError: If JSONL file doesn't exist ValueError: If file is empty """ self.jsonl_path = Path(jsonl_path) self.chunk_size_bytes = int(chunk_size_gb * 1024 ** 3) # Convert GB to bytes self.max_samples = max_samples # Limit total samples if specified print (f"Initializing ChunkManager for {self.jsonl_path} with target chunk size {chunk_size_gb} GB") if samples_per_chunk is not None: print(f" Overriding chunk size with {samples_per_chunk} samples per chunk") if max_samples is not None: print(f" Limiting dataset to {max_samples} samples") self.samples_per_chunk = samples_per_chunk # If set, overrides GB-based chunking self.enable_metadata_cache = enable_metadata_cache self.chunk_cache_dir = Path(chunk_cache_dir) if not self.jsonl_path.exists(): raise FileNotFoundError(f"JSONL file not found: {self.jsonl_path}") self.file_size_bytes = os.path.getsize(self.jsonl_path) self.file_mtime = os.path.getmtime(self.jsonl_path) if self.file_size_bytes == 0: raise ValueError("JSONL file is empty") # Will be populated by _scan_file() self.total_lines = 0 self.effective_lines = 0 self.line_sizes = [] # bytes per line self.valid_line_offsets = [] # byte offset of each VALID JSON line (for seeking) self.chunk_line_ranges = [] # [(start_line, end_line), ...] # Try to load from cache first cache_loaded = False if self.enable_metadata_cache: cache_loaded = self._load_metadata_cache() # If cache not used, scan the file if not cache_loaded: self._scan_file() self._compute_chunk_ranges() # Save metadata cache for future runs if self.enable_metadata_cache: self._save_metadata_cache() else: # Cache stores file scan metadata. Recompute chunk ranges for the # current training config so samples_per_chunk/max_samples changes # are honored without rescanning the large JSONL file. self._compute_chunk_ranges() def _get_cache_path(self) -> Path: """Get the metadata cache file path for this JSONL file.""" # Create a hash of the file path to use as cache filename file_hash = hashlib.md5(str(self.jsonl_path.absolute()).encode()).hexdigest()[:8] cache_file = self.chunk_cache_dir / f"{file_hash}.metadata.json" return cache_file def _load_metadata_cache(self) -> bool: """ Load metadata from cache if it exists and is valid. Returns: True if cache was loaded successfully, False otherwise """ cache_file = self._get_cache_path() if not cache_file.exists(): return False try: with open(cache_file, 'r', encoding='utf-8') as f: cache_data = json.load(f) # Validate cache: check file hasn't changed if (cache_data.get('file_size') != self.file_size_bytes or cache_data.get('file_mtime') != self.file_mtime or cache_data.get('jsonl_path') != str(self.jsonl_path.absolute())): return False # Load cached data self.total_lines = cache_data.get('total_lines', 0) self.line_sizes = cache_data.get('line_sizes', []) self.valid_line_offsets = cache_data.get('valid_line_offsets', []) # Convert loaded lists back to tuples for chunk_line_ranges chunk_ranges = cache_data.get('chunk_line_ranges', []) self.chunk_line_ranges = [tuple(r) for r in chunk_ranges] self.chunk_size_bytes = cache_data.get('chunk_size_bytes', self.chunk_size_bytes) print(f"✓ Loaded scan metadata from cache: {cache_file.name}") print(f" Found {self.total_lines:,} valid JSON lines in {len(self.chunk_line_ranges)} chunks") return True except Exception as e: # If cache loading fails, fall back to scanning return False def _save_metadata_cache(self) -> None: """Save metadata cache to file.""" cache_file = self._get_cache_path() cache_file.parent.mkdir(parents=True, exist_ok=True) cache_data = { 'jsonl_path': str(self.jsonl_path.absolute()), 'file_size': self.file_size_bytes, 'file_mtime': self.file_mtime, 'total_lines': self.total_lines, 'line_sizes': self.line_sizes, 'valid_line_offsets': self.valid_line_offsets, 'chunk_line_ranges': self.chunk_line_ranges, 'chunk_size_bytes': self.chunk_size_bytes, } try: # Write atomically using a temp file + rename temp_file = cache_file.with_suffix('.tmp') with open(temp_file, 'w', encoding='utf-8') as f: json.dump(cache_data, f, indent=2) temp_file.replace(cache_file) print(f" Saved scan metadata to cache: {cache_file.name}") except Exception as e: print(f" ⚠ Warning: failed to save cache: {e}") def _get_chunk_cache_dir(self) -> Path: """Get the directory for storing cached chunk data for this JSONL file.""" file_hash = hashlib.md5(str(self.jsonl_path.absolute()).encode()).hexdigest()[:8] chunk_dir = self.chunk_cache_dir / "chunks" / file_hash return chunk_dir def _get_chunk_cache_file(self, chunk_num: int) -> Path: """Get the cache file path for a specific chunk.""" chunk_dir = self._get_chunk_cache_dir() return chunk_dir / f"chunk_{chunk_num:06d}.jsonl" def _get_chunk_index_file(self) -> Path: """Get the index file that lists all cached chunks.""" chunk_dir = self._get_chunk_cache_dir() return chunk_dir / "index.json" def extract_and_cache_chunks(self) -> Dict[str, Any]: """ Extract chunks from the original JSONL file and save them as separate cached files. This is optional and should be called manually if you want to pre-cache chunks for faster repeated access. It can significantly speed up training but uses more disk space. Returns: Dictionary with cache information: - 'cache_dir': path to cache directory - 'num_chunks': number of chunks cached - 'total_size_gb': total size of cached chunks """ chunk_dir = self._get_chunk_cache_dir() chunk_dir.mkdir(parents=True, exist_ok=True) print(f"💾 Extracting {len(self.chunk_line_ranges)} chunks to cache...") total_size = 0 for chunk_num in range(len(self.chunk_line_ranges)): cache_file = self._get_chunk_cache_file(chunk_num) # Skip if already cached if cache_file.exists(): total_size += os.path.getsize(cache_file) continue # Read chunk and save to cache file chunk_examples = self.read_chunk(chunk_num, _from_cache=False) with open(cache_file, 'w', encoding='utf-8') as f: for obj in chunk_examples: f.write(json.dumps(obj) + '\n') total_size += os.path.getsize(cache_file) if (chunk_num + 1) % max(1, len(self.chunk_line_ranges) // 10) == 0: print(f" - Cached {chunk_num + 1}/{len(self.chunk_line_ranges)} chunks...") # Write index file index_data = { 'jsonl_path': str(self.jsonl_path.absolute()), 'num_chunks': len(self.chunk_line_ranges), 'chunk_ranges': self.chunk_line_ranges, } with open(self._get_chunk_index_file(), 'w', encoding='utf-8') as f: json.dump(index_data, f, indent=2) print(f"✓ Cached {len(self.chunk_line_ranges)} chunks ({total_size / (1024**3):.2f} GB)") return { 'cache_dir': str(chunk_dir), 'num_chunks': len(self.chunk_line_ranges), 'total_size_gb': total_size / (1024**3), } def clear_chunk_cache(self, keep_metadata: bool = False) -> None: """ Clear cached chunk data. Args: keep_metadata: If True, only remove chunk files, keep the metadata cache """ chunk_dir = self._get_chunk_cache_dir() if chunk_dir.exists(): import shutil shutil.rmtree(chunk_dir) print(f"✓ Cleared chunk cache: {chunk_dir}") if not keep_metadata: cache_file = self._get_cache_path() if cache_file.exists(): cache_file.unlink() print(f"✓ Cleared metadata cache: {cache_file}") def _scan_file(self) -> None: """ Scan JSONL file to count lines and track offsets. This reads the file once to: - Count total valid JSON lines - Record byte offset of each VALID line for seeking - Estimate size per line """ print(f"📖 Scanning JSONL file: {self.jsonl_path}") print(f" File size: {self.file_size_bytes / (1024**3):.2f} GB") self.valid_line_offsets = [] current_offset = 0 valid_lines = 0 try: with open(self.jsonl_path, 'r', encoding='utf-8') as f: for line in tqdm(f, desc="Scanning JSONL", unit=" lines"): # Skip empty lines - don't count toward line numbers if not line.strip(): current_offset += len(line.encode('utf-8')) continue try: json.loads(line) # Valid JSON line - record its starting byte offset self.valid_line_offsets.append(current_offset) valid_lines += 1 line_bytes = len(line.encode('utf-8')) self.line_sizes.append(line_bytes) except json.JSONDecodeError: # Skip invalid JSON lines - don't count toward line numbers pass current_offset += len(line.encode('utf-8')) except Exception as e: raise ValueError(f"Error scanning JSONL file: {e}") self.total_lines = valid_lines if self.total_lines == 0: raise ValueError("No valid JSON lines found in JSONL file") print(f"✓ Found {self.total_lines:,} valid JSON lines") # Calculate average line size avg_line_size = sum(self.line_sizes) / len(self.line_sizes) if self.line_sizes else 0 print(f" Average line size: {avg_line_size:.2f} bytes") print(f" Chunk size target: {self.chunk_size_bytes / (1024**3):.2f} GB") def _compute_chunk_ranges(self) -> None: """ Compute line ranges for each chunk based on target chunk size. If samples_per_chunk is set, uses that. Otherwise, divides file based on chunk_size_bytes. If max_samples is set, limits chunks to cover at most max_samples lines. """ if self.total_lines == 0: self.chunk_line_ranges = [] return # Apply max_samples limit to effective line count self.effective_lines = self.total_lines if self.max_samples is not None: self.effective_lines = min(self.total_lines, self.max_samples) # Determine lines per chunk if self.samples_per_chunk is not None: # Use explicit sample count lines_per_chunk = self.samples_per_chunk else: # Use GB-based calculation avg_line_size = sum(self.line_sizes) / len(self.line_sizes) if self.line_sizes else 1 lines_per_chunk = max(1, int(self.chunk_size_bytes / avg_line_size)) chunk_ranges = [] start_line = 0 # Create chunks up to self.effective_lines (honors max_samples) while start_line < self.effective_lines: end_line = min(start_line + lines_per_chunk, self.effective_lines) chunk_ranges.append((start_line, end_line)) start_line = end_line self.chunk_line_ranges = chunk_ranges self.num_chunks = len(chunk_ranges) print(f" Divided into {self.num_chunks} chunks (covering {self.effective_lines:,} lines)") def get_chunk_indices(self, chunk_num: int) -> Tuple[int, int]: """ Get (start_line, end_line) for a given chunk number. Args: chunk_num: Chunk number (0-indexed) Returns: Tuple of (start_line, end_line) where end_line is exclusive Raises: IndexError: If chunk_num is out of range """ if chunk_num < 0 or chunk_num >= len(self.chunk_line_ranges): raise IndexError(f"Chunk {chunk_num} out of range [0, {len(self.chunk_line_ranges)-1}]") return self.chunk_line_ranges[chunk_num] def read_chunk(self, chunk_num: int, _from_cache: bool = True) -> list[dict]: """ Read a specific chunk and return parsed JSON objects. If chunk cache is available, reads from cache. Otherwise reads from original JSONL using file.seek() for O(1) lookup instead of O(n) scanning. Args: chunk_num: Chunk number (0-indexed) _from_cache: Internal parameter to force reading from original (used during cache extraction) Returns: List of parsed JSON objects from that chunk Raises: IndexError: If chunk_num is out of range ValueError: If JSON parsing fails """ # Try to read from cache first (if it exists) if _from_cache: cache_file = self._get_chunk_cache_file(chunk_num) if cache_file.exists(): examples = [] try: with open(cache_file, 'r', encoding='utf-8') as f: for line in f: if line.strip(): try: obj = json.loads(line) examples.append(obj) except json.JSONDecodeError: pass return examples except Exception as e: print(f" ⚠ Warning: failed to read chunk from cache, falling back to original: {e}") # Read from original JSONL file using seek optimization start_line, end_line = self.get_chunk_indices(chunk_num) examples = [] with open(self.jsonl_path, 'r', encoding='utf-8') as f: # Seek to the byte offset of the start line # This is O(1) instead of O(start_line) iteration if start_line < len(self.valid_line_offsets): f.seek(self.valid_line_offsets[start_line]) else: # Fallback if valid_line_offsets not available (shouldn't happen) f.seek(0) current_line = start_line # Read lines from start_line to end_line for line in f: # Skip empty lines if not line.strip(): continue # Stop when we've read enough lines if current_line >= end_line: break try: obj = json.loads(line) examples.append(obj) current_line += 1 except json.JSONDecodeError: # Skip invalid JSON lines, but don't increment line counter # This maintains alignment with line numbering from scan pass return examples @property def num_chunks(self) -> int: """Return number of chunks.""" return len(self.chunk_line_ranges) @num_chunks.setter def num_chunks(self, value: int) -> None: """Set number of chunks (internal use).""" self._num_chunks = value def __repr__(self) -> str: """String representation.""" return ( f"ChunkManager(file={self.jsonl_path.name}, " f"size={self.file_size_bytes/(1024**3):.2f}GB, " f"lines={self.effective_lines:,}, " f"chunks={self.num_chunks})" )