# File size: 4,838 Bytes (9639af0)
import time
import jax
import jax.numpy as jnp
from LaughLM.config.loader import load_config
from LaughLM.model.gpt import GPTModel
from LaughLM.training.scheduler import build_scheduler
from LaughLM.training.optimizer import build_optimizer
from LaughLM.training.train_step import create_train_step
from LaughLM.model.parameter_utils import estimate_parameters
from LaughLM.utils.rng import create_rng
# ------------------------------------------------------------
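# TPU v5e peak FLOP/s used as the MFU denominator.
# NOTE(review): 1.576e15 equals 8 x 1.97e14, i.e. it looks like the bf16 peak
# of an 8-chip v5e slice rather than a single chip — confirm against the
# hardware this benchmark actually runs on.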
# ------------------------------------------------------------
TPU_V5E_PEAK_FLOPS = 1.576e15
# ------------------------------------------------------------
# Benchmark
# ------------------------------------------------------------
def _run_steps(train_step, params, opt_state, batch, count):
    """Call ``train_step`` ``count`` times, threading the training state through.

    Returns the ``(params, opt_state)`` pair produced by the last call, or the
    inputs unchanged when ``count`` is not positive. The third value returned
    by ``train_step`` (metrics) is discarded.
    """
    for _ in range(count):
        params, opt_state, _metrics = train_step(params, opt_state, batch)
    return params, opt_state


def benchmark(config_path: str, steps: int = 200, warmup: int = 20):
    """Time the training step for a config and print throughput and MFU.

    The function loads the config, initializes a ``GPTModel`` and its
    optimizer, runs the train step once (timed separately and reported as
    compilation time), runs ``warmup`` untimed steps, then times ``steps``
    steps on an all-zero int32 token batch of shape
    ``(micro_batch_per_device, seq_len)``.

    Args:
        config_path: Path passed to ``load_config``.
        steps: Number of timed steps; must be positive.
        warmup: Number of untimed steps run after the first (compiling) call.
            Values <= 0 skip the warmup phase.

    Raises:
        ValueError: If ``steps`` is not positive. Checked up front so a bad
            value fails before model initialization instead of surfacing as a
            ``ZeroDivisionError`` after the whole run.
    """
    if steps <= 0:
        raise ValueError(f"steps must be a positive integer, got {steps!r}")

    print("\nLoading config...")
    config = load_config(config_path)
    rng = create_rng(42)

    # --------------------------------------------------------
    # Build model
    # --------------------------------------------------------
    print("Initializing model...")
    model = GPTModel(config=config)
    batch_size = config.runtime.micro_batch_per_device
    seq_len = config.runtime.seq_len
    # The same synthetic batch is reused for every step.
    dummy_batch = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
    params = model.init(rng.next_key(), dummy_batch)["params"]

    # --------------------------------------------------------
    # Build optimizer and train step
    # --------------------------------------------------------
    schedule = build_scheduler(config)
    optimizer = build_optimizer(config, schedule)
    opt_state = optimizer.init(params)
    train_step = create_train_step(model, optimizer)

    # --------------------------------------------------------
    # First call (compile)
    # --------------------------------------------------------
    # The first call is timed on its own; presumably train_step is jitted, in
    # which case this interval covers tracing/compilation plus one execution.
    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    print("\nCompiling train step...")
    start_compile = time.perf_counter()
    params, opt_state = _run_steps(train_step, params, opt_state, dummy_batch, 1)
    jax.block_until_ready(params)
    compile_time = time.perf_counter() - start_compile
    print(f"Compilation time: {compile_time:.2f}s")

    # --------------------------------------------------------
    # Warmup iterations
    # --------------------------------------------------------
    print(f"\nRunning {warmup} warmup steps...")
    params, opt_state = _run_steps(
        train_step, params, opt_state, dummy_batch, warmup
    )
    # Drain pending work so none of it is charged to the timed section.
    jax.block_until_ready(params)

    # --------------------------------------------------------
    # Benchmark loop
    # --------------------------------------------------------
    print(f"\nRunning benchmark ({steps} steps)...")
    start = time.perf_counter()
    params, opt_state = _run_steps(
        train_step, params, opt_state, dummy_batch, steps
    )
    # Wait for the final result before stopping the clock; each step consumes
    # the previous step's params, so this covers the whole chain.
    jax.block_until_ready(params)
    total_time = time.perf_counter() - start
    step_time = total_time / steps

    # --------------------------------------------------------
    # Throughput
    # --------------------------------------------------------
    # NOTE(review): this counts data_parallel * gradient_accumulation
    # micro-batches per step, while the batch passed to train_step above holds
    # a single (micro_batch_per_device, seq_len) array. Confirm that
    # create_train_step really consumes that many tokens per call; otherwise
    # tokens/sec and MFU are overstated by that factor.
    tokens_per_step = (
        config.runtime.seq_len
        * config.runtime.micro_batch_per_device
        * config.parallelism.data_parallel
        * config.runtime.gradient_accumulation
    )
    tokens_per_sec = tokens_per_step / step_time

    # --------------------------------------------------------
    # MFU calculation
    # --------------------------------------------------------
    # Uses the 6 * parameters * tokens/sec estimate of training FLOP/s,
    # expressed as a percentage of TPU_V5E_PEAK_FLOPS.
    params_count = estimate_parameters(config)["total_params"]
    mfu = (6 * params_count * tokens_per_sec / TPU_V5E_PEAK_FLOPS) * 100

    # --------------------------------------------------------
    # Report
    # --------------------------------------------------------
    # NOTE(review): the separator lines look like mis-decoded box-drawing
    # characters; they are reproduced unchanged — confirm the intended glyph.
    print("\nBenchmark Results")
    print("ββββββββββββββββββββββββ")
    print(f"Steps: {steps}")
    print(f"Step time: {step_time:.4f} s")
    print(f"Tokens / step: {tokens_per_step:,}")
    print(f"Tokens / sec: {tokens_per_sec:,.0f}")
    print(f"Total parameters: {params_count:,}")
    print(f"MFU: {mfu:.2f}%")
    print("ββββββββββββββββββββββββ\n")
# ------------------------------------------------------------
# Entry
# ------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: benchmark the bundled test config with 20 warmup
    # steps followed by 200 timed steps.
    benchmark("configs/gpu_test.yaml", steps=200, warmup=20)
|