| """Data pipeline diagnostic tools.""" |
|
|
| import time |
| from typing import Dict |
|
|
| import torch |
| from torch.utils.data import DataLoader |
|
|
| from llm_lab.config import DataConfig |
| from .tokenizer import Tokenizer |
|
|
|
|
class DataPipelineDiagnostics:
    """Diagnoses the performance and quality of the data pipeline.

    Items to verify before training:
    1) Tokenizer quality: average tokens/document, unknown token ratio
    2) Packing efficiency: actual token ratio vs. padding ratio
    3) Throughput: tokens/sec (check for data loading bottlenecks)
    4) Batch shape: correctness of shape and dtype

    Each check prints a human-readable report and also returns the computed
    metrics as a dict so they can be inspected programmatically.
    """

    @staticmethod
    def check_tokenizer_quality(
        tokenizer: "Tokenizer",
        config: "DataConfig",
        num_samples: int = 1000,
    ) -> Dict[str, float]:
        """Diagnoses tokenizer quality.

        Streams up to ``num_samples`` non-empty documents from the configured
        dataset, reports tokens/chars per document and the chars-per-token
        compression ratio, then runs an encode/decode round-trip check.

        Returns:
            Dict with ``avg_tokens``, ``avg_chars``, ``compression_ratio``
            and ``roundtrip_ok``; empty dict if no usable documents found.
        """
        from datasets import load_dataset

        print("\n" + "=" * 60)
        print("Tokenizer Quality Diagnostics")
        print("=" * 60)

        # NOTE(review): trust_remote_code=True executes code from the dataset
        # repo — fine for vetted datasets, audit before pointing at new ones.
        ds = load_dataset(
            config.dataset_name,
            name=config.dataset_subset,
            split=config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )

        token_counts = []
        char_counts = []

        for example in ds:
            if len(token_counts) >= num_samples:
                break
            text = example[config.text_column]
            # Skip empty / whitespace-only documents; they would skew averages.
            if not text or not text.strip():
                continue

            tokens = tokenizer.encode(text)
            token_counts.append(len(tokens))
            char_counts.append(len(text))

        # Guard: the original raised ZeroDivisionError when every sampled
        # document was empty (or the stream itself yielded nothing).
        if not token_counts:
            print("  No usable documents found - check dataset/text_column.")
            return {}

        avg_tokens = sum(token_counts) / len(token_counts)
        avg_chars = sum(char_counts) / len(char_counts)
        compression_ratio = avg_chars / avg_tokens

        print(f" Documents analyzed: {len(token_counts):,}")
        print(f" Average tokens/document: {avg_tokens:.1f}")
        print(f" Average chars/document: {avg_chars:.1f}")
        print(f" Compression ratio (chars/token): {compression_ratio:.2f}")
        print(f" -> 3.5~4.5 is normal for English")
        print(f" Min tokens: {min(token_counts)}, Max: {max(token_counts)}")

        # Round-trip sanity check: decode(encode(x)) should contain x.
        test_text = "The quick brown fox jumps over the lazy dog."
        encoded = tokenizer.encode(test_text)
        decoded = tokenizer.decode(encoded)
        roundtrip_ok = test_text.strip() in decoded.strip()
        print(f"\n Round-trip test: {'PASSED' if roundtrip_ok else 'FAILED'}")
        print(f" Original: {test_text}")
        print(f" Encoded: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
        print(f" Decoded: {decoded}")

        return {
            "avg_tokens": avg_tokens,
            "avg_chars": avg_chars,
            "compression_ratio": compression_ratio,
            "roundtrip_ok": roundtrip_ok,
        }

    @staticmethod
    def benchmark_throughput(
        dataloader: DataLoader,
        num_batches: int = 50,
        seq_len: int = 2048,
    ) -> Dict[str, float]:
        """Measures data loading throughput.

        A key diagnostic to determine whether data loading is the bottleneck in
        GPU training. Goal: data loading should be faster than GPU computation
        (data loading != bottleneck).

        Args:
            dataloader: loader yielding dicts with an ``input_ids`` tensor.
            num_batches: maximum number of batches to time.
            seq_len: accepted for interface compatibility; currently unused
                (token counts come from the actual batch tensors).

        Returns:
            Dict with ``num_batches`` (actually consumed), ``total_tokens``,
            ``elapsed`` and ``tokens_per_sec``.
        """
        print("\n" + "=" * 60)
        print("Data Loading Throughput Benchmark")
        print("=" * 60)

        total_tokens = 0
        batches_seen = 0
        # perf_counter is monotonic; time.time() can jump (NTP/clock sync).
        start_time = time.perf_counter()

        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            batches_seen = i + 1
            total_tokens += batch["input_ids"].numel()

            if (i + 1) % 10 == 0:
                elapsed = time.perf_counter() - start_time
                tps = total_tokens / elapsed
                print(f" Batch {i+1}: {tps:,.0f} tokens/sec")

        elapsed = time.perf_counter() - start_time
        # Guard elapsed == 0 for tiny in-memory loaders.
        tps = total_tokens / elapsed if elapsed > 0 else float("inf")

        # Fix: report the batch count actually consumed — the original printed
        # the requested num_batches even when the loader was exhausted earlier.
        print(f"\n Total batches: {batches_seen}")
        print(f" Total tokens: {total_tokens:,}")
        print(f" Elapsed time: {elapsed:.2f}s")
        print(f" Average throughput: {tps:,.0f} tokens/sec")
        print(f"\n A100 training throughput reference ~50-80K tokens/sec:")
        if tps > 80_000:
            print(f" Data loading is not the bottleneck")
        elif tps > 30_000:
            print(f" Borderline - consider increasing num_workers")
        else:
            print(f" Data loading is the bottleneck! Adjust num_workers/prefetch")

        return {
            "num_batches": batches_seen,
            "total_tokens": total_tokens,
            "elapsed": elapsed,
            "tokens_per_sec": tps,
        }

    @staticmethod
    def inspect_batch(
        batch: Dict[str, torch.Tensor],
        tokenizer: "Tokenizer",
    ) -> Dict[str, float]:
        """Inspects a single batch in detail.

        Verifies shape/dtype, the next-token shift alignment between
        ``input_ids`` and ``targets``, the EOS-token density, and prints a
        decoded preview of the first sample.

        Returns:
            Dict with ``shift_consistency`` (fraction in [0, 1], should be
            1.0), ``eos_count`` and ``total_tokens``.
        """
        print("\n" + "=" * 60)
        print("Batch Detailed Inspection")
        print("=" * 60)

        input_ids = batch["input_ids"]
        targets = batch["targets"]

        print(f" input_ids shape: {input_ids.shape}")
        print(f" targets shape: {targets.shape}")
        print(f" dtype: {input_ids.dtype}")
        print(f" value range: [{input_ids.min().item()}, {input_ids.max().item()}]")

        # For next-token prediction, targets must be input_ids shifted left by
        # one position, so the overlapping region should match exactly.
        shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
        print(f" Shift consistency: {shift_correct*100:.1f}% (should be 100%)")

        # EOS density hints at average document length inside packed sequences.
        eos_count = (input_ids == tokenizer.eos_id).sum().item()
        total_tokens = input_ids.numel()
        print(f" EOS token count: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")

        # Human check: does the decoded text look like coherent training data?
        first_sample = input_ids[0][:100].tolist()
        decoded_preview = tokenizer.decode(first_sample)
        print(f"\n First sample decoded (first 100 tokens):")
        print(f" {decoded_preview[:300]}...")

        return {
            "shift_consistency": shift_correct,
            "eos_count": eos_count,
            "total_tokens": total_tokens,
        }
|