Tomoqt commited on
Commit
fd88777
·
verified ·
1 Parent(s): 267a5b4

Upload scripts/sweep_batch_sizes.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/sweep_batch_sizes.py +264 -0
scripts/sweep_batch_sizes.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Sweep per-GPU batch sizes and report throughput + 50-epoch ETA.
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import re
9
+ import subprocess
10
+ import sys
11
+ from dataclasses import dataclass, asdict
12
+ from pathlib import Path
13
+ from typing import List, Optional
14
+
15
+
16
+ RESULT_RE = re.compile(
17
+ r"\[result\]\s+status=(?P<status>\w+)\s+batch_size=(?P<batch_size>\d+)\s+"
18
+ r"world_size=(?P<world_size>\d+)\s+global_batch=(?P<global_batch>\d+)"
19
+ r"(?:\s+tok_s=(?P<tok_s>[0-9.]+))?"
20
+ r"(?:\s+elapsed_s=(?P<elapsed_s>[0-9.]+))?"
21
+ r"(?:\s+measured_steps=(?P<measured_steps>\d+))?"
22
+ r"(?:\s+mean_loss=(?P<mean_loss>[0-9.]+))?"
23
+ r"(?:\s+max_mem_gib=(?P<max_mem_gib>[0-9.]+))?"
24
+ )
25
+
26
+
27
+ @dataclass
28
+ class SweepResult:
29
+ status: str
30
+ batch_size: int
31
+ world_size: int
32
+ global_batch: int
33
+ tok_s: float = 0.0
34
+ elapsed_s: float = 0.0
35
+ measured_steps: int = 0
36
+ mean_loss: float = 0.0
37
+ max_mem_gib: float = 0.0
38
+ returncode: int = 0
39
+ stderr_tail: str = ""
40
+
41
+
42
+ def _parse_batch_sizes(text: str) -> List[int]:
43
+ values = []
44
+ for part in text.split(","):
45
+ p = part.strip()
46
+ if not p:
47
+ continue
48
+ values.append(int(p))
49
+ if not values:
50
+ raise ValueError("No batch sizes were provided.")
51
+ return values
52
+
53
+
54
+ def _parse_result(stdout: str, returncode: int, batch_size: int, nproc: int, stderr: str) -> SweepResult:
55
+ matches = RESULT_RE.findall(stdout)
56
+ if not matches:
57
+ tail = "\n".join((stderr or "").strip().splitlines()[-8:])
58
+ return SweepResult(
59
+ status="error",
60
+ batch_size=batch_size,
61
+ world_size=nproc,
62
+ global_batch=batch_size * nproc,
63
+ returncode=returncode,
64
+ stderr_tail=tail,
65
+ )
66
+
67
+ groups = RESULT_RE.search([m.group(0) for m in RESULT_RE.finditer(stdout)][-1])
68
+ assert groups is not None
69
+ d = groups.groupdict()
70
+
71
+ return SweepResult(
72
+ status=d["status"],
73
+ batch_size=int(d["batch_size"]),
74
+ world_size=int(d["world_size"]),
75
+ global_batch=int(d["global_batch"]),
76
+ tok_s=float(d["tok_s"] or 0.0),
77
+ elapsed_s=float(d["elapsed_s"] or 0.0),
78
+ measured_steps=int(d["measured_steps"] or 0),
79
+ mean_loss=float(d["mean_loss"] or 0.0),
80
+ max_mem_gib=float(d["max_mem_gib"] or 0.0),
81
+ returncode=returncode,
82
+ stderr_tail="\n".join((stderr or "").strip().splitlines()[-8:]),
83
+ )
84
+
85
+
86
+ def _run_once(
87
+ config: str,
88
+ batch_size: int,
89
+ warmup_steps: int,
90
+ steps: int,
91
+ nproc_per_node: int,
92
+ nnodes: int,
93
+ node_rank: int,
94
+ master_addr: str,
95
+ master_port: int,
96
+ num_workers: Optional[int],
97
+ disable_compile: bool,
98
+ ) -> SweepResult:
99
+ cmd = [
100
+ sys.executable,
101
+ "-m",
102
+ "torch.distributed.run",
103
+ "--nnodes",
104
+ str(nnodes),
105
+ "--node_rank",
106
+ str(node_rank),
107
+ "--nproc_per_node",
108
+ str(nproc_per_node),
109
+ "--master_addr",
110
+ str(master_addr),
111
+ "--master_port",
112
+ str(master_port),
113
+ "training/benchmark_throughput.py",
114
+ "--config",
115
+ config,
116
+ "--batch-size",
117
+ str(batch_size),
118
+ "--warmup-steps",
119
+ str(warmup_steps),
120
+ "--steps",
121
+ str(steps),
122
+ ]
123
+ if num_workers is not None:
124
+ cmd.extend(["--num-workers", str(num_workers)])
125
+ if disable_compile:
126
+ cmd.append("--disable-compile")
127
+
128
+ proc = subprocess.run(cmd, capture_output=True, text=True)
129
+ combined_stdout = (proc.stdout or "") + "\n" + (proc.stderr or "")
130
+ return _parse_result(
131
+ stdout=combined_stdout,
132
+ returncode=proc.returncode,
133
+ batch_size=batch_size,
134
+ nproc=nproc_per_node,
135
+ stderr=proc.stderr or "",
136
+ )
137
+
138
+
139
+ def _format_eta_hours(hours: float) -> str:
140
+ if hours >= 1.0:
141
+ return f"{hours:.2f}h"
142
+ return f"{hours * 60.0:.1f}m"
143
+
144
+
145
+ def main() -> int:
146
+ parser = argparse.ArgumentParser(description="Batch-size throughput sweep (DDP)")
147
+ parser.add_argument("--config", type=str, default="configs/real_config_8gpu_100m.yaml")
148
+ parser.add_argument("--batch-sizes", type=str, default="24,32,40,48,56,64,72,80,96")
149
+ parser.add_argument("--warmup-steps", type=int, default=20)
150
+ parser.add_argument("--steps", type=int, default=80)
151
+ parser.add_argument("--nproc-per-node", type=int, default=8)
152
+ parser.add_argument("--nnodes", type=int, default=1)
153
+ parser.add_argument("--node-rank", type=int, default=0)
154
+ parser.add_argument("--master-addr", type=str, default="127.0.0.1")
155
+ parser.add_argument("--master-port", type=int, default=29517)
156
+ parser.add_argument("--num-workers", type=int, default=None)
157
+ parser.add_argument("--disable-compile", action="store_true")
158
+ parser.add_argument("--stop-on-oom", dest="stop_on_oom", action="store_true")
159
+ parser.add_argument("--no-stop-on-oom", dest="stop_on_oom", action="store_false")
160
+ parser.add_argument("--tokens-per-epoch", type=float, default=30342999.0)
161
+ parser.add_argument("--epochs", type=int, default=50)
162
+ parser.add_argument("--save-json", type=str, default="sweep_results_8gpu.json")
163
+ parser.set_defaults(stop_on_oom=True)
164
+ args = parser.parse_args()
165
+
166
+ config_path = Path(args.config)
167
+ if not config_path.exists():
168
+ raise FileNotFoundError(f"Config not found: {config_path}")
169
+
170
+ batch_sizes = _parse_batch_sizes(args.batch_sizes)
171
+ results: List[SweepResult] = []
172
+
173
+ print(f"[sweep] config={config_path}")
174
+ print(f"[sweep] batch_sizes={batch_sizes}")
175
+ print(
176
+ "[sweep] launch "
177
+ f"nnodes={args.nnodes} node_rank={args.node_rank} nproc_per_node={args.nproc_per_node} "
178
+ f"master={args.master_addr}:{args.master_port}"
179
+ )
180
+ print(f"[sweep] warmup_steps={args.warmup_steps} measured_steps={args.steps}")
181
+
182
+ for idx, batch_size in enumerate(batch_sizes, start=1):
183
+ print(f"[sweep] ({idx}/{len(batch_sizes)}) batch_size={batch_size} ...")
184
+ result = _run_once(
185
+ config=str(config_path),
186
+ batch_size=batch_size,
187
+ warmup_steps=int(args.warmup_steps),
188
+ steps=int(args.steps),
189
+ nproc_per_node=int(args.nproc_per_node),
190
+ nnodes=int(args.nnodes),
191
+ node_rank=int(args.node_rank),
192
+ master_addr=str(args.master_addr),
193
+ master_port=int(args.master_port),
194
+ num_workers=args.num_workers,
195
+ disable_compile=bool(args.disable_compile),
196
+ )
197
+ results.append(result)
198
+
199
+ if result.status == "ok":
200
+ eta_hours = (args.tokens_per_epoch * args.epochs) / max(result.tok_s, 1e-9) / 3600.0
201
+ print(
202
+ "[sweep] ok "
203
+ f"global_batch={result.global_batch} tok_s={result.tok_s:.1f} "
204
+ f"max_mem_gib={result.max_mem_gib:.2f} eta_{args.epochs}ep={_format_eta_hours(eta_hours)}"
205
+ )
206
+ elif result.status == "oom":
207
+ print(f"[sweep] oom at batch_size={batch_size} (global_batch={result.global_batch})")
208
+ if args.stop_on_oom:
209
+ break
210
+ else:
211
+ print(
212
+ "[sweep] error "
213
+ f"batch_size={batch_size} returncode={result.returncode} "
214
+ f"stderr_tail={result.stderr_tail!r}"
215
+ )
216
+
217
+ ok_results = [r for r in results if r.status == "ok"]
218
+ best = max(ok_results, key=lambda r: r.tok_s) if ok_results else None
219
+
220
+ print("\n[sweep] summary")
221
+ for r in results:
222
+ if r.status == "ok":
223
+ eta_hours = (args.tokens_per_epoch * args.epochs) / max(r.tok_s, 1e-9) / 3600.0
224
+ print(
225
+ f" batch={r.batch_size:>4} global_batch={r.global_batch:>5} "
226
+ f"tok_s={r.tok_s:>10.1f} mem_gib={r.max_mem_gib:>7.2f} "
227
+ f"eta_{args.epochs}ep={_format_eta_hours(eta_hours)}"
228
+ )
229
+ else:
230
+ print(
231
+ f" batch={r.batch_size:>4} global_batch={r.global_batch:>5} "
232
+ f"status={r.status} returncode={r.returncode}"
233
+ )
234
+
235
+ if best is not None:
236
+ best_eta_hours = (args.tokens_per_epoch * args.epochs) / max(best.tok_s, 1e-9) / 3600.0
237
+ print("\n[sweep] best")
238
+ print(
239
+ f" batch_size={best.batch_size} global_batch={best.global_batch} "
240
+ f"tok_s={best.tok_s:.1f} max_mem_gib={best.max_mem_gib:.2f} "
241
+ f"eta_{args.epochs}ep={_format_eta_hours(best_eta_hours)}"
242
+ )
243
+ else:
244
+ print("\n[sweep] no successful runs")
245
+
246
+ save_path = Path(args.save_json)
247
+ payload = {
248
+ "config": str(config_path),
249
+ "epochs": int(args.epochs),
250
+ "tokens_per_epoch": float(args.tokens_per_epoch),
251
+ "results": [asdict(r) for r in results],
252
+ "best": asdict(best) if best else None,
253
+ "best_eta_hours": (
254
+ (args.tokens_per_epoch * args.epochs) / max(best.tok_s, 1e-9) / 3600.0 if best else None
255
+ ),
256
+ }
257
+ save_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
258
+ print(f"[sweep] wrote {save_path}")
259
+
260
+ return 0 if best is not None else 2
261
+
262
+
263
+ if __name__ == "__main__":
264
+ raise SystemExit(main())