LogicGoInfotechSpaces commited on
Commit
51cbd8f
·
1 Parent(s): b885690

Add error handling for missing StyleGAN discriminator weights in both FSEFull and FSEInverter classes

Browse files
Files changed (1) hide show
  1. models/methods.py +30 -18
models/methods.py CHANGED
@@ -53,15 +53,21 @@ class FSEFull(nn.Module):
53
  def load_disc(self):
54
  # We used the hyperinverter discriminator since it has a cars checkpoint
55
  print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
56
- with open(self.opts.stylegan_weights_pkl, "rb") as f:
57
- ckpt = pickle.load(f)
58
-
59
- D_original = ckpt["D"]
60
- D_original = D_original.float()
61
-
62
- self.discriminator = Discriminator(**D_original.init_kwargs)
63
- self.discriminator.load_state_dict(D_original.state_dict())
64
- self.discriminator.to(self.device)
 
 
 
 
 
 
65
 
66
  def load_disc_from_ckpt(self, ckpt):
67
  unique_keys = set(key.split(".")[0] for key in ckpt["state_dict"].keys())
@@ -177,15 +183,21 @@ class FSEInverter(nn.Module):
177
  def load_disc(self):
178
  print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
179
  # We used the hyperinverter discriminator since it has a cars checkpoint
180
- with open(self.opts.stylegan_weights_pkl, "rb") as f:
181
- ckpt = pickle.load(f)
182
-
183
- D_original = ckpt["D"]
184
- D_original = D_original.float()
185
-
186
- self.discriminator = Discriminator(**D_original.init_kwargs)
187
- self.discriminator.load_state_dict(D_original.state_dict())
188
- self.discriminator.to(self.device)
 
 
 
 
 
 
189
 
190
  def load_disc_from_ckpt(self, ckpt):
191
  unique_keys = set(key.split(".")[0] for key in ckpt["state_dict"].keys())
 
53
  def load_disc(self):
54
  # We used the hyperinverter discriminator since it has a cars checkpoint
55
  print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
56
+ try:
57
+ with open(self.opts.stylegan_weights_pkl, "rb") as f:
58
+ ckpt = pickle.load(f)
59
+
60
+ D_original = ckpt["D"]
61
+ D_original = D_original.float()
62
+
63
+ self.discriminator = Discriminator(**D_original.init_kwargs)
64
+ self.discriminator.load_state_dict(D_original.state_dict())
65
+ self.discriminator.to(self.device)
66
+ except FileNotFoundError:
67
+ print(f"Warning: {self.opts.stylegan_weights_pkl} not found, using uninitialized discriminator")
68
+ # Create a dummy discriminator
69
+ self.discriminator = Discriminator(c_dim=0, img_resolution=1024, img_channels=3)
70
+ self.discriminator.to(self.device)
71
 
72
  def load_disc_from_ckpt(self, ckpt):
73
  unique_keys = set(key.split(".")[0] for key in ckpt["state_dict"].keys())
 
183
  def load_disc(self):
184
  print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
185
  # We used the hyperinverter discriminator since it has a cars checkpoint
186
+ try:
187
+ with open(self.opts.stylegan_weights_pkl, "rb") as f:
188
+ ckpt = pickle.load(f)
189
+
190
+ D_original = ckpt["D"]
191
+ D_original = D_original.float()
192
+
193
+ self.discriminator = Discriminator(**D_original.init_kwargs)
194
+ self.discriminator.load_state_dict(D_original.state_dict())
195
+ self.discriminator.to(self.device)
196
+ except FileNotFoundError:
197
+ print(f"Warning: {self.opts.stylegan_weights_pkl} not found, using uninitialized discriminator")
198
+ # Create a dummy discriminator
199
+ self.discriminator = Discriminator(c_dim=0, img_resolution=1024, img_channels=3)
200
+ self.discriminator.to(self.device)
201
 
202
  def load_disc_from_ckpt(self, ckpt):
203
  unique_keys = set(key.split(".")[0] for key in ckpt["state_dict"].keys())