Update train.py
Browse files
train.py
CHANGED
|
@@ -24,21 +24,17 @@ print("Torchvision Version: ",torchvision.__version__)
|
|
| 24 |
# Much of the code is a modified version of the code available at
|
| 25 |
# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
|
| 26 |
|
| 27 |
-
|
| 28 |
-
# nohup python train.py > logs/empty_cell_aug_28032024.txt 2>&1 &
|
| 29 |
-
# echo $! > logs/save_pid.txt
|
| 30 |
-
|
| 31 |
parser = argparse.ArgumentParser('arguments for training')
|
| 32 |
|
| 33 |
-
parser.add_argument('--tr_empty_folder', type=str, default="/
|
| 34 |
help='path to training data with empty images')
|
| 35 |
-
parser.add_argument('--val_empty_folder', type=str, default="/
|
| 36 |
help='path to validation data with empty images')
|
| 37 |
-
parser.add_argument('--tr_ok_folder', type=str, default="/
|
| 38 |
help='path to training data with ok images')
|
| 39 |
-
parser.add_argument('--val_ok_folder', type=str, default="/
|
| 40 |
help='path to validation data with ok images')
|
| 41 |
-
parser.add_argument('--results_folder', type=str, default="results/
|
| 42 |
help='Folder for saving training results.')
|
| 43 |
parser.add_argument('--save_model_path', type=str, default="./models/",
|
| 44 |
help='Path for saving model file.')
|
|
@@ -56,11 +52,11 @@ parser.add_argument('--random_seed', type=int, default=8765,
|
|
| 56 |
help='Number used for initializing random number generation.')
|
| 57 |
parser.add_argument('--early_stop_threshold', type=int, default=3,
|
| 58 |
help='Threshold value of epochs after which training stops if validation accuracy does not improve.')
|
| 59 |
-
parser.add_argument('--save_model_format', type=str, default='
|
| 60 |
help='Defines the format for saving the model.')
|
| 61 |
parser.add_argument('--augment_choice', type=str, default=None,
|
| 62 |
help='Defines which image augmentation(s) are used. Defaults to randomly selected augmentations.')
|
| 63 |
-
parser.add_argument('--model_name', type=str, default='
|
| 64 |
help='Current date.')
|
| 65 |
parser.add_argument('--date', type=str, default=time.strftime("%d%m%Y"),
|
| 66 |
help='Current date.')
|
|
|
|
| 24 |
# Much of the code is a modified version of the code available at
|
| 25 |
# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
parser = argparse.ArgumentParser('arguments for training')
|
| 28 |
|
| 29 |
+
parser.add_argument('--tr_empty_folder', type=str, default="/path/to/empty/train/images",
|
| 30 |
help='path to training data with empty images')
|
| 31 |
+
parser.add_argument('--val_empty_folder', type=str, default="/path/to/empty/validation/images",
|
| 32 |
help='path to validation data with empty images')
|
| 33 |
+
parser.add_argument('--tr_ok_folder', type=str, default="/path/to/non-empty/train/images",
|
| 34 |
help='path to training data with ok images')
|
| 35 |
+
parser.add_argument('--val_ok_folder', type=str, default="/path/to/non-empty/validation/images",
|
| 36 |
help='path to validation data with ok images')
|
| 37 |
+
parser.add_argument('--results_folder', type=str, default="./results/",
|
| 38 |
help='Folder for saving training results.')
|
| 39 |
parser.add_argument('--save_model_path', type=str, default="./models/",
|
| 40 |
help='Path for saving model file.')
|
|
|
|
| 52 |
help='Number used for initializing random number generation.')
|
| 53 |
parser.add_argument('--early_stop_threshold', type=int, default=3,
|
| 54 |
help='Threshold value of epochs after which training stops if validation accuracy does not improve.')
|
| 55 |
+
parser.add_argument('--save_model_format', type=str, default='onnx',
|
| 56 |
help='Defines the format for saving the model.')
|
| 57 |
parser.add_argument('--augment_choice', type=str, default=None,
|
| 58 |
help='Defines which image augmentation(s) are used. Defaults to randomly selected augmentations.')
|
| 59 |
+
parser.add_argument('--model_name', type=str, default='test_model',
|
| 60 |
help='Current date.')
|
| 61 |
parser.add_argument('--date', type=str, default=time.strftime("%d%m%Y"),
|
| 62 |
help='Current date.')
|