# sllm/finetune/prepare_data.py
# geeteshcodes — Initial commit (7f974df, verified)
"""
finetune/prepare_data.py

Downloads teknium/OpenHermes-2.5 from HuggingFace, formats conversations
as ChatML, tokenizes with our custom tokenizer + 2 new special tokens,
and saves train_sft.pt / val_sft.pt (plus a meta.json summary) to
finetune/data/ (the default; override with --output_dir).

Also saves the tokenizer (with special tokens baked in) to the same
directory so sft_train.py and chat.py can load it without re-adding tokens.

Usage:
    python finetune/prepare_data.py
    python finetune/prepare_data.py --n_samples 50000

Dataset structure (OpenHermes-2.5):
    Each row has a "conversations" key:
    [
        {"from": "system", "value": "..."},  # optional
        {"from": "human", "value": "..."},
        {"from": "gpt", "value": "..."},
        ...  # may have more turns
    ]
"""
import os
import sys
import json
import random
import argparse
from pathlib import Path
import torch
from transformers import PreTrainedTokenizerFast
from datasets import load_dataset
from tqdm import tqdm
# ------------------------------------------------------------------ #
# Paths (relative to project root, not this script)
# ------------------------------------------------------------------ #
# Directory holding this script, and the project root one level above it.
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent

# Put the project root first on the import search path.
sys.path.insert(0, str(PROJECT_ROOT))

# Location of the pretrained tokenizer files inside the project tree.
TOKENIZER_DIR = PROJECT_ROOT.joinpath("tokenizer", "fineweb_edu_tokenizer")

# ChatML turn delimiters that get added to the tokenizer as special tokens.
SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"]

# Token cap applied when truncating a tokenized conversation
# (given by the original author as the model's context length).
MAX_LENGTH = 1024

# Dataset role names → ChatML role names. Names not listed here are
# passed through unchanged by the lookup in format_and_tokenize.
ROLE_MAP = dict(
    system="system",
    human="user",
    gpt="assistant",
    user="user",
    assistant="assistant",
)
# ------------------------------------------------------------------ #
# TOKENIZER
# ------------------------------------------------------------------ #
def load_and_extend_tokenizer() -> PreTrainedTokenizerFast:
    """
    Load the tokenizer stored at TOKENIZER_DIR and register the ChatML tokens.

    Only the entries of SPECIAL_TOKENS that are missing from the loaded
    vocabulary are added, so calling this on an already-extended tokenizer
    adds nothing. Progress is reported on stdout.

    Returns:
        The tokenizer, with every SPECIAL_TOKENS entry present in its vocab.
        (The original author notes the expected growth as 32,000 → 32,002;
        the base size is not verifiable from this file.)
    """
    tok = PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))

    # Snapshot the vocabulary once and work out which markers are absent.
    known = tok.get_vocab()
    missing = [marker for marker in SPECIAL_TOKENS if marker not in known]

    if not missing:
        print(" Special tokens already present — skipping add.")
    else:
        n_added = tok.add_special_tokens({"additional_special_tokens": missing})
        print(f" Added {n_added} special token(s): {missing}")

    print(f" Final vocab size: {len(tok):,}")
    return tok
# ------------------------------------------------------------------ #
# FORMAT + TOKENIZE ONE CONVERSATION
# ------------------------------------------------------------------ #
def format_and_tokenize(
    conversations: list[dict],
    tokenizer: PreTrainedTokenizerFast,
) -> tuple[list[int], list[int]] | None:
    """
    Converts a list of chat turns into (input_ids, labels).

    ChatML format per turn:
        <|im_start|>{role}\\n{content}<|im_end|>\\n

    Labels:
        - User / system turns → all -100 (not learned)
        - Assistant turns → header (-100) + content (actual token ids)
          i.e. we learn the response (including <|im_end|>\\n) but not the
          "<|im_start|>assistant\\n" prefix

    Args:
        conversations: Turn dicts. The role is read from "from" (falling
            back to "role") and the text from "value" (falling back to
            "content"). Turns with an empty role or empty text are skipped.
            A None/empty list yields None.
        tokenizer: Object whose encode(text, add_special_tokens=False)
            returns a list of token ids.

    Returns:
        (input_ids, labels), two equal-length lists of at most MAX_LENGTH
        items, or None for:
        - Conversations with no assistant tokens left after truncation to
          MAX_LENGTH (nothing to learn)
        - Conversations that tokenize to fewer than 8 tokens
    """
    input_ids: list[int] = []
    labels: list[int] = []

    for turn in conversations or []:
        # The `or` fallbacks also cover keys that are present but None/empty,
        # which would otherwise crash on .strip().
        role_raw = (turn.get("from") or turn.get("role") or "").strip().lower()
        content = (turn.get("value") or turn.get("content") or "").strip()
        role = ROLE_MAP.get(role_raw, role_raw)
        if not content or not role:
            continue

        # ---- header: <|im_start|>role\n — never labeled ----------- #
        header_ids = tokenizer.encode(
            f"<|im_start|>{role}\n", add_special_tokens=False
        )
        # ---- body: content<|im_end|>\n ------------------------------ #
        body_ids = tokenizer.encode(
            f"{content}<|im_end|>\n", add_special_tokens=False
        )

        input_ids.extend(header_ids)
        input_ids.extend(body_ids)

        labels.extend([-100] * len(header_ids))
        if role == "assistant":
            # Teach the model the body (response + im_end), not the header
            labels.extend(body_ids)
        else:
            # User / system: no learning signal
            labels.extend([-100] * len(body_ids))

        # Everything past MAX_LENGTH is cut below, so further turns cannot
        # change the result — stop tokenizing early.
        if len(input_ids) >= MAX_LENGTH:
            break

    # Truncate to context window BEFORE validating, so the checks below
    # describe the sequence that is actually returned.
    input_ids = input_ids[:MAX_LENGTH]
    labels = labels[:MAX_LENGTH]

    # Must keep at least one labeled token to be a valid training example.
    # (A long prompt can push every assistant token past the cut-off; such
    # an example would carry no training signal at all.)
    if all(label == -100 for label in labels):
        return None

    # Skip micro-sequences (likely malformed)
    if len(input_ids) < 8:
        return None

    return input_ids, labels
# ------------------------------------------------------------------ #
# ARG PARSING
# ------------------------------------------------------------------ #
def parse_args():
    """
    Build the command-line parser for this script and parse sys.argv.

    Options:
        --n_samples   int,   default 80000
        --val_ratio   float, default 0.05
        --output_dir  str,   default "<this script's directory>/data"
        --seed        int,   default 42

    Returns:
        The argparse.Namespace produced by ArgumentParser.parse_args().
    """
    parser = argparse.ArgumentParser(description="Prepare SFT data from OpenHermes-2.5")

    # (flag, add_argument keyword options), registered in display order.
    option_table = (
        ("--n_samples", {
            "type": int,
            "default": 80_000,
            "help": "Number of conversations to sample (default: 80000)",
        }),
        ("--val_ratio", {
            "type": float,
            "default": 0.05,
            "help": "Fraction held out for validation (default: 0.05)",
        }),
        ("--output_dir", {
            "type": str,
            "default": str(SCRIPT_DIR / "data"),
            "help": "Where to save train_sft.pt, val_sft.pt, and tokenizer",
        }),
        ("--seed", {
            "type": int,
            "default": 42,
        }),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
# ------------------------------------------------------------------ #
# MAIN
# ------------------------------------------------------------------ #
def main():
    """
    Run the data-preparation pipeline end to end.

    Steps:
        1. Load the tokenizer, add the ChatML tokens, and save it to the
           output directory.
        2. Load teknium/OpenHermes-2.5 and draw a seeded random sample of
           at most --n_samples rows from its "train" split.
        3. Format + tokenize each sampled conversation, dropping those that
           format_and_tokenize rejects.
        4. Shuffle, split into train/validation, and write train_sft.pt,
           val_sft.pt and meta.json to the output directory.

    Raises:
        RuntimeError: if no sampled conversation yields a valid example.
    """
    args = parse_args()
    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    print("\n" + "=" * 60)
    print(" SLLM-150M SFT — Data Preparation")
    print("=" * 60)

    # ---------------------------------------------------------------- #
    # 1. Tokenizer
    # ---------------------------------------------------------------- #
    print("\n[1/4] Loading tokenizer + adding ChatML special tokens...")
    tokenizer = load_and_extend_tokenizer()
    # Save the extended tokenizer to data dir so training/chat can load it
    tokenizer.save_pretrained(args.output_dir)
    print(f" Extended tokenizer saved → {args.output_dir}/")

    # ---------------------------------------------------------------- #
    # 2. Dataset download
    # ---------------------------------------------------------------- #
    print("\n[2/4] Loading teknium/OpenHermes-2.5 from HuggingFace...")
    ds = load_dataset("teknium/OpenHermes-2.5")
    full = ds["train"]  # only split in this dataset
    print(f" Full dataset size: {len(full):,} examples")

    # Sample a subset (seeded above, so the draw is reproducible)
    n = min(args.n_samples, len(full))
    indices = random.sample(range(len(full)), n)
    subset = full.select(indices)
    print(f" Sampled: {n:,} examples (seed={args.seed})")

    # ---------------------------------------------------------------- #
    # 3. Tokenize
    # ---------------------------------------------------------------- #
    print("\n[3/4] Formatting and tokenizing conversations...")
    all_input_ids: list[torch.Tensor] = []
    all_labels: list[torch.Tensor] = []
    skipped = 0
    for example in tqdm(subset, desc="Tokenizing", unit="conv"):
        # `or []` also covers a row whose "conversations" value is None.
        conversations = example.get("conversations") or []
        result = format_and_tokenize(conversations, tokenizer)
        if result is None:
            skipped += 1
            continue
        ids, lbls = result
        all_input_ids.append(torch.tensor(ids, dtype=torch.long))
        all_labels.append(torch.tensor(lbls, dtype=torch.long))

    total = len(all_input_ids)
    print(f"\n Kept : {total:,}")
    print(f" Skipped: {skipped:,} (no assistant turn or too short)")
    if total == 0:
        raise RuntimeError("No valid examples produced — check dataset structure.")

    # Print a sample so we can visually verify
    print("\n ── Sample (first conversation, first 400 chars) ──")
    sample_decoded = tokenizer.decode(all_input_ids[0].tolist(), skip_special_tokens=False)
    print(" " + sample_decoded[:400].replace("\n", "\n "))
    print()

    # ---------------------------------------------------------------- #
    # 4. Split + save
    # ---------------------------------------------------------------- #
    print("[4/4] Splitting and saving...")
    perm = list(range(total))
    random.shuffle(perm)

    val_n = max(1, int(total * args.val_ratio))
    # Never let validation take the whole set: an uncapped val_n makes
    # train_n zero or negative, and a negative train_n makes the slices
    # below wrap around so the saved files disagree with the reported
    # counts. With a single example, it goes to train and val is empty.
    val_n = min(val_n, total - 1)
    train_n = total - val_n

    train_order = perm[:train_n]
    val_order = perm[train_n:]
    train_ids = [all_input_ids[i] for i in train_order]
    train_lbl = [all_labels[i] for i in train_order]
    val_ids = [all_input_ids[i] for i in val_order]
    val_lbl = [all_labels[i] for i in val_order]

    train_path = os.path.join(args.output_dir, "train_sft.pt")
    val_path = os.path.join(args.output_dir, "val_sft.pt")
    torch.save({"input_ids": train_ids, "labels": train_lbl}, train_path)
    torch.save({"input_ids": val_ids, "labels": val_lbl}, val_path)

    # Stats (computed over train + val together)
    lengths = [len(x) for x in all_input_ids]
    label_ratios = [(t != -100).float().mean().item() for t in all_labels]
    avg_len = sum(lengths) / len(lengths)
    avg_lbl_ratio = sum(label_ratios) / len(label_ratios)
    print(f"\n train_sft.pt : {train_n:,} examples")
    print(f" val_sft.pt : {val_n:,} examples")
    print(f"\n Avg seq length : {avg_len:.0f} tokens (max={max(lengths)})")
    print(f" Avg assistant ratio : {avg_lbl_ratio:.1%} of tokens are labeled")

    # Save metadata for reference
    meta = {
        "dataset": "teknium/OpenHermes-2.5",
        "n_sampled": n,
        "n_train": train_n,
        "n_val": val_n,
        "vocab_size": len(tokenizer),
        "special_tokens": SPECIAL_TOKENS,
        "max_length": MAX_LENGTH,
        "seed": args.seed,
    }
    # Explicit encoding so the file does not depend on the platform default.
    with open(os.path.join(args.output_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    print(f"\n meta.json saved → {args.output_dir}/meta.json")

    print("\n" + "=" * 60)
    print(" Data preparation complete!")
    print("=" * 60)
    # The literal's lines stay at column 0 so the printed text is unchanged.
    print("""
Next step:
python finetune/sft_train.py \\
--base_ckpt runs/sllm_150m/ckpt_0011500.pt \\
--run_dir runs/sllm_150m_chat \\
--max_steps 2000 \\
--batch_size 4 --grad_accum 8 \\
--grad_checkpoint
""")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()