25ep
Browse files
train.py
CHANGED
|
@@ -29,8 +29,8 @@ ds_path = "datasets/576"
|
|
| 29 |
project = "unet"
|
| 30 |
batch_size = 25
|
| 31 |
base_learning_rate = 9.5e-6
|
| 32 |
-
min_learning_rate =
|
| 33 |
-
num_epochs =
|
| 34 |
# samples/save per epoch
|
| 35 |
sample_interval_share = 10
|
| 36 |
use_wandb = True
|
|
@@ -904,7 +904,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
| 904 |
if (global_step % 100 == 0) or (global_step % sample_interval == 0):
|
| 905 |
accelerator.wait_for_everyone()
|
| 906 |
|
| 907 |
-
grad = torch.tensor(0.0, device=device)
|
| 908 |
if not fbp:
|
| 909 |
if accelerator.sync_gradients:
|
| 910 |
with torch.amp.autocast('cuda', enabled=False):
|
|
|
|
| 29 |
project = "unet"
|
| 30 |
batch_size = 25
|
| 31 |
base_learning_rate = 9.5e-6
|
| 32 |
+
min_learning_rate = 7e-6
|
| 33 |
+
num_epochs = 24
|
| 34 |
# samples/save per epoch
|
| 35 |
sample_interval_share = 10
|
| 36 |
use_wandb = True
|
|
|
|
| 904 |
if (global_step % 100 == 0) or (global_step % sample_interval == 0):
|
| 905 |
accelerator.wait_for_everyone()
|
| 906 |
|
| 907 |
+
grad = torch.tensor(0.0, device=device)
|
| 908 |
if not fbp:
|
| 909 |
if accelerator.sync_gradients:
|
| 910 |
with torch.amp.autocast('cuda', enabled=False):
|