feat: add device config into cli arg
Browse files
train.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import pytorch_lightning as ptl
|
|
@@ -10,7 +11,12 @@ from utils import get_current_tag
|
|
| 10 |
|
| 11 |
torch.set_float32_matmul_precision('high')
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
final_batch_size = 128
|
| 16 |
single_device_num_workers = 24
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
import os
|
| 3 |
import torch
|
| 4 |
import pytorch_lightning as ptl
|
|
|
|
| 11 |
|
| 12 |
torch.set_float32_matmul_precision('high')
|
| 13 |
|
| 14 |
+
parser = argparse.ArgumentParser()
|
| 15 |
+
parser.add_argument('-d', '--devices', nargs='*', type=int, default=[0])
|
| 16 |
+
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
|
| 19 |
+
devices = args.devices
|
| 20 |
|
| 21 |
final_batch_size = 128
|
| 22 |
single_device_num_workers = 24
|