Spaces:
Sleeping
Sleeping
Commit
·
51cbd8f
1
Parent(s):
b885690
Add error handling for missing StyleGAN discriminator weights in both FSEFull and FSEInverter classes
Browse files- models/methods.py +30 -18
models/methods.py
CHANGED
|
@@ -53,15 +53,21 @@ class FSEFull(nn.Module):
|
|
| 53 |
def load_disc(self):
|
| 54 |
# We used the hyperinverter discriminator since it has a cars checkpoint
|
| 55 |
print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
def load_disc_from_ckpt(self, ckpt):
|
| 67 |
unique_keys = set(key.split(".")[0] for key in ckpt["state_dict"].keys())
|
|
@@ -177,15 +183,21 @@ class FSEInverter(nn.Module):
|
|
| 177 |
def load_disc(self):
|
| 178 |
print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
|
| 179 |
# We used the hyperinverter discriminator since it has a cars checkpoint
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
def load_disc_from_ckpt(self, ckpt):
|
| 191 |
unique_keys = set(key.split(".")[0] for key in ckpt["state_dict"].keys())
|
|
|
|
| 53 |
def load_disc(self):
|
| 54 |
# We used the hyperinverter discriminator since it has a cars checkpoint
|
| 55 |
print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
|
| 56 |
+
try:
|
| 57 |
+
with open(self.opts.stylegan_weights_pkl, "rb") as f:
|
| 58 |
+
ckpt = pickle.load(f)
|
| 59 |
+
|
| 60 |
+
D_original = ckpt["D"]
|
| 61 |
+
D_original = D_original.float()
|
| 62 |
+
|
| 63 |
+
self.discriminator = Discriminator(**D_original.init_kwargs)
|
| 64 |
+
self.discriminator.load_state_dict(D_original.state_dict())
|
| 65 |
+
self.discriminator.to(self.device)
|
| 66 |
+
except FileNotFoundError:
|
| 67 |
+
print(f"Warning: {self.opts.stylegan_weights_pkl} not found, using uninitialized discriminator")
|
| 68 |
+
# Create a dummy discriminator
|
| 69 |
+
self.discriminator = Discriminator(c_dim=0, img_resolution=1024, img_channels=3)
|
| 70 |
+
self.discriminator.to(self.device)
|
| 71 |
|
| 72 |
def load_disc_from_ckpt(self, ckpt):
|
| 73 |
unique_keys = set(key.split(".")[0] for key in ckpt["state_dict"].keys())
|
|
|
|
| 183 |
def load_disc(self):
|
| 184 |
print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
|
| 185 |
# We used the hyperinverter discriminator since it has a cars checkpoint
|
| 186 |
+
try:
|
| 187 |
+
with open(self.opts.stylegan_weights_pkl, "rb") as f:
|
| 188 |
+
ckpt = pickle.load(f)
|
| 189 |
+
|
| 190 |
+
D_original = ckpt["D"]
|
| 191 |
+
D_original = D_original.float()
|
| 192 |
+
|
| 193 |
+
self.discriminator = Discriminator(**D_original.init_kwargs)
|
| 194 |
+
self.discriminator.load_state_dict(D_original.state_dict())
|
| 195 |
+
self.discriminator.to(self.device)
|
| 196 |
+
except FileNotFoundError:
|
| 197 |
+
print(f"Warning: {self.opts.stylegan_weights_pkl} not found, using uninitialized discriminator")
|
| 198 |
+
# Create a dummy discriminator
|
| 199 |
+
self.discriminator = Discriminator(c_dim=0, img_resolution=1024, img_channels=3)
|
| 200 |
+
self.discriminator.to(self.device)
|
| 201 |
|
| 202 |
def load_disc_from_ckpt(self, ckpt):
|
| 203 |
unique_keys = set(key.split(".")[0] for key in ckpt["state_dict"].keys())
|