"""Data pipeline diagnostic tools."""

import itertools
import time
from typing import Dict

import torch
from torch.utils.data import DataLoader

from llm_lab.config import DataConfig
from .tokenizer import Tokenizer


class DataPipelineDiagnostics:
    """Diagnoses the performance and quality of the data pipeline.

    Items to verify before training:
      1) Tokenizer quality: average tokens/document, unknown token ratio
      2) Packing efficiency: actual token ratio vs. padding ratio
      3) Throughput: tokens/sec (check for data loading bottlenecks)
      4) Batch shape: correctness of shape and dtype

    All methods are static and report their findings by printing to stdout.
    """

    @staticmethod
    def check_tokenizer_quality(
        tokenizer: Tokenizer,
        config: DataConfig,
        num_samples: int = 1000,
    ):
        """Diagnose tokenizer quality on a streamed sample of the dataset.

        Streams documents from the dataset described by ``config``, encodes up
        to ``num_samples`` non-blank documents with ``tokenizer`` and prints
        token/character statistics. It then runs an encode/decode round-trip
        check on a fixed English sentence.

        Args:
            tokenizer: Tokenizer exposing ``encode(str)`` and ``decode(ids)``.
            config: Data config supplying ``dataset_name``, ``dataset_subset``,
                ``dataset_split`` and ``text_column``.
            num_samples: Maximum number of non-blank documents to analyze.
        """
        # Function-scope import, presumably so that importing this module does
        # not require `datasets` unless this diagnostic is actually run.
        from datasets import load_dataset

        print("\n" + "=" * 60)
        print("Tokenizer Quality Diagnostics")
        print("=" * 60)

        # NOTE(review): trust_remote_code=True allows the dataset repository to
        # execute its own loading code; only use with trusted datasets.
        ds = load_dataset(
            config.dataset_name,
            name=config.dataset_subset,
            split=config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )

        token_counts = []
        char_counts = []
        for example in ds:
            if len(token_counts) >= num_samples:
                break
            text = example[config.text_column]
            # Blank documents carry no signal and would drag the averages down.
            if not text or not text.strip():
                continue
            token_counts.append(len(tokenizer.encode(text)))
            char_counts.append(len(text))

        # Guard against an empty sample (empty split, all-blank text, or
        # num_samples <= 0); the statistics below would otherwise raise.
        if not token_counts:
            print(" No non-empty documents found; skipping corpus statistics.")
        else:
            n_docs = len(token_counts)
            avg_tokens = sum(token_counts) / n_docs
            avg_chars = sum(char_counts) / n_docs

            print(f" Documents analyzed: {n_docs:,}")
            print(f" Average tokens/document: {avg_tokens:.1f}")
            print(f" Average chars/document: {avg_chars:.1f}")
            if avg_tokens > 0:
                # Characters per token: higher means the tokenizer compresses more.
                compression_ratio = avg_chars / avg_tokens
                print(f" Compression ratio (chars/token): {compression_ratio:.2f}")
            else:
                print(" Compression ratio (chars/token): n/a (no tokens produced)")
            print(" -> 3.5~4.5 is normal for English")
            print(f" Min tokens: {min(token_counts)}, Max: {max(token_counts)}")

        # Round-trip decode test. A containment check (rather than equality)
        # presumably tolerates extra text the decoder may emit around the
        # sentence, such as special tokens.
        test_text = "The quick brown fox jumps over the lazy dog."
        encoded = tokenizer.encode(test_text)
        decoded = tokenizer.decode(encoded)
        roundtrip_ok = test_text.strip() in decoded.strip()

        print(f"\n Round-trip test: {'PASSED' if roundtrip_ok else 'FAILED'}")
        print(f" Original: {test_text}")
        print(f" Encoded: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
        print(f" Decoded: {decoded}")

    @staticmethod
    def benchmark_throughput(
        dataloader: DataLoader,
        num_batches: int = 50,
        seq_len: int = 2048,
    ):
        """Measure data-loading throughput in tokens/sec.

        This is a key diagnostic for deciding whether data loading is the
        bottleneck in GPU training. The goal is for data loading to be faster
        than GPU computation, so that it is not the bottleneck.

        Args:
            dataloader: Iterable yielding dict batches with an ``"input_ids"``
                tensor; every element of that tensor is counted as a token.
            num_batches: Maximum number of batches to consume. Fewer are
                measured if the loader is exhausted first.
            seq_len: Currently unused; kept for API compatibility.
        """
        print("\n" + "=" * 60)
        print("Data Loading Throughput Benchmark")
        print("=" * 60)

        total_tokens = 0
        batches_seen = 0
        # Start the clock before the iterator is created, so any iterator
        # start-up cost (e.g. spawning loader workers) counts against
        # throughput. perf_counter is monotonic and high-resolution, which
        # makes it suitable for benchmarking.
        start_time = time.perf_counter()
        # islice stops after `num_batches` items without pulling an extra
        # batch from the loader. The timed window therefore covers exactly the
        # batches whose tokens are counted.
        for batch in itertools.islice(dataloader, max(num_batches, 0)):
            total_tokens += batch["input_ids"].numel()
            batches_seen += 1

            if batches_seen % 10 == 0:
                elapsed = time.perf_counter() - start_time
                tps = total_tokens / elapsed if elapsed > 0 else 0.0
                print(f" Batch {batches_seen}: {tps:,.0f} tokens/sec")

        elapsed = time.perf_counter() - start_time

        if batches_seen == 0:
            print("\n No batches were produced; nothing to measure.")
            return

        tps = total_tokens / elapsed if elapsed > 0 else 0.0

        print(f"\n Total batches: {batches_seen}")
        if batches_seen < num_batches:
            print(f" (dataloader exhausted before the requested {num_batches} batches)")
        print(f" Total tokens: {total_tokens:,}")
        print(f" Elapsed time: {elapsed:.2f}s")
        print(f" Average throughput: {tps:,.0f} tokens/sec")

        print("\n A100 training throughput reference ~50-80K tokens/sec:")
        if tps > 80_000:
            print(" Data loading is not the bottleneck")
        elif tps > 30_000:
            print(" Borderline - consider increasing num_workers")
        else:
            print(" Data loading is the bottleneck! Adjust num_workers/prefetch")

    @staticmethod
    def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
        """Inspect a single batch in detail and print the findings.

        Args:
            batch: Dict with ``"input_ids"`` and ``"targets"`` tensors. The
                indexing below (``[:, 1:]``, ``[0]``) assumes both are 2-D,
                presumably (batch, seq_len); confirm against the collator.
            tokenizer: Tokenizer providing ``eos_id`` and ``decode(ids)``.
        """
        print("\n" + "=" * 60)
        print("Batch Detailed Inspection")
        print("=" * 60)

        input_ids = batch["input_ids"]
        targets = batch["targets"]

        print(f" input_ids shape: {input_ids.shape}")
        print(f" targets shape: {targets.shape}")
        print(f" dtype: {input_ids.dtype}")

        # min()/max() and the ratios below are undefined on an empty tensor.
        if input_ids.numel() == 0:
            print(" Batch is empty; skipping remaining checks.")
            return

        print(f" value range: [{input_ids.min().item()}, {input_ids.max().item()}]")

        # Verify the next-token shift relationship: targets[:, i] == input_ids[:, i+1].
        # NOTE(review): if targets use an ignore index (e.g. -100) for masked
        # positions, this score will legitimately fall below 100%.
        if input_ids.shape != targets.shape:
            print(" Shift consistency: skipped (input_ids/targets shape mismatch)")
        elif input_ids.shape[-1] < 2:
            print(" Shift consistency: skipped (sequence length < 2)")
        else:
            shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
            print(f" Shift consistency: {shift_correct*100:.1f}% (should be 100%)")

        # EOS token distribution (document boundaries).
        eos_count = (input_ids == tokenizer.eos_id).sum().item()
        total_tokens = input_ids.numel()
        print(f" EOS token count: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")

        # Decoded preview of the first sample.
        first_sample = input_ids[0][:100].tolist()
        decoded_preview = tokenizer.decode(first_sample)
        print("\n First sample decoded (first 100 tokens):")
        print(f" {decoded_preview[:300]}...")