| | """ |
| | Data loading utilities for Circuit Transformer. |
| | |
| | Supports: |
| | - Single text file: --data path/to/file.txt |
| | - Directory of text files: --data path/to/dir/ |
| | - HuggingFace dataset: --data hf:dataset_name |
| | |
| | Caching: |
| | - HF datasets: memory-mapped binary files (.bin) — O(1) RAM |
| | - Text files: torch .pt files (legacy, in-memory) |
| | - Cache location: ./circuits/.cache/ (or custom via cache_dir) |
| | |
| | Parallelism: |
| | - HF datasets tokenized via dataset.map(num_proc=N) — multiprocessing, bypasses GIL |
| | - Fast tokenizer uses Rust internally — additional parallelism within each worker |
| | """ |
| |
|
| | import os |
| | import struct |
| | import hashlib |
| | import multiprocessing |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | from torch.utils.data import Dataset, DataLoader |
| |
|
# Default location for tokenized-data caches (.bin memmaps and legacy .pt files).
DEFAULT_CACHE_DIR = "./circuits/.cache"
| |
|
| | |
| | |
| | |
# Size in bytes of the .bin file header: struct 'II' = (n_chunks, max_seq_len), two uint32s.
HEADER_SIZE = 8
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _cache_key(data_source: str, max_seq_len: int, num_samples: int | None) -> str: |
| | """Generate cache filename from parameters.""" |
| | key_str = f"{data_source}|{max_seq_len}|{num_samples}" |
| | hash_val = hashlib.md5(key_str.encode()).hexdigest()[:12] |
| | name = data_source.replace("/", "_").replace(":", "_").replace(".", "_")[-30:] |
| | return f"{name}_{max_seq_len}_{hash_val}.bin" |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class MemmapDataset(Dataset):
    """Dataset backed by a memory-mapped binary file; RAM use is O(1) in size.

    File layout: HEADER_SIZE bytes of struct('II') = (total rows, row width),
    followed by total * max_seq_len int32 token ids.
    """

    def __init__(self, path, start=None, end=None):
        self.path = str(path)
        # Read the tiny header eagerly; the payload stays on disk via memmap.
        with open(self.path, 'rb') as header_file:
            total, width = struct.unpack('II', header_file.read(HEADER_SIZE))
        self._total = total
        self.max_seq_len = width
        self.data = np.memmap(
            self.path, dtype=np.int32, mode='r',
            offset=HEADER_SIZE, shape=(total, width),
        )
        # start/end allow train/val views over a shared file.
        self.start = 0 if start is None else start
        self.end = total if end is None else end

    def __len__(self):
        return self.end - self.start

    def __getitem__(self, idx):
        # .copy() detaches the row from the memmap before handing it to torch.
        row = self.data[self.start + idx].copy()
        tokens = torch.from_numpy(row).long()
        return {"input_ids": tokens, "labels": tokens.clone()}

    def split(self, val_fraction=0.1):
        """Split into (train, val) datasets. Both share the same memmap file."""
        n_rows = self.end - self.start
        n_val = max(1, int(n_rows * val_fraction))
        boundary = self.end - n_val
        train_view = MemmapDataset(self.path, self.start, boundary)
        val_view = MemmapDataset(self.path, boundary, self.end)
        return train_view, val_view
| |
|
| |
|
class TextDataset(Dataset):
    """In-memory dataset over pre-tokenized chunks; suited to small corpora."""

    def __init__(self, token_chunks: list[list[int]], max_seq_len: int):
        # Chunks are plain lists of token ids; padding/truncation happens lazily.
        self.chunks = token_chunks
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        # Truncate to max_seq_len, then zero-pad up to exactly max_seq_len.
        chunk = list(self.chunks[idx][: self.max_seq_len])
        chunk.extend([0] * (self.max_seq_len - len(chunk)))
        input_ids = torch.tensor(chunk, dtype=torch.long)
        return {"input_ids": input_ids, "labels": input_ids.clone()}

    def split(self, val_fraction=0.1):
        """Shuffle chunks in place, then split into (train, val) datasets."""
        import random
        random.shuffle(self.chunks)
        n_val = max(1, int(len(self.chunks) * val_fraction))
        train_ds = TextDataset(self.chunks[n_val:], self.max_seq_len)
        val_ds = TextDataset(self.chunks[:n_val], self.max_seq_len)
        return train_ds, val_ds
| |
|
| |
|
| | |
| | |
| | |
| |
|
class _SentencePieceTokenizer:
    """Minimal tokenizer wrapper using sentencepiece directly.
    Bypasses transformers tokenizer bugs across versions.

    Mimics the subset of the HF tokenizer interface this module uses:
    len(), .vocab_size, .encode(), .decode(), and batch __call__.
    """

    def __init__(self, model_path, name):
        import sentencepiece as spm
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)
        self._vocab_size = self.sp.GetPieceSize()
        self.eos_token_id = self.sp.eos_id()
        self.bos_token_id = self.sp.bos_id()
        # NOTE(review): assumes the model defines bos/eos pieces; if eos_id()
        # returns -1 (unset), IdToPiece may fail — confirm for target models.
        self.eos_token = self.sp.IdToPiece(self.eos_token_id)
        self.bos_token = self.sp.IdToPiece(self.bos_token_id)
        # No pad token here; get_tokenizer() aliases pad -> eos after construction.
        self.pad_token = None
        self.pad_token_id = None
        self.name_or_path = name

    def __len__(self):
        return self._vocab_size

    @property
    def vocab_size(self):
        return self._vocab_size

    def encode(self, text, add_special_tokens=False, return_tensors=None):
        # add_special_tokens is accepted for interface compatibility but ignored.
        ids = self.sp.Encode(text)
        if return_tensors == "pt":
            import torch
            # Wrap in an extra list to add a batch dimension, matching HF output.
            return torch.tensor([ids])
        return ids

    def decode(self, ids, skip_special_tokens=False):
        # skip_special_tokens is accepted for interface compatibility but ignored.
        if hasattr(ids, 'tolist'):
            ids = ids.tolist()  # accept torch/numpy tensors as well as lists
        return self.sp.Decode(list(ids))

    def __call__(self, texts, add_special_tokens=False, **kwargs):
        """Batch-tokenize; returns an HF-style {"input_ids": [...]} dict."""
        if isinstance(texts, str):
            texts = [texts]
        return {"input_ids": [self.sp.Encode(t) for t in texts]}
| |
|
| |
|
def get_tokenizer(name: str = "gpt2"):
    """Get tokenizer from HuggingFace, with sentencepiece fallback.

    Tries AutoTokenizer with use_fast=True, then use_fast=False; if both
    fail, downloads the raw sentencepiece model file and wraps it directly.

    Args:
        name: Tokenizer name or path. Default "gpt2" (50257 vocab).
            Use e.g. "facebook/MobileLLM-125M" for 32K vocab.

    Returns:
        A tokenizer whose pad token is guaranteed to be set (aliased to EOS
        when the pretrained tokenizer defines none).
    """
    from transformers import AutoTokenizer

    # Prefer the fast (Rust) tokenizer; fall back to the slow Python one.
    for use_fast in (True, False):
        try:
            tokenizer = AutoTokenizer.from_pretrained(name, use_fast=use_fast,
                                                      trust_remote_code=True)
            # NOTE(review): guard suggests some transformers versions return a
            # bool here instead of raising — confirm which versions trigger it.
            if isinstance(tokenizer, bool):
                continue
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        except Exception:
            # Broad catch is deliberate: any failure moves on to the next
            # loading strategy rather than aborting.
            continue

    # Last resort: load the sentencepiece model file directly from the Hub.
    print(f"AutoTokenizer failed for {name}, falling back to sentencepiece")
    from huggingface_hub import hf_hub_download
    model_path = hf_hub_download(name, "tokenizer.model")
    tokenizer = _SentencePieceTokenizer(model_path, name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _stream_chunks_to_memmap(tokenized, total_examples, max_seq_len, output_path,
                             num_workers=1, read_batch=10_000):
    """Stream tokenized examples into a memory-mapped binary file.

    Single-process, numpy-batch approach. Reads batches from Arrow dataset,
    flattens to numpy int32, writes complete chunks to disk.
    Memory: O(read_batch * avg_seq_len * 4 bytes).
    No fork, no multiprocessing, no OOM.

    Args:
        tokenized: Sliceable dataset where tokenized[a:b]["input_ids"] yields
            a list of token-id lists (HF Arrow slicing convention).
        total_examples: Number of examples to consume from `tokenized`.
        max_seq_len: Fixed row width of the output file, in tokens.
        output_path: Final .bin path; written via a .tmp file + rename.
        num_workers: Unused; kept for interface compatibility.
        read_batch: Examples pulled from the dataset per iteration.

    Returns:
        Number of max_seq_len-sized chunks written.

    Output layout: HEADER_SIZE-byte header struct('II') = (n_chunks,
    max_seq_len), then n_chunks rows of max_seq_len int32 tokens.
    """
    from itertools import chain
    from tqdm import tqdm

    temp_path = str(output_path) + ".tmp"
    n_chunks = 0
    total_tokens = 0
    # Tokens left over from the previous batch that didn't fill a full chunk.
    carryover = np.array([], dtype=np.int32)

    n_batches = (total_examples + read_batch - 1) // read_batch

    with open(temp_path, 'wb') as f:
        # Placeholder header; the true n_chunks is back-patched after streaming.
        f.write(struct.pack('II', 0, max_seq_len))

        for batch_start in tqdm(range(0, total_examples, read_batch),
                                total=n_batches, desc="Chunking",
                                mininterval=1.0):
            batch_end = min(batch_start + read_batch, total_examples)
            batch_ids = tokenized[batch_start:batch_end]["input_ids"]

            # Count tokens first so np.fromiter can preallocate exactly.
            n_tok = sum(len(ids) for ids in batch_ids if ids)
            if n_tok == 0:
                del batch_ids
                continue

            flat = np.fromiter(
                chain.from_iterable(ids for ids in batch_ids if ids),
                dtype=np.int32, count=n_tok,
            )
            del batch_ids  # release the Python-list batch promptly
            total_tokens += n_tok

            # Prepend the remainder of the previous batch.
            if len(carryover) > 0:
                flat = np.concatenate([carryover, flat])

            # Write every complete max_seq_len row in one contiguous write.
            n_complete = len(flat) // max_seq_len
            if n_complete > 0:
                f.write(flat[:n_complete * max_seq_len].tobytes())
                n_chunks += n_complete

            # .copy() detaches the tail from `flat` so the big array can be freed.
            carryover = flat[n_complete * max_seq_len:].copy()
            del flat

        # Flush the final partial chunk, zero-padded, if long enough to keep
        # (>= 32 tokens, matching the minimum used by tokenize_text()).
        if len(carryover) >= 32:
            padded = np.zeros(max_seq_len, dtype=np.int32)
            padded[:len(carryover)] = carryover
            f.write(padded.tobytes())
            n_chunks += 1

        # Back-patch the header with the true chunk count.
        f.seek(0)
        f.write(struct.pack('II', n_chunks, max_seq_len))

    os.rename(temp_path, str(output_path))
    size_gb = os.path.getsize(output_path) / 1e9
    print(f"Total tokens: {total_tokens:,} → {n_chunks:,} chunks ({size_gb:.1f} GB)")
    return n_chunks
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _flatten_chat(example):
    """Collapse a chat-format example into a plain {"text": ...} row.

    Handles datasets like Bespoke-Stratos-17k and OpenThoughts-114k
    which store data as: system (str) + conversations (list of {from, value}).
    Empty system prompts and empty turns are skipped.
    """
    segments = []
    system_prompt = example.get("system")
    if system_prompt:
        segments.append(system_prompt.strip())
    for turn in example.get("conversations", []):
        content = turn.get("value", "")
        if content:
            segments.append(content.strip())
    return {"text": "\n\n".join(segments)}
| |
|
| |
|
def _estimate_avg_chars(dataset, text_column: str, n_sample: int = 200) -> float:
    """Estimate average text length (chars) from the first n_sample examples.

    Not a random sample — the leading examples are assumed representative.
    Returns 0.0 for an empty dataset; None values count as empty strings.
    """
    sample_size = min(n_sample, len(dataset))
    lengths = [len(dataset[row][text_column] or "") for row in range(sample_size)]
    return sum(lengths) / max(sample_size, 1)
| |
|
| |
|
def _adaptive_params(avg_chars: float, n_examples: int):
    """Scale worker count and batch sizes to the average example length.

    Long examples (chain-of-thought reasoning) need smaller batches and fewer
    workers to avoid OOM on memory-constrained systems (especially WSL).

    Args:
        avg_chars: Estimated average characters per example.
        n_examples: Unused; kept for interface compatibility.

    Returns:
        (num_proc, tok_batch, read_batch) tuple.
    """
    worker_budget = max(1, multiprocessing.cpu_count() - 1)

    # (min avg_chars exclusive, worker cap, tokenize batch, read batch),
    # ordered from longest examples to shortest.
    tiers = (
        (20_000, 4, 64, 500),
        (5_000, 8, 256, 2_000),
        (1_000, 16, 500, 5_000),
    )
    for threshold, worker_cap, tok_batch, read_batch in tiers:
        if avg_chars > threshold:
            return min(worker_budget, worker_cap), tok_batch, read_batch

    # Short examples: largest batches, most workers.
    return min(worker_budget, 32), 1000, 10_000
| |
|
| |
|
def load_hf_dataset(
    name: str,
    split: str,
    text_column: str,
    tokenizer,
    max_seq_len: int,
    num_samples: int | None = None,
    hf_config: str | None = None,
    cache_path: Path | None = None,
    data_format: str = "text",
) -> MemmapDataset:
    """Load HF dataset with parallel tokenization and streaming to memmap.

    Args:
        name: HuggingFace dataset identifier.
        split: Dataset split to load (e.g. "train").
        text_column: Column containing raw text (replaced by "text" for chat).
        tokenizer: Batch-callable tokenizer returning {"input_ids": [...]}.
        max_seq_len: Chunk width for the output memmap.
        num_samples: Optional cap on the number of examples.
        hf_config: Optional dataset config name.
        cache_path: Destination .bin path; a private temp dir is used if None.
        data_format: "text" (use text_column) or "chat" (flatten conversations).

    Returns:
        MemmapDataset backed by the streamed .bin file.

    Parallelism:
        - dataset.map(num_proc=N) uses multiprocessing — bypasses GIL
        - GPT2TokenizerFast runs Rust tokenization — bypasses GIL
        - batched=True enables efficient batch processing

    Memory:
        - Adaptive batch sizes based on avg example length — prevents OOM on long sequences
        - Tokenized data in Arrow format (memory-mapped by HuggingFace)
        - Chunks streamed to binary memmap file — never in RAM
    """
    from datasets import load_dataset

    config_str = f", config={hf_config}" if hf_config else ""
    print(f"Loading HF dataset: {name} (split={split}{config_str})")
    dataset = load_dataset(name, hf_config, split=split)

    if num_samples is not None:
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    # Chat-format datasets are flattened to a plain "text" column first.
    if data_format == "chat":
        # Modest worker cap: flattening is cheap, fork overhead dominates.
        flat_proc = min(max(1, multiprocessing.cpu_count() - 1), 8)
        print(f"Flattening {len(dataset):,} chat examples to plain text...")
        dataset = dataset.map(
            _flatten_chat,
            num_proc=flat_proc,
            remove_columns=dataset.column_names,
            desc="Flattening chat",
        )
        text_column = "text"

    # Scale workers/batches to example length to avoid OOM on long texts.
    avg_chars = _estimate_avg_chars(dataset, text_column)
    num_proc, tok_batch, read_batch = _adaptive_params(avg_chars, len(dataset))
    print(f" Avg example length: ~{avg_chars:,.0f} chars → "
          f"{num_proc} workers, tok_batch={tok_batch}, read_batch={read_batch}")

    # Drop rows with missing/whitespace-only text before tokenizing.
    print(f"Filtering empty examples from {len(dataset):,}...")
    dataset = dataset.filter(
        lambda x: bool(x[text_column] and x[text_column].strip()),
        num_proc=num_proc,
        desc="Filtering",
    )
    print(f" {len(dataset):,} non-empty examples")

    print(f"Tokenizing {len(dataset):,} examples with {num_proc} workers...")

    def tokenize_batch(examples):
        return tokenizer(examples[text_column], add_special_tokens=False)

    tokenized = dataset.map(
        tokenize_batch,
        batched=True,
        batch_size=tok_batch,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
        desc="Tokenizing",
    )

    if cache_path is None:
        # FIX: the original used tempfile.mktemp(), which is deprecated and
        # racy (the returned name can be claimed by another process before
        # use). A fresh private directory from mkdtemp() is safe, and the
        # file inside it does not pre-exist — required because
        # _stream_chunks_to_memmap finishes with os.rename(), which fails on
        # Windows if the destination already exists (ruling out mkstemp).
        import tempfile
        cache_path = Path(tempfile.mkdtemp(prefix="circuits_")) / "data.bin"

    _stream_chunks_to_memmap(tokenized, len(tokenized), max_seq_len, cache_path,
                             read_batch=read_batch)
    return MemmapDataset(cache_path)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def tokenize_text(text: str, tokenizer, max_seq_len: int) -> list[list[int]]:
    """Tokenize *text* and slice the token stream into max_seq_len chunks.

    Chunks shorter than 32 tokens (only possible for the final slice) are
    dropped as too short to be useful.
    """
    token_ids = tokenizer.encode(text)
    slices = (
        token_ids[offset:offset + max_seq_len]
        for offset in range(0, len(token_ids), max_seq_len)
    )
    return [piece for piece in slices if len(piece) >= 32]
| |
|
| |
|
def load_text_file(path: str, tokenizer, max_seq_len: int) -> list[list[int]]:
    """Read a UTF-8 text file and tokenize it into max_seq_len chunks."""
    text = Path(path).read_text(encoding="utf-8")
    return tokenize_text(text, tokenizer, max_seq_len)
| |
|
| |
|
def load_text_directory(path: str, tokenizer, max_seq_len: int) -> list[list[int]]:
    """Tokenize every .txt file under *path* (recursive), concatenating chunks.

    Files are visited in sorted order for deterministic output.
    """
    root = Path(path)
    chunks: list[list[int]] = []
    for txt_file in sorted(root.glob("**/*.txt")):
        chunks.extend(load_text_file(str(txt_file), tokenizer, max_seq_len))
    return chunks
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_data(
    data_source: str,
    tokenizer,
    max_seq_len: int,
    text_column: str = "text",
    num_samples: int | None = None,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    data_format: str = "text",
) -> Dataset:
    """
    Load data from various sources. Returns a Dataset with .split() support.

    Args:
        data_source: Path or HF dataset identifier
            - "path/to/file.txt" — single file
            - "path/to/dir/" — directory of .txt files
            - "hf:dataset_name" — HuggingFace dataset (train split)
            - "hf:dataset:split" — HuggingFace with specific split
            - "hf:dataset:config:split" — with config and split
        tokenizer: Tokenizer to use
        max_seq_len: Maximum sequence length
        text_column: Column name for HF datasets
        num_samples: Limit samples from HF dataset
        cache_dir: Directory for cache files (None to disable)
        data_format: "text" or "chat" (forwarded to load_hf_dataset)

    Returns:
        Dataset object supporting len(), __getitem__(), and split(fraction)

    Raises:
        ValueError: If data_source is neither an "hf:" spec nor an existing path.
    """
    cache_path = None
    if cache_dir is not None:
        cache_path = Path(cache_dir) / _cache_key(data_source, max_seq_len, num_samples)
        cache_path.parent.mkdir(parents=True, exist_ok=True)

        # FIX: all cache probing lives inside this branch so that
        # cache_dir=None ("disable caching", per the docstring) can never
        # dereference cache_path while it is None.

        # Fast path: memmap .bin cache already built.
        if cache_path.exists():
            print(f"Loading from cache: {cache_path}")
            ds = MemmapDataset(cache_path)
            print(f" Loaded {len(ds):,} chunks")
            return ds

        # Older caches were saved as torch .pt files; still honor them.
        legacy_path = cache_path.with_suffix('.pt')
        if legacy_path.exists():
            print(f"Loading from legacy cache: {legacy_path}")
            data = torch.load(legacy_path, weights_only=False)
            chunks = data["chunks"]
            print(f" Loaded {len(chunks):,} chunks")
            return TextDataset(chunks, max_seq_len)

    # Dispatch on source type: "hf:" spec, single file, or directory.
    if data_source.startswith("hf:"):
        # Spec grammar: hf:name | hf:name:split | hf:name:config:split
        parts = data_source[3:].split(":")
        name = parts[0]
        hf_config = None
        split = "train"
        if len(parts) == 2:
            split = parts[1]
        elif len(parts) == 3:
            hf_config = parts[1]
            split = parts[2]
        return load_hf_dataset(
            name, split, text_column, tokenizer, max_seq_len,
            num_samples, hf_config=hf_config, cache_path=cache_path,
            data_format=data_format,
        )
    elif os.path.isfile(data_source):
        chunks = load_text_file(data_source, tokenizer, max_seq_len)
    elif os.path.isdir(data_source):
        chunks = load_text_directory(data_source, tokenizer, max_seq_len)
    else:
        raise ValueError(f"Unknown data source: {data_source}")

    # Text-file sources are cached in the legacy in-memory .pt format.
    if cache_path is not None:
        legacy_path = cache_path.with_suffix('.pt')
        torch.save({"chunks": chunks, "data_source": data_source,
                    "max_seq_len": max_seq_len, "num_samples": num_samples}, legacy_path)
        print(f"Saved to cache: {legacy_path}")

    return TextDataset(chunks, max_seq_len)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def create_dataloader(
    dataset,
    batch_size: int,
    max_seq_len: int = None,
    shuffle: bool = True,
    num_workers: int = 0,
) -> DataLoader:
    """Wrap a Dataset (or a raw list of token chunks) in a torch DataLoader."""
    if not isinstance(dataset, Dataset):
        # Raw chunk lists are adapted via TextDataset (max_seq_len required).
        dataset = TextDataset(dataset, max_seq_len)
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_kwargs)
| |
|