| """ |
finetune/check_data.py

Smoke-test: loads rows from OpenHermes-2.5, runs the selected one through
the same format_and_tokenize() logic used by prepare_data.py, and prints
a full visual audit so you can confirm everything lines up.

Checks:
  1. Raw conversation structure from the dataset
  2. ChatML text that gets fed to the tokenizer
  3. Token IDs and decoded tokens (side-by-side)
  4. Label mask — ✓ (labeled) vs ✗ (masked -100) for every token
  5. Label ratio (should be ~30-60% assistant tokens)

Run from project root:
    python finetune/check_data.py
    python finetune/check_data.py --row 3   # inspect a specific row index
| """ |
|
|
| import sys |
| import argparse |
| from pathlib import Path |
|
|
| |
| |
| |
|
|
# Directory that contains this script.
SCRIPT_DIR = Path(__file__).resolve().parent
# Project root: the parent of the script directory.
PROJECT_ROOT = SCRIPT_DIR.parent
# Location of the tokenizer files loaded by load_tokenizer().
TOKENIZER_DIR = PROJECT_ROOT.joinpath("tokenizer", "fineweb_edu_tokenizer")

# Put the project root at the front of the module search path.
sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from transformers import PreTrainedTokenizerFast |
| from datasets import load_dataset |
|
|
# ChatML turn delimiters; load_tokenizer() adds any that the vocab lacks.
SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"]
# Tokenized samples are cut to at most this many tokens.
MAX_LENGTH = 1024

# Maps the speaker tag found in a dataset turn to the ChatML role name.
ROLE_MAP = dict(
    system="system",
    human="user",
    gpt="assistant",
    user="user",
    assistant="assistant",
)
|
|
|
|
| |
| |
| |
|
|
def load_tokenizer() -> PreTrainedTokenizerFast:
    """Load the project tokenizer and ensure the ChatML markers are in its vocab.

    Reads the tokenizer from ``TOKENIZER_DIR`` and registers every entry of
    ``SPECIAL_TOKENS`` that is not yet in the vocabulary as an additional
    special token.

    Returns:
        The loaded tokenizer (extended in place when markers were missing).

    Raises:
        FileNotFoundError: If ``TOKENIZER_DIR`` is not an existing directory.
    """
    if not TOKENIZER_DIR.is_dir():
        # Fail early with an actionable message rather than a library-level
        # loading error that does not mention the expected location.
        raise FileNotFoundError(f"Tokenizer directory not found: {TOKENIZER_DIR}")

    tok = PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))

    # Fetch the vocabulary once instead of once per candidate token.
    vocab = tok.get_vocab()
    missing = [t for t in SPECIAL_TOKENS if t not in vocab]
    if missing:
        tok.add_special_tokens({"additional_special_tokens": missing})
    return tok
|
|
|
|
def format_and_tokenize(conversations, tokenizer):
    """Turn a conversation into ChatML token IDs plus a training label mask.

    Each turn is encoded as a header (``<|im_start|>``, the role, a newline)
    followed by a body (the stripped content, ``<|im_end|>``, a newline).
    Header tokens are always masked with -100; body tokens keep their IDs as
    labels only for ``assistant`` turns. The result is truncated to
    ``MAX_LENGTH`` tokens.

    Intended to mirror prepare_data.py.
    NOTE(review): the "no labeled token" check now runs AFTER truncation, so
    samples whose assistant tokens were all cut off are rejected — confirm
    prepare_data.py applies the same rule.

    Args:
        conversations: Iterable of turn dicts keyed by ``from``/``value`` or
            ``role``/``content``.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.

    Returns:
        ``(input_ids, labels)`` as two equal-length lists, or ``None`` when
        the truncated sample has no labeled token or fewer than 8 tokens.
    """
    input_ids, labels = [], []

    for turn in conversations:
        # `or` chains tolerate keys that are absent *or* present with None,
        # which would otherwise crash on .strip().
        role_raw = (turn.get("from") or turn.get("role") or "").strip().lower()
        content = (turn.get("value") or turn.get("content") or "").strip()
        role = ROLE_MAP.get(role_raw, role_raw)

        # Turns without a speaker or without text contribute nothing.
        if not content or not role:
            continue

        # Header and body are encoded separately so the header length is
        # known exactly when building the mask.
        header_ids = tokenizer.encode(f"<|im_start|>{role}\n", add_special_tokens=False)
        body_ids = tokenizer.encode(f"{content}<|im_end|>\n", add_special_tokens=False)

        input_ids.extend(header_ids)
        input_ids.extend(body_ids)

        labels.extend([-100] * len(header_ids))
        if role == "assistant":
            labels.extend(body_ids)
        else:
            labels.extend([-100] * len(body_ids))

    input_ids = input_ids[:MAX_LENGTH]
    labels = labels[:MAX_LENGTH]

    # Checked on the truncated labels: a long prompt can push every assistant
    # token past MAX_LENGTH, leaving nothing to learn from.
    if all(label == -100 for label in labels):
        return None

    if len(input_ids) < 8:
        return None

    return input_ids, labels
|
|
|
|
| |
| |
| |
|
|
def print_section(title: str) -> None:
    """Print a section heading framed by two 60-character horizontal rules.

    Args:
        title: Heading text, printed between the rules after one space.
    """
    # Box-drawing horizontal line (U+2500); the glyph was mis-encoded before.
    rule = "─" * 60
    print(f"\n{rule}")
    print(f" {title}")
    print(rule)
|
|
|
|
def print_token_table(input_ids, labels, tokenizer, max_rows: int = 80):
    """Print a per-token audit table followed by labeled/masked totals.

    Columns: index, decoded token (repr, clipped to 22 characters), token ID,
    label value, and a colored marker:
        green ✓ = labeled token  — the model learns this
        red   ✗ = masked (-100)  — the model ignores this

    Args:
        input_ids: Token IDs to display.
        labels: One label per token, aligned with ``input_ids``; -100 = masked.
        tokenizer: Object whose ``decode([token_id])`` returns the token text.
        max_rows: Maximum number of table rows to print. The totals below the
            table always cover all of ``labels``.
    """
    GREEN = "\033[92m"
    RED = "\033[91m"
    RESET = "\033[0m"

    print(f"\n {'IDX':>5} {'TOKEN':<22} {'ID':>6} {'LABEL':>8} {'LEARN?'}")
    print(f" {'─'*5} {'─'*22} {'─'*6} {'─'*8} {'─'*6}")

    for i, (tok_id, lbl) in enumerate(zip(input_ids, labels)):
        # Stop only when a row beyond the limit actually exists, so the
        # "more tokens" note is never printed with a count of zero.
        if i >= max_rows:
            print(f" ... ({len(input_ids) - max_rows} more tokens not shown)")
            break

        tok_str = repr(tokenizer.decode([tok_id]))[:22]
        if lbl == -100:
            learn_str = f"{RED}✗ masked{RESET}"
        else:
            learn_str = f"{GREEN}✓ learn {RESET}"

        # Same width for masked and labeled values keeps the column aligned.
        print(f" {i:>5} {tok_str:<22} {tok_id:>6} {lbl:>8} {learn_str}")

    n_total = len(labels)
    n_labeled = sum(1 for lbl in labels if lbl != -100)
    n_masked = n_total - n_labeled
    # Avoid ZeroDivisionError on an empty sequence; ratios then print as 0.0%.
    denom = n_total or 1

    print(f"\n Total tokens : {n_total}")
    print(f" Labeled      : {n_labeled} ({n_labeled / denom:.1%}) ← assistant tokens")
    print(f" Masked       : {n_masked} ({n_masked / denom:.1%}) ← user/system tokens")
|
|
|
|
| |
| |
| |
|
|
def parse_args(argv=None):
    """Parse the command-line options of the data check.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read the process command line.

    Returns:
        argparse.Namespace with integer attributes ``row`` and ``n_fetch``.
    """
    p = argparse.ArgumentParser(
        description="Check one OpenHermes row through the SFT pipeline"
    )
    p.add_argument(
        "--row",
        type=int,
        default=0,
        help="Which row to inspect in detail (0-indexed, within the rows fetched via --n_fetch)",
    )
    p.add_argument(
        "--n_fetch",
        type=int,
        default=20,
        help="How many rows to fetch from HuggingFace (default: 20)",
    )
    return p.parse_args(argv)
|
|
|
|
def main():
    """Run the end-to-end alignment audit for one OpenHermes-2.5 row.

    Steps, each printed under its own section heading:
      1. load the tokenizer and verify both ChatML markers are in the vocab;
      2. load the first ``--n_fetch`` rows and pick row ``--row``;
      3. print the raw conversation turns;
      4. print the ChatML text built from those turns;
      5. run format_and_tokenize() and verify ids/labels have equal length;
      6. run label-mask sanity checks;
      7. print the per-token table;
      8. print the decoded round trip.

    The closing banner reports success only when no sanity check flagged an
    issue.

    Raises:
        SystemExit: On an invalid ``--n_fetch`` or an out-of-range ``--row``.
        RuntimeError: If a ChatML marker is missing from the vocab or the
            tokenized ids and labels differ in length.
    """
    args = parse_args()
    if args.n_fetch < 1:
        raise SystemExit("ERROR: --n_fetch must be at least 1")

    print("\n" + "=" * 60)
    print(" SFT Pipeline — Data Alignment Check")
    print("=" * 60)

    # -- 1. Tokenizer ------------------------------------------------------
    print_section("1. Tokenizer")
    tokenizer = load_tokenizer()
    im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    print(f" Vocab size : {len(tokenizer):,}")
    print(f" <|im_start|> : token ID {im_start_id}")
    print(f" <|im_end|> : token ID {im_end_id}")
    # Explicit raises instead of assert, so the check survives `python -O`.
    for marker, marker_id in (("<|im_start|>", im_start_id), ("<|im_end|>", im_end_id)):
        if marker_id is None or marker_id == tokenizer.unk_token_id:
            raise RuntimeError(f"ERROR: {marker} not in vocab!")
    print(" ✓ Special tokens present in vocab")

    # -- 2. Dataset row ----------------------------------------------------
    print_section(f"2. Loading row {args.row} from OpenHermes-2.5")
    print(f" Loading first {args.n_fetch} rows from local cache (Arrow format)...")
    # Slice the split so --n_fetch is actually honored.
    ds = load_dataset("teknium/OpenHermes-2.5", split=f"train[:{args.n_fetch}]")
    if not 0 <= args.row < len(ds):
        raise SystemExit(
            f"ERROR: --row {args.row} is out of range; only {len(ds)} rows were "
            "fetched (raise --n_fetch to reach later rows)"
        )
    row = ds[args.row]
    convs = row.get("conversations") or []

    print(f" Row index : {args.row}")
    print(f" Turns in conv : {len(convs)}")

    # -- 3. Raw conversation -----------------------------------------------
    # Key fallbacks match format_and_tokenize() so the previews describe the
    # same turns that get tokenized.
    print_section("3. Raw conversation (from dataset)")
    for i, turn in enumerate(convs):
        role = turn.get("from") or turn.get("role") or "?"
        content = (turn.get("value") or turn.get("content") or "").strip()
        preview = content[:120].replace("\n", "↵")
        print(f" [{i}] from={role!r:12s} | {preview!r}")

    # -- 4. ChatML text ----------------------------------------------------
    print_section("4. ChatML text (what tokenizer sees)")
    chatml_parts = []
    for turn in convs:
        role_raw = (turn.get("from") or turn.get("role") or "").strip().lower()
        content = (turn.get("value") or turn.get("content") or "").strip()
        role = ROLE_MAP.get(role_raw, role_raw)
        if content and role:
            chatml_parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
    chatml = "".join(chatml_parts)
    print(chatml[:800])
    if len(chatml) > 800:
        print(f" ... ({len(chatml) - 800} more chars)")

    # -- 5. Tokenize -------------------------------------------------------
    print_section("5. format_and_tokenize() output")
    result = format_and_tokenize(convs, tokenizer)

    if result is None:
        print(" ✗ RETURNED None — no assistant turn or too short.")
        print(" Try a different --row index.")
        return

    input_ids, labels = result
    print(f" input_ids length : {len(input_ids)}")
    print(f" labels length : {len(labels)}")
    if len(input_ids) != len(labels):
        raise RuntimeError("MISMATCH: input_ids and labels have different lengths!")
    print(" ✓ Lengths match")

    # -- 6. Label alignment ------------------------------------------------
    print_section("6. Label alignment sanity checks")
    problems = []  # collected so the final banner reflects the real outcome

    im_start_positions = [i for i, t in enumerate(input_ids) if t == im_start_id]
    im_end_positions = [i for i, t in enumerate(input_ids) if t == im_end_id]

    print(f" <|im_start|> positions : {im_start_positions}")
    print(f" <|im_end|> positions : {im_end_positions}")

    # Turn headers must never be training targets.
    im_start_masked = all(labels[i] == -100 for i in im_start_positions)
    print(f" All <|im_start|> tokens are masked (-100) : {'✓' if im_start_masked else '✗ FAIL'}")
    if not im_start_masked:
        problems.append("<|im_start|> token(s) carry a label")

    labeled_ids = [t for t, lbl in zip(input_ids, labels) if lbl != -100]
    labeled_text = tokenizer.decode(labeled_ids, skip_special_tokens=False)
    labeled_preview = labeled_text[:300].replace("\n", "↵")
    print("\n Labeled (assistant) text preview:")
    print(f" {labeled_preview!r}")

    # Heuristic: a leaked role header would show up as "user\n"/"system\n".
    # Assistant text that happens to contain such a line ending triggers it
    # too, so treat a hit as something to inspect, not as proof of a bug.
    if "user\n" in labeled_text or "system\n" in labeled_text:
        print(" ⚠ WARNING: user/system content found in labeled tokens!")
        problems.append("possible user/system content in labeled tokens")
    else:
        print(" ✓ Labeled tokens contain only assistant content")

    # -- 7. Token table ----------------------------------------------------
    print_section("7. Token-by-token table (first 80 tokens)")
    print_token_table(input_ids, labels, tokenizer, max_rows=80)

    # -- 8. Round trip -----------------------------------------------------
    print_section("8. Full decode round-trip (skip_special_tokens=False)")
    decoded = tokenizer.decode(input_ids, skip_special_tokens=False)
    print(decoded[:600])

    # -- Summary -----------------------------------------------------------
    print("\n" + "=" * 60)
    if problems:
        print(f" CHECK COMPLETE — {len(problems)} issue(s) flagged ✗")
        for problem in problems:
            print(f"   - {problem}")
    else:
        print(" CHECK COMPLETE — pipeline looks aligned ✓")
    print("=" * 60)

    # Only point at the full prep run when nothing was flagged.
    if not problems:
        print("\nWhen ready, run the full data prep:")
        print(" python finetune/prepare_data.py")
|
|
|
|
# Script entry point: run the audit only when this file is executed directly.
if __name__ == "__main__":
    main()
|
|