Tomoqt committed on
Commit
5fb8a28
·
verified ·
1 Parent(s): fd88777

Upload training/benchmark_throughput.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/benchmark_throughput.py +276 -0
training/benchmark_throughput.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Measure steady-state training throughput (non-pad target tokens/sec).
4
+
5
+ Supports single process and DDP launch via torchrun.
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import os
11
+ import sys
12
+ import time
13
+ from pathlib import Path
14
+ from typing import Dict, Optional
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+
20
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
+
22
+ from models.smiles_tokenizer import SmilesTokenizer
23
+ from training.core_train import compute_next_token_loss
24
+ from training.train_autoregressive import (
25
+ DistContext,
26
+ _autocast_context,
27
+ _cleanup_dist,
28
+ _infer_dist_context,
29
+ _reduce_pair,
30
+ _seed_everything,
31
+ _set_perf_flags,
32
+ _build_model,
33
+ create_loaders,
34
+ load_config,
35
+ load_nmr_tokenizer,
36
+ )
37
+
38
+
39
def _parse_args(argv: Optional[list] = None) -> argparse.Namespace:
    """Parse the command-line options of the throughput benchmark.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            fall back to ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        Namespace with ``config``, ``batch_size``, ``steps``, ``warmup_steps``,
        ``device``, ``num_workers`` and ``disable_compile``.
    """
    parser = argparse.ArgumentParser(description="Throughput micro-benchmark for autoregressive pretraining")
    parser.add_argument("--config", type=str, required=True, help="YAML config path")
    parser.add_argument("--batch-size", type=int, required=True, help="Per-rank train batch size override")
    parser.add_argument("--steps", type=int, default=80, help="Measured optimizer steps")
    parser.add_argument("--warmup-steps", type=int, default=20, help="Warmup steps before timing")
    parser.add_argument("--device", type=str, default=None, help="Override config training.device")
    parser.add_argument("--num-workers", type=int, default=None, help="Override config training.num_workers")
    parser.add_argument(
        "--disable-compile",
        action="store_true",
        help="Disable torch.compile even if enabled in config",
    )
    return parser.parse_args(argv)
53
+
54
+
55
def _max_reduce(value: float, device: torch.device, dist_ctx: DistContext) -> float:
    """Return the maximum of ``value`` over all ranks.

    Args:
        value: This rank's local scalar.
        device: Device on which the temporary tensor for the collective is
            created when running distributed.
        dist_ctx: Distributed context; only its ``enabled`` flag is read here.

    Returns:
        ``value`` unchanged (as a float) when ``dist_ctx.enabled`` is false;
        otherwise the result of a MAX all-reduce, carried as float32.
    """
    if not dist_ctx.enabled:
        # Nothing to reduce: skip the device allocation, the host sync and the
        # float32 round-trip that would otherwise perturb the value.
        return float(value)
    tensor = torch.tensor([value], device=device, dtype=torch.float32)
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    return float(tensor.item())
60
+
61
+
62
def _sync_if_needed(device: torch.device) -> None:
    """Block until work queued on an asynchronous accelerator has finished.

    Used around timestamps so that wall-clock measurements cover the work that
    was actually executed rather than just the time taken to enqueue it.

    Args:
        device: Device the benchmark runs on. CUDA and MPS are synchronized;
            any other device type is a no-op.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device=device)
    elif device.type == "mps":
        # torch.mps is not present in every torch build, so look it up defensively.
        mps_sync = getattr(getattr(torch, "mps", None), "synchronize", None)
        if mps_sync is not None:
            mps_sync()
65
+
66
+
67
def _maybe_enable_compile(model: torch.nn.Module, training_cfg: Dict, disable_compile: bool) -> torch.nn.Module:
    """Wrap ``model`` with ``torch.compile`` when the config asks for it.

    Args:
        model: Module to (optionally) compile.
        training_cfg: Training config mapping. Reads ``compile``,
            ``compile_mode``, ``compile_dynamic`` and ``compile_fullgraph``.
        disable_compile: When true, compilation is skipped regardless of the
            config.

    Returns:
        The compiled module, or ``model`` itself when compilation is disabled
        or this torch build has no ``torch.compile``.
    """
    wants_compile = bool(training_cfg.get("compile", False)) and not disable_compile
    if not wants_compile or not hasattr(torch, "compile"):
        return model
    return torch.compile(
        model,
        mode=str(training_cfg.get("compile_mode", "max-autotune")),
        dynamic=bool(training_cfg.get("compile_dynamic", False)),
        fullgraph=bool(training_cfg.get("compile_fullgraph", False)),
    )
79
+
80
+
81
def main() -> int:
    """Run the throughput benchmark and print the result on the main rank.

    The function loads the config, applies the command-line overrides, builds
    the data loader, model (optionally compiled / wrapped in DDP) and an AdamW
    optimizer, then runs ``--warmup-steps`` untimed steps followed by
    ``--steps`` timed optimizer steps. Throughput is the number of non-pad
    target tokens summed over ranks divided by the slowest rank's elapsed time.

    Returns:
        ``0`` after a successful measurement, ``3`` when any rank reported an
        out-of-memory error.

    Raises:
        ValueError: Unsupported device name or precision.
        FileNotFoundError: The configured tokenized directory does not exist.
        RuntimeError: MPS requested together with distributed execution, the
            train loader yields no batch at all, or a non-OOM runtime error
            occurred during a step.
    """
    args = _parse_args()
    cfg = load_config(args.config)
    training_cfg = cfg["training"]

    # Benchmark-specific overrides of the training section of the config.
    training_cfg["batch_size"] = int(args.batch_size)
    training_cfg["test_batch_size"] = int(args.batch_size)
    training_cfg["num_epochs"] = 1
    training_cfg["log_every_steps"] = 0
    training_cfg["drop_last"] = True

    if args.device is not None:
        training_cfg["device"] = str(args.device).lower()
    if args.num_workers is not None:
        training_cfg["num_workers"] = int(args.num_workers)

    requested_device_name = str(training_cfg.get("device", "cpu")).lower()
    # Placeholder so the ``finally`` clause has a context to clean up even if
    # inferring the real one raises.
    dist_ctx = DistContext(enabled=False)

    try:
        dist_ctx = _infer_dist_context(training_cfg, requested_device_name)
        device = _resolve_device(requested_device_name, dist_ctx)

        _set_perf_flags(training_cfg)
        base_seed = int(training_cfg.get("seed", 1337))
        # Offset the seed by rank so each rank gets a different seed.
        _seed_everything(base_seed + dist_ctx.rank)

        tokenized_dir = Path(cfg["data"]["tokenized_dir"])
        if not tokenized_dir.exists():
            raise FileNotFoundError(f"Tokenized directory not found: {tokenized_dir}")

        smiles_tokenizer = SmilesTokenizer(vocab_file=str(Path(__file__).with_name("vocab.txt")))
        nmr_tokenizer = load_nmr_tokenizer(tokenized_dir)

        train_loader, _, _ = create_loaders(
            tokenized_dir,
            smiles_tokenizer,
            nmr_tokenizer,
            cfg,
            dist_ctx=dist_ctx,
            device=device,
        )

        model = _build_model(cfg, smiles_tokenizer, nmr_tokenizer, device)
        model = _maybe_enable_compile(model, training_cfg, disable_compile=bool(args.disable_compile))

        if dist_ctx.enabled:
            model = DDP(
                model,
                device_ids=[dist_ctx.local_rank] if device.type == "cuda" else None,
                output_device=dist_ctx.local_rank if device.type == "cuda" else None,
                find_unused_parameters=bool(training_cfg.get("ddp_find_unused_parameters", False)),
                gradient_as_bucket_view=bool(training_cfg.get("ddp_gradient_as_bucket_view", True)),
                static_graph=bool(training_cfg.get("ddp_static_graph", True)),
            )

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=float(training_cfg["learning_rate"]),
            weight_decay=float(training_cfg.get("weight_decay", 0.01)),
        )

        precision = str(training_cfg.get("precision", "fp32")).lower()
        if precision not in {"fp32", "bf16", "fp16"}:
            raise ValueError(f"Unsupported precision '{precision}'. Use one of: fp32, bf16, fp16.")

        # Loss scaling is only used for fp16 on CUDA.
        use_grad_scaler = device.type == "cuda" and precision == "fp16"
        scaler = torch.cuda.amp.GradScaler(enabled=use_grad_scaler)

        pad_token_id = smiles_tokenizer.pad_token_id
        non_blocking = device.type == "cuda"
        warmup_steps = max(0, int(args.warmup_steps))
        measured_steps = max(1, int(args.steps))
        total_steps = warmup_steps + measured_steps

        iterator = iter(train_loader)
        # Accumulate on-device so the timed region contains no per-step
        # host/device synchronization; both values are read once after the loop.
        loss_accum = torch.zeros((), device=device, dtype=torch.float32)
        token_accum = torch.zeros((), device=device, dtype=torch.long)
        timed_start = None
        timed_end = None
        model.train()
        optimizer.zero_grad(set_to_none=True)

        oom_happened = False

        for step_idx in range(total_steps):
            try:
                batch = next(iterator)
            except StopIteration:
                # Loader exhausted: start over so any number of steps can be run.
                iterator = iter(train_loader)
                try:
                    batch = next(iterator)
                except StopIteration:
                    raise RuntimeError(
                        "Train loader produced no batches; the dataset may be smaller than "
                        f"batch_size={int(args.batch_size)} with drop_last enabled."
                    ) from None

            target_tokens, ir_data, nmr_tokens = batch
            target_tokens = target_tokens.to(device, non_blocking=non_blocking)
            nmr_tokens = nmr_tokens.to(device, non_blocking=non_blocking)
            if ir_data is not None:
                ir_data = ir_data.to(device, non_blocking=non_blocking)

            if step_idx == warmup_steps:
                # First measured step: drain warmup work, then start the clock.
                _sync_if_needed(device)
                timed_start = time.perf_counter()

            try:
                with _autocast_context(device, precision):
                    logits = model(
                        nmr_tokens=nmr_tokens,
                        ir_data=ir_data,
                        target_seq=target_tokens[:, :-1],
                    )
                    loss = compute_next_token_loss(logits, target_tokens, pad_token_id)

                if scaler.is_enabled():
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            except RuntimeError as exc:
                if "out of memory" in str(exc).lower():
                    oom_happened = True
                    if device.type == "cuda":
                        torch.cuda.empty_cache()
                    # NOTE(review): under DDP the other ranks may still be inside
                    # a step's collectives when this rank leaves the loop — confirm
                    # how the launcher handles a single-rank OOM.
                    break
                raise

            if step_idx >= warmup_steps:
                # reshape(()) accepts exactly one element, like the ``.item()`` read it replaces.
                loss_accum += loss.detach().to(torch.float32).reshape(())
                token_accum += (target_tokens[:, 1:] != pad_token_id).sum()

        local_measured_tokens = 0.0
        local_loss_sum = 0.0
        if not oom_happened:
            # ``.item()` waits for the queued work on any backend; stop the clock
            # right after, before the bookkeeping collectives below.
            local_measured_tokens = float(token_accum.item())
            local_loss_sum = float(loss_accum.item())
            _sync_if_needed(device)
            timed_end = time.perf_counter()

        local_oom = 1.0 if oom_happened else 0.0
        global_oom = _max_reduce(local_oom, device, dist_ctx)
        global_batch = int(args.batch_size) * dist_ctx.world_size

        if global_oom > 0:
            if dist_ctx.is_main:
                print(
                    f"[result] status=oom batch_size={int(args.batch_size)} "
                    f"world_size={dist_ctx.world_size} global_batch={global_batch}"
                )
            return 3

        if timed_start is None or timed_end is None:
            # Defensive: no timed step ran, so report the minimum elapsed time.
            elapsed = 1e-9
        else:
            elapsed = max(timed_end - timed_start, 1e-9)
        # The slowest rank determines the overall throughput.
        elapsed_max = _max_reduce(elapsed, device, dist_ctx)

        global_tokens, _ = _reduce_pair(local_measured_tokens, 0.0, device, dist_ctx)
        global_loss_sum, global_count = _reduce_pair(local_loss_sum, float(measured_steps), device, dist_ctx)
        tok_s = global_tokens / elapsed_max
        mean_loss = global_loss_sum / max(global_count, 1.0)

        max_mem_gib = 0.0
        if device.type == "cuda":
            max_mem_gib = torch.cuda.max_memory_allocated(device=device) / (1024**3)
            max_mem_gib = _max_reduce(max_mem_gib, device, dist_ctx)

        if dist_ctx.is_main:
            payload = {
                "status": "ok",
                "batch_size": int(args.batch_size),
                "world_size": dist_ctx.world_size,
                "global_batch": global_batch,
                "tok_s": round(tok_s, 2),
                "elapsed_s": round(elapsed_max, 3),
                "measured_steps": measured_steps,
                "mean_loss": round(mean_loss, 6),
                "max_mem_gib": round(max_mem_gib, 3),
            }
            print(
                "[result] "
                f"status=ok batch_size={payload['batch_size']} world_size={payload['world_size']} "
                f"global_batch={payload['global_batch']} tok_s={payload['tok_s']:.2f} "
                f"elapsed_s={payload['elapsed_s']:.3f} measured_steps={payload['measured_steps']} "
                f"mean_loss={payload['mean_loss']:.6f} max_mem_gib={payload['max_mem_gib']:.3f}"
            )
            print("[result_json] " + json.dumps(payload, sort_keys=True))
        return 0
    finally:
        _cleanup_dist(dist_ctx)


def _resolve_device(requested_device_name: str, dist_ctx: DistContext) -> torch.device:
    """Map the requested device name ("cuda", "cpu" or "mps") to a torch.device."""
    if requested_device_name == "cuda":
        # One GPU per rank when distributed, otherwise the first GPU.
        cuda_index = dist_ctx.local_rank if dist_ctx.enabled else 0
        return torch.device(f"cuda:{cuda_index}")
    if requested_device_name == "cpu":
        return torch.device("cpu")
    if requested_device_name == "mps":
        if dist_ctx.enabled:
            raise RuntimeError("MPS DDP is unsupported for this benchmark. Use CUDA + NCCL for multi-GPU.")
        return torch.device("mps")
    raise ValueError(f"Unsupported device '{requested_device_name}'.")
273
+
274
+
275
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())