PhoenixStormJr commited on
Commit
aedec4c
·
verified ·
1 Parent(s): 7819a94

Update train_nsf_sim_cache_sid_load_pretrain.py

Browse files
train_nsf_sim_cache_sid_load_pretrain.py CHANGED
@@ -174,10 +174,10 @@ def run(rank, n_gpus, hps):
174
  net_g = DDP(net_g)
175
  net_d = DDP(net_d)
176
 
177
- try: # 如果能加载自动resume
178
  _, _, _, epoch_str = utils.load_checkpoint(
179
  utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
180
- ) # D多半加载没事
181
  if rank == 0:
182
  logger.info("loaded D")
183
  # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
@@ -187,7 +187,7 @@ def run(rank, n_gpus, hps):
187
  global_step = (epoch_str - 1) * len(train_loader)
188
  # epoch_str = 1
189
  # global_step = 0
190
- except: # 如果首次不能加载,加载pretrain
191
  # traceback.print_exc()
192
  epoch_str = 1
193
  global_step = 0
@@ -198,7 +198,7 @@ def run(rank, n_gpus, hps):
198
  net_g.module.load_state_dict(
199
  torch.load(hps.pretrainG, map_location="cpu")["model"]
200
  )
201
- ) ##测试不加载优化器
202
  if hps.pretrainD != "":
203
  if rank == 0:
204
  logger.info("loaded pretrained %s" % (hps.pretrainD))
 
174
  net_g = DDP(net_g)
175
  net_d = DDP(net_d)
176
 
177
+ try: # If you can load automatically resume
178
  _, _, _, epoch_str = utils.load_checkpoint(
179
  utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
180
+ ) # D Most of the time, it is OK to load
181
  if rank == 0:
182
  logger.info("loaded D")
183
  # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
 
187
  global_step = (epoch_str - 1) * len(train_loader)
188
  # epoch_str = 1
189
  # global_step = 0
190
+ except: # If you cannot load it for the first time, load pretrain
191
  # traceback.print_exc()
192
  epoch_str = 1
193
  global_step = 0
 
198
  net_g.module.load_state_dict(
199
  torch.load(hps.pretrainG, map_location="cpu")["model"]
200
  )
201
+ ) ##Test without loading the optimizer
202
  if hps.pretrainD != "":
203
  if rank == 0:
204
  logger.info("loaded pretrained %s" % (hps.pretrainD))