Commit ·
2aee25d
1
Parent(s): 8e86dac
Revert batch size default to 256
Browse files
- scripts/train_all.py +1 -1
scripts/train_all.py
CHANGED
|
@@ -301,7 +301,7 @@ def parse_args():
|
|
| 301 |
p = argparse.ArgumentParser(description="Train small/base/large PAWN models simultaneously")
|
| 302 |
p.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
|
| 303 |
p.add_argument("--total-steps", type=int, default=100_000, help="Total training steps")
|
| 304 |
-
p.add_argument("--batch-size", type=int, default=
|
| 305 |
p.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
|
| 306 |
p.add_argument("--log-dir", type=str, default="logs", help="Log directory")
|
| 307 |
p.add_argument("--log-interval", type=int, default=10)
|
|
|
|
| 301 |
p = argparse.ArgumentParser(description="Train small/base/large PAWN models simultaneously")
|
| 302 |
p.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
|
| 303 |
p.add_argument("--total-steps", type=int, default=100_000, help="Total training steps")
|
| 304 |
+
p.add_argument("--batch-size", type=int, default=256, help="Batch size (shared across models)")
|
| 305 |
p.add_argument("--num-workers", type=int, default=4, help="DataLoader workers")
|
| 306 |
p.add_argument("--log-dir", type=str, default="logs", help="Log directory")
|
| 307 |
p.add_argument("--log-interval", type=int, default=10)
|