# LLM-1B-Lab / llm_lab/data/diagnostics.py
# Author: Vjeong
# docs: translate all Korean comments and docstrings to English (commit 858e8b2)
"""Data pipeline diagnostic tools."""
import time
from typing import Dict
import torch
from torch.utils.data import DataLoader
from llm_lab.config import DataConfig
from .tokenizer import Tokenizer
class DataPipelineDiagnostics:
    """Diagnoses the performance and quality of the data pipeline.

    Items to verify before training:
    1) Tokenizer quality: average tokens/document, unknown token ratio
    2) Packing efficiency: actual token ratio vs. padding ratio
    3) Throughput: tokens/sec (check for data loading bottlenecks)
    4) Batch shape: correctness of shape and dtype
    """

    @staticmethod
    def check_tokenizer_quality(
        tokenizer: Tokenizer,
        config: DataConfig,
        num_samples: int = 1000,
    ):
        """Diagnoses tokenizer quality.

        Streams up to ``num_samples`` non-empty documents from the configured
        dataset, reports tokens/chars per document and the chars-per-token
        compression ratio, then runs an encode/decode round-trip sanity check.

        Args:
            tokenizer: Tokenizer exposing ``encode`` and ``decode``.
            config: Data configuration (dataset name/subset/split, text column).
            num_samples: Maximum number of documents to analyze.
        """
        from datasets import load_dataset

        print("\n" + "=" * 60)
        print("Tokenizer Quality Diagnostics")
        print("=" * 60)

        # NOTE(review): trust_remote_code=True executes code from the dataset
        # repository; only use with datasets from trusted sources.
        ds = load_dataset(
            config.dataset_name,
            name=config.dataset_subset,
            split=config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )
        token_counts = []
        char_counts = []
        sample_count = 0
        for example in ds:
            if sample_count >= num_samples:
                break
            text = example[config.text_column]
            if not text or not text.strip():
                continue  # skip empty / whitespace-only documents
            tokens = tokenizer.encode(text)
            token_counts.append(len(tokens))
            char_counts.append(len(text))
            sample_count += 1

        # Fix: an empty or exhausted dataset previously crashed with
        # ZeroDivisionError (and min()/max() on an empty list) below.
        if not token_counts:
            print("  No non-empty documents found - nothing to analyze")
            return

        avg_tokens = sum(token_counts) / len(token_counts)
        avg_chars = sum(char_counts) / len(char_counts)
        compression_ratio = avg_chars / avg_tokens  # Characters per token ratio
        print(f"  Documents analyzed: {len(token_counts):,}")
        print(f"  Average tokens/document: {avg_tokens:.1f}")
        print(f"  Average chars/document: {avg_chars:.1f}")
        print(f"  Compression ratio (chars/token): {compression_ratio:.2f}")
        print(f"    -> 3.5~4.5 is normal for English")
        print(f"  Min tokens: {min(token_counts)}, Max: {max(token_counts)}")

        # Round-trip decode test: decoding the encoding should reproduce the
        # original text (modulo whitespace normalization by the tokenizer).
        test_text = "The quick brown fox jumps over the lazy dog."
        encoded = tokenizer.encode(test_text)
        decoded = tokenizer.decode(encoded)
        roundtrip_ok = test_text.strip() in decoded.strip()
        print(f"\n  Round-trip test: {'PASSED' if roundtrip_ok else 'FAILED'}")
        print(f"  Original: {test_text}")
        print(f"  Encoded: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
        print(f"  Decoded: {decoded}")

    @staticmethod
    def benchmark_throughput(
        dataloader: DataLoader,
        num_batches: int = 50,
        seq_len: int = 2048,
    ):
        """Measures data loading throughput.

        A key diagnostic to determine whether data loading is the bottleneck in
        GPU training. Goal: data loading should be faster than GPU computation
        (data loading != bottleneck).

        Args:
            dataloader: Iterable yielding batches with an ``"input_ids"`` tensor.
            num_batches: Maximum number of batches to time.
            seq_len: Unused; kept for backward compatibility with callers.
        """
        print("\n" + "=" * 60)
        print("Data Loading Throughput Benchmark")
        print("=" * 60)
        total_tokens = 0
        # Fix: track the number of batches actually consumed; the loader may
        # yield fewer than num_batches, and the summary previously misreported.
        batches_done = 0
        start_time = time.time()
        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            batch_tokens = batch["input_ids"].numel()
            total_tokens += batch_tokens
            batches_done = i + 1
            if (i + 1) % 10 == 0:
                elapsed = time.time() - start_time
                # Guard against a ~0 elapsed time on very fast iterations.
                tps = total_tokens / max(elapsed, 1e-9)
                print(f"  Batch {i+1}: {tps:,.0f} tokens/sec")
        elapsed = time.time() - start_time
        tps = total_tokens / max(elapsed, 1e-9)
        print(f"\n  Total batches: {batches_done}")
        print(f"  Total tokens: {total_tokens:,}")
        print(f"  Elapsed time: {elapsed:.2f}s")
        print(f"  Average throughput: {tps:,.0f} tokens/sec")
        print(f"\n  A100 training throughput reference ~50-80K tokens/sec:")
        if tps > 80_000:
            print(f"  Data loading is not the bottleneck")
        elif tps > 30_000:
            print(f"  Borderline - consider increasing num_workers")
        else:
            print(f"  Data loading is the bottleneck! Adjust num_workers/prefetch")

    @staticmethod
    def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
        """Inspects a single batch in detail.

        Prints shape/dtype/value-range, verifies the next-token shift
        relationship between inputs and targets, reports the EOS-token density
        (a proxy for document boundaries in packed sequences), and decodes a
        preview of the first sample.

        Args:
            batch: Dict with ``"input_ids"`` and ``"targets"`` of shape
                (batch, seq_len) — assumed; verify against the collator.
            tokenizer: Tokenizer exposing ``decode`` and ``eos_id``.
        """
        print("\n" + "=" * 60)
        print("Batch Detailed Inspection")
        print("=" * 60)
        input_ids = batch["input_ids"]
        targets = batch["targets"]
        print(f"  input_ids shape: {input_ids.shape}")
        print(f"  targets shape: {targets.shape}")
        print(f"  dtype: {input_ids.dtype}")
        print(f"  value range: [{input_ids.min().item()}, {input_ids.max().item()}]")

        # Verify shift relationship: targets[i] == input_ids[i+1]
        shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
        print(f"  Shift consistency: {shift_correct*100:.1f}% (should be 100%)")

        # EOS token distribution (document boundaries)
        eos_count = (input_ids == tokenizer.eos_id).sum().item()
        total_tokens = input_ids.numel()
        print(f"  EOS token count: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")

        # Decode preview of the first sample
        first_sample = input_ids[0][:100].tolist()
        decoded_preview = tokenizer.decode(first_sample)
        print(f"\n  First sample decoded (first 100 tokens):")
        print(f"  {decoded_preview[:300]}...")