Spaces:
Sleeping
Sleeping
Commit
·
bfac43d
1
Parent(s):
372a1fe
Add error handling for missing checkpoint files in load_weights methods
Browse files- models/methods.py +40 -18
models/methods.py
CHANGED
|
@@ -88,29 +88,44 @@ class FSEFull(nn.Module):
|
|
| 88 |
def load_weights(self):
|
| 89 |
if self.opts.checkpoint_path != "":
|
| 90 |
print(f"Loading from checkpoint: {self.opts.checkpoint_path}")
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
| 95 |
else:
|
| 96 |
print(f"Loading Discriminator and Inverter from Inverter checkpoint: {self.inverter_pth}")
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
self.inverter = self.inverter.eval().to(self.device)
|
| 102 |
toogle_grad(self.inverter, False)
|
| 103 |
|
| 104 |
print("Loading Decoder from", self.opts.stylegan_weights)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
self.decoder = self.decoder.eval().to(self.device)
|
| 109 |
toogle_grad(self.decoder, False)
|
| 110 |
|
| 111 |
print("Loading E4E from", self.opts.e4e_path)
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
self.e4e_encoder = self.e4e_encoder.eval().to(self.device)
|
| 115 |
toogle_grad(self.e4e_encoder, False)
|
| 116 |
|
|
@@ -227,14 +242,21 @@ class FSEInverter(nn.Module):
|
|
| 227 |
def load_weights(self):
|
| 228 |
if self.opts.checkpoint_path != "":
|
| 229 |
print("Loading from checkpoint: {}".format(self.opts.checkpoint_path))
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
print("Loading decoder from", self.opts.stylegan_weights)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
def set_encoder(self):
|
| 240 |
inverter = psp_encoders.Inverter(opts=self.opts, n_styles=18)
|
|
|
|
def load_weights(self):
    """Load weights for all FSEFull sub-modules, tolerating missing files.

    Sources, in order:
      * encoder/inverter/discriminator — from ``self.opts.checkpoint_path``
        when set, otherwise from the standalone inverter checkpoint
        ``self.inverter_pth``;
      * decoder — from ``self.opts.stylegan_weights`` (also provides
        ``latent_avg``);
      * e4e encoder — from ``self.opts.e4e_path``.

    Each missing checkpoint file is reported with a warning and the
    corresponding module keeps its current (uninitialized) weights instead
    of crashing. After loading, every sub-module is switched to eval mode,
    moved to ``self.device`` and frozen via ``toogle_grad`` (presumably
    toggles ``requires_grad`` — project helper, name kept as-is).
    """
    if self.opts.checkpoint_path != "":
        print(f"Loading from checkpoint: {self.opts.checkpoint_path}")
        try:
            # map_location="cpu" so GPU-saved checkpoints load on CPU-only hosts.
            ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
            self.load_disc_from_ckpt(ckpt)
            self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
            self.inverter.load_state_dict(get_keys(ckpt, "inverter"), strict=True)
        except FileNotFoundError:
            print(f"Warning: {self.opts.checkpoint_path} not found, using uninitialized weights")
    else:
        print(f"Loading Discriminator and Inverter from Inverter checkpoint: {self.inverter_pth}")
        try:
            ckpt = torch.load(self.inverter_pth, map_location="cpu")
            self.load_disc_from_ckpt(ckpt)
            # NOTE(review): the inverter checkpoint stores its weights under the
            # "encoder" prefix — intentional per the original code, confirm.
            self.inverter.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
        except FileNotFoundError:
            print(f"Warning: {self.inverter_pth} not found, using uninitialized weights")

    self.inverter = self.inverter.eval().to(self.device)
    toogle_grad(self.inverter, False)

    print("Loading Decoder from", self.opts.stylegan_weights)
    try:
        # Fix: add map_location="cpu" for consistency with the other loads;
        # without it a CUDA-saved StyleGAN checkpoint fails on CPU-only machines.
        ckpt = torch.load(self.opts.stylegan_weights, map_location="cpu")
        # strict=False: decoder checkpoint may carry extra (e.g. non-EMA) keys.
        self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
        self.latent_avg = ckpt['latent_avg'].to(self.device)
    except FileNotFoundError:
        print(f"Warning: {self.opts.stylegan_weights} not found, using uninitialized decoder")
        # Fallback latent average: zeros shaped (18, 512) — 18 style layers of
        # width 512, matching the StyleGAN2 FFHQ configuration used here.
        self.latent_avg = torch.zeros(18, 512).to(self.device)

    self.decoder = self.decoder.eval().to(self.device)
    toogle_grad(self.decoder, False)

    print("Loading E4E from", self.opts.e4e_path)
    try:
        ckpt = torch.load(self.opts.e4e_path, map_location="cpu")
        self.e4e_encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
    except FileNotFoundError:
        print(f"Warning: {self.opts.e4e_path} not found, using uninitialized E4E encoder")

    self.e4e_encoder = self.e4e_encoder.eval().to(self.device)
    toogle_grad(self.e4e_encoder, False)
| 131 |
|
|
|
|
def load_weights(self):
    """Load weights for the FSEInverter's encoder/discriminator and decoder.

    Loads the training checkpoint from ``self.opts.checkpoint_path`` (when
    set) and the StyleGAN decoder from ``self.opts.stylegan_weights``; a
    missing file produces a warning and leaves the affected module with its
    current (uninitialized) weights rather than raising.
    """
    if self.opts.checkpoint_path != "":
        print("Loading from checkpoint: {}".format(self.opts.checkpoint_path))
        try:
            # map_location="cpu" so GPU-saved checkpoints load on CPU-only hosts.
            ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
            self.load_disc_from_ckpt(ckpt)
            self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
        except FileNotFoundError:
            print(f"Warning: {self.opts.checkpoint_path} not found, using uninitialized weights")

    print("Loading decoder from", self.opts.stylegan_weights)
    try:
        # Fix: add map_location="cpu" for consistency with the checkpoint load
        # above; without it a CUDA-saved checkpoint fails on CPU-only machines.
        ckpt = torch.load(self.opts.stylegan_weights, map_location="cpu")
        # strict=False: decoder checkpoint may carry extra (e.g. non-EMA) keys.
        self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
        self.latent_avg = ckpt['latent_avg'].to(self.device)
    except FileNotFoundError:
        print(f"Warning: {self.opts.stylegan_weights} not found, using uninitialized decoder")
        # Fallback latent average: zeros shaped (18, 512) — 18 style layers of
        # width 512, matching the StyleGAN2 FFHQ configuration used here.
        self.latent_avg = torch.zeros(18, 512).to(self.device)
|
| 260 |
|
| 261 |
def set_encoder(self):
|
| 262 |
inverter = psp_encoders.Inverter(opts=self.opts, n_styles=18)
|