"""
finetune/check_data.py

Smoke-test: fetches the first few rows of OpenHermes-2.5, runs one of them
through the same format_and_tokenize() logic used by prepare_data.py, and
prints a full visual audit so you can confirm everything lines up.

Checks:
  1. Raw conversation structure from the dataset
  2. ChatML text that gets fed to the tokenizer
  3. Token IDs and decoded tokens (side-by-side)
  4. Label mask — ✓ (labeled) vs ✗ (masked -100) for every token
  5. Label ratio (should be ~30-60% assistant tokens)

Run from project root:
    python finetune/check_data.py
    python finetune/check_data.py --row 3   # inspect a specific row index

The script exits with a non-zero status when one of the alignment checks
fails, so it can also be used as a gate before the full data prep.
"""

import sys
import argparse
from pathlib import Path

# ------------------------------------------------------------------ #
# Paths
# ------------------------------------------------------------------ #
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent
TOKENIZER_DIR = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"

sys.path.insert(0, str(PROJECT_ROOT))

from transformers import PreTrainedTokenizerFast
from datasets import load_dataset

SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"]
MAX_LENGTH = 1024
ROLE_MAP = {
    "system": "system",
    "human": "user",
    "gpt": "assistant",
    "user": "user",
    "assistant": "assistant",
}


# ------------------------------------------------------------------ #
# Replicated from prepare_data.py (no import to keep this self-contained)
# ------------------------------------------------------------------ #
def load_tokenizer() -> PreTrainedTokenizerFast:
    """Load the tokenizer from TOKENIZER_DIR and make sure the ChatML markers exist.

    Any entry of SPECIAL_TOKENS that is missing from the vocabulary is
    registered as an additional special token.
    """
    tok = PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
    vocab = tok.get_vocab()  # build the vocab dict once, not once per token
    new = [t for t in SPECIAL_TOKENS if t not in vocab]
    if new:
        tok.add_special_tokens({"additional_special_tokens": new})
    return tok


def _turn_role_and_content(turn):
    """Return (role, content) for one turn.

    Accepts both the "from"/"value" and the "role"/"content" key spellings;
    the role is lower-cased and mapped through ROLE_MAP (unknown roles pass
    through unchanged), the content is stripped of surrounding whitespace.
    """
    role_raw = turn.get("from", turn.get("role", "")).strip().lower()
    content = turn.get("value", turn.get("content", "")).strip()
    return ROLE_MAP.get(role_raw, role_raw), content


def format_and_tokenize(conversations, tokenizer):
    """Identical logic to prepare_data.py — returns (input_ids, labels) or None.

    Each non-empty turn is rendered as "<|im_start|>{role}\\n{content}<|im_end|>\\n".
    Only the body of assistant turns (content + <|im_end|> + newline) keeps its
    token ids as labels; headers and every non-assistant turn are masked with -100.

    Returns None when no token is labeled or when fewer than 8 tokens remain
    after truncation to MAX_LENGTH.
    """
    input_ids, labels = [], []

    for turn in conversations:
        role, content = _turn_role_and_content(turn)
        if not content or not role:
            continue

        header_ids = tokenizer.encode(f"<|im_start|>{role}\n", add_special_tokens=False)
        body_ids = tokenizer.encode(f"{content}<|im_end|>\n", add_special_tokens=False)

        input_ids.extend(header_ids)
        input_ids.extend(body_ids)
        labels.extend([-100] * len(header_ids))
        if role == "assistant":
            labels.extend(body_ids)
        else:
            labels.extend([-100] * len(body_ids))

    # NOTE(review): this check runs BEFORE truncation, so a sample whose first
    # MAX_LENGTH tokens are all prompt can still come back fully masked. Kept
    # as-is to mirror prepare_data.py; main() reports that case explicitly.
    if all(lbl == -100 for lbl in labels):
        return None

    input_ids = input_ids[:MAX_LENGTH]
    labels = labels[:MAX_LENGTH]

    if len(input_ids) < 8:
        return None
    return input_ids, labels


# ------------------------------------------------------------------ #
# Pretty-print helpers
# ------------------------------------------------------------------ #
def print_section(title: str):
    """Print a section title framed by two horizontal rules."""
    rule = "─" * 60
    print(f"\n{rule}")
    print(f" {title}")
    print(rule)


def print_token_table(input_ids, labels, tokenizer, max_rows: int = 80):
    """
    Prints a table: idx | token_str | id | label | learn?

    Green ✓ = labeled (assistant) — model learns this
    Red   ✗ = masked -100        — model ignores this

    At most `max_rows` rows are printed, followed by a summary of how many
    tokens are labeled vs masked over the whole sequence.
    """
    GREEN = "\033[92m"
    RED = "\033[91m"
    RESET = "\033[0m"

    print(f"\n {'IDX':>5} {'TOKEN':<22} {'ID':>6} {'LABEL':>8} {'LEARN?'}")
    print(f" {'─'*5} {'─'*22} {'─'*6} {'─'*8} {'─'*6}")

    limit = max(max_rows, 0)
    for i, (tok_id, lbl) in enumerate(zip(input_ids[:limit], labels[:limit])):
        tok_str = repr(tokenizer.decode([tok_id]))[:22]
        if lbl == -100:
            learn_str = f"{RED}✗ masked{RESET}"
        else:
            learn_str = f"{GREEN}✓ learn {RESET}"
        # Same width for both branches so the LABEL column stays aligned.
        print(f" {i:>5} {tok_str:<22} {tok_id:>6} {lbl:>8} {learn_str}")

    remaining = len(input_ids) - limit
    if remaining > 0:  # only mention hidden rows when there are some
        print(f" ... ({remaining} more tokens not shown)")

    # Summary
    n_total = len(labels)
    n_labeled = sum(1 for lbl in labels if lbl != -100)
    n_masked = n_total - n_labeled
    print(f"\n Total tokens : {n_total}")
    if n_total == 0:
        # Nothing to compute a ratio over; avoid dividing by zero.
        return
    print(f" Labeled : {n_labeled} ({n_labeled/n_total:.1%}) ← assistant tokens")
    print(f" Masked : {n_masked} ({n_masked/n_total:.1%}) ← user/system tokens")


# ------------------------------------------------------------------ #
# MAIN
# ------------------------------------------------------------------ #
def parse_args():
    """Parse the command line; rejects a negative --row and a non-positive --n_fetch."""
    p = argparse.ArgumentParser(description="Check one OpenHermes row through the SFT pipeline")
    p.add_argument("--row", type=int, default=0,
                   help="Which row to inspect in detail (0-indexed; the fetch is "
                        "widened automatically if the row lies beyond --n_fetch)")
    p.add_argument("--n_fetch", type=int, default=20,
                   help="How many rows to fetch from HuggingFace (default: 20)")
    args = p.parse_args()
    if args.row < 0:
        p.error("--row must be >= 0")
    if args.n_fetch < 1:
        p.error("--n_fetch must be >= 1")
    return args


def main():
    """Run every check on one dataset row and print the audit.

    Exits with status 1 when a check fails; returns normally when the row
    cannot be used at all (format_and_tokenize() returned None).
    """
    args = parse_args()
    problems = []  # human-readable descriptions of every failed check

    print("\n" + "=" * 60)
    print(" SFT Pipeline — Data Alignment Check")
    print("=" * 60)

    # ---- 1. Tokenizer ---------------------------------------------- #
    print_section("1. Tokenizer")
    tokenizer = load_tokenizer()
    im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    print(f" Vocab size : {len(tokenizer):,}")
    print(f" <|im_start|> : token ID {im_start_id}")
    print(f" <|im_end|> : token ID {im_end_id}")
    # Explicit checks rather than assert: asserts vanish under `python -O`.
    for token, token_id in (("<|im_start|>", im_start_id), ("<|im_end|>", im_end_id)):
        if token_id is None or token_id == tokenizer.unk_token_id:
            sys.exit(f"ERROR: {token} not in vocab!")
    print(" ✓ Special tokens present in vocab")

    # ---- 2. Load one row ------------------------------------------- #
    # Fetch at least enough rows to contain the requested one, so that
    # `--row K` keeps working without also having to raise --n_fetch.
    n_fetch = max(args.n_fetch, args.row + 1)
    print_section(f"2. Loading row {args.row} from OpenHermes-2.5")
    print(f" Loading first {n_fetch} rows from local cache (Arrow format)...")
    ds = load_dataset("teknium/OpenHermes-2.5", split=f"train[:{n_fetch}]")
    if args.row >= len(ds):
        sys.exit(f"ERROR: --row {args.row} is out of range; only {len(ds)} rows were loaded.")
    row = ds[args.row]
    convs = row.get("conversations") or []
    print(f" Row index : {args.row}")
    print(f" Turns in conv : {len(convs)}")

    # ---- 3. Raw conversation --------------------------------------- #
    print_section("3. Raw conversation (from dataset)")
    for i, turn in enumerate(convs):
        role = turn.get("from", turn.get("role", "?"))
        content = turn.get("value", turn.get("content", "")).strip()
        preview = content[:120].replace("\n", "↵")
        print(f" [{i}] from={role!r:12s} | {preview!r}")

    # ---- 4. ChatML formatted text ---------------------------------- #
    print_section("4. ChatML text (what tokenizer sees)")
    # Uses the same turn normalization as format_and_tokenize() so the preview
    # cannot drift from what is actually tokenized.
    rendered = []
    for turn in convs:
        role, content = _turn_role_and_content(turn)
        if content and role:
            rendered.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
    chatml = "".join(rendered)
    print(chatml[:800])
    if len(chatml) > 800:
        print(f" ... ({len(chatml) - 800} more chars)")

    # ---- 5. Run through format_and_tokenize ----------------------- #
    print_section("5. format_and_tokenize() output")
    result = format_and_tokenize(convs, tokenizer)
    if result is None:
        print(" ✗ RETURNED None — no assistant turn or too short.")
        print(" Try a different --row index.")
        return

    input_ids, labels = result
    print(f" input_ids length : {len(input_ids)}")
    print(f" labels length : {len(labels)}")
    if len(input_ids) != len(labels):
        sys.exit("MISMATCH: input_ids and labels have different lengths!")
    print(" ✓ Lengths match")

    # format_and_tokenize() checks for labeled tokens before truncating, so
    # re-check on the truncated sequence that will actually be trained on.
    if all(lbl == -100 for lbl in labels):
        print(f" ✗ FAIL: no labeled token survives truncation to {MAX_LENGTH} tokens")
        problems.append("no labeled tokens after truncation")

    # ---- 6. Verify label alignment --------------------------------- #
    print_section("6. Label alignment sanity checks")

    # Every im_start should be masked
    im_start_positions = [i for i, t in enumerate(input_ids) if t == im_start_id]
    im_end_positions = [i for i, t in enumerate(input_ids) if t == im_end_id]

    print(f" <|im_start|> positions : {im_start_positions}")
    print(f" <|im_end|> positions : {im_end_positions}")

    im_start_masked = all(labels[i] == -100 for i in im_start_positions)
    print(f" All <|im_start|> tokens are masked (-100) : {'✓' if im_start_masked else '✗ FAIL'}")
    if not im_start_masked:
        problems.append("some <|im_start|> tokens are labeled")

    # Decode the labeled span to confirm it's the assistant content
    labeled_ids = [t for t, lbl in zip(input_ids, labels) if lbl != -100]
    labeled_text = tokenizer.decode(labeled_ids, skip_special_tokens=False)
    labeled_preview = labeled_text[:300].replace("\n", "↵")
    print("\n Labeled (assistant) text preview:")
    print(f" {labeled_preview!r}")

    # Heuristic only: an assistant answer may legitimately contain these
    # substrings, so this stays a warning and does not fail the run.
    if "user\n" in labeled_text or "system\n" in labeled_text:
        print(" ✗ WARNING: user/system content found in labeled tokens!")
    else:
        print(" ✓ Labeled tokens contain only assistant content")

    # ---- 7. Token-by-token table ----------------------------------- #
    print_section("7. Token-by-token table (first 80 tokens)")
    print_token_table(input_ids, labels, tokenizer, max_rows=80)

    # ---- 8. Decode round-trip ------------------------------------- #
    print_section("8. Full decode round-trip (skip_special_tokens=False)")
    decoded = tokenizer.decode(input_ids, skip_special_tokens=False)
    print(decoded[:600])

    print("\n" + "=" * 60)
    if problems:
        # Do not claim success when a check above reported a failure.
        print(" CHECK FAILED — see the ✗ items above:")
        for problem in problems:
            print(f"   - {problem}")
        print("=" * 60)
        sys.exit(1)

    print(" CHECK COMPLETE — pipeline looks aligned ✓")
    print("=" * 60)
    print("\nWhen ready, run the full data prep:")
    print(" python finetune/prepare_data.py")


if __name__ == "__main__":
    main()