LogicGoInfotechSpaces commited on
Commit
bfac43d
·
1 Parent(s): 372a1fe

Add error handling for missing checkpoint files in load_weights methods

Browse files
Files changed (1) hide show
  1. models/methods.py +40 -18
models/methods.py CHANGED
@@ -88,29 +88,44 @@ class FSEFull(nn.Module):
88
  def load_weights(self):
89
  if self.opts.checkpoint_path != "":
90
  print(f"Loading from checkpoint: {self.opts.checkpoint_path}")
91
- ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
92
- self.load_disc_from_ckpt(ckpt)
93
- self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
94
- self.inverter.load_state_dict(get_keys(ckpt, "inverter"), strict=True)
 
 
 
95
  else:
96
  print(f"Loading Discriminator and Inverter from Inverter checkpoint: {self.inverter_pth}")
97
- ckpt = torch.load(self.inverter_pth, map_location="cpu")
98
- self.load_disc_from_ckpt(ckpt)
99
- self.inverter.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
 
 
 
100
 
101
  self.inverter = self.inverter.eval().to(self.device)
102
  toogle_grad(self.inverter, False)
103
 
104
  print("Loading Decoder from", self.opts.stylegan_weights)
105
- ckpt = torch.load(self.opts.stylegan_weights)
106
- self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
107
- self.latent_avg = ckpt['latent_avg'].to(self.device)
 
 
 
 
 
108
  self.decoder = self.decoder.eval().to(self.device)
109
  toogle_grad(self.decoder, False)
110
 
111
  print("Loading E4E from", self.opts.e4e_path)
112
- ckpt = torch.load(self.opts.e4e_path, map_location="cpu")
113
- self.e4e_encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
 
 
 
 
114
  self.e4e_encoder = self.e4e_encoder.eval().to(self.device)
115
  toogle_grad(self.e4e_encoder, False)
116
 
@@ -227,14 +242,21 @@ class FSEInverter(nn.Module):
227
  def load_weights(self):
228
  if self.opts.checkpoint_path != "":
229
  print("Loading from checkpoint: {}".format(self.opts.checkpoint_path))
230
- ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
231
- self.load_disc_from_ckpt(ckpt)
232
- self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
 
 
 
233
 
234
  print("Loading decoder from", self.opts.stylegan_weights)
235
- ckpt = torch.load(self.opts.stylegan_weights)
236
- self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
237
- self.latent_avg = ckpt['latent_avg'].to(self.device)
 
 
 
 
238
 
239
  def set_encoder(self):
240
  inverter = psp_encoders.Inverter(opts=self.opts, n_styles=18)
 
88
  def load_weights(self):
89
  if self.opts.checkpoint_path != "":
90
  print(f"Loading from checkpoint: {self.opts.checkpoint_path}")
91
+ try:
92
+ ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
93
+ self.load_disc_from_ckpt(ckpt)
94
+ self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
95
+ self.inverter.load_state_dict(get_keys(ckpt, "inverter"), strict=True)
96
+ except FileNotFoundError:
97
+ print(f"Warning: {self.opts.checkpoint_path} not found, using uninitialized weights")
98
  else:
99
  print(f"Loading Discriminator and Inverter from Inverter checkpoint: {self.inverter_pth}")
100
+ try:
101
+ ckpt = torch.load(self.inverter_pth, map_location="cpu")
102
+ self.load_disc_from_ckpt(ckpt)
103
+ self.inverter.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
104
+ except FileNotFoundError:
105
+ print(f"Warning: {self.inverter_pth} not found, using uninitialized weights")
106
 
107
  self.inverter = self.inverter.eval().to(self.device)
108
  toogle_grad(self.inverter, False)
109
 
110
  print("Loading Decoder from", self.opts.stylegan_weights)
111
+ try:
112
+ ckpt = torch.load(self.opts.stylegan_weights)
113
+ self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
114
+ self.latent_avg = ckpt['latent_avg'].to(self.device)
115
+ except FileNotFoundError:
116
+ print(f"Warning: {self.opts.stylegan_weights} not found, using uninitialized decoder")
117
+ self.latent_avg = torch.zeros(18, 512).to(self.device)
118
+
119
  self.decoder = self.decoder.eval().to(self.device)
120
  toogle_grad(self.decoder, False)
121
 
122
  print("Loading E4E from", self.opts.e4e_path)
123
+ try:
124
+ ckpt = torch.load(self.opts.e4e_path, map_location="cpu")
125
+ self.e4e_encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
126
+ except FileNotFoundError:
127
+ print(f"Warning: {self.opts.e4e_path} not found, using uninitialized E4E encoder")
128
+
129
  self.e4e_encoder = self.e4e_encoder.eval().to(self.device)
130
  toogle_grad(self.e4e_encoder, False)
131
 
 
242
  def load_weights(self):
243
  if self.opts.checkpoint_path != "":
244
  print("Loading from checkpoint: {}".format(self.opts.checkpoint_path))
245
+ try:
246
+ ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
247
+ self.load_disc_from_ckpt(ckpt)
248
+ self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
249
+ except FileNotFoundError:
250
+ print(f"Warning: {self.opts.checkpoint_path} not found, using uninitialized weights")
251
 
252
  print("Loading decoder from", self.opts.stylegan_weights)
253
+ try:
254
+ ckpt = torch.load(self.opts.stylegan_weights)
255
+ self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
256
+ self.latent_avg = ckpt['latent_avg'].to(self.device)
257
+ except FileNotFoundError:
258
+ print(f"Warning: {self.opts.stylegan_weights} not found, using uninitialized decoder")
259
+ self.latent_avg = torch.zeros(18, 512).to(self.device)
260
 
261
  def set_encoder(self):
262
  inverter = psp_encoders.Inverter(opts=self.opts, n_styles=18)