""" Data loading utilities for Circuit Transformer. Supports: - Single text file: --data path/to/file.txt - Directory of text files: --data path/to/dir/ - HuggingFace dataset: --data hf:dataset_name Caching: - HF datasets: memory-mapped binary files (.bin) — O(1) RAM - Text files: torch .pt files (legacy, in-memory) - Cache location: ./circuits/.cache/ (or custom via cache_dir) Parallelism: - HF datasets tokenized via dataset.map(num_proc=N) — multiprocessing, bypasses GIL - Fast tokenizer uses Rust internally — additional parallelism within each worker """ import os import struct import hashlib import multiprocessing from pathlib import Path import numpy as np import torch from torch.utils.data import Dataset, DataLoader DEFAULT_CACHE_DIR = "./circuits/.cache" # Memmap binary format: # Header: 8 bytes = [uint32 n_chunks, uint32 max_seq_len] # Data: n_chunks * max_seq_len * 4 bytes (int32, row-major) HEADER_SIZE = 8 # --------------------------------------------------------------------------- # Cache utilities # --------------------------------------------------------------------------- def _cache_key(data_source: str, max_seq_len: int, num_samples: int | None) -> str: """Generate cache filename from parameters.""" key_str = f"{data_source}|{max_seq_len}|{num_samples}" hash_val = hashlib.md5(key_str.encode()).hexdigest()[:12] name = data_source.replace("/", "_").replace(":", "_").replace(".", "_")[-30:] return f"{name}_{max_seq_len}_{hash_val}.bin" # --------------------------------------------------------------------------- # Dataset classes # --------------------------------------------------------------------------- class MemmapDataset(Dataset): """Dataset backed by memory-mapped binary file. O(1) RAM regardless of size.""" def __init__(self, path, start=None, end=None): self.path = str(path) with open(self.path, 'rb') as f: total, self.max_seq_len = struct.unpack('II', f.read(HEADER_SIZE)) self._total = total self.data = np.memmap( self.path, dtype=np.int32, mode='r', offset=HEADER_SIZE, shape=(total, self.max_seq_len), ) self.start = start if start is not None else 0 self.end = end if end is not None else total def __len__(self): return self.end - self.start def __getitem__(self, idx): tokens = torch.from_numpy(self.data[self.start + idx].copy()).long() return {"input_ids": tokens, "labels": tokens.clone()} def split(self, val_fraction=0.1): """Split into (train, val) datasets. Both share the same memmap file.""" total = self.end - self.start n_val = max(1, int(total * val_fraction)) train = MemmapDataset(self.path, self.start, self.end - n_val) val = MemmapDataset(self.path, self.end - n_val, self.end) return train, val class TextDataset(Dataset): """Simple in-memory dataset from tokenized chunks. For small datasets.""" def __init__(self, token_chunks: list[list[int]], max_seq_len: int): self.chunks = token_chunks self.max_seq_len = max_seq_len def __len__(self) -> int: return len(self.chunks) def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: tokens = self.chunks[idx] if len(tokens) < self.max_seq_len: tokens = tokens + [0] * (self.max_seq_len - len(tokens)) else: tokens = tokens[: self.max_seq_len] input_ids = torch.tensor(tokens, dtype=torch.long) return {"input_ids": input_ids, "labels": input_ids.clone()} def split(self, val_fraction=0.1): """Split into (train, val) datasets with shuffle.""" import random random.shuffle(self.chunks) n_val = max(1, int(len(self.chunks) * val_fraction)) val = TextDataset(self.chunks[:n_val], self.max_seq_len) train = TextDataset(self.chunks[n_val:], self.max_seq_len) return train, val # --------------------------------------------------------------------------- # Tokenizer # --------------------------------------------------------------------------- class _SentencePieceTokenizer: """Minimal tokenizer wrapper using sentencepiece directly. Bypasses transformers tokenizer bugs across versions.""" def __init__(self, model_path, name): import sentencepiece as spm self.sp = spm.SentencePieceProcessor() self.sp.Load(model_path) self._vocab_size = self.sp.GetPieceSize() self.eos_token_id = self.sp.eos_id() self.bos_token_id = self.sp.bos_id() self.eos_token = self.sp.IdToPiece(self.eos_token_id) self.bos_token = self.sp.IdToPiece(self.bos_token_id) self.pad_token = None self.pad_token_id = None self.name_or_path = name def __len__(self): return self._vocab_size @property def vocab_size(self): return self._vocab_size def encode(self, text, add_special_tokens=False, return_tensors=None): ids = self.sp.Encode(text) if return_tensors == "pt": import torch return torch.tensor([ids]) return ids def decode(self, ids, skip_special_tokens=False): if hasattr(ids, 'tolist'): ids = ids.tolist() return self.sp.Decode(list(ids)) def __call__(self, texts, add_special_tokens=False, **kwargs): if isinstance(texts, str): texts = [texts] return {"input_ids": [self.sp.Encode(t) for t in texts]} def get_tokenizer(name: str = "gpt2"): """Get tokenizer from HuggingFace, with sentencepiece fallback. Args: name: Tokenizer name or path. Default "gpt2" (50257 vocab). Use e.g. "facebook/MobileLLM-125M" for 32K vocab. """ from transformers import AutoTokenizer # Try AutoTokenizer (fast then slow) for use_fast in (True, False): try: tokenizer = AutoTokenizer.from_pretrained(name, use_fast=use_fast, trust_remote_code=True) if isinstance(tokenizer, bool): continue if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer except Exception: continue # Fallback: load sentencepiece model directly (bypasses transformers bugs) print(f"AutoTokenizer failed for {name}, falling back to sentencepiece") from huggingface_hub import hf_hub_download model_path = hf_hub_download(name, "tokenizer.model") tokenizer = _SentencePieceTokenizer(model_path, name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id return tokenizer # --------------------------------------------------------------------------- # Streaming memmap writer # --------------------------------------------------------------------------- def _stream_chunks_to_memmap(tokenized, total_examples, max_seq_len, output_path, num_workers=1, read_batch=10_000): """Stream tokenized examples into a memory-mapped binary file. Single-process, numpy-batch approach. Reads batches from Arrow dataset, flattens to numpy int32, writes complete chunks to disk. Memory: O(read_batch * avg_seq_len * 4 bytes). No fork, no multiprocessing, no OOM. """ from itertools import chain from tqdm import tqdm temp_path = str(output_path) + ".tmp" n_chunks = 0 total_tokens = 0 carryover = np.array([], dtype=np.int32) n_batches = (total_examples + read_batch - 1) // read_batch with open(temp_path, 'wb') as f: f.write(struct.pack('II', 0, max_seq_len)) # placeholder header for batch_start in tqdm(range(0, total_examples, read_batch), total=n_batches, desc="Chunking", mininterval=1.0): batch_end = min(batch_start + read_batch, total_examples) batch_ids = tokenized[batch_start:batch_end]["input_ids"] # Count tokens, flatten Arrow→numpy without intermediate Python list n_tok = sum(len(ids) for ids in batch_ids if ids) if n_tok == 0: del batch_ids continue flat = np.fromiter( chain.from_iterable(ids for ids in batch_ids if ids), dtype=np.int32, count=n_tok, ) del batch_ids total_tokens += n_tok # Prepend carryover from previous batch if len(carryover) > 0: flat = np.concatenate([carryover, flat]) # Write complete chunks n_complete = len(flat) // max_seq_len if n_complete > 0: f.write(flat[:n_complete * max_seq_len].tobytes()) n_chunks += n_complete carryover = flat[n_complete * max_seq_len:].copy() del flat # Handle remaining tokens if len(carryover) >= 32: padded = np.zeros(max_seq_len, dtype=np.int32) padded[:len(carryover)] = carryover f.write(padded.tobytes()) n_chunks += 1 # Write actual count into header f.seek(0) f.write(struct.pack('II', n_chunks, max_seq_len)) os.rename(temp_path, str(output_path)) size_gb = os.path.getsize(output_path) / 1e9 print(f"Total tokens: {total_tokens:,} → {n_chunks:,} chunks ({size_gb:.1f} GB)") return n_chunks # --------------------------------------------------------------------------- # HuggingFace dataset loader (parallel + memmap) # --------------------------------------------------------------------------- def _flatten_chat(example): """Convert chat format (system + conversations list) to plain text. Handles datasets like Bespoke-Stratos-17k and OpenThoughts-114k which store data as: system (str) + conversations (list of {from, value}). """ parts = [] if example.get("system"): parts.append(example["system"].strip()) for msg in example.get("conversations", []): value = msg.get("value", "") if value: parts.append(value.strip()) return {"text": "\n\n".join(parts)} def _estimate_avg_chars(dataset, text_column: str, n_sample: int = 200) -> float: """Estimate average text length from a sample of the dataset.""" n = min(n_sample, len(dataset)) total = sum(len(dataset[i][text_column] or "") for i in range(n)) return total / max(n, 1) def _adaptive_params(avg_chars: float, n_examples: int): """Scale worker count, batch sizes based on average example length. Long examples (chain-of-thought reasoning) need smaller batches and fewer workers to avoid OOM on memory-constrained systems (especially WSL). """ cpu_count = max(1, multiprocessing.cpu_count() - 1) if avg_chars > 20_000: # very long (OpenThoughts-style, ~7K+ tokens) num_proc = min(cpu_count, 4) tok_batch = 64 read_batch = 500 elif avg_chars > 5_000: # long (detailed SFT, ~1.5K+ tokens) num_proc = min(cpu_count, 8) tok_batch = 256 read_batch = 2_000 elif avg_chars > 1_000: # medium (typical SFT) num_proc = min(cpu_count, 16) tok_batch = 500 read_batch = 5_000 else: # short (web text, wiki) num_proc = min(cpu_count, 32) tok_batch = 1000 read_batch = 10_000 return num_proc, tok_batch, read_batch def load_hf_dataset( name: str, split: str, text_column: str, tokenizer, max_seq_len: int, num_samples: int | None = None, hf_config: str | None = None, cache_path: Path | None = None, data_format: str = "text", ) -> MemmapDataset: """Load HF dataset with parallel tokenization and streaming to memmap. Parallelism: - dataset.map(num_proc=N) uses multiprocessing — bypasses GIL - GPT2TokenizerFast runs Rust tokenization — bypasses GIL - batched=True enables efficient batch processing Memory: - Adaptive batch sizes based on avg example length — prevents OOM on long sequences - Tokenized data in Arrow format (memory-mapped by HuggingFace) - Chunks streamed to binary memmap file — never in RAM """ from datasets import load_dataset config_str = f", config={hf_config}" if hf_config else "" print(f"Loading HF dataset: {name} (split={split}{config_str})") dataset = load_dataset(name, hf_config, split=split) if num_samples is not None: dataset = dataset.select(range(min(num_samples, len(dataset)))) # Flatten chat format to plain text if data_format == "chat": # Use conservative parallelism for flattening — light operation flat_proc = min(max(1, multiprocessing.cpu_count() - 1), 8) print(f"Flattening {len(dataset):,} chat examples to plain text...") dataset = dataset.map( _flatten_chat, num_proc=flat_proc, remove_columns=dataset.column_names, desc="Flattening chat", ) text_column = "text" # Estimate avg example length and adapt parameters avg_chars = _estimate_avg_chars(dataset, text_column) num_proc, tok_batch, read_batch = _adaptive_params(avg_chars, len(dataset)) print(f" Avg example length: ~{avg_chars:,.0f} chars → " f"{num_proc} workers, tok_batch={tok_batch}, read_batch={read_batch}") # Filter empty examples print(f"Filtering empty examples from {len(dataset):,}...") dataset = dataset.filter( lambda x: bool(x[text_column] and x[text_column].strip()), num_proc=num_proc, desc="Filtering", ) print(f" {len(dataset):,} non-empty examples") # Parallel tokenization print(f"Tokenizing {len(dataset):,} examples with {num_proc} workers...") def tokenize_batch(examples): return tokenizer(examples[text_column], add_special_tokens=False) tokenized = dataset.map( tokenize_batch, batched=True, batch_size=tok_batch, num_proc=num_proc, remove_columns=dataset.column_names, desc="Tokenizing", ) # Stream to memmap — use temp path if no cache configured if cache_path is None: import tempfile cache_path = Path(tempfile.mktemp(suffix='.bin')) _stream_chunks_to_memmap(tokenized, len(tokenized), max_seq_len, cache_path, read_batch=read_batch) return MemmapDataset(cache_path) # --------------------------------------------------------------------------- # Text file loaders (unchanged — small datasets, in-memory is fine) # --------------------------------------------------------------------------- def tokenize_text(text: str, tokenizer, max_seq_len: int) -> list[list[int]]: """Tokenize text into chunks of max_seq_len.""" tokens = tokenizer.encode(text) chunks = [] for i in range(0, len(tokens), max_seq_len): chunk = tokens[i : i + max_seq_len] if len(chunk) >= 32: chunks.append(chunk) return chunks def load_text_file(path: str, tokenizer, max_seq_len: int) -> list[list[int]]: """Load and tokenize a single text file.""" with open(path, "r", encoding="utf-8") as f: text = f.read() return tokenize_text(text, tokenizer, max_seq_len) def load_text_directory(path: str, tokenizer, max_seq_len: int) -> list[list[int]]: """Load and tokenize all .txt files from a directory.""" all_chunks = [] path = Path(path) for txt_file in sorted(path.glob("**/*.txt")): chunks = load_text_file(str(txt_file), tokenizer, max_seq_len) all_chunks.extend(chunks) return all_chunks # --------------------------------------------------------------------------- # Main entry point # --------------------------------------------------------------------------- def load_data( data_source: str, tokenizer, max_seq_len: int, text_column: str = "text", num_samples: int | None = None, cache_dir: str | None = DEFAULT_CACHE_DIR, data_format: str = "text", ) -> Dataset: """ Load data from various sources. Returns a Dataset with .split() support. Args: data_source: Path or HF dataset identifier - "path/to/file.txt" — single file - "path/to/dir/" — directory of .txt files - "hf:dataset_name" — HuggingFace dataset (train split) - "hf:dataset:split" — HuggingFace with specific split - "hf:dataset:config:split" — with config and split tokenizer: Tokenizer to use max_seq_len: Maximum sequence length text_column: Column name for HF datasets num_samples: Limit samples from HF dataset cache_dir: Directory for cache files (None to disable) Returns: Dataset object supporting len(), __getitem__(), and split(fraction) """ cache_path = None if cache_dir is not None: cache_path = Path(cache_dir) / _cache_key(data_source, max_seq_len, num_samples) cache_path.parent.mkdir(parents=True, exist_ok=True) # Check for memmap cache (.bin) if cache_path.exists(): print(f"Loading from cache: {cache_path}") ds = MemmapDataset(cache_path) print(f" Loaded {len(ds):,} chunks") return ds # Check for legacy cache (.pt) legacy_path = cache_path.with_suffix('.pt') if legacy_path.exists(): print(f"Loading from legacy cache: {legacy_path}") data = torch.load(legacy_path, weights_only=False) chunks = data["chunks"] print(f" Loaded {len(chunks):,} chunks") return TextDataset(chunks, max_seq_len) # Load and tokenize if data_source.startswith("hf:"): parts = data_source[3:].split(":") name = parts[0] hf_config = None split = "train" if len(parts) == 2: split = parts[1] elif len(parts) == 3: hf_config = parts[1] split = parts[2] return load_hf_dataset( name, split, text_column, tokenizer, max_seq_len, num_samples, hf_config=hf_config, cache_path=cache_path, data_format=data_format, ) elif os.path.isfile(data_source): chunks = load_text_file(data_source, tokenizer, max_seq_len) elif os.path.isdir(data_source): chunks = load_text_directory(data_source, tokenizer, max_seq_len) else: raise ValueError(f"Unknown data source: {data_source}") # For text files: save legacy cache if cache_dir is not None: legacy_path = cache_path.with_suffix('.pt') torch.save({"chunks": chunks, "data_source": data_source, "max_seq_len": max_seq_len, "num_samples": num_samples}, legacy_path) print(f"Saved to cache: {legacy_path}") return TextDataset(chunks, max_seq_len) # --------------------------------------------------------------------------- # DataLoader factory # --------------------------------------------------------------------------- def create_dataloader( dataset, batch_size: int, max_seq_len: int = None, shuffle: bool = True, num_workers: int = 0, ) -> DataLoader: """Create a DataLoader from a Dataset or list of chunks.""" if not isinstance(dataset, Dataset): # Legacy compatibility: list of token chunks dataset = TextDataset(dataset, max_seq_len) return DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, )