Prisma / data.py
y3i12's picture
Initial commit
56e82ec
"""
Data loading utilities for Circuit Transformer.
Supports:
- Single text file: --data path/to/file.txt
- Directory of text files: --data path/to/dir/
- HuggingFace dataset: --data hf:dataset_name
Caching:
- HF datasets: memory-mapped binary files (.bin) — O(1) RAM
- Text files: torch .pt files (legacy, in-memory)
- Cache location: ./circuits/.cache/ (or custom via cache_dir)
Parallelism:
- HF datasets tokenized via dataset.map(num_proc=N) — multiprocessing, bypasses GIL
- Fast tokenizer uses Rust internally — additional parallelism within each worker
"""
import os
import struct
import hashlib
import multiprocessing
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
# Where tokenized caches are written unless a caller passes its own cache_dir.
DEFAULT_CACHE_DIR = "./circuits/.cache"

# On-disk layout of the memmap chunk files (.bin):
#   header -> two uint32 fields: n_chunks, max_seq_len   (2 x 4 bytes)
#   body   -> n_chunks * max_seq_len int32 token ids, row-major
HEADER_SIZE = 2 * 4
# ---------------------------------------------------------------------------
# Cache utilities
# ---------------------------------------------------------------------------
def _cache_key(data_source: str, max_seq_len: int, num_samples: int | None) -> str:
"""Generate cache filename from parameters."""
key_str = f"{data_source}|{max_seq_len}|{num_samples}"
hash_val = hashlib.md5(key_str.encode()).hexdigest()[:12]
name = data_source.replace("/", "_").replace(":", "_").replace(".", "_")[-30:]
return f"{name}_{max_seq_len}_{hash_val}.bin"
# ---------------------------------------------------------------------------
# Dataset classes
# ---------------------------------------------------------------------------
class MemmapDataset(Dataset):
    """Dataset backed by a memory-mapped binary chunk file. O(1) RAM regardless of size.

    File layout: a ``HEADER_SIZE``-byte header holding two native uint32 values
    ``(n_chunks, max_seq_len)``, followed by ``n_chunks * max_seq_len`` int32
    token ids in row-major order.

    ``start``/``end`` select the contiguous row window ``[start, end)`` of the
    file. This is how :meth:`split` builds train/val views that share one file
    without copying data.
    """

    def __init__(self, path, start=None, end=None):
        self.path = str(path)
        with open(self.path, 'rb') as f:
            total, self.max_seq_len = struct.unpack('II', f.read(HEADER_SIZE))
        self._total = total
        self.start = start if start is not None else 0
        self.end = end if end is not None else total
        if not 0 <= self.start <= self.end <= total:
            raise ValueError(
                f"Invalid row window [{self.start}, {self.end}) for "
                f"{total} chunks in {self.path}"
            )
        self.data = self._open_memmap()

    def _open_memmap(self):
        """Map the token body of the file read-only as a (total, max_seq_len) int32 array."""
        return np.memmap(
            self.path, dtype=np.int32, mode='r',
            offset=HEADER_SIZE, shape=(self._total, self.max_seq_len),
        )

    def __getstate__(self):
        # Pickling an np.memmap would serialize the whole array contents. This
        # matters when DataLoader workers are started with "spawn". Ship only the
        # path and window instead; __setstate__ re-opens the mapping.
        state = self.__dict__.copy()
        state.pop('data', None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.data = self._open_memmap()

    def __len__(self):
        return self.end - self.start

    def __getitem__(self, idx):
        """Return ``{"input_ids", "labels"}`` (identical int64 tensors) for row ``idx``.

        Negative indices count from the end of this window. Out-of-window
        indices raise IndexError rather than reading rows from a sibling split.
        The IndexError also lets plain iteration stop correctly.
        """
        n = len(self)
        i = idx + n if idx < 0 else idx
        if not 0 <= i < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        tokens = torch.from_numpy(self.data[self.start + i].copy()).long()
        return {"input_ids": tokens, "labels": tokens.clone()}

    def split(self, val_fraction=0.1):
        """Split into contiguous (train, val) views of the same memmap file.

        The last ``max(1, int(len(self) * val_fraction))`` rows become the
        validation set. No shuffling is performed.
        """
        total = self.end - self.start
        n_val = max(1, int(total * val_fraction))
        train = MemmapDataset(self.path, self.start, self.end - n_val)
        val = MemmapDataset(self.path, self.end - n_val, self.end)
        return train, val
class TextDataset(Dataset):
"""Simple in-memory dataset from tokenized chunks. For small datasets."""
def __init__(self, token_chunks: list[list[int]], max_seq_len: int):
self.chunks = token_chunks
self.max_seq_len = max_seq_len
def __len__(self) -> int:
return len(self.chunks)
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
tokens = self.chunks[idx]
if len(tokens) < self.max_seq_len:
tokens = tokens + [0] * (self.max_seq_len - len(tokens))
else:
tokens = tokens[: self.max_seq_len]
input_ids = torch.tensor(tokens, dtype=torch.long)
return {"input_ids": input_ids, "labels": input_ids.clone()}
def split(self, val_fraction=0.1):
"""Split into (train, val) datasets with shuffle."""
import random
random.shuffle(self.chunks)
n_val = max(1, int(len(self.chunks) * val_fraction))
val = TextDataset(self.chunks[:n_val], self.max_seq_len)
train = TextDataset(self.chunks[n_val:], self.max_seq_len)
return train, val
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
class _SentencePieceTokenizer:
"""Minimal tokenizer wrapper using sentencepiece directly.
Bypasses transformers tokenizer bugs across versions."""
def __init__(self, model_path, name):
import sentencepiece as spm
self.sp = spm.SentencePieceProcessor()
self.sp.Load(model_path)
self._vocab_size = self.sp.GetPieceSize()
self.eos_token_id = self.sp.eos_id()
self.bos_token_id = self.sp.bos_id()
self.eos_token = self.sp.IdToPiece(self.eos_token_id)
self.bos_token = self.sp.IdToPiece(self.bos_token_id)
self.pad_token = None
self.pad_token_id = None
self.name_or_path = name
def __len__(self):
return self._vocab_size
@property
def vocab_size(self):
return self._vocab_size
def encode(self, text, add_special_tokens=False, return_tensors=None):
ids = self.sp.Encode(text)
if return_tensors == "pt":
import torch
return torch.tensor([ids])
return ids
def decode(self, ids, skip_special_tokens=False):
if hasattr(ids, 'tolist'):
ids = ids.tolist()
return self.sp.Decode(list(ids))
def __call__(self, texts, add_special_tokens=False, **kwargs):
if isinstance(texts, str):
texts = [texts]
return {"input_ids": [self.sp.Encode(t) for t in texts]}
def get_tokenizer(name: str = "gpt2", *, trust_remote_code: bool = True):
    """Get a tokenizer from HuggingFace, with a raw-SentencePiece fallback.

    The loading order is:

    1. ``AutoTokenizer`` with the fast implementation.
    2. ``AutoTokenizer`` with the slow implementation.
    3. ``tokenizer.model`` loaded directly with SentencePiece, taken from
       ``name`` if it is a local directory containing that file, otherwise
       downloaded from the HuggingFace Hub.

    If the tokenizer has no pad token, the EOS token is used as pad.

    Args:
        name: Tokenizer name or path. Default "gpt2" (50257 vocab).
            Use e.g. "facebook/MobileLLM-125M" for 32K vocab.
        trust_remote_code: forwarded to ``AutoTokenizer.from_pretrained``.
            SECURITY: when True, Python code shipped with the hub repository
            is executed locally, so only enable it for repositories you trust.
            Defaults to True for backward compatibility.

    Returns:
        An HF tokenizer, or a ``_SentencePieceTokenizer`` fallback.
    """
    from transformers import AutoTokenizer

    last_error = None
    # Try AutoTokenizer (fast then slow)
    for use_fast in (True, False):
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                name, use_fast=use_fast, trust_remote_code=trust_remote_code,
            )
            if isinstance(tokenizer, bool):
                # Defensive: a bool is not a usable tokenizer; try the next option.
                continue
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        except Exception as err:  # any loader failure -> try the next strategy
            last_error = err

    # Fallback: load the sentencepiece model directly (bypasses transformers bugs).
    # Report why AutoTokenizer failed instead of discarding the error silently.
    print(f"AutoTokenizer failed for {name} ({last_error!r}), "
          f"falling back to sentencepiece")
    local_model = Path(name) / "tokenizer.model"
    if local_model.is_file():
        model_path = str(local_model)
    else:
        from huggingface_hub import hf_hub_download
        model_path = hf_hub_download(name, "tokenizer.model")
    tokenizer = _SentencePieceTokenizer(model_path, name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
# ---------------------------------------------------------------------------
# Streaming memmap writer
# ---------------------------------------------------------------------------
def _stream_chunks_to_memmap(tokenized, total_examples, max_seq_len, output_path,
num_workers=1, read_batch=10_000):
"""Stream tokenized examples into a memory-mapped binary file.
Single-process, numpy-batch approach. Reads batches from Arrow dataset,
flattens to numpy int32, writes complete chunks to disk.
Memory: O(read_batch * avg_seq_len * 4 bytes).
No fork, no multiprocessing, no OOM.
"""
from itertools import chain
from tqdm import tqdm
temp_path = str(output_path) + ".tmp"
n_chunks = 0
total_tokens = 0
carryover = np.array([], dtype=np.int32)
n_batches = (total_examples + read_batch - 1) // read_batch
with open(temp_path, 'wb') as f:
f.write(struct.pack('II', 0, max_seq_len)) # placeholder header
for batch_start in tqdm(range(0, total_examples, read_batch),
total=n_batches, desc="Chunking",
mininterval=1.0):
batch_end = min(batch_start + read_batch, total_examples)
batch_ids = tokenized[batch_start:batch_end]["input_ids"]
# Count tokens, flatten Arrow→numpy without intermediate Python list
n_tok = sum(len(ids) for ids in batch_ids if ids)
if n_tok == 0:
del batch_ids
continue
flat = np.fromiter(
chain.from_iterable(ids for ids in batch_ids if ids),
dtype=np.int32, count=n_tok,
)
del batch_ids
total_tokens += n_tok
# Prepend carryover from previous batch
if len(carryover) > 0:
flat = np.concatenate([carryover, flat])
# Write complete chunks
n_complete = len(flat) // max_seq_len
if n_complete > 0:
f.write(flat[:n_complete * max_seq_len].tobytes())
n_chunks += n_complete
carryover = flat[n_complete * max_seq_len:].copy()
del flat
# Handle remaining tokens
if len(carryover) >= 32:
padded = np.zeros(max_seq_len, dtype=np.int32)
padded[:len(carryover)] = carryover
f.write(padded.tobytes())
n_chunks += 1
# Write actual count into header
f.seek(0)
f.write(struct.pack('II', n_chunks, max_seq_len))
os.rename(temp_path, str(output_path))
size_gb = os.path.getsize(output_path) / 1e9
print(f"Total tokens: {total_tokens:,} → {n_chunks:,} chunks ({size_gb:.1f} GB)")
return n_chunks
# ---------------------------------------------------------------------------
# HuggingFace dataset loader (parallel + memmap)
# ---------------------------------------------------------------------------
def _flatten_chat(example):
"""Convert chat format (system + conversations list) to plain text.
Handles datasets like Bespoke-Stratos-17k and OpenThoughts-114k
which store data as: system (str) + conversations (list of {from, value}).
"""
parts = []
if example.get("system"):
parts.append(example["system"].strip())
for msg in example.get("conversations", []):
value = msg.get("value", "")
if value:
parts.append(value.strip())
return {"text": "\n\n".join(parts)}
def _estimate_avg_chars(dataset, text_column: str, n_sample: int = 200) -> float:
"""Estimate average text length from a sample of the dataset."""
n = min(n_sample, len(dataset))
total = sum(len(dataset[i][text_column] or "") for i in range(n))
return total / max(n, 1)
def _adaptive_params(avg_chars: float, n_examples: int):
"""Scale worker count, batch sizes based on average example length.
Long examples (chain-of-thought reasoning) need smaller batches and fewer
workers to avoid OOM on memory-constrained systems (especially WSL).
"""
cpu_count = max(1, multiprocessing.cpu_count() - 1)
if avg_chars > 20_000: # very long (OpenThoughts-style, ~7K+ tokens)
num_proc = min(cpu_count, 4)
tok_batch = 64
read_batch = 500
elif avg_chars > 5_000: # long (detailed SFT, ~1.5K+ tokens)
num_proc = min(cpu_count, 8)
tok_batch = 256
read_batch = 2_000
elif avg_chars > 1_000: # medium (typical SFT)
num_proc = min(cpu_count, 16)
tok_batch = 500
read_batch = 5_000
else: # short (web text, wiki)
num_proc = min(cpu_count, 32)
tok_batch = 1000
read_batch = 10_000
return num_proc, tok_batch, read_batch
def load_hf_dataset(
    name: str,
    split: str,
    text_column: str,
    tokenizer,
    max_seq_len: int,
    num_samples: int | None = None,
    hf_config: str | None = None,
    cache_path: Path | None = None,
    data_format: str = "text",
) -> MemmapDataset:
    """Load an HF dataset, tokenize it in parallel and stream it to a memmap file.

    Args:
        name: dataset id passed to ``datasets.load_dataset``.
        split: dataset split to load.
        text_column: column holding the raw text. Ignored when
            ``data_format == "chat"``, which produces a ``"text"`` column.
        tokenizer: HF-style tokenizer. It is called on a list of strings with
            ``add_special_tokens=False`` and must return ``{"input_ids": ...}``.
        max_seq_len: chunk length in tokens.
        num_samples: keep only the first ``num_samples`` rows (before filtering).
        hf_config: optional dataset configuration name.
        cache_path: output ``.bin`` path. If None, a file in a fresh temporary
            directory is used.
        data_format: ``"chat"`` flattens system + conversations into plain
            text; any other value treats ``text_column`` as plain text.

    Returns:
        A MemmapDataset over the written chunk file.

    Raises:
        ValueError: if ``text_column`` is not a column of the (flattened) dataset.

    Parallelism:
    - dataset.map(num_proc=N) uses multiprocessing, which bypasses the GIL.
    - Fast (Rust) HF tokenizers also parallelize inside each worker.
    - batched=True enables efficient batch processing.

    Memory:
    - Adaptive batch sizes based on average example length prevent OOM on
      long sequences.
    - Tokenized data stays in Arrow format (memory-mapped by HuggingFace).
    - Chunks are streamed to a binary memmap file and never held in RAM.
    """
    from datasets import load_dataset
    config_str = f", config={hf_config}" if hf_config else ""
    print(f"Loading HF dataset: {name} (split={split}{config_str})")
    dataset = load_dataset(name, hf_config, split=split)
    if num_samples is not None:
        dataset = dataset.select(range(min(num_samples, len(dataset))))
    # Flatten chat format to plain text
    if data_format == "chat":
        # Flattening is a light operation, so cap parallelism at 8 workers.
        flat_proc = min(max(1, multiprocessing.cpu_count() - 1), 8)
        print(f"Flattening {len(dataset):,} chat examples to plain text...")
        dataset = dataset.map(
            _flatten_chat,
            num_proc=flat_proc,
            remove_columns=dataset.column_names,
            desc="Flattening chat",
        )
        text_column = "text"
    # Fail early with an actionable message instead of a KeyError deep in a helper.
    if text_column not in dataset.column_names:
        raise ValueError(
            f"Column {text_column!r} not found in dataset {name!r}; "
            f"available columns: {dataset.column_names}"
        )
    # Estimate avg example length and adapt parameters
    avg_chars = _estimate_avg_chars(dataset, text_column)
    num_proc, tok_batch, read_batch = _adaptive_params(avg_chars, len(dataset))
    print(f"  Avg example length: ~{avg_chars:,.0f} chars → "
          f"{num_proc} workers, tok_batch={tok_batch}, read_batch={read_batch}")
    # Filter empty examples
    print(f"Filtering empty examples from {len(dataset):,}...")
    dataset = dataset.filter(
        lambda x: bool(x[text_column] and x[text_column].strip()),
        num_proc=num_proc,
        desc="Filtering",
    )
    print(f"  {len(dataset):,} non-empty examples")
    # Parallel tokenization
    print(f"Tokenizing {len(dataset):,} examples with {num_proc} workers...")

    def tokenize_batch(examples):
        return tokenizer(examples[text_column], add_special_tokens=False)

    tokenized = dataset.map(
        tokenize_batch,
        batched=True,
        batch_size=tok_batch,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
        desc="Tokenizing",
    )
    # Stream to memmap. Without a configured cache, write into a fresh temp dir.
    if cache_path is None:
        import tempfile
        # mkdtemp creates a private directory atomically. The deprecated
        # tempfile.mktemp only returns a name that another process could claim
        # first. The file is intentionally not deleted: the returned dataset
        # and its split() views re-open it by path.
        cache_path = Path(tempfile.mkdtemp(prefix="circuits_data_")) / "chunks.bin"
    _stream_chunks_to_memmap(tokenized, len(tokenized), max_seq_len, cache_path,
                             read_batch=read_batch)
    return MemmapDataset(cache_path)
# ---------------------------------------------------------------------------
# Text file loaders (unchanged — small datasets, in-memory is fine)
# ---------------------------------------------------------------------------
def tokenize_text(text: str, tokenizer, max_seq_len: int,
                  min_chunk_len: int = 32) -> list[list[int]]:
    """Tokenize ``text`` and cut the token stream into consecutive chunks.

    Full ``max_seq_len``-token chunks are always kept. A shorter trailing chunk
    is kept only if it has at least ``min_chunk_len`` tokens. This is the same
    remainder rule the memmap writer applies. Previously every chunk had to
    reach 32 tokens, so any ``max_seq_len < 32`` silently produced no data.

    Special tokens are not added (``add_special_tokens=False``), matching the
    HF-dataset tokenization path.

    Raises:
        ValueError: if ``max_seq_len`` is not positive.
    """
    if max_seq_len <= 0:
        raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    for i in range(0, len(tokens), max_seq_len):
        chunk = tokens[i : i + max_seq_len]
        if len(chunk) == max_seq_len or len(chunk) >= min_chunk_len:
            chunks.append(chunk)
    return chunks
def load_text_file(path: str, tokenizer, max_seq_len: int) -> list[list[int]]:
    """Read one UTF-8 text file and return its token chunks (see tokenize_text)."""
    contents = Path(path).read_text(encoding="utf-8")
    return tokenize_text(contents, tokenizer, max_seq_len)
def load_text_directory(path: str, tokenizer, max_seq_len: int) -> list[list[int]]:
    """Load and tokenize every ``*.txt`` file under ``path``, recursively.

    Files are visited in sorted path order, so chunk order is deterministic.
    Directories whose names happen to end in ``.txt`` also match the glob.
    They are skipped instead of crashing ``open()``.
    """
    all_chunks = []
    for txt_file in sorted(Path(path).glob("**/*.txt")):
        if not txt_file.is_file():
            continue
        all_chunks.extend(load_text_file(str(txt_file), tokenizer, max_seq_len))
    return all_chunks
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def load_data(
    data_source: str,
    tokenizer,
    max_seq_len: int,
    text_column: str = "text",
    num_samples: int | None = None,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    data_format: str = "text",
) -> Dataset:
    """
    Load data from various sources. Returns a Dataset with .split() support.

    Args:
        data_source: Path or HF dataset identifier
            - "path/to/file.txt" — single file
            - "path/to/dir/" — directory of .txt files
            - "hf:dataset_name" — HuggingFace dataset (train split)
            - "hf:dataset:split" — HuggingFace with specific split
            - "hf:dataset:config:split" — with config and split
              (an empty config, "hf:dataset::split", means no config)
        tokenizer: Tokenizer to use
        max_seq_len: Maximum sequence length
        text_column: Column name for HF datasets
        num_samples: Limit samples from HF dataset
        cache_dir: Directory for cache files (None to disable)
        data_format: "chat" flattens system + conversations HF datasets into
            text; anything else treats text_column as plain text

    Returns:
        Dataset object supporting len(), __getitem__(), and split(fraction)

    Raises:
        ValueError: for a malformed "hf:" spec, or a source that is neither a
            file, a directory nor an "hf:" identifier.
    """
    cache_path = None
    if cache_dir is not None:
        # NOTE(review): the cache key covers only (data_source, max_seq_len,
        # num_samples). Reusing the same source with a different tokenizer,
        # text_column or data_format silently returns the old cached tokens;
        # delete the cache file in that case.
        cache_path = Path(cache_dir) / _cache_key(data_source, max_seq_len, num_samples)
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        # Check for memmap cache (.bin)
        if cache_path.exists():
            print(f"Loading from cache: {cache_path}")
            ds = MemmapDataset(cache_path)
            print(f"  Loaded {len(ds):,} chunks")
            return ds
        # Check for legacy cache (.pt)
        legacy_path = cache_path.with_suffix('.pt')
        if legacy_path.exists():
            print(f"Loading from legacy cache: {legacy_path}")
            # SECURITY: weights_only=False unpickles arbitrary objects, which can
            # run code. Only point cache_dir at directories you trust.
            data = torch.load(legacy_path, weights_only=False)
            chunks = data["chunks"]
            print(f"  Loaded {len(chunks):,} chunks")
            return TextDataset(chunks, max_seq_len)
    # Load and tokenize
    if data_source.startswith("hf:"):
        parts = data_source[3:].split(":")
        # Reject specs that would otherwise be misparsed silently, e.g. extra
        # fields ignored and the default "train" split used.
        if not parts[0] or len(parts) > 3:
            raise ValueError(
                f"Invalid HF data source {data_source!r}; expected "
                "'hf:name', 'hf:name:split' or 'hf:name:config:split'"
            )
        name = parts[0]
        hf_config = None
        split = "train"
        if len(parts) == 2:
            split = parts[1]
        elif len(parts) == 3:
            hf_config = parts[1] or None
            split = parts[2]
        return load_hf_dataset(
            name, split, text_column, tokenizer, max_seq_len,
            num_samples, hf_config=hf_config, cache_path=cache_path,
            data_format=data_format,
        )
    elif os.path.isfile(data_source):
        chunks = load_text_file(data_source, tokenizer, max_seq_len)
    elif os.path.isdir(data_source):
        chunks = load_text_directory(data_source, tokenizer, max_seq_len)
    else:
        raise ValueError(f"Unknown data source: {data_source}")
    # For text files: save legacy cache
    if cache_dir is not None:
        legacy_path = cache_path.with_suffix('.pt')
        torch.save({"chunks": chunks, "data_source": data_source,
                    "max_seq_len": max_seq_len, "num_samples": num_samples}, legacy_path)
        print(f"Saved to cache: {legacy_path}")
    return TextDataset(chunks, max_seq_len)
# ---------------------------------------------------------------------------
# DataLoader factory
# ---------------------------------------------------------------------------
def create_dataloader(
dataset,
batch_size: int,
max_seq_len: int = None,
shuffle: bool = True,
num_workers: int = 0,
) -> DataLoader:
"""Create a DataLoader from a Dataset or list of chunks."""
if not isinstance(dataset, Dataset):
# Legacy compatibility: list of token chunks
dataset = TextDataset(dataset, max_seq_len)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=True,
)