# Smile_Changer / models/methods.py
# Last commit by LogicGoInfotechSpaces (bfac43d):
#   Add error handling for missing checkpoint files in load_weights methods
import math
import sys
import pickle
import torch
import argparse
import numpy as np
import torch.nn.functional as F
from torch import nn
from models.psp.encoders import psp_encoders
from models.psp.stylegan2.model import Generator
from models.hyperinverter.stylegan2_ada import Discriminator
from utils.class_registry import ClassRegistry
from utils.common_utils import get_keys
from utils.model_utils import toogle_grad
from configs.paths import DefaultPaths
from argparse import Namespace
from training.loggers import BaseTimer
# Add ./utils to the module search path. The path is relative, so it is
# resolved against the current working directory, not this file's location.
sys.path.append("./utils")
# Registry of model classes. The classes below register themselves under
# string keys ("fse_full", "fse_inverter") via add_to_registry.
methods_registry = ClassRegistry()
@methods_registry.add_to_registry("fse_full", stop_args=("self", "checkpoint_path"))
class FSEFull(nn.Module):
    """FSE model: frozen inverter + StyleGAN2 decoder + trainable feature editor.

    Sub-modules built here:
      * ``self.inverter`` (``psp_encoders.Inverter``): ``fs_backbone`` maps the
        resized image to a latent (added to ``latent_avg``) plus a feature
        map, and ``fuser`` merges that feature map with decoder features.
        It is put in eval mode with grads toggled off in ``load_weights``.
      * ``self.e4e_encoder`` (``psp_encoders.Encoder4Editing``): loaded and
        frozen in ``load_weights``; not used in ``forward``.
      * ``self.encoder`` (``psp_encoders.ContentLayerDeepFast``): the
        trainable feature editor returned by ``set_encoder``.
      * ``self.decoder`` (StyleGAN2 ``Generator``): frozen in ``load_weights``.
      * ``self.discriminator``: built in ``load_disc``.

    Args:
        device: torch device string. Falls back to ``"cpu"`` when a tensor
            cannot be moved to it.
        paths: extra options merged into ``self.opts`` with ``dict.update``.
            Expected to provide ``stylegan_weights``, ``stylegan_weights_pkl``
            and ``e4e_path``.
        checkpoint_path: full FSE checkpoint (encoder + inverter, optionally a
            discriminator). ``None`` or ``""`` means "not given".
        inverter_pth: inverter-only checkpoint. Used when ``checkpoint_path``
            is not given.
    """

    def __init__(self,
                 device="cuda:0",
                 paths=DefaultPaths,
                 checkpoint_path=None,
                 inverter_pth=None):
        super(FSEFull, self).__init__()
        self.opts = {
            "device": device,
            "checkpoint_path": checkpoint_path,
            "stylegan_size": 1024
        }
        self.opts.update(paths)
        self.opts = Namespace(**self.opts)
        # Handle device detection and fallback to CPU if CUDA is not available
        resolved_device = self._resolve_device(device)
        if resolved_device != device:
            # opts is handed to the psp encoders. After a CPU fallback it must
            # not keep advertising the unusable device.
            self.opts.device = resolved_device
        self.device = torch.device(resolved_device)
        self.inverter_pth = inverter_pth
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.stylegan_size, 512, 8)
        self.latent_avg = None
        self.load_disc()
        # Not referenced inside this class; kept for external users.
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.load_weights()

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` if a tensor can be moved to it, else ``"cpu"``."""
        try:
            torch.randn(1).to(device)
            print("Device: {}".format(device))
        except Exception as e:
            print("Could not use device {}, {}".format(device, e))
            print("Set device to CPU")
            device = "cpu"
        return device

    def load_disc(self):
        """Build ``self.discriminator`` from the StyleGAN2-ADA pickle.

        If the pickle file does not exist, a freshly initialised discriminator
        at the generator's resolution is created instead.
        """
        # We used the hyperinverter discriminator since it has a cars checkpoint
        print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
        try:
            # NOTE: pickle.load can execute arbitrary code; only point
            # stylegan_weights_pkl at trusted files.
            with open(self.opts.stylegan_weights_pkl, "rb") as f:
                ckpt = pickle.load(f)
            D_original = ckpt["D"]
            D_original = D_original.float()
            self.discriminator = Discriminator(**D_original.init_kwargs)
            self.discriminator.load_state_dict(D_original.state_dict())
            self.discriminator.to(self.device)
        except FileNotFoundError:
            print(f"Warning: {self.opts.stylegan_weights_pkl} not found, using uninitialized discriminator")
            # Create a dummy discriminator matching the generator resolution.
            self.discriminator = Discriminator(
                c_dim=0, img_resolution=self.opts.stylegan_size, img_channels=3
            )
            self.discriminator.to(self.device)

    def load_disc_from_ckpt(self, ckpt):
        """Load discriminator weights from ``ckpt`` if it contains any.

        Args:
            ckpt: checkpoint dict with a ``"state_dict"`` entry whose keys are
                prefixed by sub-module name (e.g. ``"discriminator.xxx"``).
        """
        has_disc = any(key.split(".")[0] == "discriminator" for key in ckpt["state_dict"])
        if has_disc:
            # get_keys presumably selects (and strips) the "discriminator."
            # prefix; verify against utils.common_utils.
            self.discriminator.load_state_dict(get_keys(ckpt, "discriminator"), strict=True)
        else:
            print("Can not find Discriminator weights in checkpoint, leave default weights.")

    def load_weights(self):
        """Load weights for the inverter, editor, decoder and e4e encoder.

        The weight source depends on the constructor arguments:
          * ``checkpoint_path`` set: editor + inverter (+ discriminator) come
            from the full FSE checkpoint.
          * otherwise, ``inverter_pth`` set: inverter (+ discriminator) come
            from the inverter checkpoint.
          * neither set: a warning is printed and the inverter keeps its
            initial weights.

        A missing file is reported, and the affected module keeps its initial
        weights (``latent_avg`` then defaults to zeros of shape (18, 512)).
        The inverter, decoder and e4e encoder are then put in eval mode, moved
        to ``self.device``, and have grads toggled off via ``toogle_grad``.
        """
        # Truthiness, not `!= ""`: the default checkpoint_path is None, and
        # torch.load(None) would crash.
        if self.opts.checkpoint_path:
            print(f"Loading from checkpoint: {self.opts.checkpoint_path}")
            try:
                ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
                self.load_disc_from_ckpt(ckpt)
                self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
                self.inverter.load_state_dict(get_keys(ckpt, "inverter"), strict=True)
            except FileNotFoundError:
                print(f"Warning: {self.opts.checkpoint_path} not found, using uninitialized weights")
        elif self.inverter_pth:
            print(f"Loading Discriminator and Inverter from Inverter checkpoint: {self.inverter_pth}")
            try:
                ckpt = torch.load(self.inverter_pth, map_location="cpu")
                self.load_disc_from_ckpt(ckpt)
                # The standalone inverter checkpoint stores it under "encoder".
                self.inverter.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
            except FileNotFoundError:
                print(f"Warning: {self.inverter_pth} not found, using uninitialized weights")
        else:
            print("Warning: neither checkpoint_path nor inverter_pth given, using uninitialized inverter weights")
        self.inverter = self.inverter.eval().to(self.device)
        toogle_grad(self.inverter, False)

        print("Loading Decoder from", self.opts.stylegan_weights)
        try:
            # map_location="cpu" lets a CUDA-saved checkpoint load after a CPU
            # fallback. The decoder is moved to self.device below anyway.
            ckpt = torch.load(self.opts.stylegan_weights, map_location="cpu")
            self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
            self.latent_avg = ckpt['latent_avg'].to(self.device)
        except FileNotFoundError:
            print(f"Warning: {self.opts.stylegan_weights} not found, using uninitialized decoder")
            self.latent_avg = torch.zeros(18, 512).to(self.device)
        self.decoder = self.decoder.eval().to(self.device)
        toogle_grad(self.decoder, False)

        print("Loading E4E from", self.opts.e4e_path)
        try:
            ckpt = torch.load(self.opts.e4e_path, map_location="cpu")
            self.e4e_encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
        except FileNotFoundError:
            print(f"Warning: {self.opts.e4e_path} not found, using uninitialized E4E encoder")
        self.e4e_encoder = self.e4e_encoder.eval().to(self.device)
        toogle_grad(self.e4e_encoder, False)

    def set_encoder(self):
        """Create the inverter and e4e encoder; return the trainable editor."""
        self.inverter = psp_encoders.Inverter(opts=self.opts, n_styles=18)
        self.e4e_encoder = psp_encoders.Encoder4Editing(50, "ir_se", self.opts)
        feat_editor = psp_encoders.ContentLayerDeepFast(6, 1024, 512)
        return feat_editor  # trainable part

    def forward(self, x, return_latents=False, n_iter=1e5):
        """Reconstruct ``x`` through the inverter, feature editor and decoder.

        Args:
            x: input image batch, resized to 256x256 before inversion.
            return_latents: if True, also return intermediate tensors.
            n_iter: training iteration. The injected-feature scale is
                ``min(1.0, 0.0001 * n_iter)``, which reaches 1.0 at iteration
                10000; the default 1e5 gives full scale.

        Returns:
            ``images``, or ``(images, w_recon, fused_feat, predicted_feat)``
            when ``return_latents`` is True. The two feature tensors are moved
            to CPU when ``self.encoder`` is in eval mode.
        """
        x = F.interpolate(x, size=(256, 256), mode="bilinear", align_corners=False)
        # The inverter and the first decoder pass are frozen, so no autograd
        # graph is recorded for them.
        with torch.no_grad():
            w_recon, predicted_feat = self.inverter.fs_backbone(x)
            w_recon = w_recon + self.latent_avg
            # Only the intermediate features are used from this pass; the
            # meaning of early_stop=64 is defined by Generator (presumably it
            # stops at the 64x64 level).
            _, w_feats = self.decoder(
                [w_recon],
                input_is_latent=True,
                return_features=True,
                is_stylespace=False,
                randomize_noise=False,
                early_stop=64
            )
            w_feat = w_feats[9]  # bs x 512 x 64 x 64
            fused_feat = self.inverter.fuser(torch.cat([predicted_feat, w_feat], dim=1))
            delta = torch.zeros_like(fused_feat)  # inversion case
        # Trainable part: runs with gradients enabled.
        edited_feat = self.encoder(torch.cat([fused_feat, delta], dim=1))
        # 18 slots; only index 9 is replaced by the edited feature.
        feats = [None] * 9 + [edited_feat] + [None] * (17 - 9)
        images, _ = self.decoder(
            [w_recon],
            input_is_latent=True,
            return_features=True,
            new_features=feats,
            feature_scale=min(1.0, 0.0001 * n_iter),
            is_stylespace=False,
            randomize_noise=False
        )
        if return_latents:
            if not self.encoder.training:
                fused_feat = fused_feat.cpu()
                predicted_feat = predicted_feat.cpu()
            return images, w_recon, fused_feat, predicted_feat
        return images
@methods_registry.add_to_registry("fse_inverter", stop_args=("self", "checkpoint_path"))
class FSEInverter(nn.Module):
    """Stand-alone FSE inverter trained against a StyleGAN2 decoder.

    ``self.encoder`` (a ``psp_encoders.Inverter``) is the trainable part.
    ``self.decoder`` (StyleGAN2 ``Generator``) and ``self.discriminator`` are
    loaded from checkpoints.

    Args:
        device: torch device string. Falls back to ``"cpu"`` when a tensor
            cannot be moved to it.
        paths: extra options merged into ``self.opts`` with ``dict.update``.
            Expected to provide ``stylegan_weights`` and
            ``stylegan_weights_pkl``.
        checkpoint_path: inverter checkpoint to resume from. ``None`` or ``""``
            means "not given".
    """

    def __init__(self,
                 device="cuda:0",
                 paths=DefaultPaths,
                 checkpoint_path=None):
        super(FSEInverter, self).__init__()
        self.opts = {
            "device": device,
            "checkpoint_path": checkpoint_path,
            "stylegan_size": 1024
        }
        self.opts.update(paths)
        self.opts = Namespace(**self.opts)
        # Handle device detection and fallback to CPU if CUDA is not available
        resolved_device = self._resolve_device(device)
        if resolved_device != device:
            # opts is handed to the psp Inverter. After a CPU fallback it must
            # not keep advertising the unusable device.
            self.opts.device = resolved_device
        self.device = torch.device(resolved_device)
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.stylegan_size, 512, 8)
        self.latent_avg = None
        self.load_disc()
        # Not referenced inside this class; kept for external users.
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.load_weights()

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` if a tensor can be moved to it, else ``"cpu"``."""
        try:
            torch.randn(1).to(device)
            print("Device: {}".format(device))
        except Exception as e:
            print("Could not use device {}, {}".format(device, e))
            print("Set device to CPU")
            device = "cpu"
        return device

    def load_disc(self):
        """Build ``self.discriminator`` from the StyleGAN2-ADA pickle.

        If the pickle file does not exist, a freshly initialised discriminator
        at the generator's resolution is created instead.
        """
        print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
        # We used the hyperinverter discriminator since it has a cars checkpoint
        try:
            # NOTE: pickle.load can execute arbitrary code; only point
            # stylegan_weights_pkl at trusted files.
            with open(self.opts.stylegan_weights_pkl, "rb") as f:
                ckpt = pickle.load(f)
            D_original = ckpt["D"]
            D_original = D_original.float()
            self.discriminator = Discriminator(**D_original.init_kwargs)
            self.discriminator.load_state_dict(D_original.state_dict())
            self.discriminator.to(self.device)
        except FileNotFoundError:
            print(f"Warning: {self.opts.stylegan_weights_pkl} not found, using uninitialized discriminator")
            # Create a dummy discriminator matching the generator resolution.
            self.discriminator = Discriminator(
                c_dim=0, img_resolution=self.opts.stylegan_size, img_channels=3
            )
            self.discriminator.to(self.device)

    def load_disc_from_ckpt(self, ckpt):
        """Load discriminator weights from ``ckpt`` if it contains any.

        Args:
            ckpt: checkpoint dict with a ``"state_dict"`` entry whose keys are
                prefixed by sub-module name (e.g. ``"discriminator.xxx"``).
        """
        has_disc = any(key.split(".")[0] == "discriminator" for key in ckpt["state_dict"])
        if has_disc:
            # get_keys presumably selects (and strips) the "discriminator."
            # prefix; verify against utils.common_utils.
            self.discriminator.load_state_dict(get_keys(ckpt, "discriminator"), strict=True)
        else:
            print("Can not find Discriminator weights in checkpoint, leave default weights.")

    def load_weights(self):
        """Load inverter/discriminator weights and the StyleGAN2 decoder.

        A missing file is reported, and the affected module keeps its initial
        weights (``latent_avg`` then defaults to zeros of shape (18, 512)).

        NOTE(review): unlike FSEFull.load_weights, the decoder is not put in
        eval mode, frozen, or moved to ``self.device`` here. Confirm that the
        caller handles this (e.g. via ``model.to(device)``).
        """
        # Truthiness, not `!= ""`: the default checkpoint_path is None, and
        # torch.load(None) would crash.
        if self.opts.checkpoint_path:
            print("Loading from checkpoint: {}".format(self.opts.checkpoint_path))
            try:
                ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
                self.load_disc_from_ckpt(ckpt)
                self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
            except FileNotFoundError:
                print(f"Warning: {self.opts.checkpoint_path} not found, using uninitialized weights")
        print("Loading decoder from", self.opts.stylegan_weights)
        try:
            # map_location="cpu" lets a CUDA-saved checkpoint load after a CPU
            # fallback. The decoder itself has not been moved off the CPU here.
            ckpt = torch.load(self.opts.stylegan_weights, map_location="cpu")
            self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
            self.latent_avg = ckpt['latent_avg'].to(self.device)
        except FileNotFoundError:
            print(f"Warning: {self.opts.stylegan_weights} not found, using uninitialized decoder")
            self.latent_avg = torch.zeros(18, 512).to(self.device)

    def set_encoder(self):
        """Create and return the trainable ``psp_encoders.Inverter``."""
        inverter = psp_encoders.Inverter(opts=self.opts, n_styles=18)
        return inverter  # trainable part

    def forward(self, x, return_latents=False, n_iter=1e5):
        """Invert ``x`` and reconstruct it with the StyleGAN2 decoder.

        Args:
            x: input image batch, resized to 256x256 before inversion.
            return_latents: if True, also return intermediate tensors.
            n_iter: training iteration. The injected-feature scale is
                ``min(1.0, 0.0001 * n_iter)``, which reaches 1.0 at iteration
                10000; the default 1e5 gives full scale.

        Returns:
            ``images``, or ``(images, w_recon, fused_feat, w_feat)`` when
            ``return_latents`` is True. The two feature tensors are moved to
            CPU when ``self.encoder`` is in eval mode.
        """
        x = F.interpolate(x, size=(256, 256), mode="bilinear", align_corners=False)
        # No torch.no_grad() here: the inverter (self.encoder) is trained.
        w_recon, predicted_feat = self.encoder.fs_backbone(x)
        w_recon = w_recon + self.latent_avg
        # Only the intermediate features are used from this pass; the meaning
        # of early_stop=64 is defined by Generator (presumably it stops at the
        # 64x64 level).
        _, w_feats = self.decoder(
            [w_recon],
            input_is_latent=True,
            return_features=True,
            is_stylespace=False,
            randomize_noise=False,
            early_stop=64
        )
        w_feat = w_feats[9]  # bs x 512 x 64 x 64
        fused_feat = self.encoder.fuser(torch.cat([predicted_feat, w_feat], dim=1))
        # 18 slots; only index 9 is replaced by the fused feature.
        feats = [None] * 9 + [fused_feat] + [None] * (17 - 9)
        images, _ = self.decoder(
            [w_recon],
            input_is_latent=True,
            return_features=True,
            new_features=feats,
            feature_scale=min(1.0, 0.0001 * n_iter),
            is_stylespace=False,
            randomize_noise=False
        )
        if return_latents:
            if not self.encoder.training:
                fused_feat = fused_feat.cpu()
                w_feat = w_feat.cpu()
            return images, w_recon, fused_feat, w_feat
        return images