Update train_nsf_sim_cache_sid_load_pretrain.py
Browse files
train_nsf_sim_cache_sid_load_pretrain.py
CHANGED
|
@@ -174,10 +174,10 @@ def run(rank, n_gpus, hps):
|
|
| 174 |
net_g = DDP(net_g)
|
| 175 |
net_d = DDP(net_d)
|
| 176 |
|
| 177 |
-
try: #
|
| 178 |
_, _, _, epoch_str = utils.load_checkpoint(
|
| 179 |
utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
|
| 180 |
-
) # D
|
| 181 |
if rank == 0:
|
| 182 |
logger.info("loaded D")
|
| 183 |
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
|
|
@@ -187,7 +187,7 @@ def run(rank, n_gpus, hps):
|
|
| 187 |
global_step = (epoch_str - 1) * len(train_loader)
|
| 188 |
# epoch_str = 1
|
| 189 |
# global_step = 0
|
| 190 |
-
except: #
|
| 191 |
# traceback.print_exc()
|
| 192 |
epoch_str = 1
|
| 193 |
global_step = 0
|
|
@@ -198,7 +198,7 @@ def run(rank, n_gpus, hps):
|
|
| 198 |
net_g.module.load_state_dict(
|
| 199 |
torch.load(hps.pretrainG, map_location="cpu")["model"]
|
| 200 |
)
|
| 201 |
-
) ##
|
| 202 |
if hps.pretrainD != "":
|
| 203 |
if rank == 0:
|
| 204 |
logger.info("loaded pretrained %s" % (hps.pretrainD))
|
|
|
|
| 174 |
net_g = DDP(net_g)
|
| 175 |
net_d = DDP(net_d)
|
| 176 |
|
| 177 |
+
try: # Attempt to resume training automatically from the latest saved checkpoint
|
| 178 |
_, _, _, epoch_str = utils.load_checkpoint(
|
| 179 |
utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
|
| 180 |
+
) # The D (discriminator) checkpoint normally loads without issues
|
| 181 |
if rank == 0:
|
| 182 |
logger.info("loaded D")
|
| 183 |
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
|
|
|
|
| 187 |
global_step = (epoch_str - 1) * len(train_loader)
|
| 188 |
# epoch_str = 1
|
| 189 |
# global_step = 0
|
| 190 |
+
except: # No checkpoint exists on the first run, so fall back to loading the pretrained weights
|
| 191 |
# traceback.print_exc()
|
| 192 |
epoch_str = 1
|
| 193 |
global_step = 0
|
|
|
|
| 198 |
net_g.module.load_state_dict(
|
| 199 |
torch.load(hps.pretrainG, map_location="cpu")["model"]
|
| 200 |
)
|
| 201 |
+
) ## Load the pretrained generator weights without restoring the optimizer state
|
| 202 |
if hps.pretrainD != "":
|
| 203 |
if rank == 0:
|
| 204 |
logger.info("loaded pretrained %s" % (hps.pretrainD))
|