| """ |
| finetune/prepare_data.py |
| |
| Downloads teknium/OpenHermes-2.5 from HuggingFace, formats conversations |
| as ChatML, tokenizes with our custom tokenizer + 2 new special tokens, |
| and saves train_sft.pt / val_sft.pt to finetune/data/. |
| |
| Also saves the tokenizer (with special tokens baked in) to finetune/data/ |
| so sft_train.py and chat.py can load it without re-adding tokens. |
| |
| Usage: |
| python finetune/prepare_data.py |
| python finetune/prepare_data.py --n_samples 50000 |
| |
| Dataset structure (OpenHermes-2.5): |
| Each row has a "conversations" key: |
| [ |
| {"from": "system", "value": "..."}, # optional |
| {"from": "human", "value": "..."}, |
| {"from": "gpt", "value": "..."}, |
| ... # may have more turns |
| ] |
| """ |
|
|
| import os |
| import sys |
| import json |
| import random |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| from transformers import PreTrainedTokenizerFast |
| from datasets import load_dataset |
| from tqdm import tqdm |
|
|
| |
| |
| |
|
|
# Absolute directory holding this script, and the directory one level above it.
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent

# Put the project root at the front of the import search path.
sys.path.insert(0, str(PROJECT_ROOT))

# Directory the base tokenizer is loaded from.
TOKENIZER_DIR = PROJECT_ROOT.joinpath("tokenizer", "fineweb_edu_tokenizer")

# ChatML turn delimiters that get registered on the tokenizer as special tokens.
SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"]

# Tokenized conversations are truncated to this many tokens.
MAX_LENGTH = 1024

# Maps the dataset's speaker names onto ChatML role names; names that are
# already ChatML roles map to themselves.
ROLE_MAP = {
    "system": "system",
    "human": "user",
    "gpt": "assistant",
    "user": "user",
    "assistant": "assistant",
}
|
|
|
|
| |
| |
| |
|
|
def load_and_extend_tokenizer() -> PreTrainedTokenizerFast:
    """
    Load the pretrained tokenizer from TOKENIZER_DIR and register whichever
    entries of SPECIAL_TOKENS its vocabulary does not already contain.

    Prints what was added (or that nothing was needed) plus the final
    vocabulary size, then returns the tokenizer.
    """
    tok = PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))

    # Only register markers the vocabulary is missing, so re-running on an
    # already-extended tokenizer is a no-op.
    known = tok.get_vocab()
    missing = []
    for marker in SPECIAL_TOKENS:
        if marker not in known:
            missing.append(marker)

    if not missing:
        print(" Special tokens already present — skipping add.")
    else:
        n_added = tok.add_special_tokens({"additional_special_tokens": missing})
        print(f" Added {n_added} special token(s): {missing}")

    print(f" Final vocab size: {len(tok):,}")
    return tok
|
|
|
|
| |
| |
| |
|
|
def format_and_tokenize(
    conversations: list[dict],
    tokenizer: PreTrainedTokenizerFast,
) -> tuple[list[int], list[int]] | None:
    """
    Convert a list of chat turns into parallel (input_ids, labels) lists.

    ChatML format per turn:
        <|im_start|>{role}\\n{content}<|im_end|>\\n

    Each turn is a dict carrying the speaker under "from" (or "role") and the
    text under "value" (or "content"). Speaker names are normalised through
    ROLE_MAP; unknown names are used as-is. Turns with an empty speaker or
    empty text are skipped.

    Labels:
        - User / system turns → all -100 (not learned)
        - Assistant turns     → header (-100) + body (actual token ids), i.e.
          the response and its closing "<|im_end|>\\n" are learned but the
          "<|im_start|>assistant\\n" prefix is not

    The result is truncated to MAX_LENGTH tokens.

    Returns None for:
        - Conversations that tokenize to fewer than 8 tokens
        - Conversations with no supervised (assistant) token left *after*
          truncation — this includes conversations with no assistant turn at
          all, and ones whose prompt alone fills the MAX_LENGTH window
    """
    input_ids: list[int] = []
    labels: list[int] = []

    for turn in conversations or []:
        # `or` chaining (rather than .get(key, default)) also covers keys
        # that are present but hold None, which would crash on .strip().
        role_raw = (turn.get("from") or turn.get("role") or "").strip().lower()
        content = (turn.get("value") or turn.get("content") or "").strip()
        role = ROLE_MAP.get(role_raw, role_raw)

        if not content or not role:
            continue

        # Header and body are encoded separately so the header length is
        # known and can be masked independently of the body.
        header_ids = tokenizer.encode(
            f"<|im_start|>{role}\n", add_special_tokens=False
        )
        body_ids = tokenizer.encode(
            f"{content}<|im_end|>\n", add_special_tokens=False
        )

        input_ids.extend(header_ids)
        input_ids.extend(body_ids)

        labels.extend([-100] * len(header_ids))
        if role == "assistant":
            labels.extend(body_ids)
        else:
            labels.extend([-100] * len(body_ids))

        # Everything past MAX_LENGTH is discarded below, so tokenizing
        # further turns cannot change the result.
        if len(input_ids) >= MAX_LENGTH:
            break

    input_ids = input_ids[:MAX_LENGTH]
    labels = labels[:MAX_LENGTH]

    if len(input_ids) < 8:
        return None

    # Checked after truncation: an example whose assistant tokens were all
    # cut off would otherwise be kept with every label masked.
    if all(label == -100 for label in labels):
        return None

    return input_ids, labels
|
|
|
|
| |
| |
| |
|
|
def parse_args():
    """Build the command-line parser for this script and parse sys.argv."""
    parser = argparse.ArgumentParser(description="Prepare SFT data from OpenHermes-2.5")

    # (flag, add_argument keyword arguments), registered in this order.
    options = (
        ("--n_samples", {
            "type": int,
            "default": 80_000,
            "help": "Number of conversations to sample (default: 80000)",
        }),
        ("--val_ratio", {
            "type": float,
            "default": 0.05,
            "help": "Fraction held out for validation (default: 0.05)",
        }),
        ("--output_dir", {
            "type": str,
            "default": str(SCRIPT_DIR / "data"),
            "help": "Where to save train_sft.pt, val_sft.pt, and tokenizer",
        }),
        ("--seed", {
            "type": int,
            "default": 42,
        }),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
|
|
|
|
| |
| |
| |
|
|
def main():
    """
    Run the four-stage SFT data-preparation pipeline.

    1. Load the tokenizer, add the ChatML special tokens, and save the
       extended tokenizer into the output directory.
    2. Load teknium/OpenHermes-2.5 and draw a seeded random subset of
       ``--n_samples`` conversations.
    3. Format each conversation as ChatML and tokenize it, dropping the
       ones format_and_tokenize rejects.
    4. Shuffle, split into train/validation, and write train_sft.pt,
       val_sft.pt and meta.json to the output directory.

    Raises:
        ValueError: if --n_samples < 1 or --val_ratio is outside [0, 1).
        RuntimeError: if no conversation survives tokenization.
    """
    args = parse_args()

    # Fail fast, before any download. A ratio >= 1 would make the train
    # count negative, and the negative slices below would then silently
    # produce a wrong split.
    if args.n_samples < 1:
        raise ValueError(f"--n_samples must be >= 1, got {args.n_samples}")
    if not 0.0 <= args.val_ratio < 1.0:
        raise ValueError(f"--val_ratio must be in [0, 1), got {args.val_ratio}")

    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    print("\n" + "=" * 60)
    print(" SLLM-150M SFT — Data Preparation")
    print("=" * 60)

    # ── 1. Tokenizer ──────────────────────────────────────────────
    print("\n[1/4] Loading tokenizer + adding ChatML special tokens...")
    tokenizer = load_and_extend_tokenizer()

    # Saved next to the data so later scripts can load it with the special
    # tokens already registered.
    tokenizer.save_pretrained(args.output_dir)
    print(f" Extended tokenizer saved → {args.output_dir}/")

    # ── 2. Dataset ────────────────────────────────────────────────
    print("\n[2/4] Loading teknium/OpenHermes-2.5 from HuggingFace...")
    ds = load_dataset("teknium/OpenHermes-2.5")
    full = ds["train"]
    print(f" Full dataset size: {len(full):,} examples")

    # Seeded sample without replacement, capped at the dataset size.
    n = min(args.n_samples, len(full))
    indices = random.sample(range(len(full)), n)
    subset = full.select(indices)
    print(f" Sampled: {n:,} examples (seed={args.seed})")

    # ── 3. Format + tokenize ──────────────────────────────────────
    print("\n[3/4] Formatting and tokenizing conversations...")

    all_input_ids: list[torch.Tensor] = []
    all_labels: list[torch.Tensor] = []
    skipped = 0

    for example in tqdm(subset, desc="Tokenizing", unit="conv"):
        # `or []` also covers rows where the field is present but None.
        conversations = example.get("conversations") or []
        result = format_and_tokenize(conversations, tokenizer)

        if result is None:
            skipped += 1
            continue

        ids, lbls = result
        all_input_ids.append(torch.tensor(ids, dtype=torch.long))
        all_labels.append(torch.tensor(lbls, dtype=torch.long))

    total = len(all_input_ids)
    print(f"\n Kept : {total:,}")
    print(f" Skipped: {skipped:,} (no assistant turn or too short)")

    if total == 0:
        raise RuntimeError("No valid examples produced — check dataset structure.")

    # Decode one example back to text as a visual sanity check.
    print("\n ── Sample (first conversation, first 400 chars) ──")
    sample_decoded = tokenizer.decode(all_input_ids[0].tolist(), skip_special_tokens=False)
    print(" " + sample_decoded[:400].replace("\n", "\n "))
    print()

    # ── 4. Split + save ───────────────────────────────────────────
    print("[4/4] Splitting and saving...")

    perm = list(range(total))
    random.shuffle(perm)

    # Hold out at least one example for validation, but never let that
    # empty the training set (which happens when only one example survived).
    val_n = min(max(1, int(total * args.val_ratio)), total - 1)
    train_n = total - val_n
    if val_n == 0:
        print(" Warning: only one valid example — validation split is empty.")

    train_idx = perm[:train_n]
    val_idx = perm[train_n:]

    train_ids = [all_input_ids[i] for i in train_idx]
    train_lbl = [all_labels[i] for i in train_idx]
    val_ids = [all_input_ids[i] for i in val_idx]
    val_lbl = [all_labels[i] for i in val_idx]

    train_path = os.path.join(args.output_dir, "train_sft.pt")
    val_path = os.path.join(args.output_dir, "val_sft.pt")

    # Each file holds a dict of two parallel lists of variable-length
    # 1-D long tensors.
    torch.save({"input_ids": train_ids, "labels": train_lbl}, train_path)
    torch.save({"input_ids": val_ids, "labels": val_lbl}, val_path)

    # Summary statistics over every kept example (train + val together).
    lengths = [len(x) for x in all_input_ids]
    label_ratios = [(t != -100).float().mean().item() for t in all_labels]
    avg_len = sum(lengths) / len(lengths)
    avg_lbl_ratio = sum(label_ratios) / len(label_ratios)

    print(f"\n train_sft.pt : {train_n:,} examples")
    print(f" val_sft.pt : {val_n:,} examples")
    print(f"\n Avg seq length : {avg_len:.0f} tokens (max={max(lengths)})")
    print(f" Avg assistant ratio : {avg_lbl_ratio:.1%} of tokens are labeled")

    # Record how this data was produced.
    meta = {
        "dataset": "teknium/OpenHermes-2.5",
        "n_sampled": n,
        "n_train": train_n,
        "n_val": val_n,
        "vocab_size": len(tokenizer),
        "special_tokens": SPECIAL_TOKENS,
        "max_length": MAX_LENGTH,
        "seed": args.seed,
    }
    with open(os.path.join(args.output_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    print(f"\n meta.json saved → {args.output_dir}/meta.json")

    print("\n" + "=" * 60)
    print(" Data preparation complete!")
    print("=" * 60)
    print("""
Next step:
python finetune/sft_train.py \\
--base_ckpt runs/sllm_150m/ckpt_0011500.pt \\
--run_dir runs/sllm_150m_chat \\
--max_steps 2000 \\
--batch_size 4 --grad_accum 8 \\
--grad_checkpoint
""")


if __name__ == "__main__":
    main()
|
|