| """ |
| SFT Data Pipeline |
| ================= |
| Supervised Fine-Tuning dataset with ChatML template and label masking. |
| |
| Supports: |
| - Dolci-Think-SFT (chat with <think> blocks) |
| - Dolci-Instruct-SFT-Tool-Use (function calling with <functions>, <function_calls>) |
| - Any ChatML-format conversation dataset |
| |
| Label mask: loss computed ONLY on assistant response tokens. |
| System, user, and environment tokens are masked out (ignore_index=-100). |
| |
| Tokenizer: freqformer/tokenizer (Mistral-7B BPE 32K + ChatML special tokens) |
| |
| Usage: |
| # Tokenize HuggingFace SFT dataset to binary |
| python -m freqformer.sft_data --dataset allenai/Dolci-Think-SFT-7B \ |
| --out data/sft_think --max_seq_len 4096 |
| |
| # Tokenize tool-use dataset |
| python -m freqformer.sft_data --dataset allenai/Dolci-Instruct-SFT-Tool-Use \ |
| --out data/sft_tools --max_seq_len 4096 |
| |
| # Merge multiple SFT datasets |
| python -m freqformer.sft_data --dataset allenai/Dolci-Think-SFT-7B \ |
| --out data/sft_mixed --max_seq_len 4096 |
| python -m freqformer.sft_data --dataset allenai/Dolci-Instruct-SFT-Tool-Use \ |
| --out data/sft_mixed --max_seq_len 4096 --append |
| """ |
|
|
| import argparse |
| import os |
| import struct |
| import json |
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| from torch.utils.data import Dataset, DataLoader, DistributedSampler |
| from transformers import AutoTokenizer |
| from typing import Optional |
|
|
|
|
| |
| |
| |
# Binary SFT file format: a fixed 48-byte header followed by num_examples
# fixed-width rows (int32 token ids in *.bin, uint8 mask bytes in *.mask.bin).
# Header layout (little-endian, written by _init_sft_file):
#   magic(8) | version(u32) | num_examples(u64) | seq_len(u32)
#   | vocab_size(u32) | pad_id(u32) | 16 reserved zero bytes  = 48 bytes total
HEADER_MAGIC = b"FREQSFT1"
HEADER_SIZE = 48
|
|
|
|
| def _is_rank0() -> bool: |
| return not dist.is_initialized() or dist.get_rank() == 0 |
|
|
|
|
| |
| |
| |
class ChatMLTemplate:
    """
    Renders conversations in ChatML format and produces per-token label masks.

    Template (matching OLMo 3.1):
        <|im_start|>{role}\n{content}<|im_end|>\n
    for system / user / assistant / environment turns, except that the FINAL
    assistant turn is terminated with EOS instead of <|im_end|>\n.

    Tool-use conventions appended to the body text:
        system:    content + " <functions>{functions}</functions>"
        user:      content + "\n<functions>{functions}</functions>"
        assistant: content + "<function_calls>{function_calls}</function_calls>"

    The label mask is True only on assistant body + closing tokens, so SFT
    loss is computed exclusively on assistant responses.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        self.eos_id = tokenizer.eos_token_id
        # Fall back to EOS when the tokenizer defines no pad token.
        pad = tokenizer.pad_token_id
        self.pad_id = pad if pad is not None else self.eos_id

    def _body_text(self, role, content, functions, function_calls):
        """Compose the body text of one message, applying tool-use suffixes."""
        if role == "system":
            return content + (f" <functions>{functions}</functions>" if functions else "")
        if role == "user":
            return content + (f"\n<functions>{functions}</functions>" if functions else "")
        if role == "assistant":
            pieces = []
            if content:
                pieces.append(content)
            if function_calls:
                pieces.append(f"<function_calls>{function_calls}</function_calls>")
            return "".join(pieces)
        # environment and any unknown role: plain content.
        return content

    def apply(self, messages: list[dict], add_generation_prompt: bool = False) -> dict:
        """
        Convert a list of messages to token_ids + label_mask.

        Returns:
            dict with "input_ids" (list[int]) and "label_mask" (list[bool])
        """
        ids: list[int] = []
        mask: list[bool] = []
        encode = self.tokenizer.encode
        last = len(messages) - 1

        for idx, msg in enumerate(messages):
            role = msg["role"]
            content = msg.get("content") or ""
            body_text = self._body_text(
                role, content, msg.get("functions"), msg.get("function_calls")
            )

            header_ids = encode(f"<|im_start|>{role}\n", add_special_tokens=False)
            body_ids = encode(body_text, add_special_tokens=False) if body_text else []

            # Final assistant turn ends with EOS; every other turn closes
            # with the regular ChatML footer.
            if role == "assistant" and idx == last:
                footer_ids = [self.eos_id]
            else:
                footer_ids = encode("<|im_end|>\n", add_special_tokens=False)

            on_loss = role == "assistant"
            ids += header_ids
            mask += [False] * len(header_ids)   # headers never contribute to loss
            ids += body_ids
            mask += [on_loss] * len(body_ids)
            ids += footer_ids
            mask += [on_loss] * len(footer_ids)

        if add_generation_prompt:
            # Open an assistant turn for inference-time generation.
            prompt_ids = encode("<|im_start|>assistant\n", add_special_tokens=False)
            ids += prompt_ids
            mask += [False] * len(prompt_ids)

        return {"input_ids": ids, "label_mask": mask}
|
|
|
|
| |
| |
| |
class SFTDataset(Dataset):
    """
    Memory-mapped SFT dataset serving (input_ids, targets, label_mask) triples.

    For a stored row `tokens` / `mask`:
        input_ids:  tokens[:-1]
        targets:    tokens[1:]  (next-token prediction, shifted by one)
        label_mask: mask[1:]    (aligned with targets; True on assistant tokens)
    """

    def __init__(self, path_prefix: str, max_seq_len: Optional[int] = None):
        tok_path = path_prefix + ".bin"
        mask_path = path_prefix + ".mask.bin"
        assert os.path.exists(tok_path), f"SFT data not found: {tok_path}"
        assert os.path.exists(mask_path), f"SFT mask not found: {mask_path}"

        # Parse the fixed 48-byte header of the token file.
        with open(tok_path, "rb") as f:
            magic = f.read(8)
            assert magic == HEADER_MAGIC, f"Invalid SFT file: {tok_path}"
            _version, = struct.unpack("<I", f.read(4))
            self.num_examples, = struct.unpack("<Q", f.read(8))
            self.file_seq_len, = struct.unpack("<I", f.read(4))
            self.vocab_size, = struct.unpack("<I", f.read(4))
            self.pad_id, = struct.unpack("<I", f.read(4))

        # Optionally serve shorter sequences than what is stored on disk.
        if max_seq_len and max_seq_len < self.file_seq_len:
            self.seq_len = max_seq_len
        else:
            self.seq_len = self.file_seq_len

        # Zero-copy views over the fixed-width rows after the header.
        shape = (self.num_examples, self.file_seq_len)
        self.tokens = np.memmap(tok_path, dtype=np.int32, mode="r",
                                offset=HEADER_SIZE, shape=shape)
        self.masks = np.memmap(mask_path, dtype=np.uint8, mode="r",
                               offset=HEADER_SIZE, shape=shape)

    def __len__(self):
        return self.num_examples

    def __getitem__(self, idx):
        n = self.seq_len
        toks = self.tokens[idx, :n + 1].astype(np.int64)
        msk = self.masks[idx, :n + 1].astype(np.bool_)

        # When seq_len == file_seq_len the slice above yields only n items;
        # pad to n + 1 so the shifted views below always have length n.
        short = n + 1 - len(toks)
        if short > 0:
            toks = np.pad(toks, (0, short), constant_values=self.pad_id)
            msk = np.pad(msk, (0, short), constant_values=False)

        return (
            torch.from_numpy(toks[:-1].copy()),
            torch.from_numpy(toks[1:].copy()),
            torch.from_numpy(msk[1:].copy()),
        )
|
|
|
|
def get_sft_dataloaders(
    data_dir: str,
    max_seq_len: int,
    batch_size: int,
    train_file: str = "train",
    val_file: str = "val",
    num_workers: int = 4,
    distributed: bool = False,
):
    """Create train and val SFT dataloaders.

    The val loader is built only when `{val_file}.bin` exists; otherwise
    None is returned in its place.
    """
    def _make_loader(ds, shuffle):
        # Under DDP the sampler owns shuffling; DataLoader.shuffle must be
        # False whenever a sampler is supplied.
        sampler = DistributedSampler(ds, shuffle=shuffle) if distributed else None
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle and sampler is None,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            persistent_workers=num_workers > 0,
            prefetch_factor=4 if num_workers > 0 else None,
        )

    train_ds = SFTDataset(os.path.join(data_dir, train_file), max_seq_len)
    if _is_rank0():
        print(f"SFT Train: {train_ds.num_examples:,} examples, seq_len={train_ds.seq_len}")

    val_loader = None
    val_prefix = os.path.join(data_dir, val_file)
    if os.path.exists(val_prefix + ".bin"):
        val_ds = SFTDataset(val_prefix, max_seq_len)
        if _is_rank0():
            print(f"SFT Val: {val_ds.num_examples:,} examples, seq_len={val_ds.seq_len}")
        val_loader = _make_loader(val_ds, shuffle=False)

    return _make_loader(train_ds, shuffle=True), val_loader
|
|
|
|
| |
| |
| |
| def _detect_messages_column(ds, streaming: bool) -> str: |
| """Auto-detect the column containing conversation messages.""" |
| if streaming: |
| first = next(iter(ds)) |
| cols = list(first.keys()) |
| else: |
| cols = ds.column_names |
|
|
| |
| for candidate in ["messages", "conversations", "chat", "dialogue"]: |
| if candidate in cols: |
| return candidate |
|
|
| |
| for col in cols: |
| if streaming: |
| sample = first[col] |
| else: |
| sample = ds[0][col] |
| if isinstance(sample, list) and len(sample) > 0 and isinstance(sample[0], dict): |
| return col |
|
|
| raise ValueError(f"Cannot detect messages column from: {cols}") |
|
|
|
|
def _sft_worker_init(tokenizer_name):
    """Pool-worker initializer: build a per-process tokenizer + ChatML template.

    The template lives in a module global because pool workers cannot share
    the parent process's tokenizer object.
    """
    global _worker_template
    _worker_template = ChatMLTemplate(AutoTokenizer.from_pretrained(tokenizer_name))
|
|
|
|
| def _sft_worker_process(messages): |
| """Process a single conversation in a worker. Returns (result_dict, skip_reason).""" |
| global _worker_template |
| if not messages or len(messages) < 2: |
| return None, "short" |
| if isinstance(messages, str): |
| try: |
| messages = json.loads(messages) |
| except json.JSONDecodeError: |
| return None, "json_error" |
| result = _worker_template.apply(messages) |
| token_ids = result["input_ids"] |
| label_mask = result["label_mask"] |
| if not any(label_mask): |
| return None, "no_assistant" |
| if len(token_ids) < 2: |
| return None, "too_short" |
| return {"input_ids": token_ids, "label_mask": label_mask}, None |
|
|
|
|
def tokenize_sft_dataset(
    dataset_name: str,
    tokenizer_name: str,
    out_dir: str,
    subset: Optional[str] = None,
    max_seq_len: int = 4096,
    max_samples: int = 0,
    split: str = "train",
    val_ratio: float = 0.02,
    streaming: bool = False,
    append: bool = False,
    num_proc: int = 0,
) -> None:
    """Tokenize a HuggingFace SFT dataset with ChatML template + label masks.

    Uses multiprocessing for non-streaming datasets to utilize all CPU cores.

    Writes fixed-width binary files under out_dir:
        train.bin / train.mask.bin   (always)
        val.bin / val.mask.bin       (when val_ratio > 0)
    Each file starts with a 48-byte header (see HEADER_MAGIC/HEADER_SIZE),
    followed by one row per example: int32 token ids (.bin) or uint8 mask
    bytes (.mask.bin), padded/truncated to exactly max_seq_len entries.

    Args:
        dataset_name: HuggingFace hub dataset id.
        tokenizer_name: tokenizer id/path (must carry the ChatML special tokens).
        out_dir: output directory; created if missing.
        subset: optional dataset config name.
        max_seq_len: fixed row width; longer conversations are truncated.
        max_samples: sample cap (0 = unlimited).
        split: dataset split to read.
        val_ratio: probability an example is routed to the val files (0 disables val).
        streaming: stream the dataset (forces the single-process path).
        append: append to existing binaries instead of recreating them.
        num_proc: worker processes; <= 0 means min(cpu_count, 16).
    """
    from datasets import load_dataset
    import multiprocessing

    label = f"{dataset_name}/{subset}" if subset else dataset_name
    print(f"Loading dataset: {label} (split={split}, streaming={streaming})")

    ds_kwargs = {}
    if subset:
        ds_kwargs["name"] = subset
    if streaming:
        ds_kwargs["streaming"] = True

    ds = load_dataset(dataset_name, split=split, **ds_kwargs)

    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    vocab_size = len(tokenizer)
    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    template = ChatMLTemplate(tokenizer)

    # Find which column holds the conversations (e.g. "messages").
    msg_col = _detect_messages_column(ds, streaming)
    print(f"Messages column: '{msg_col}'")
    print(f"Vocab size: {vocab_size}, pad_id: {pad_id}, max_seq_len: {max_seq_len}")

    if num_proc <= 0:
        num_proc = min(multiprocessing.cpu_count(), 16)

    os.makedirs(out_dir, exist_ok=True)
    # Fixed seed so the train/val routing is reproducible across runs.
    rng = np.random.default_rng(42)

    def _init_sft_file(prefix: str) -> None:
        """Create SFT binary files with placeholder header."""
        tok_path = prefix + ".bin"
        mask_path = prefix + ".mask.bin"
        for p in (tok_path, mask_path):
            with open(p, "wb") as f:
                f.write(HEADER_MAGIC)
                f.write(struct.pack("<I", 1))            # format version
                f.write(struct.pack("<Q", 0))            # num_examples, patched at the end
                f.write(struct.pack("<I", max_seq_len))  # fixed row width
                # vocab_size / pad_id are only meaningful in the token file.
                f.write(struct.pack("<I", vocab_size if p == tok_path else 0))
                f.write(struct.pack("<I", pad_id if p == tok_path else 0))
                f.write(b"\x00" * 16)                    # reserved; pads header to 48 bytes

    def _read_existing_count(prefix: str) -> int:
        """Read num_examples from existing SFT binary header."""
        tok_path = prefix + ".bin"
        if not os.path.exists(tok_path):
            return 0
        with open(tok_path, "rb") as f:
            magic = f.read(8)
            if magic != HEADER_MAGIC:
                return 0
            f.read(4)  # skip the version field
            return struct.unpack("<Q", f.read(8))[0]

    def _update_header_count(f, n) -> None:
        # num_examples lives at byte offset 12 (8-byte magic + 4-byte version).
        f.seek(12)
        f.write(struct.pack("<Q", n))

    train_prefix = os.path.join(out_dir, "train")
    val_prefix = os.path.join(out_dir, "val")

    # Fresh runs recreate the files; append keeps existing ones and only
    # creates whichever files are missing.
    if not append:
        _init_sft_file(train_prefix)
        if val_ratio > 0:
            _init_sft_file(val_prefix)
    else:
        if not os.path.exists(train_prefix + ".bin"):
            _init_sft_file(train_prefix)
        if val_ratio > 0 and not os.path.exists(val_prefix + ".bin"):
            _init_sft_file(val_prefix)

    # Example counts already present in the files (nonzero only when appending).
    existing_train = _read_existing_count(train_prefix) if append else 0
    existing_val = _read_existing_count(val_prefix) if (append and val_ratio > 0) else 0

    # Open r+b (so the header can be patched in place) and seek to EOF so
    # new rows land after any existing data.
    train_tok_f = open(train_prefix + ".bin", "r+b")
    train_mask_f = open(train_prefix + ".mask.bin", "r+b")
    train_tok_f.seek(0, 2)
    train_mask_f.seek(0, 2)
    train_count = 0

    val_tok_f = val_mask_f = None
    val_count = 0
    if val_ratio > 0:
        val_tok_f = open(val_prefix + ".bin", "r+b")
        val_mask_f = open(val_prefix + ".mask.bin", "r+b")
        val_tok_f.seek(0, 2)
        val_mask_f.seek(0, 2)

    # Pre-built pad rows, copied once per example (tokens=pad_id, mask=0).
    pad_tokens = np.full(max_seq_len, pad_id, dtype=np.int32)
    pad_mask = np.zeros(max_seq_len, dtype=np.uint8)

    def _write_example(token_ids, label_mask_list) -> None:
        """Pad/truncate one example and route to train or val file."""
        nonlocal train_count, val_count
        # Truncation simply drops everything past max_seq_len.
        ids = token_ids[:max_seq_len]
        msk = label_mask_list[:max_seq_len]
        row_t = pad_tokens.copy()
        row_m = pad_mask.copy()
        row_t[:len(ids)] = ids
        row_m[:len(msk)] = msk
        data_t = row_t.tobytes()
        data_m = row_m.tobytes()
        # Bernoulli routing with the seeded rng -> ~val_ratio of examples in val.
        if val_ratio > 0 and rng.random() < val_ratio:
            val_tok_f.write(data_t)
            val_mask_f.write(data_m)
            val_count += 1
        else:
            train_tok_f.write(data_t)
            train_mask_f.write(data_m)
            train_count += 1

    try:
        # --- Parallel path: in-memory dataset, tokenized by a worker pool ---
        if not streaming and num_proc > 1:
            if max_samples:
                ds = ds.select(range(min(max_samples, len(ds))))

            total_messages = len(ds)
            print(f"\nTokenizing {total_messages:,} conversations with {num_proc} processes...")

            def _iter_messages():
                # Send only the messages column to the workers.
                for row in ds:
                    yield row[msg_col]

            skipped = 0
            count = 0
            log_every = 10000
            recent_lens = []    # token lengths since the last log line
            recent_ratios = []  # assistant-token fraction since the last log line

            def _iter_chunk_results(chunk):
                # imap_unordered yields one (result, reason) tuple at a time;
                # this wrapper also tolerates a list of such tuples.
                if isinstance(chunk, list):
                    return chunk
                return [chunk]

            with multiprocessing.Pool(
                num_proc,
                initializer=_sft_worker_init,
                initargs=(tokenizer_name,),
                maxtasksperchild=1000,  # recycle workers periodically
            ) as pool:
                for chunk in pool.imap_unordered(_sft_worker_process, _iter_messages(), chunksize=64):
                    for result, skip_reason in _iter_chunk_results(chunk):
                        if result is None:
                            skipped += 1
                            continue
                        token_ids = result["input_ids"]
                        lm = result["label_mask"]
                        _write_example(token_ids, lm)
                        count += 1
                        recent_lens.append(len(token_ids))
                        recent_ratios.append(sum(lm) / max(len(lm), 1))
                        if count % log_every == 0:
                            avg_len = sum(recent_lens) / len(recent_lens)
                            avg_ratio = sum(recent_ratios) / len(recent_ratios)
                            print(f" {count:>10,} examples | avg_len={avg_len:.0f} | assistant_ratio={avg_ratio:.1%}")
                            recent_lens.clear()
                            recent_ratios.clear()

            print(f" Total: {count:,} examples ({skipped:,} skipped)")

        # --- Sequential path: streaming datasets, or num_proc == 1 ---
        else:
            skipped = 0
            count = 0
            log_every = 10000
            recent_lens = []
            recent_ratios = []

            for example in ds:
                if max_samples and count >= max_samples:
                    break

                messages = example[msg_col]
                # NOTE(review): when messages is a JSON string, this length
                # check counts characters, not messages, and the decoded
                # list's length is never re-checked below — a JSON-encoded
                # single-message conversation slips through. Consider
                # decoding before this guard.
                if not messages or len(messages) < 2:
                    skipped += 1
                    continue

                # Some datasets store the conversation as a JSON string.
                if isinstance(messages, str):
                    try:
                        messages = json.loads(messages)
                    except json.JSONDecodeError:
                        skipped += 1
                        continue

                result = template.apply(messages)
                token_ids = result["input_ids"]
                label_mask = result["label_mask"]

                # Skip conversations with no assistant tokens (no loss signal)
                # or fewer than 2 tokens (cannot form an input/target pair).
                if not any(label_mask):
                    skipped += 1
                    continue
                if len(token_ids) < 2:
                    skipped += 1
                    continue

                _write_example(token_ids, label_mask)
                count += 1

                recent_lens.append(len(token_ids))
                recent_ratios.append(sum(label_mask) / max(len(label_mask), 1))
                if count % log_every == 0:
                    avg_len = sum(recent_lens) / len(recent_lens)
                    avg_ratio = sum(recent_ratios) / len(recent_ratios)
                    print(f" {count:>10,} examples | avg_len={avg_len:.0f} | assistant_ratio={avg_ratio:.1%}")
                    recent_lens.clear()
                    recent_ratios.clear()

            print(f" Total: {count:,} examples ({skipped:,} skipped)")

        # Patch the final example counts into both headers (token + mask files).
        _update_header_count(train_tok_f, existing_train + train_count)
        _update_header_count(train_mask_f, existing_train + train_count)
        if val_ratio > 0:
            _update_header_count(val_tok_f, existing_val + val_count)
            _update_header_count(val_mask_f, existing_val + val_count)
    finally:
        train_tok_f.close()
        train_mask_f.close()
        if val_tok_f is not None:
            val_tok_f.close()
        if val_mask_f is not None:
            val_mask_f.close()

    print(f" Train: {train_count:,}, Val: {val_count:,}")
    train_mb = os.path.getsize(train_prefix + ".bin") / 1024 / 1024
    print(f" Saved to {out_dir}/ ({train_mb:.0f} MB train)")
|
|
|
|
| |
| |
| |
def main():
    """CLI entry point: parse arguments and run SFT tokenization."""
    p = argparse.ArgumentParser(description="Tokenize SFT datasets for FreqFormer")
    p.add_argument("--dataset", type=str, required=True, help="HuggingFace dataset name")
    p.add_argument("--subset", type=str, default=None, help="Dataset subset/config")
    p.add_argument("--tokenizer", type=str, default="freqformer/tokenizer",
                   help="Tokenizer name/path")
    p.add_argument("--out", type=str, required=True, help="Output directory")
    p.add_argument("--max_seq_len", type=int, default=4096, help="Max sequence length")
    p.add_argument("--max_samples", type=int, default=0, help="Max samples (0=all)")
    p.add_argument("--split", type=str, default="train", help="Dataset split")
    p.add_argument("--val_ratio", type=float, default=0.02, help="Validation split ratio")
    p.add_argument("--streaming", action="store_true", help="Stream large datasets")
    p.add_argument("--append", action="store_true", help="Append to existing files")
    p.add_argument("--num_proc", type=int, default=0, help="Number of CPU processes (0=auto)")
    args = p.parse_args()

    if args.append:
        print("Mode: APPEND")

    tokenize_sft_dataset(
        dataset_name=args.dataset,
        tokenizer_name=args.tokenizer,
        out_dir=args.out,
        subset=args.subset,
        max_seq_len=args.max_seq_len,
        max_samples=args.max_samples,
        split=args.split,
        val_ratio=args.val_ratio,
        streaming=args.streaming,
        append=args.append,
        num_proc=args.num_proc,
    )

    print(f"\nDone! SFT data saved to {args.out}/")
    print(f"Use --sft_data_dir {args.out} when running SFT training.")
|
|
|
|
# Script entry point: `python -m freqformer.sft_data ...`
if __name__ == "__main__":
    main()
|
|