Spaces:
Sleeping
Sleeping
| import math | |
| import sys | |
| import pickle | |
| import torch | |
| import argparse | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from models.psp.encoders import psp_encoders | |
| from models.psp.stylegan2.model import Generator | |
| from models.hyperinverter.stylegan2_ada import Discriminator | |
| from utils.class_registry import ClassRegistry | |
| from utils.common_utils import get_keys | |
| from utils.model_utils import toogle_grad | |
| from configs.paths import DefaultPaths | |
| from argparse import Namespace | |
| from training.loggers import BaseTimer | |
# Make the local "utils" directory importable. The path is relative, so it is
# resolved against the current working directory of the running process.
# NOTE(review): this runs after the `utils.*` imports above have already been
# executed, so it cannot be what makes those imports work - confirm it is
# still needed.
| sys.path.append("./utils") | |
# Module-level registry instance.
# NOTE(review): nothing in the visible part of this file registers into it or
# reads from it - presumably other modules do; confirm against callers.
| methods_registry = ClassRegistry() | |
class FSEFull(nn.Module):
    """Inverter + feature editor + StyleGAN2 decoder pipeline.

    The module owns five sub-networks:

    * ``self.encoder`` - a ``ContentLayerDeepFast`` feature editor (marked as
      the trainable part in ``set_encoder``);
    * ``self.inverter`` - a ``psp_encoders.Inverter``;
    * ``self.e4e_encoder`` - a ``psp_encoders.Encoder4Editing``;
    * ``self.decoder`` - a StyleGAN2 ``Generator``;
    * ``self.discriminator`` - a StyleGAN2-ADA ``Discriminator``.

    The inverter, decoder and e4e encoder are put in eval mode, moved to
    ``self.device`` and passed to ``toogle_grad(..., False)`` in
    ``load_weights``.

    Args:
        device: Requested torch device. If a tensor cannot be placed on it,
            the module falls back to ``"cpu"``.
        paths: Mapping merged into the options with ``dict.update``. The
            class reads ``stylegan_weights_pkl``, ``stylegan_weights`` and
            ``e4e_path`` from the resulting options.
        checkpoint_path: Checkpoint holding ``encoder`` / ``inverter`` (and
            optionally ``discriminator``) weights. ``None`` and ``""`` both
            mean "no checkpoint".
        inverter_pth: Inverter checkpoint used when ``checkpoint_path`` is not
            given. ``None`` and ``""`` both mean "no inverter checkpoint".
    """

    def __init__(self,
                 device="cuda:0",
                 paths=DefaultPaths,
                 checkpoint_path=None,
                 inverter_pth=None):
        super().__init__()

        # Resolve the device before building the options so that sub-modules
        # constructed from ``self.opts`` see the same device as ``self.device``
        # (previously ``opts.device`` kept the unusable device after the
        # fallback to CPU).
        device = self._resolve_device(device)
        self.device = torch.device(device)

        options = {
            "device": device,
            "checkpoint_path": checkpoint_path,
            "stylegan_size": 1024,
        }
        options.update(paths)
        self.opts = Namespace(**options)

        self.inverter_pth = inverter_pth

        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.stylegan_size, 512, 8)
        self.latent_avg = None
        self.load_disc()
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.load_weights()

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` if a tensor can be moved to it, else ``"cpu"``."""
        try:
            torch.randn(1).to(device)
        except Exception as e:  # deliberate best-effort probe: any failure -> CPU
            print("Could not use device {}, {}".format(device, e))
            print("Set device to CPU")
            return "cpu"
        print("Device: {}".format(device))
        return device

    def load_disc(self):
        """Build ``self.discriminator`` from the pickled StyleGAN checkpoint.

        If the pickle file does not exist, an uninitialized discriminator is
        created instead. Either way the discriminator is moved to
        ``self.device``.
        """
        # We used the hyperinverter discriminator since it has a cars checkpoint
        print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point stylegan_weights_pkl at trusted files.
            with open(self.opts.stylegan_weights_pkl, "rb") as f:
                ckpt = pickle.load(f)
        except FileNotFoundError:
            print(f"Warning: {self.opts.stylegan_weights_pkl} not found, using uninitialized discriminator")
            # Create a dummy discriminator
            self.discriminator = Discriminator(c_dim=0, img_resolution=1024, img_channels=3)
        else:
            D_original = ckpt["D"].float()
            self.discriminator = Discriminator(**D_original.init_kwargs)
            self.discriminator.load_state_dict(D_original.state_dict())
        self.discriminator.to(self.device)

    def load_disc_from_ckpt(self, ckpt):
        """Load discriminator weights from ``ckpt`` if it contains any.

        Args:
            ckpt: Checkpoint dict with a ``"state_dict"`` entry whose keys are
                dotted parameter names.
        """
        top_level_names = {key.split(".")[0] for key in ckpt["state_dict"]}
        if "discriminator" in top_level_names:
            self.discriminator.load_state_dict(get_keys(ckpt, "discriminator"), strict=True)
        else:
            print("Can not find Discriminator weights in checkpoint, leave default weights.")

    def load_weights(self):
        """Load encoder/inverter, decoder and e4e weights.

        A missing file is reported with a warning and the corresponding
        network keeps its initial weights. All checkpoints are mapped to CPU
        on load; the networks are moved to ``self.device`` afterwards.
        """
        # Truthiness check: both None (the constructor default) and "" mean
        # that no checkpoint was supplied.
        if self.opts.checkpoint_path:
            print(f"Loading from checkpoint: {self.opts.checkpoint_path}")
            try:
                ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
            except FileNotFoundError:
                print(f"Warning: {self.opts.checkpoint_path} not found, using uninitialized weights")
            else:
                self.load_disc_from_ckpt(ckpt)
                self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
                self.inverter.load_state_dict(get_keys(ckpt, "inverter"), strict=True)
        elif self.inverter_pth:
            print(f"Loading Discriminator and Inverter from Inverter checkpoint: {self.inverter_pth}")
            try:
                ckpt = torch.load(self.inverter_pth, map_location="cpu")
            except FileNotFoundError:
                print(f"Warning: {self.inverter_pth} not found, using uninitialized weights")
            else:
                self.load_disc_from_ckpt(ckpt)
                # In an inverter checkpoint the inverter weights are stored
                # under the "encoder" prefix.
                self.inverter.load_state_dict(get_keys(ckpt, "encoder"), strict=True)
        else:
            print("Warning: no checkpoint_path or inverter_pth given, using uninitialized weights")

        self.inverter = self.inverter.eval().to(self.device)
        toogle_grad(self.inverter, False)

        print("Loading Decoder from", self.opts.stylegan_weights)
        try:
            # map_location keeps the load working on machines without CUDA;
            # tensors are moved to self.device explicitly below.
            ckpt = torch.load(self.opts.stylegan_weights, map_location="cpu")
        except FileNotFoundError:
            print(f"Warning: {self.opts.stylegan_weights} not found, using uninitialized decoder")
            self.latent_avg = torch.zeros(18, 512).to(self.device)
        else:
            self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
            self.latent_avg = ckpt['latent_avg'].to(self.device)

        self.decoder = self.decoder.eval().to(self.device)
        toogle_grad(self.decoder, False)

        print("Loading E4E from", self.opts.e4e_path)
        try:
            ckpt = torch.load(self.opts.e4e_path, map_location="cpu")
        except FileNotFoundError:
            print(f"Warning: {self.opts.e4e_path} not found, using uninitialized E4E encoder")
        else:
            self.e4e_encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)

        self.e4e_encoder = self.e4e_encoder.eval().to(self.device)
        toogle_grad(self.e4e_encoder, False)

    def set_encoder(self):
        """Create the sub-encoders and return the feature editor.

        Side effect: assigns ``self.inverter`` and ``self.e4e_encoder``.
        """
        self.inverter = psp_encoders.Inverter(opts=self.opts, n_styles=18)
        self.e4e_encoder = psp_encoders.Encoder4Editing(50, "ir_se", self.opts)
        feat_editor = psp_encoders.ContentLayerDeepFast(6, 1024, 512)
        return feat_editor  # trainable part

    def forward(self, x, return_latents=False, n_iter=1e5):
        """Invert ``x`` and re-synthesize it through the feature editor.

        Args:
            x: Batch of images (4-D, as required by bilinear
                ``F.interpolate``); it is resized to 256x256 first.
            return_latents: If true, also return the latents and features.
            n_iter: Only used to compute the decoder's ``feature_scale`` as
                ``min(1.0, 0.0001 * n_iter)``.

        Returns:
            ``images``, or ``(images, w_recon, fused_feat, predicted_feat)``
            when ``return_latents`` is true. In that case ``fused_feat`` and
            ``predicted_feat`` are moved to CPU if ``self.encoder`` is not in
            training mode.
        """
        x = F.interpolate(x, size=(256, 256), mode="bilinear", align_corners=False)

        # The inverter, first decoder pass and fuser run without gradients.
        with torch.no_grad():
            w_recon, predicted_feat = self.inverter.fs_backbone(x)
            w_recon = w_recon + self.latent_avg

            _, w_feats = self.decoder(
                [w_recon],
                input_is_latent=True,
                return_features=True,
                is_stylespace=False,
                randomize_noise=False,
                early_stop=64
            )

            w_feat = w_feats[9]  # bs x 512 x 64 x 64

            fused_feat = self.inverter.fuser(torch.cat([predicted_feat, w_feat], dim=1))

        delta = torch.zeros_like(fused_feat)  # inversion case
        edited_feat = self.encoder(torch.cat([fused_feat, delta], dim=1))

        # 18 slots; only slot 9 carries a replacement feature map.
        feats = [None] * 9 + [edited_feat] + [None] * (17 - 9)

        images, _ = self.decoder(
            [w_recon],
            input_is_latent=True,
            return_features=True,
            new_features=feats,
            feature_scale=min(1.0, 0.0001 * n_iter),
            is_stylespace=False,
            randomize_noise=False
        )

        if return_latents:
            if not self.encoder.training:
                fused_feat = fused_feat.cpu()
                predicted_feat = predicted_feat.cpu()
            return images, w_recon, fused_feat, predicted_feat
        return images
class FSEInverter(nn.Module):
    """Inverter + StyleGAN2 decoder pipeline.

    The module owns three sub-networks:

    * ``self.encoder`` - a ``psp_encoders.Inverter`` (marked as the trainable
      part in ``set_encoder``);
    * ``self.decoder`` - a StyleGAN2 ``Generator``;
    * ``self.discriminator`` - a StyleGAN2-ADA ``Discriminator``.

    Only the discriminator and ``latent_avg`` are placed on ``self.device``
    by this class; the encoder and decoder are not moved here.
    NOTE(review): callers presumably call ``.to(device)`` on the whole
    module - confirm.

    Args:
        device: Requested torch device. If a tensor cannot be placed on it,
            the module falls back to ``"cpu"``.
        paths: Mapping merged into the options with ``dict.update``. The
            class reads ``stylegan_weights_pkl`` and ``stylegan_weights``
            from the resulting options.
        checkpoint_path: Checkpoint holding ``encoder`` (and optionally
            ``discriminator``) weights. ``None`` and ``""`` both mean
            "no checkpoint".
    """

    def __init__(self,
                 device="cuda:0",
                 paths=DefaultPaths,
                 checkpoint_path=None):
        super().__init__()

        # Resolve the device before building the options so that sub-modules
        # constructed from ``self.opts`` see the same device as ``self.device``
        # (previously ``opts.device`` kept the unusable device after the
        # fallback to CPU).
        device = self._resolve_device(device)
        self.device = torch.device(device)

        options = {
            "device": device,
            "checkpoint_path": checkpoint_path,
            "stylegan_size": 1024,
        }
        options.update(paths)
        self.opts = Namespace(**options)

        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.stylegan_size, 512, 8)
        self.latent_avg = None
        self.load_disc()
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.load_weights()

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` if a tensor can be moved to it, else ``"cpu"``."""
        try:
            torch.randn(1).to(device)
        except Exception as e:  # deliberate best-effort probe: any failure -> CPU
            print("Could not use device {}, {}".format(device, e))
            print("Set device to CPU")
            return "cpu"
        print("Device: {}".format(device))
        return device

    def load_disc(self):
        """Build ``self.discriminator`` from the pickled StyleGAN checkpoint.

        If the pickle file does not exist, an uninitialized discriminator is
        created instead. Either way the discriminator is moved to
        ``self.device``.
        """
        print("Loading default Discriminator from ", self.opts.stylegan_weights_pkl)
        # We used the hyperinverter discriminator since it has a cars checkpoint
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point stylegan_weights_pkl at trusted files.
            with open(self.opts.stylegan_weights_pkl, "rb") as f:
                ckpt = pickle.load(f)
        except FileNotFoundError:
            print(f"Warning: {self.opts.stylegan_weights_pkl} not found, using uninitialized discriminator")
            # Create a dummy discriminator
            self.discriminator = Discriminator(c_dim=0, img_resolution=1024, img_channels=3)
        else:
            D_original = ckpt["D"].float()
            self.discriminator = Discriminator(**D_original.init_kwargs)
            self.discriminator.load_state_dict(D_original.state_dict())
        self.discriminator.to(self.device)

    def load_disc_from_ckpt(self, ckpt):
        """Load discriminator weights from ``ckpt`` if it contains any.

        Args:
            ckpt: Checkpoint dict with a ``"state_dict"`` entry whose keys are
                dotted parameter names.
        """
        top_level_names = {key.split(".")[0] for key in ckpt["state_dict"]}
        if "discriminator" in top_level_names:
            self.discriminator.load_state_dict(get_keys(ckpt, "discriminator"), strict=True)
        else:
            print("Can not find Discriminator weights in checkpoint, leave default weights.")

    def load_weights(self):
        """Load encoder and decoder weights.

        A missing file is reported with a warning and the corresponding
        network keeps its initial weights. Checkpoints are mapped to CPU on
        load; ``latent_avg`` is then moved to ``self.device``.
        """
        # Truthiness check: both None (the constructor default) and "" mean
        # that no checkpoint was supplied.
        if self.opts.checkpoint_path:
            print("Loading from checkpoint: {}".format(self.opts.checkpoint_path))
            try:
                ckpt = torch.load(self.opts.checkpoint_path, map_location="cpu")
            except FileNotFoundError:
                print(f"Warning: {self.opts.checkpoint_path} not found, using uninitialized weights")
            else:
                self.load_disc_from_ckpt(ckpt)
                self.encoder.load_state_dict(get_keys(ckpt, "encoder"), strict=True)

        print("Loading decoder from", self.opts.stylegan_weights)
        try:
            # map_location keeps the load working on machines without CUDA;
            # load_state_dict copies values into the existing parameters, so
            # the decoder's own device placement is unaffected.
            ckpt = torch.load(self.opts.stylegan_weights, map_location="cpu")
        except FileNotFoundError:
            print(f"Warning: {self.opts.stylegan_weights} not found, using uninitialized decoder")
            self.latent_avg = torch.zeros(18, 512).to(self.device)
        else:
            self.decoder.load_state_dict(ckpt["g_ema"], strict=False)
            self.latent_avg = ckpt['latent_avg'].to(self.device)

    def set_encoder(self):
        """Create and return the inverter network."""
        inverter = psp_encoders.Inverter(opts=self.opts, n_styles=18)
        return inverter  # trainable part

    def forward(self, x, return_latents=False, n_iter=1e5):
        """Invert ``x`` and re-synthesize it with the fused feature map.

        Args:
            x: Batch of images (4-D, as required by bilinear
                ``F.interpolate``); it is resized to 256x256 first.
            return_latents: If true, also return the latents and features.
            n_iter: Only used to compute the decoder's ``feature_scale`` as
                ``min(1.0, 0.0001 * n_iter)``.

        Returns:
            ``images``, or ``(images, w_recon, fused_feat, w_feat)`` when
            ``return_latents`` is true. In that case ``fused_feat`` and
            ``w_feat`` are moved to CPU if ``self.encoder`` is not in training
            mode.
        """
        x = F.interpolate(x, size=(256, 256), mode="bilinear", align_corners=False)

        w_recon, predicted_feat = self.encoder.fs_backbone(x)
        w_recon = w_recon + self.latent_avg

        # First decoder pass: only used to read out intermediate features.
        _, w_feats = self.decoder(
            [w_recon],
            input_is_latent=True,
            return_features=True,
            is_stylespace=False,
            randomize_noise=False,
            early_stop=64
        )

        w_feat = w_feats[9]  # bs x 512 x 64 x 64

        fused_feat = self.encoder.fuser(torch.cat([predicted_feat, w_feat], dim=1))

        # 18 slots; only slot 9 carries a replacement feature map.
        feats = [None] * 9 + [fused_feat] + [None] * (17 - 9)

        images, _ = self.decoder(
            [w_recon],
            input_is_latent=True,
            return_features=True,
            new_features=feats,
            feature_scale=min(1.0, 0.0001 * n_iter),
            is_stylespace=False,
            randomize_noise=False
        )

        if return_latents:
            if not self.encoder.training:
                fused_feat = fused_feat.cpu()
                w_feat = w_feat.cpu()
            return images, w_recon, fused_feat, w_feat
        return images