# freqformer/sft_data.py — FreqFormer-32M (uploaded via huggingface_hub, commit 01fcd60)
"""
SFT Data Pipeline
=================
Supervised Fine-Tuning dataset with ChatML template and label masking.
Supports:
- Dolci-Think-SFT (chat with <think> blocks)
- Dolci-Instruct-SFT-Tool-Use (function calling with <functions>, <function_calls>)
- Any ChatML-format conversation dataset
Label mask: loss computed ONLY on assistant response tokens.
System, user, and environment tokens are masked out (ignore_index=-100).
Tokenizer: freqformer/tokenizer (Mistral-7B BPE 32K + ChatML special tokens)
Usage:
# Tokenize HuggingFace SFT dataset to binary
python -m freqformer.sft_data --dataset allenai/Dolci-Think-SFT-7B \
--out data/sft_think --max_seq_len 4096
# Tokenize tool-use dataset
python -m freqformer.sft_data --dataset allenai/Dolci-Instruct-SFT-Tool-Use \
--out data/sft_tools --max_seq_len 4096
# Merge multiple SFT datasets
python -m freqformer.sft_data --dataset allenai/Dolci-Think-SFT-7B \
--out data/sft_mixed --max_seq_len 4096
python -m freqformer.sft_data --dataset allenai/Dolci-Instruct-SFT-Tool-Use \
--out data/sft_mixed --max_seq_len 4096 --append
"""
import argparse
import os
import struct
import json
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from transformers import AutoTokenizer
from typing import Optional
# ---------------------------------------------------------------------------
# Binary format: same FREQTOK1 header + parallel label_mask file
# ---------------------------------------------------------------------------
# Magic bytes identifying a FreqFormer SFT binary file (token and mask files
# both start with this 48-byte header; num_examples is patched in after write).
HEADER_MAGIC = b"FREQSFT1"
HEADER_SIZE = 48 # magic(8) + version(4) + num_examples(8) + max_seq_len(4) + vocab_size(4) + pad_id(4) + reserved(16)
def _is_rank0() -> bool:
    """True when not running under torch.distributed, or when this is rank 0."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
# ---------------------------------------------------------------------------
# ChatML Template
# ---------------------------------------------------------------------------
class ChatMLTemplate:
    """
    Renders conversations in ChatML and produces per-token label masks.

    Template (matching OLMo 3.1):
        <|im_start|>{role}\n{body}<|im_end|>\n
    with the final assistant turn terminated by eos_token instead of
    <|im_end|>\n.

    Tool-use extensions appended to the message content:
        system:    " <functions>{functions}</functions>"
        user:      "\n<functions>{functions}</functions>"
        assistant: "<function_calls>{function_calls}</function_calls>"

    Label mask is True ONLY for assistant tokens (body + closing token);
    headers and all system/user/environment tokens are masked out.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        self.eos_id = tokenizer.eos_token_id
        # Fall back to eos when the tokenizer defines no pad token.
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else self.eos_id

    def _body_text(self, role, content, functions, function_calls):
        """Compose the message body text for one turn (tool-use aware)."""
        if role == "system":
            return content + (f" <functions>{functions}</functions>" if functions else "")
        if role == "user":
            return content + (f"\n<functions>{functions}</functions>" if functions else "")
        if role == "assistant":
            pieces = [content] if content else []
            if function_calls:
                pieces.append(f"<function_calls>{function_calls}</function_calls>")
            return "".join(pieces)
        # environment and any unrecognized role: content passes through as-is
        return content

    def apply(self, messages: list[dict], add_generation_prompt: bool = False) -> dict:
        """
        Tokenize *messages* and build the matching label mask.

        Returns:
            {"input_ids": list[int], "label_mask": list[bool]} where the mask
            is True only on assistant body/footer tokens.
        """
        encode = self.tokenizer.encode
        token_ids: list[int] = []
        label_mask: list[bool] = []
        last_idx = len(messages) - 1
        for idx, msg in enumerate(messages):
            role = msg["role"]
            trainable = role == "assistant"
            body = self._body_text(
                role,
                msg.get("content") or "",
                msg.get("functions"),
                msg.get("function_calls"),
            )
            # Header <|im_start|>{role}\n — never trained on.
            header_ids = encode(f"<|im_start|>{role}\n", add_special_tokens=False)
            token_ids += header_ids
            label_mask += [False] * len(header_ids)
            # Body — loss only on assistant turns.
            body_ids = encode(body, add_special_tokens=False) if body else []
            token_ids += body_ids
            label_mask += [trainable] * len(body_ids)
            # Footer — assistant footers are trained on so the model learns
            # to emit eos / <|im_end|>; the final assistant turn ends with eos.
            if trainable and idx == last_idx:
                footer_ids = [self.eos_id]
            else:
                footer_ids = encode("<|im_end|>\n", add_special_tokens=False)
            token_ids += footer_ids
            label_mask += [trainable] * len(footer_ids)
        if add_generation_prompt:
            # Open an assistant turn for inference; never trained on.
            gen_ids = encode("<|im_start|>assistant\n", add_special_tokens=False)
            token_ids += gen_ids
            label_mask += [False] * len(gen_ids)
        return {
            "input_ids": token_ids,
            "label_mask": label_mask,
        }
# ---------------------------------------------------------------------------
# SFT Dataset (mmap)
# ---------------------------------------------------------------------------
class SFTDataset(Dataset):
    """
    Memory-mapped SFT dataset serving (input_ids, targets, label_mask).

    Each stored row is a fixed-width tokenized conversation; __getitem__
    shifts it by one position:
        input_ids  = tokens[:-1]
        targets    = tokens[1:]
        label_mask = mask[1:]   (aligned with targets)
    """

    def __init__(self, path_prefix: str, max_seq_len: Optional[int] = None):
        tok_path = path_prefix + ".bin"
        mask_path = path_prefix + ".mask.bin"
        assert os.path.exists(tok_path), f"SFT data not found: {tok_path}"
        assert os.path.exists(mask_path), f"SFT mask not found: {mask_path}"
        # Parse the fixed 48-byte FREQSFT1 header of the token file.
        with open(tok_path, "rb") as fh:
            assert fh.read(8) == HEADER_MAGIC, f"Invalid SFT file: {tok_path}"
            _version, = struct.unpack("<I", fh.read(4))
            self.num_examples, = struct.unpack("<Q", fh.read(8))
            self.file_seq_len, = struct.unpack("<I", fh.read(4))
            self.vocab_size, = struct.unpack("<I", fh.read(4))
            self.pad_id, = struct.unpack("<I", fh.read(4))
        # Serve a shorter window than the file stores, if requested.
        if max_seq_len and max_seq_len < self.file_seq_len:
            self.seq_len = max_seq_len
        else:
            self.seq_len = self.file_seq_len
        # Lazily map both files; rows are (num_examples, file_seq_len).
        shape = (self.num_examples, self.file_seq_len)
        self.tokens = np.memmap(tok_path, dtype=np.int32, mode="r",
                                offset=HEADER_SIZE, shape=shape)
        self.masks = np.memmap(mask_path, dtype=np.uint8, mode="r",
                               offset=HEADER_SIZE, shape=shape)

    def __len__(self):
        return self.num_examples

    def __getitem__(self, idx):
        n = self.seq_len
        toks = self.tokens[idx, :n + 1].astype(np.int64)  # one extra for shift
        msk = self.masks[idx, :n + 1].astype(np.bool_)
        # When the stored row is exactly seq_len wide there is no extra
        # token for the shift — pad the tail with pad_id / False instead.
        short = n + 1 - len(toks)
        if short > 0:
            toks = np.pad(toks, (0, short), constant_values=self.pad_id)
            msk = np.pad(msk, (0, short), constant_values=False)
        return (
            torch.from_numpy(toks[:-1].copy()),  # input_ids, (seq_len,)
            torch.from_numpy(toks[1:].copy()),   # targets, (seq_len,)
            torch.from_numpy(msk[1:].copy()),    # label_mask aligned w/ targets
        )
def get_sft_dataloaders(
    data_dir: str,
    max_seq_len: int,
    batch_size: int,
    train_file: str = "train",
    val_file: str = "val",
    num_workers: int = 4,
    distributed: bool = False,
):
    """Build the train (and optional val) SFT dataloaders from *data_dir*.

    The val loader is created only when "{val_file}.bin" exists; otherwise
    the second element of the returned (train_loader, val_loader) is None.
    """
    use_workers = num_workers > 0
    # Options shared by both loaders.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=use_workers,
        prefetch_factor=4 if use_workers else None,
    )

    train_ds = SFTDataset(os.path.join(data_dir, train_file), max_seq_len)
    if _is_rank0():
        print(f"SFT Train: {train_ds.num_examples:,} examples, seq_len={train_ds.seq_len}")

    val_loader = None
    val_prefix = os.path.join(data_dir, val_file)
    if os.path.exists(val_prefix + ".bin"):
        val_ds = SFTDataset(val_prefix, max_seq_len)
        if _is_rank0():
            print(f"SFT Val: {val_ds.num_examples:,} examples, seq_len={val_ds.seq_len}")
        val_sampler = DistributedSampler(val_ds, shuffle=False) if distributed else None
        val_loader = DataLoader(val_ds, shuffle=False, sampler=val_sampler, **loader_kwargs)

    train_sampler = DistributedSampler(train_ds, shuffle=True) if distributed else None
    train_loader = DataLoader(
        train_ds,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        **loader_kwargs,
    )
    return train_loader, val_loader
# ---------------------------------------------------------------------------
# Tokenize HuggingFace SFT Dataset
# ---------------------------------------------------------------------------
def _detect_messages_column(ds, streaming: bool) -> str:
    """Guess which dataset column holds the conversation messages.

    Well-known chat column names are checked first; failing that, the first
    column whose sample value is a non-empty list of dicts wins.
    """
    # Streaming datasets have no column_names attribute — peek at a record.
    first = next(iter(ds)) if streaming else None
    cols = list(first.keys()) if streaming else ds.column_names

    known = [c for c in ("messages", "conversations", "chat", "dialogue") if c in cols]
    if known:
        return known[0]

    # Fall back: any column whose first value looks like [{...}, ...].
    for col in cols:
        sample = first[col] if streaming else ds[0][col]
        if isinstance(sample, list) and sample and isinstance(sample[0], dict):
            return col
    raise ValueError(f"Cannot detect messages column from: {cols}")
def _sft_worker_init(tokenizer_name):
    """Pool initializer: build a per-process tokenizer + ChatML template.

    multiprocessing workers cannot share the parent's tokenizer object, so
    each worker loads its own copy into the module-global _worker_template,
    which _sft_worker_process reads on every call.
    """
    global _worker_template
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    _worker_template = ChatMLTemplate(tok)
def _sft_worker_process(messages):
    """Tokenize one conversation in a pool worker.

    Args:
        messages: list of ChatML message dicts, or a JSON string encoding one.

    Returns:
        (result, skip_reason): on success, result is {"input_ids",
        "label_mask"} and skip_reason is None; on skip, result is None and
        skip_reason is one of "json_error", "short", "no_assistant",
        "too_short".
    """
    global _worker_template
    # Decode JSON-encoded conversations BEFORE the turn-count check; the old
    # order measured len() of the JSON string (characters, not turns) and
    # never re-checked the decoded list, letting single-turn conversations
    # slip through.
    if isinstance(messages, str):
        try:
            messages = json.loads(messages)
        except json.JSONDecodeError:
            return None, "json_error"
    if not messages or len(messages) < 2:
        return None, "short"
    result = _worker_template.apply(messages)
    token_ids = result["input_ids"]
    label_mask = result["label_mask"]
    if not any(label_mask):
        # Nothing to train on: no assistant tokens in this conversation.
        return None, "no_assistant"
    if len(token_ids) < 2:
        return None, "too_short"
    return {"input_ids": token_ids, "label_mask": label_mask}, None
def tokenize_sft_dataset(
    dataset_name: str,
    tokenizer_name: str,
    out_dir: str,
    subset: Optional[str] = None,
    max_seq_len: int = 4096,
    max_samples: int = 0,
    split: str = "train",
    val_ratio: float = 0.02,
    streaming: bool = False,
    append: bool = False,
    num_proc: int = 0,
):
    """Tokenize a HuggingFace SFT dataset with ChatML template + label masks.

    Uses multiprocessing for non-streaming datasets to utilize all CPU cores.

    Writes fixed-width FREQSFT1 binaries under out_dir: train.bin /
    train.mask.bin and (when val_ratio > 0) val.bin / val.mask.bin. Every
    example is truncated/padded to max_seq_len int32 tokens plus a parallel
    uint8 label-mask row.

    Args:
        dataset_name: HuggingFace dataset id to load.
        tokenizer_name: tokenizer id/path (must define the ChatML tokens).
        out_dir: output directory (created if missing).
        subset: optional dataset config name.
        max_seq_len: fixed on-disk row width, in tokens.
        max_samples: cap on processed conversations (0 = all).
        split: dataset split to read.
        val_ratio: probability an example is routed to the val file.
        streaming: stream the dataset (forces single-process tokenization).
        append: append rows to existing binaries instead of re-initializing.
        num_proc: worker processes (<=0 -> min(cpu_count, 16)).
    """
    from datasets import load_dataset
    import multiprocessing
    label = f"{dataset_name}/{subset}" if subset else dataset_name
    print(f"Loading dataset: {label} (split={split}, streaming={streaming})")
    ds_kwargs = {}
    if subset:
        ds_kwargs["name"] = subset
    if streaming:
        ds_kwargs["streaming"] = True
    ds = load_dataset(dataset_name, split=split, **ds_kwargs)
    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    vocab_size = len(tokenizer)  # includes added tokens
    # Fall back to eos when the tokenizer defines no pad token.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    template = ChatMLTemplate(tokenizer)
    # Detect messages column
    msg_col = _detect_messages_column(ds, streaming)
    print(f"Messages column: '{msg_col}'")
    print(f"Vocab size: {vocab_size}, pad_id: {pad_id}, max_seq_len: {max_seq_len}")
    if num_proc <= 0:
        num_proc = min(multiprocessing.cpu_count(), 16)
    # ---- Shared disk-write helpers (used by both modes) ----
    os.makedirs(out_dir, exist_ok=True)
    # Fixed seed so the train/val routing is reproducible across runs.
    rng = np.random.default_rng(42)

    def _init_sft_file(prefix: str):
        """Create SFT binary files with placeholder header."""
        # num_examples is written as 0 here and patched in place by
        # _update_header_count once the real count is known.
        tok_path = prefix + ".bin"
        mask_path = prefix + ".mask.bin"
        for p in (tok_path, mask_path):
            with open(p, "wb") as f:
                f.write(HEADER_MAGIC)
                f.write(struct.pack("<I", 1))  # version
                f.write(struct.pack("<Q", 0))  # num_examples (placeholder)
                f.write(struct.pack("<I", max_seq_len))
                # vocab_size/pad_id are only meaningful in the token file;
                # the mask file stores zeros in those header slots.
                f.write(struct.pack("<I", vocab_size if p == tok_path else 0))
                f.write(struct.pack("<I", pad_id if p == tok_path else 0))
                f.write(b"\x00" * 16)  # reserved

    def _read_existing_count(prefix: str) -> int:
        """Read num_examples from existing SFT binary header."""
        tok_path = prefix + ".bin"
        if not os.path.exists(tok_path):
            return 0
        with open(tok_path, "rb") as f:
            magic = f.read(8)
            if magic != HEADER_MAGIC:
                return 0
            f.read(4)  # version
            return struct.unpack("<Q", f.read(8))[0]

    def _update_header_count(f, n):
        # num_examples sits immediately after magic(8) + version(4).
        f.seek(12)  # offset to num_examples
        f.write(struct.pack("<Q", n))

    train_prefix = os.path.join(out_dir, "train")
    val_prefix = os.path.join(out_dir, "val")
    # NOTE(review): in append mode the header keeps the ORIGINAL file's
    # max_seq_len, but new rows are written max_seq_len wide — appending with
    # a different --max_seq_len would corrupt the fixed-width layout; confirm
    # callers always reuse the same value.
    if not append:
        _init_sft_file(train_prefix)
        if val_ratio > 0:
            _init_sft_file(val_prefix)
    else:
        if not os.path.exists(train_prefix + ".bin"):
            _init_sft_file(train_prefix)
        if val_ratio > 0 and not os.path.exists(val_prefix + ".bin"):
            _init_sft_file(val_prefix)
    existing_train = _read_existing_count(train_prefix) if append else 0
    existing_val = _read_existing_count(val_prefix) if (append and val_ratio > 0) else 0
    # Open for in-place header updates and seek to EOF to append rows.
    train_tok_f = open(train_prefix + ".bin", "r+b")
    train_mask_f = open(train_prefix + ".mask.bin", "r+b")
    train_tok_f.seek(0, 2)
    train_mask_f.seek(0, 2)
    train_count = 0
    val_tok_f = val_mask_f = None
    val_count = 0
    if val_ratio > 0:
        val_tok_f = open(val_prefix + ".bin", "r+b")
        val_mask_f = open(val_prefix + ".mask.bin", "r+b")
        val_tok_f.seek(0, 2)
        val_mask_f.seek(0, 2)
    # Reusable padded row buffers
    pad_tokens = np.full(max_seq_len, pad_id, dtype=np.int32)
    pad_mask = np.zeros(max_seq_len, dtype=np.uint8)

    def _write_example(token_ids, label_mask_list):
        """Pad/truncate one example and route to train or val file."""
        nonlocal train_count, val_count
        # Truncate to max_seq_len, then copy into a fully padded row so
        # every record on disk has identical width.
        ids = token_ids[:max_seq_len]
        msk = label_mask_list[:max_seq_len]
        row_t = pad_tokens.copy()
        row_m = pad_mask.copy()
        row_t[:len(ids)] = ids
        row_m[:len(msk)] = msk
        data_t = row_t.tobytes()
        data_m = row_m.tobytes()
        if val_ratio > 0 and rng.random() < val_ratio:
            val_tok_f.write(data_t)
            val_mask_f.write(data_m)
            val_count += 1
        else:
            train_tok_f.write(data_t)
            train_mask_f.write(data_m)
            train_count += 1

    try:
        # ---- Non-streaming + multiprocessing ----
        if not streaming and num_proc > 1:
            if max_samples:
                ds = ds.select(range(min(max_samples, len(ds))))
            total_messages = len(ds)
            print(f"\nTokenizing {total_messages:,} conversations with {num_proc} processes...")

            def _iter_messages():
                # Feed only the messages column to the pool to keep IPC small.
                for row in ds:
                    yield row[msg_col]

            skipped = 0
            count = 0
            log_every = 10000
            recent_lens = []
            recent_ratios = []

            def _iter_chunk_results(chunk):
                # _sft_worker_process returns a single (result, reason)
                # tuple; this wrapper also tolerates list-valued results.
                if isinstance(chunk, list):
                    return chunk
                return [chunk]

            with multiprocessing.Pool(
                num_proc,
                initializer=_sft_worker_init,
                initargs=(tokenizer_name,),
                maxtasksperchild=1000,  # recycle workers to bound memory growth
            ) as pool:
                for chunk in pool.imap_unordered(_sft_worker_process, _iter_messages(), chunksize=64):
                    for result, skip_reason in _iter_chunk_results(chunk):
                        if result is None:
                            skipped += 1
                            continue
                        token_ids = result["input_ids"]
                        lm = result["label_mask"]
                        _write_example(token_ids, lm)
                        count += 1
                        # Rolling stats over the last log_every examples.
                        recent_lens.append(len(token_ids))
                        recent_ratios.append(sum(lm) / max(len(lm), 1))
                        if count % log_every == 0:
                            avg_len = sum(recent_lens) / len(recent_lens)
                            avg_ratio = sum(recent_ratios) / len(recent_ratios)
                            print(f" {count:>10,} examples | avg_len={avg_len:.0f} | assistant_ratio={avg_ratio:.1%}")
                            recent_lens.clear()
                            recent_ratios.clear()
            print(f" Total: {count:,} examples ({skipped:,} skipped)")
        # ---- Streaming / single-proc fallback (direct disk writes) ----
        else:
            skipped = 0
            count = 0
            log_every = 10000
            recent_lens = []
            recent_ratios = []
            for example in ds:
                if max_samples and count >= max_samples:
                    break
                messages = example[msg_col]
                # NOTE(review): length is checked before JSON decoding, so for
                # string-encoded conversations this counts characters rather
                # than turns; a decoded single-turn conversation passes
                # through — mirrors the pattern in _sft_worker_process.
                if not messages or len(messages) < 2:
                    skipped += 1
                    continue
                if isinstance(messages, str):
                    try:
                        messages = json.loads(messages)
                    except json.JSONDecodeError:
                        skipped += 1
                        continue
                result = template.apply(messages)
                token_ids = result["input_ids"]
                label_mask = result["label_mask"]
                if not any(label_mask):
                    # No assistant tokens -> nothing to compute loss on.
                    skipped += 1
                    continue
                if len(token_ids) < 2:
                    skipped += 1
                    continue
                _write_example(token_ids, label_mask)
                count += 1
                recent_lens.append(len(token_ids))
                recent_ratios.append(sum(label_mask) / max(len(label_mask), 1))
                if count % log_every == 0:
                    avg_len = sum(recent_lens) / len(recent_lens)
                    avg_ratio = sum(recent_ratios) / len(recent_ratios)
                    print(f" {count:>10,} examples | avg_len={avg_len:.0f} | assistant_ratio={avg_ratio:.1%}")
                    recent_lens.clear()
                    recent_ratios.clear()
            print(f" Total: {count:,} examples ({skipped:,} skipped)")
        # ---- Finalize: update headers ----
        _update_header_count(train_tok_f, existing_train + train_count)
        _update_header_count(train_mask_f, existing_train + train_count)
        if val_ratio > 0:
            _update_header_count(val_tok_f, existing_val + val_count)
            _update_header_count(val_mask_f, existing_val + val_count)
    finally:
        train_tok_f.close()
        train_mask_f.close()
        if val_tok_f is not None:
            val_tok_f.close()
        if val_mask_f is not None:
            val_mask_f.close()
    print(f" Train: {train_count:,}, Val: {val_count:,}")
    train_mb = os.path.getsize(train_prefix + ".bin") / 1024 / 1024
    print(f" Saved to {out_dir}/ ({train_mb:.0f} MB train)")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: parse arguments and run tokenize_sft_dataset."""
    cli = argparse.ArgumentParser(description="Tokenize SFT datasets for FreqFormer")
    cli.add_argument("--dataset", type=str, required=True, help="HuggingFace dataset name")
    cli.add_argument("--subset", type=str, default=None, help="Dataset subset/config")
    cli.add_argument("--tokenizer", type=str, default="freqformer/tokenizer",
                     help="Tokenizer name/path")
    cli.add_argument("--out", type=str, required=True, help="Output directory")
    cli.add_argument("--max_seq_len", type=int, default=4096, help="Max sequence length")
    cli.add_argument("--max_samples", type=int, default=0, help="Max samples (0=all)")
    cli.add_argument("--split", type=str, default="train", help="Dataset split")
    cli.add_argument("--val_ratio", type=float, default=0.02, help="Validation split ratio")
    cli.add_argument("--streaming", action="store_true", help="Stream large datasets")
    cli.add_argument("--append", action="store_true", help="Append to existing files")
    cli.add_argument("--num_proc", type=int, default=0, help="Number of CPU processes (0=auto)")
    args = cli.parse_args()

    if args.append:
        print("Mode: APPEND")

    tokenize_sft_dataset(
        dataset_name=args.dataset,
        tokenizer_name=args.tokenizer,
        out_dir=args.out,
        subset=args.subset,
        max_seq_len=args.max_seq_len,
        max_samples=args.max_samples,
        split=args.split,
        val_ratio=args.val_ratio,
        streaming=args.streaming,
        append=args.append,
        num_proc=args.num_proc,
    )
    print(f"\nDone! SFT data saved to {args.out}/")
    print(f"Use --sft_data_dir {args.out} when running SFT training.")


if __name__ == "__main__":
    main()