# sllm/finetune/check_data.py
# Source page metadata: committed by geeteshcodes, commit 7f974df ("Initial commit", verified).
"""
finetune/check_data.py
Smoke-test: loads 5 rows from OpenHermes-2.5, runs them through the
same format_and_tokenize() logic used by prepare_data.py, and prints
a full visual audit so you can confirm everything lines up.
Checks:
1. Raw conversation structure from the dataset
2. ChatML text that gets fed to the tokenizer
3. Token IDs and decoded tokens (side-by-side)
4. Label mask β€” βœ“ (labeled) vs (masked -100) for every token
5. Label ratio (should be ~30-60% assistant tokens)
Run from project root:
python finetune/check_data.py
python finetune/check_data.py --row 3 # inspect a specific row index
"""
import sys
import argparse
from pathlib import Path
# ------------------------------------------------------------------ #
# Paths
# ------------------------------------------------------------------ #
# Directory holding this script; the project root is its parent.
SCRIPT_DIR = Path(__file__).resolve().parents[0]
PROJECT_ROOT = SCRIPT_DIR.parent
# Locally trained tokenizer loaded by load_tokenizer().
TOKENIZER_DIR = PROJECT_ROOT.joinpath("tokenizer", "fineweb_edu_tokenizer")
# Put the project root first on the import path.
sys.path.insert(0, f"{PROJECT_ROOT}")
from transformers import PreTrainedTokenizerFast
from datasets import load_dataset
SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"]
MAX_LENGTH = 1024
ROLE_MAP = {
"system": "system",
"human": "user",
"gpt": "assistant",
"user": "user",
"assistant": "assistant",
}
# ------------------------------------------------------------------ #
# Replicated from prepare_data.py (no import to keep this self-contained)
# ------------------------------------------------------------------ #
def load_tokenizer() -> PreTrainedTokenizerFast:
    """Load the project tokenizer from TOKENIZER_DIR and ensure the ChatML markers exist.

    Every entry of SPECIAL_TOKENS that is absent from the vocabulary is
    registered as an additional special token before the tokenizer is returned.
    """
    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
    vocab = tokenizer.get_vocab()
    missing = [marker for marker in SPECIAL_TOKENS if marker not in vocab]
    if missing:
        tokenizer.add_special_tokens({"additional_special_tokens": missing})
    return tokenizer
def format_and_tokenize(conversations, tokenizer):
    """Build ``(input_ids, labels)`` for one conversation, mirroring prepare_data.py.

    Each turn with a non-empty role and content is rendered as ChatML,
    ``<|im_start|>{role}\\n{content}<|im_end|>\\n``. The header and the body are
    encoded separately (``add_special_tokens=False``) so the header span is
    known exactly and can be masked. Label policy:

    * assistant turn -> header masked with -100, body (including its trailing
      ``<|im_end|>\\n``) labeled with its own token ids;
    * any other role -> every token masked with -100.

    Both lists are then truncated to MAX_LENGTH.

    Returns:
        ``(input_ids, labels)``, or None if no token is labeled (checked before
        truncation) or the truncated sequence has fewer than 8 tokens.
    """
    ids, targets = [], []
    for turn in conversations:
        # Accept both "from"/"value" and "role"/"content" turn schemas.
        raw_role = turn.get("from", turn.get("role", "")).strip().lower()
        text = turn.get("value", turn.get("content", "")).strip()
        role = ROLE_MAP.get(raw_role, raw_role)
        if not text or not role:
            continue

        head = tokenizer.encode(f"<|im_start|>{role}\n", add_special_tokens=False)
        body = tokenizer.encode(f"{text}<|im_end|>\n", add_special_tokens=False)

        ids += head + body
        targets += [-100] * len(head)
        targets += (body if role == "assistant" else [-100] * len(body))

    # NOTE(review): this check runs BEFORE truncation, so if every assistant
    # token lies beyond MAX_LENGTH the returned labels are all -100. Kept as-is
    # to stay identical to prepare_data.py; consider re-checking after truncation
    # in both places.
    if all(t == -100 for t in targets):
        return None

    ids, targets = ids[:MAX_LENGTH], targets[:MAX_LENGTH]
    if len(ids) < 8:
        return None
    return ids, targets
# ------------------------------------------------------------------ #
# Pretty-print helpers
# ------------------------------------------------------------------ #
def print_section(title: str):
    """Print *title* framed above and below by a 60-character horizontal rule."""
    rule = "─" * 60
    print(f"\n{rule}\n {title}\n{rule}")
def print_token_table(input_ids, labels, tokenizer, max_rows: int = 80):
    """
    Print one row per token: idx | decoded token | token id | label | learn?

    Green ✓ = labeled (assistant) — model learns this
    Red   ✗ = masked -100         — model ignores this

    At most ``max_rows`` rows are listed; if more tokens exist, a
    "... (N more tokens not shown)" line follows. A labeled/masked summary
    over the WHOLE sequence is always printed.

    Args:
        input_ids: token ids of the sequence.
        labels: per-token labels; -100 marks a masked position.
        tokenizer: any object whose ``decode(list_of_ids)`` returns text.
        max_rows: maximum number of table rows to print.
    """
    GREEN = "\033[92m"
    RED = "\033[91m"
    RESET = "\033[0m"
    print(f"\n {'IDX':>5} {'TOKEN':<22} {'ID':>6} {'LABEL':>8} {'LEARN?'}")
    print(f" {'─'*5} {'─'*22} {'─'*6} {'─'*8} {'─'*6}")
    for i, (tok_id, lbl) in enumerate(zip(input_ids, labels)):
        # Check the limit BEFORE printing, so exactly max_rows rows are shown and
        # the "more tokens" note only appears when tokens are actually omitted.
        if i >= max_rows:
            print(f" ... ({len(input_ids) - max_rows} more tokens not shown)")
            break
        tok_str = repr(tokenizer.decode([tok_id]))[:22]
        if lbl == -100:
            learn_str = f"{RED}✗ masked{RESET}"
        else:
            learn_str = f"{GREEN}✓ learn {RESET}"
        # Same 8-wide right alignment for -100 and real ids, matching the header.
        print(f" {i:>5} {tok_str:<22} {tok_id:>6} {lbl:>8} {learn_str}")
    # Summary
    n_total = len(labels)
    n_labeled = sum(1 for lab in labels if lab != -100)
    n_masked = n_total - n_labeled
    denom = n_total or 1  # avoid ZeroDivisionError on an empty sequence
    print(f"\n Total tokens : {n_total}")
    print(f" Labeled : {n_labeled} ({n_labeled/denom:.1%}) ← assistant tokens")
    print(f" Masked : {n_masked} ({n_masked/denom:.1%}) ← user/system tokens")
# ------------------------------------------------------------------ #
# MAIN
# ------------------------------------------------------------------ #
def parse_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace with integer attributes ``row`` and ``n_fetch``.
    """
    parser = argparse.ArgumentParser(
        description="Check one OpenHermes row through the SFT pipeline"
    )
    option_specs = (
        ("--row", 0,
         "Which row to inspect in detail (0-indexed, from the first 20 fetched)"),
        ("--n_fetch", 20,
         "How many rows to fetch from HuggingFace (default: 20)"),
    )
    for flag, default, help_text in option_specs:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    return parser.parse_args()
def _turn_role_content(turn):
    """Return the ``(role, content)`` pair for one dataset turn.

    Uses the same key fallbacks ("from"/"role", "value"/"content") and the same
    ROLE_MAP normalisation as format_and_tokenize(), so the audit shows what the
    tokenizer path actually sees.
    """
    role_raw = turn.get("from", turn.get("role", "")).strip().lower()
    content = turn.get("value", turn.get("content", "")).strip()
    return ROLE_MAP.get(role_raw, role_raw), content


def _print_raw_conversation(convs):
    """Print a one-line preview of every turn as stored in the dataset."""
    for i, turn in enumerate(convs):
        role = turn.get("from", turn.get("role", "?"))
        content = turn.get("value", turn.get("content", "")).strip()
        preview = content[:120].replace("\n", "↡")
        print(f" [{i}] from={role!r:12s} | {preview!r}")


def _build_chatml(convs):
    """Return the ChatML text that format_and_tokenize() encodes (before truncation)."""
    parts = []
    for turn in convs:
        role, content = _turn_role_content(turn)
        if content and role:
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
    return "".join(parts)


def _check_label_alignment(input_ids, labels, tokenizer, im_start_id, im_end_id):
    """Run the label-alignment sanity checks and print each result.

    Returns:
        List of descriptions of the checks that failed (empty if all passed).
    """
    failures = []

    def report(ok, description):
        print(f" {description} : {'✓' if ok else '✗ FAIL'}")
        if not ok:
            failures.append(description)

    im_start_positions = [i for i, t in enumerate(input_ids) if t == im_start_id]
    im_end_positions = [i for i, t in enumerate(input_ids) if t == im_end_id]
    print(f" <|im_start|> positions : {im_start_positions}")
    print(f" <|im_end|> positions : {im_end_positions}")

    # Turn headers must never be trained on.
    report(all(labels[i] == -100 for i in im_start_positions),
           "All <|im_start|> tokens are masked (-100)")
    # format_and_tokenize() checks for labels BEFORE truncating to MAX_LENGTH,
    # so a long prompt can push every assistant token past the cut-off.
    report(any(lab != -100 for lab in labels),
           f"Some tokens remain labeled after truncation to {MAX_LENGTH}")
    # The model only learns to end its reply if an assistant <|im_end|> is labeled.
    report(any(labels[i] != -100 for i in im_end_positions),
           "At least one <|im_end|> is labeled")

    # Decode the labeled span to confirm it's the assistant content.
    labeled_ids = [t for t, lab in zip(input_ids, labels) if lab != -100]
    labeled_text = tokenizer.decode(labeled_ids, skip_special_tokens=False)
    print("\n Labeled (assistant) text preview:")
    print(f" {labeled_text[:300].replace(chr(10), '↡')!r}")

    # Structural leak check: split the sequence at every <|im_start|> (one
    # segment per turn) and require that any segment containing labels has
    # "assistant" as the first line after its <|im_start|> token.
    bounds = [0] + im_start_positions + [len(input_ids)]
    leaked_roles = []
    for start, end in zip(bounds, bounds[1:]):
        if start == end or all(lab == -100 for lab in labels[start:end]):
            continue
        if input_ids[start] == im_start_id:
            turn_text = tokenizer.decode(input_ids[start + 1:end], skip_special_tokens=False)
            role = turn_text.lstrip().split("\n", 1)[0].strip()
        else:
            role = ""  # labeled tokens before any <|im_start|> header
        if role != "assistant":
            leaked_roles.append(role or "<no header>")
    if leaked_roles:
        print(f" ✗ WARNING: labeled tokens found in non-assistant turns: {leaked_roles}")
    report(not leaked_roles, "Labeled tokens appear only in assistant turns")

    return failures


def main():
    """Audit one OpenHermes-2.5 row end to end through the SFT formatting path.

    Prints each stage (tokenizer, raw turns, ChatML text, token ids, label
    mask, decode round-trip). Exits with status 1 when a required check fails
    and with status 2 for an invalid ``--row``.
    """
    args = parse_args()
    print("\n" + "=" * 60)
    print(" SFT Pipeline — Data Alignment Check")
    print("=" * 60)

    # ---- 1. Tokenizer ---------------------------------------------- #
    print_section("1. Tokenizer")
    tokenizer = load_tokenizer()
    im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    print(f" Vocab size : {len(tokenizer):,}")
    print(f" <|im_start|> : token ID {im_start_id}")
    print(f" <|im_end|> : token ID {im_end_id}")
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    missing = [
        name
        for name, tid in (("<|im_start|>", im_start_id), ("<|im_end|>", im_end_id))
        if tid is None or tid == tokenizer.unk_token_id
    ]
    if missing:
        for name in missing:
            print(f" ✗ ERROR: {name} not in vocab!")
        sys.exit(1)
    print(" ✓ Special tokens present in vocab")

    # ---- 2. Load one row ------------------------------------------- #
    print_section(f"2. Loading row {args.row} from OpenHermes-2.5")
    if args.row < 0:
        print(f" ✗ --row must be a non-negative index (got {args.row})")
        sys.exit(2)
    # Honour --n_fetch (previously ignored: the whole train split was loaded)
    # while always fetching enough leading rows to contain --row.
    n_rows = max(args.n_fetch, args.row + 1)
    print(f" Loading first {n_rows} rows of the train split...")
    ds = load_dataset("teknium/OpenHermes-2.5", split=f"train[:{n_rows}]")
    row = ds[args.row]
    convs = row.get("conversations", [])
    print(f" Row index : {args.row}")
    print(f" Turns in conv : {len(convs)}")

    # ---- 3. Raw conversation --------------------------------------- #
    print_section("3. Raw conversation (from dataset)")
    _print_raw_conversation(convs)

    # ---- 4. ChatML formatted text ---------------------------------- #
    print_section("4. ChatML text (what tokenizer sees)")
    chatml = _build_chatml(convs)
    print(chatml[:800])
    if len(chatml) > 800:
        print(f" ... ({len(chatml) - 800} more chars)")

    # ---- 5. Run through format_and_tokenize ----------------------- #
    print_section("5. format_and_tokenize() output")
    result = format_and_tokenize(convs, tokenizer)
    if result is None:
        print(" ✗ RETURNED None — no assistant turn or too short.")
        print(" Try a different --row index.")
        return
    input_ids, labels = result
    print(f" input_ids length : {len(input_ids)}")
    print(f" labels length : {len(labels)}")
    if len(input_ids) != len(labels):
        print(" ✗ MISMATCH: input_ids and labels have different lengths!")
        sys.exit(1)
    print(" ✓ Lengths match")

    # ---- 6. Verify label alignment --------------------------------- #
    print_section("6. Label alignment sanity checks")
    failures = _check_label_alignment(input_ids, labels, tokenizer, im_start_id, im_end_id)

    # ---- 7. Token-by-token table ----------------------------------- #
    print_section("7. Token-by-token table (first 80 tokens)")
    print_token_table(input_ids, labels, tokenizer, max_rows=80)

    # ---- 8. Decode round-trip ------------------------------------- #
    print_section("8. Full decode round-trip (skip_special_tokens=False)")
    decoded = tokenizer.decode(input_ids, skip_special_tokens=False)
    print(decoded[:600])

    # Only claim success when every check above actually passed.
    print("\n" + "=" * 60)
    if failures:
        print(" CHECK FAILED — the following checks did not pass:")
        for failure in failures:
            print(f"   ✗ {failure}")
        print("=" * 60)
        sys.exit(1)
    print(" CHECK COMPLETE — pipeline looks aligned ✓")
    print("=" * 60)
    print("\nWhen ready, run the full data prep:")
    print(" python finetune/prepare_data.py")
if __name__ == "__main__":
main()