# artydemo/scripts/train_cnn_safe.py
# Pablo Dejuan
# Add HF trainer utilities (dataset+OOM-safe training)
# cf79e95
"""
Run `scripts/train_cnn.py` with a simple batch-size fallback strategy.
On some GPU setups, a too-large `--batch-size` can trigger an out-of-memory (OOM) error and crash early.
This wrapper retries once with a smaller batch size if the first run
exits non-zero.
"""
from __future__ import annotations
import argparse
import subprocess
import sys
from pathlib import Path
def run_train(
    *,
    arch: str,
    epochs: int,
    batch_size: int,
    cpu: bool,
    extra_args: list[str] | None = None,
) -> int:
    """Run ``scripts/train_cnn.py`` once in a child process and return its exit code.

    The script is resolved relative to the directory two levels above this
    file, and that same directory is used as the child's working directory.

    Args:
        arch: Value forwarded as ``--arch``.
        epochs: Value forwarded as ``--epochs``.
        batch_size: Value forwarded as ``--batch-size``.
        cpu: When true, ``--cpu`` is appended to the command line.
        extra_args: Additional arguments appended verbatim after the
            options above; ``None`` or an empty list adds nothing.

    Returns:
        The return code of the child process.
    """
    repo_root = Path(__file__).resolve().parent.parent
    command = [
        sys.executable,
        str(repo_root / "scripts" / "train_cnn.py"),
        "--arch",
        arch,
        "--epochs",
        str(epochs),
        "--batch-size",
        str(batch_size),
        # Optional tail: the CPU switch first, then any pass-through arguments.
        *(["--cpu"] if cpu else []),
        *(extra_args or []),
    ]
    completed = subprocess.run(command, cwd=str(repo_root))
    return int(completed.returncode)
def main() -> None:
    """Parse CLI options, train once, and retry once with the fallback batch size.

    The first attempt uses ``--batch-size-primary``. If it exits non-zero and
    ``--batch-size-fallback`` is non-zero and differs from the primary value,
    a second attempt is made with the fallback batch size. When the last
    attempt that ran failed, the process exits with that attempt's return
    code; when an attempt succeeds, the function returns normally.

    Arguments destined for ``train_cnn.py`` may be listed after
    ``--extra-args`` or given as options this wrapper does not define.
    The second form is what makes dash-prefixed flags (e.g. ``--lr 0.1``)
    forwardable: argparse never accepts such tokens as values of a
    ``nargs="*"`` option, so with strict parsing they were rejected.

    Raises:
        SystemExit: On argument errors, or with the child's return code
            when the last attempt that ran exited non-zero.
    """
    parser = argparse.ArgumentParser(description="Train with a batch-size fallback strategy.")
    parser.add_argument("--arch", choices=["cnn", "cnnrnn"], required=True)
    parser.add_argument("--epochs", type=int, required=True)
    parser.add_argument("--batch-size-primary", type=int, default=32)
    parser.add_argument("--batch-size-fallback", type=int, default=16)
    parser.add_argument("--cpu", action="store_true", help="Force CPU (mainly for local debugging)")
    parser.add_argument(
        "--extra-args",
        nargs="*",
        default=None,
        help="Extra args passed through to train_cnn.py (unrecognized options are forwarded as well)",
    )

    # parse_known_args keeps options the wrapper does not define instead of
    # aborting, so they can be handed to the training script.
    args, unrecognized = parser.parse_known_args()
    forwarded = [*(args.extra_args or []), *unrecognized]

    rc = run_train(
        arch=args.arch,
        epochs=args.epochs,
        batch_size=args.batch_size_primary,
        cpu=args.cpu,
        extra_args=forwarded,
    )
    if rc == 0:
        return

    # A fallback of 0 disables the retry; an identical batch size would only
    # repeat the run that just failed.
    retry_allowed = bool(args.batch_size_fallback) and args.batch_size_fallback != args.batch_size_primary
    if not retry_allowed:
        sys.exit(rc)

    # Make the retry visible in logs; otherwise the first failure looks final.
    print(
        f"train_cnn.py exited with code {rc}; "
        f"retrying with --batch-size {args.batch_size_fallback}",
        file=sys.stderr,
    )
    rc = run_train(
        arch=args.arch,
        epochs=args.epochs,
        batch_size=args.batch_size_fallback,
        cpu=args.cpu,
        extra_args=forwarded,
    )
    if rc != 0:
        sys.exit(rc)
# Script entry point: run the wrapper only when executed directly, not on import.
if __name__ == "__main__":
    main()