"""Data pipeline diagnostic tools."""
import time
from typing import Dict
import torch
from torch.utils.data import DataLoader
from llm_lab.config import DataConfig
from .tokenizer import Tokenizer
class DataPipelineDiagnostics:
    """Diagnoses the performance and quality of the data pipeline.

    Items to verify before training:
      1) Tokenizer quality: average tokens/document, unknown token ratio
      2) Packing efficiency: actual token ratio vs. padding ratio
      3) Throughput: tokens/sec (check for data loading bottlenecks)
      4) Batch shape: correctness of shape and dtype

    All checks are exposed as static methods that print their findings to
    stdout rather than returning values.
    """
@staticmethod
def check_tokenizer_quality(
tokenizer: Tokenizer,
config: DataConfig,
num_samples: int = 1000,
):
"""Diagnoses tokenizer quality."""
from datasets import load_dataset
print("\n" + "=" * 60)
print("Tokenizer Quality Diagnostics")
print("=" * 60)
ds = load_dataset(
config.dataset_name,
name=config.dataset_subset,
split=config.dataset_split,
streaming=True,
trust_remote_code=True,
)
token_counts = []
char_counts = []
sample_count = 0
for example in ds:
if sample_count >= num_samples:
break
text = example[config.text_column]
if not text or not text.strip():
continue
tokens = tokenizer.encode(text)
token_counts.append(len(tokens))
char_counts.append(len(text))
sample_count += 1
avg_tokens = sum(token_counts) / len(token_counts)
avg_chars = sum(char_counts) / len(char_counts)
compression_ratio = avg_chars / avg_tokens # Characters per token ratio
print(f" Documents analyzed: {len(token_counts):,}")
print(f" Average tokens/document: {avg_tokens:.1f}")
print(f" Average chars/document: {avg_chars:.1f}")
print(f" Compression ratio (chars/token): {compression_ratio:.2f}")
print(f" -> 3.5~4.5 is normal for English")
print(f" Min tokens: {min(token_counts)}, Max: {max(token_counts)}")
# Round-trip decode test
test_text = "The quick brown fox jumps over the lazy dog."
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
roundtrip_ok = test_text.strip() in decoded.strip()
print(f"\n Round-trip test: {'PASSED' if roundtrip_ok else 'FAILED'}")
print(f" Original: {test_text}")
print(f" Encoded: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
print(f" Decoded: {decoded}")
@staticmethod
def benchmark_throughput(
dataloader: DataLoader,
num_batches: int = 50,
seq_len: int = 2048,
):
"""Measures data loading throughput.
A key diagnostic to determine whether data loading is the bottleneck in GPU training.
Goal: data loading should be faster than GPU computation (data loading != bottleneck).
"""
print("\n" + "=" * 60)
print("Data Loading Throughput Benchmark")
print("=" * 60)
total_tokens = 0
start_time = time.time()
for i, batch in enumerate(dataloader):
if i >= num_batches:
break
batch_tokens = batch["input_ids"].numel()
total_tokens += batch_tokens
if (i + 1) % 10 == 0:
elapsed = time.time() - start_time
tps = total_tokens / elapsed
print(f" Batch {i+1}: {tps:,.0f} tokens/sec")
elapsed = time.time() - start_time
tps = total_tokens / elapsed
print(f"\n Total batches: {num_batches}")
print(f" Total tokens: {total_tokens:,}")
print(f" Elapsed time: {elapsed:.2f}s")
print(f" Average throughput: {tps:,.0f} tokens/sec")
print(f"\n A100 training throughput reference ~50-80K tokens/sec:")
if tps > 80_000:
print(f" Data loading is not the bottleneck")
elif tps > 30_000:
print(f" Borderline - consider increasing num_workers")
else:
print(f" Data loading is the bottleneck! Adjust num_workers/prefetch")
@staticmethod
def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
"""Inspects a single batch in detail."""
print("\n" + "=" * 60)
print("Batch Detailed Inspection")
print("=" * 60)
input_ids = batch["input_ids"]
targets = batch["targets"]
print(f" input_ids shape: {input_ids.shape}")
print(f" targets shape: {targets.shape}")
print(f" dtype: {input_ids.dtype}")
print(f" value range: [{input_ids.min().item()}, {input_ids.max().item()}]")
# Verify shift relationship: targets[i] == input_ids[i+1]
shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
print(f" Shift consistency: {shift_correct*100:.1f}% (should be 100%)")
# EOS token distribution (document boundaries)
eos_count = (input_ids == tokenizer.eos_id).sum().item()
total_tokens = input_ids.numel()
print(f" EOS token count: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")
# Decode preview of the first sample
first_sample = input_ids[0][:100].tolist()
decoded_preview = tokenizer.decode(first_sample)
print(f"\n First sample decoded (first 100 tokens):")
print(f" {decoded_preview[:300]}...")
|