Upload scripts/sweep_batch_sizes.py with huggingface_hub
Browse files- scripts/sweep_batch_sizes.py +264 -0
scripts/sweep_batch_sizes.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Sweep per-GPU batch sizes and report throughput + 50-epoch ETA.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import re
|
| 9 |
+
import subprocess
|
| 10 |
+
import sys
|
| 11 |
+
from dataclasses import dataclass, asdict
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import List, Optional
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
RESULT_RE = re.compile(
|
| 17 |
+
r"\[result\]\s+status=(?P<status>\w+)\s+batch_size=(?P<batch_size>\d+)\s+"
|
| 18 |
+
r"world_size=(?P<world_size>\d+)\s+global_batch=(?P<global_batch>\d+)"
|
| 19 |
+
r"(?:\s+tok_s=(?P<tok_s>[0-9.]+))?"
|
| 20 |
+
r"(?:\s+elapsed_s=(?P<elapsed_s>[0-9.]+))?"
|
| 21 |
+
r"(?:\s+measured_steps=(?P<measured_steps>\d+))?"
|
| 22 |
+
r"(?:\s+mean_loss=(?P<mean_loss>[0-9.]+))?"
|
| 23 |
+
r"(?:\s+max_mem_gib=(?P<max_mem_gib>[0-9.]+))?"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class SweepResult:
|
| 29 |
+
status: str
|
| 30 |
+
batch_size: int
|
| 31 |
+
world_size: int
|
| 32 |
+
global_batch: int
|
| 33 |
+
tok_s: float = 0.0
|
| 34 |
+
elapsed_s: float = 0.0
|
| 35 |
+
measured_steps: int = 0
|
| 36 |
+
mean_loss: float = 0.0
|
| 37 |
+
max_mem_gib: float = 0.0
|
| 38 |
+
returncode: int = 0
|
| 39 |
+
stderr_tail: str = ""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _parse_batch_sizes(text: str) -> List[int]:
|
| 43 |
+
values = []
|
| 44 |
+
for part in text.split(","):
|
| 45 |
+
p = part.strip()
|
| 46 |
+
if not p:
|
| 47 |
+
continue
|
| 48 |
+
values.append(int(p))
|
| 49 |
+
if not values:
|
| 50 |
+
raise ValueError("No batch sizes were provided.")
|
| 51 |
+
return values
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _parse_result(stdout: str, returncode: int, batch_size: int, nproc: int, stderr: str) -> SweepResult:
|
| 55 |
+
matches = RESULT_RE.findall(stdout)
|
| 56 |
+
if not matches:
|
| 57 |
+
tail = "\n".join((stderr or "").strip().splitlines()[-8:])
|
| 58 |
+
return SweepResult(
|
| 59 |
+
status="error",
|
| 60 |
+
batch_size=batch_size,
|
| 61 |
+
world_size=nproc,
|
| 62 |
+
global_batch=batch_size * nproc,
|
| 63 |
+
returncode=returncode,
|
| 64 |
+
stderr_tail=tail,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
groups = RESULT_RE.search([m.group(0) for m in RESULT_RE.finditer(stdout)][-1])
|
| 68 |
+
assert groups is not None
|
| 69 |
+
d = groups.groupdict()
|
| 70 |
+
|
| 71 |
+
return SweepResult(
|
| 72 |
+
status=d["status"],
|
| 73 |
+
batch_size=int(d["batch_size"]),
|
| 74 |
+
world_size=int(d["world_size"]),
|
| 75 |
+
global_batch=int(d["global_batch"]),
|
| 76 |
+
tok_s=float(d["tok_s"] or 0.0),
|
| 77 |
+
elapsed_s=float(d["elapsed_s"] or 0.0),
|
| 78 |
+
measured_steps=int(d["measured_steps"] or 0),
|
| 79 |
+
mean_loss=float(d["mean_loss"] or 0.0),
|
| 80 |
+
max_mem_gib=float(d["max_mem_gib"] or 0.0),
|
| 81 |
+
returncode=returncode,
|
| 82 |
+
stderr_tail="\n".join((stderr or "").strip().splitlines()[-8:]),
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _run_once(
|
| 87 |
+
config: str,
|
| 88 |
+
batch_size: int,
|
| 89 |
+
warmup_steps: int,
|
| 90 |
+
steps: int,
|
| 91 |
+
nproc_per_node: int,
|
| 92 |
+
nnodes: int,
|
| 93 |
+
node_rank: int,
|
| 94 |
+
master_addr: str,
|
| 95 |
+
master_port: int,
|
| 96 |
+
num_workers: Optional[int],
|
| 97 |
+
disable_compile: bool,
|
| 98 |
+
) -> SweepResult:
|
| 99 |
+
cmd = [
|
| 100 |
+
sys.executable,
|
| 101 |
+
"-m",
|
| 102 |
+
"torch.distributed.run",
|
| 103 |
+
"--nnodes",
|
| 104 |
+
str(nnodes),
|
| 105 |
+
"--node_rank",
|
| 106 |
+
str(node_rank),
|
| 107 |
+
"--nproc_per_node",
|
| 108 |
+
str(nproc_per_node),
|
| 109 |
+
"--master_addr",
|
| 110 |
+
str(master_addr),
|
| 111 |
+
"--master_port",
|
| 112 |
+
str(master_port),
|
| 113 |
+
"training/benchmark_throughput.py",
|
| 114 |
+
"--config",
|
| 115 |
+
config,
|
| 116 |
+
"--batch-size",
|
| 117 |
+
str(batch_size),
|
| 118 |
+
"--warmup-steps",
|
| 119 |
+
str(warmup_steps),
|
| 120 |
+
"--steps",
|
| 121 |
+
str(steps),
|
| 122 |
+
]
|
| 123 |
+
if num_workers is not None:
|
| 124 |
+
cmd.extend(["--num-workers", str(num_workers)])
|
| 125 |
+
if disable_compile:
|
| 126 |
+
cmd.append("--disable-compile")
|
| 127 |
+
|
| 128 |
+
proc = subprocess.run(cmd, capture_output=True, text=True)
|
| 129 |
+
combined_stdout = (proc.stdout or "") + "\n" + (proc.stderr or "")
|
| 130 |
+
return _parse_result(
|
| 131 |
+
stdout=combined_stdout,
|
| 132 |
+
returncode=proc.returncode,
|
| 133 |
+
batch_size=batch_size,
|
| 134 |
+
nproc=nproc_per_node,
|
| 135 |
+
stderr=proc.stderr or "",
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _format_eta_hours(hours: float) -> str:
|
| 140 |
+
if hours >= 1.0:
|
| 141 |
+
return f"{hours:.2f}h"
|
| 142 |
+
return f"{hours * 60.0:.1f}m"
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def main() -> int:
|
| 146 |
+
parser = argparse.ArgumentParser(description="Batch-size throughput sweep (DDP)")
|
| 147 |
+
parser.add_argument("--config", type=str, default="configs/real_config_8gpu_100m.yaml")
|
| 148 |
+
parser.add_argument("--batch-sizes", type=str, default="24,32,40,48,56,64,72,80,96")
|
| 149 |
+
parser.add_argument("--warmup-steps", type=int, default=20)
|
| 150 |
+
parser.add_argument("--steps", type=int, default=80)
|
| 151 |
+
parser.add_argument("--nproc-per-node", type=int, default=8)
|
| 152 |
+
parser.add_argument("--nnodes", type=int, default=1)
|
| 153 |
+
parser.add_argument("--node-rank", type=int, default=0)
|
| 154 |
+
parser.add_argument("--master-addr", type=str, default="127.0.0.1")
|
| 155 |
+
parser.add_argument("--master-port", type=int, default=29517)
|
| 156 |
+
parser.add_argument("--num-workers", type=int, default=None)
|
| 157 |
+
parser.add_argument("--disable-compile", action="store_true")
|
| 158 |
+
parser.add_argument("--stop-on-oom", dest="stop_on_oom", action="store_true")
|
| 159 |
+
parser.add_argument("--no-stop-on-oom", dest="stop_on_oom", action="store_false")
|
| 160 |
+
parser.add_argument("--tokens-per-epoch", type=float, default=30342999.0)
|
| 161 |
+
parser.add_argument("--epochs", type=int, default=50)
|
| 162 |
+
parser.add_argument("--save-json", type=str, default="sweep_results_8gpu.json")
|
| 163 |
+
parser.set_defaults(stop_on_oom=True)
|
| 164 |
+
args = parser.parse_args()
|
| 165 |
+
|
| 166 |
+
config_path = Path(args.config)
|
| 167 |
+
if not config_path.exists():
|
| 168 |
+
raise FileNotFoundError(f"Config not found: {config_path}")
|
| 169 |
+
|
| 170 |
+
batch_sizes = _parse_batch_sizes(args.batch_sizes)
|
| 171 |
+
results: List[SweepResult] = []
|
| 172 |
+
|
| 173 |
+
print(f"[sweep] config={config_path}")
|
| 174 |
+
print(f"[sweep] batch_sizes={batch_sizes}")
|
| 175 |
+
print(
|
| 176 |
+
"[sweep] launch "
|
| 177 |
+
f"nnodes={args.nnodes} node_rank={args.node_rank} nproc_per_node={args.nproc_per_node} "
|
| 178 |
+
f"master={args.master_addr}:{args.master_port}"
|
| 179 |
+
)
|
| 180 |
+
print(f"[sweep] warmup_steps={args.warmup_steps} measured_steps={args.steps}")
|
| 181 |
+
|
| 182 |
+
for idx, batch_size in enumerate(batch_sizes, start=1):
|
| 183 |
+
print(f"[sweep] ({idx}/{len(batch_sizes)}) batch_size={batch_size} ...")
|
| 184 |
+
result = _run_once(
|
| 185 |
+
config=str(config_path),
|
| 186 |
+
batch_size=batch_size,
|
| 187 |
+
warmup_steps=int(args.warmup_steps),
|
| 188 |
+
steps=int(args.steps),
|
| 189 |
+
nproc_per_node=int(args.nproc_per_node),
|
| 190 |
+
nnodes=int(args.nnodes),
|
| 191 |
+
node_rank=int(args.node_rank),
|
| 192 |
+
master_addr=str(args.master_addr),
|
| 193 |
+
master_port=int(args.master_port),
|
| 194 |
+
num_workers=args.num_workers,
|
| 195 |
+
disable_compile=bool(args.disable_compile),
|
| 196 |
+
)
|
| 197 |
+
results.append(result)
|
| 198 |
+
|
| 199 |
+
if result.status == "ok":
|
| 200 |
+
eta_hours = (args.tokens_per_epoch * args.epochs) / max(result.tok_s, 1e-9) / 3600.0
|
| 201 |
+
print(
|
| 202 |
+
"[sweep] ok "
|
| 203 |
+
f"global_batch={result.global_batch} tok_s={result.tok_s:.1f} "
|
| 204 |
+
f"max_mem_gib={result.max_mem_gib:.2f} eta_{args.epochs}ep={_format_eta_hours(eta_hours)}"
|
| 205 |
+
)
|
| 206 |
+
elif result.status == "oom":
|
| 207 |
+
print(f"[sweep] oom at batch_size={batch_size} (global_batch={result.global_batch})")
|
| 208 |
+
if args.stop_on_oom:
|
| 209 |
+
break
|
| 210 |
+
else:
|
| 211 |
+
print(
|
| 212 |
+
"[sweep] error "
|
| 213 |
+
f"batch_size={batch_size} returncode={result.returncode} "
|
| 214 |
+
f"stderr_tail={result.stderr_tail!r}"
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
ok_results = [r for r in results if r.status == "ok"]
|
| 218 |
+
best = max(ok_results, key=lambda r: r.tok_s) if ok_results else None
|
| 219 |
+
|
| 220 |
+
print("\n[sweep] summary")
|
| 221 |
+
for r in results:
|
| 222 |
+
if r.status == "ok":
|
| 223 |
+
eta_hours = (args.tokens_per_epoch * args.epochs) / max(r.tok_s, 1e-9) / 3600.0
|
| 224 |
+
print(
|
| 225 |
+
f" batch={r.batch_size:>4} global_batch={r.global_batch:>5} "
|
| 226 |
+
f"tok_s={r.tok_s:>10.1f} mem_gib={r.max_mem_gib:>7.2f} "
|
| 227 |
+
f"eta_{args.epochs}ep={_format_eta_hours(eta_hours)}"
|
| 228 |
+
)
|
| 229 |
+
else:
|
| 230 |
+
print(
|
| 231 |
+
f" batch={r.batch_size:>4} global_batch={r.global_batch:>5} "
|
| 232 |
+
f"status={r.status} returncode={r.returncode}"
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
if best is not None:
|
| 236 |
+
best_eta_hours = (args.tokens_per_epoch * args.epochs) / max(best.tok_s, 1e-9) / 3600.0
|
| 237 |
+
print("\n[sweep] best")
|
| 238 |
+
print(
|
| 239 |
+
f" batch_size={best.batch_size} global_batch={best.global_batch} "
|
| 240 |
+
f"tok_s={best.tok_s:.1f} max_mem_gib={best.max_mem_gib:.2f} "
|
| 241 |
+
f"eta_{args.epochs}ep={_format_eta_hours(best_eta_hours)}"
|
| 242 |
+
)
|
| 243 |
+
else:
|
| 244 |
+
print("\n[sweep] no successful runs")
|
| 245 |
+
|
| 246 |
+
save_path = Path(args.save_json)
|
| 247 |
+
payload = {
|
| 248 |
+
"config": str(config_path),
|
| 249 |
+
"epochs": int(args.epochs),
|
| 250 |
+
"tokens_per_epoch": float(args.tokens_per_epoch),
|
| 251 |
+
"results": [asdict(r) for r in results],
|
| 252 |
+
"best": asdict(best) if best else None,
|
| 253 |
+
"best_eta_hours": (
|
| 254 |
+
(args.tokens_per_epoch * args.epochs) / max(best.tok_s, 1e-9) / 3600.0 if best else None
|
| 255 |
+
),
|
| 256 |
+
}
|
| 257 |
+
save_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 258 |
+
print(f"[sweep] wrote {save_path}")
|
| 259 |
+
|
| 260 |
+
return 0 if best is not None else 2
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
|
| 264 |
+
raise SystemExit(main())
|