feat: add cli option for model name (log)
Browse files
train.py
CHANGED
|
@@ -77,6 +77,13 @@ parser.add_argument(
|
|
| 77 |
default=["./dataset/font_img"],
|
| 78 |
help="Datasets paths, seperated by space (default: ['./dataset/font_img'])",
|
| 79 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
args = parser.parse_args()
|
| 82 |
|
|
@@ -124,7 +131,7 @@ data_module = FontDataModule(
|
|
| 124 |
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
|
| 125 |
num_warmup_iter = data_module.get_train_num_iter(num_device) * num_warmup_epochs
|
| 126 |
|
| 127 |
-
model_name =
|
| 128 |
|
| 129 |
logger_unconditioned = TensorBoardLogger(
|
| 130 |
save_dir=os.getcwd(), name="tensorboard", version=model_name
|
|
|
|
| 77 |
default=["./dataset/font_img"],
|
| 78 |
help="Datasets paths, seperated by space (default: ['./dataset/font_img'])",
|
| 79 |
)
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"-n",
|
| 82 |
+
"--model-name",
|
| 83 |
+
type=str,
|
| 84 |
+
default=get_current_tag(),
|
| 85 |
+
help="Model name (default: current tag)",
|
| 86 |
+
)
|
| 87 |
|
| 88 |
args = parser.parse_args()
|
| 89 |
|
|
|
|
| 131 |
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
|
| 132 |
num_warmup_iter = data_module.get_train_num_iter(num_device) * num_warmup_epochs
|
| 133 |
|
| 134 |
+
model_name = args.model_name
|
| 135 |
|
| 136 |
logger_unconditioned = TensorBoardLogger(
|
| 137 |
save_dir=os.getcwd(), name="tensorboard", version=model_name
|