File size: 5,645 Bytes
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
 
 
858e8b2
8a58ffe
858e8b2
 
 
 
 
8a58ffe
 
 
 
 
 
 
 
858e8b2
8a58ffe
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858e8b2
8a58ffe
858e8b2
 
 
 
 
 
8a58ffe
858e8b2
8a58ffe
 
 
 
858e8b2
 
 
 
8a58ffe
 
 
 
 
 
 
858e8b2
8a58ffe
858e8b2
 
8a58ffe
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858e8b2
 
 
 
 
8a58ffe
858e8b2
8a58ffe
858e8b2
8a58ffe
858e8b2
8a58ffe
 
 
858e8b2
8a58ffe
858e8b2
8a58ffe
 
 
 
 
 
 
 
858e8b2
8a58ffe
858e8b2
8a58ffe
858e8b2
8a58ffe
858e8b2
8a58ffe
 
858e8b2
8a58ffe
858e8b2
8a58ffe
 
858e8b2
8a58ffe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Data pipeline diagnostic tools."""

import time
from typing import Dict

import torch
from torch.utils.data import DataLoader

from llm_lab.config import DataConfig
from .tokenizer import Tokenizer


class DataPipelineDiagnostics:
    """Diagnoses the performance and quality of the data pipeline.

    Items to verify before training:
      1) Tokenizer quality: average tokens/document, unknown token ratio
      2) Packing efficiency: actual token ratio vs. padding ratio
      3) Throughput: tokens/sec (check for data loading bottlenecks)
      4) Batch shape: correctness of shape and dtype
    """

    @staticmethod
    def check_tokenizer_quality(
        tokenizer: Tokenizer,
        config: DataConfig,
        num_samples: int = 1000,
    ):
        """Diagnoses tokenizer quality.

        Streams up to ``num_samples`` non-empty documents from the configured
        dataset, prints token/character statistics (compression ratio), and
        runs an encode/decode round-trip sanity test.

        Args:
            tokenizer: Tokenizer under test; must provide ``encode``/``decode``.
            config: Dataset location (name, subset, split, text column).
            num_samples: Maximum number of documents to analyze.
        """
        # Imported lazily so the module does not require `datasets` unless
        # this diagnostic is actually run.
        from datasets import load_dataset

        print("\n" + "=" * 60)
        print("Tokenizer Quality Diagnostics")
        print("=" * 60)

        ds = load_dataset(
            config.dataset_name,
            name=config.dataset_subset,
            split=config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )

        token_counts = []
        char_counts = []

        for example in ds:
            if len(token_counts) >= num_samples:
                break
            text = example[config.text_column]
            # Skip empty/whitespace-only documents so the averages are not skewed.
            if not text or not text.strip():
                continue

            tokens = tokenizer.encode(text)
            token_counts.append(len(tokens))
            char_counts.append(len(text))

        # BUGFIX: guard against division by zero when the stream yields no
        # usable documents (e.g. wrong split or wrong text column previously
        # crashed with ZeroDivisionError instead of a readable message).
        if not token_counts:
            print("  No non-empty documents found - check dataset name/split/column.")
            return

        avg_tokens = sum(token_counts) / len(token_counts)
        avg_chars = sum(char_counts) / len(char_counts)
        compression_ratio = avg_chars / avg_tokens  # Characters per token ratio

        print(f"  Documents analyzed: {len(token_counts):,}")
        print(f"  Average tokens/document: {avg_tokens:.1f}")
        print(f"  Average chars/document: {avg_chars:.1f}")
        print(f"  Compression ratio (chars/token): {compression_ratio:.2f}")
        print(f"    -> 3.5~4.5 is normal for English")
        print(f"  Min tokens: {min(token_counts)}, Max: {max(token_counts)}")

        # Round-trip decode test: the decoded text should contain the original
        # (tokenizers may normalize whitespace, hence `in` rather than `==`).
        test_text = "The quick brown fox jumps over the lazy dog."
        encoded = tokenizer.encode(test_text)
        decoded = tokenizer.decode(encoded)
        roundtrip_ok = test_text.strip() in decoded.strip()
        print(f"\n  Round-trip test: {'PASSED' if roundtrip_ok else 'FAILED'}")
        print(f"    Original:  {test_text}")
        print(f"    Encoded: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
        print(f"    Decoded: {decoded}")

    @staticmethod
    def benchmark_throughput(
        dataloader: DataLoader,
        num_batches: int = 50,
        seq_len: int = 2048,
    ):
        """Measures data loading throughput.

        A key diagnostic to determine whether data loading is the bottleneck in GPU training.
        Goal: data loading should be faster than GPU computation (data loading != bottleneck).

        Args:
            dataloader: Loader whose batches contain an ``"input_ids"`` tensor.
            num_batches: Maximum number of batches to time.
            seq_len: Unused; kept for backward compatibility with existing callers.
        """
        print("\n" + "=" * 60)
        print("Data Loading Throughput Benchmark")
        print("=" * 60)

        total_tokens = 0
        batches_seen = 0
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # (NTP adjustments) and is the wrong clock for elapsed-time measurement.
        start_time = time.perf_counter()

        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            total_tokens += batch["input_ids"].numel()
            batches_seen += 1

            if (i + 1) % 10 == 0:
                elapsed = time.perf_counter() - start_time
                tps = total_tokens / elapsed
                print(f"  Batch {i+1}: {tps:,.0f} tokens/sec")

        elapsed = time.perf_counter() - start_time

        # BUGFIX: an empty dataloader previously divided 0 by a near-zero
        # elapsed time; report the condition instead of printing nonsense.
        if batches_seen == 0 or elapsed <= 0:
            print("\n  No batches produced - cannot measure throughput.")
            return

        tps = total_tokens / elapsed

        # BUGFIX: report the number of batches actually consumed; the original
        # printed the *requested* num_batches even when the loader ran short.
        print(f"\n  Total batches: {batches_seen}")
        print(f"  Total tokens: {total_tokens:,}")
        print(f"  Elapsed time: {elapsed:.2f}s")
        print(f"  Average throughput: {tps:,.0f} tokens/sec")
        print(f"\n  A100 training throughput reference ~50-80K tokens/sec:")
        if tps > 80_000:
            print(f"     Data loading is not the bottleneck")
        elif tps > 30_000:
            print(f"     Borderline - consider increasing num_workers")
        else:
            print(f"     Data loading is the bottleneck! Adjust num_workers/prefetch")

    @staticmethod
    def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
        """Inspects a single batch in detail.

        Prints shape/dtype/value-range, verifies the next-token shift
        relationship between inputs and targets, reports the EOS density
        (a proxy for document-boundary frequency), and decodes a preview.

        Args:
            batch: Must contain ``"input_ids"`` and ``"targets"`` tensors of
                shape (batch, seq) — assumed; confirm against the collator.
            tokenizer: Used for ``eos_id`` and for decoding the preview.
        """
        print("\n" + "=" * 60)
        print("Batch Detailed Inspection")
        print("=" * 60)

        input_ids = batch["input_ids"]
        targets = batch["targets"]

        print(f"  input_ids shape: {input_ids.shape}")
        print(f"  targets shape:   {targets.shape}")
        print(f"  dtype:           {input_ids.dtype}")
        print(f"  value range:     [{input_ids.min().item()}, {input_ids.max().item()}]")

        # Verify shift relationship: targets[i] == input_ids[i+1].
        # Anything under 100% means the collator is mis-aligning labels.
        shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
        print(f"  Shift consistency: {shift_correct*100:.1f}% (should be 100%)")

        # EOS token distribution (document boundaries).
        eos_count = (input_ids == tokenizer.eos_id).sum().item()
        total_tokens = input_ids.numel()
        print(f"  EOS token count: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")

        # Decode preview of the first sample to eyeball data quality.
        first_sample = input_ids[0][:100].tolist()
        decoded_preview = tokenizer.decode(first_sample)
        print(f"\n  First sample decoded (first 100 tokens):")
        print(f"  {decoded_preview[:300]}...")