|
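"""Training runners for feature-style encoder (FSE) training.

Defines a base runner that wires up datasets, dataloaders, optimizers, losses,
metrics and logging, plus two concrete runners registered in
`training_runners`: an inverter runner ("fse_inverter") and a feature editor
runner ("fse_editor").
"""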
|
import os
import sys
import json
import wandb
import datetime
import omegaconf

import torch
import numpy as np
import torch.nn.functional as F

from collections import defaultdict
from tqdm.auto import tqdm
from io import BytesIO
from PIL import Image
from time import time
from pathlib import Path
from abc import abstractmethod

from runners.base_runner import BaseRunner
from utils.class_registry import ClassRegistry
from datasets.transforms import transforms_registry
from datasets.datasets import ImageDataset
from datasets.loaders import InfiniteLoader
from training.losses import disc_losses, LossBuilder
from training.optimizers import optimizers
from metrics.metrics import metrics_registry

from training.loggers import Timer, StreamingMeans, TrainigLogger
from utils.common_utils import tensor2im, get_keys
from models.methods import methods_registry

from models.psp.encoders.psp_encoders import ProgressiveStage
from utils.model_utils import toogle_grad


training_runners = ClassRegistry()
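
# Registry of training runners: the concrete runners below register themselves
# via the @training_runners.add_to_registry decorators and can then be resolved
# by name, the same way `optimizers[...]` and `metrics_registry[...]` are used
# in this file.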


# Editing directions used by the editor runner, mapped to the strengths
# sampled for each direction during training.
FACE_DIRECTIONS = {
    "age": [-7, -5, 5, 7, 10],
    "fs_makeup": [5, 8, 12],
    "afro": [0.03, 0.07],
    "angry": [0.06, 0.1],
    "purple_hair": [0.07, 0.1, 0.12],
    "glasses": [-10, -7],
    "face_roundness": [-13, -7, 7, 13],
    "rotation": [-5.0, -3.0, -1.0, 1.0, 3.0, 5.0],
    "bobcut": [0.07, 0.12, 0.18],
    "bowlcut": [0.07, 0.14],
    "mohawk": [0.07, 0.10],
    "blond hair": [-8, -4, 4, 8],
    "fs_smiling": [-6, -3, 3, 6, 9],
}


def get_random_edit():
    """Sample a random editing direction and a strength for it."""
    direction = np.random.choice(list(FACE_DIRECTIONS.keys()))
    strength = np.random.choice(FACE_DIRECTIONS[direction])
    return direction, strength
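
# Example draws, with values taken from FACE_DIRECTIONS above:
#   ("age", 7), ("afro", 0.03), ("rotation", -5.0)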


@training_runners.add_to_registry(name="base_training_runner")
class BaseTrainingRunner(BaseRunner):
    def setup(self):
        self.start_step = self.config.train.start_step
        self._setup_device()
        self._setup_experiment_dir()

        self._setup_method()
        self._setup_logger()

        self._setup_metrics()
        self._setup_datasets()

        # A different (typically larger) batch size can be used before the
        # adversarial phase starts; see train_step for the switch.
        start_batch_size = (
            self.config.train.bs_used_before_adv_loss
            if self.config.train.train_dis
            else self.config.model.batch_size
        )
        self._setup_dataloaders(start_batch_size)

        self._setup_latent_editor()
        self._setup_optimizers()
        self._setup_loss()

    def _setup_logger(self):
        self.logger = TrainigLogger(self.config)

    def _setup_datasets(self):
        print("Loading dataset")
        transform_dict = transforms_registry[self.config.data.transform]().get_transforms()
        self.train_dataset = ImageDataset(
            self.config.data.input_train_dir, transform_dict["train"]
        )

        self.test_dataset = ImageDataset(
            self.config.data.input_val_dir, transform_dict["test"]
        )
        self.paths = self.test_dataset.paths

        # A small fixed set of images that is inverted and logged with
        # per-image metric captions at every validation step.
        self.special_dataset = ImageDataset(
            self.config.data.special_dir, transform_dict["test"]
        )
        self.special_paths = self.special_dataset.paths

    def _setup_dataloaders(self, batch_size):
        self.train_dataloader = InfiniteLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=self.config.model.workers,
            drop_last=True,
            is_infinite=True,
        )
        self.test_dataloader = InfiniteLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self.config.model.workers,
            is_infinite=False,
        )
        self.special_dataloader = InfiniteLoader(
            self.special_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self.config.model.workers,
            is_infinite=False,
        )

    def _setup_optimizers(self):
        params = list(self.method.encoder.parameters())

        optimizer_args = dict(
            self.config.optimizers[self.config.train.encoder_optimizer]
        )
        optimizer_args["params"] = params
        self.encoder_optimizer = optimizers[self.config.train.encoder_optimizer](
            **optimizer_args
        )

        if self.config.model.checkpoint_path != "":
            ckpt = torch.load(self.config.model.checkpoint_path, map_location="cpu")
            if "encoder_opt" in ckpt.keys():
                self.encoder_optimizer.load_state_dict(ckpt["encoder_opt"])
            else:
                print("WARNING: continuing training without loading encoder optimizer state!")

        if self.config.train.train_dis:
            params = list(self.method.discriminator.parameters())
            optimizer_args = dict(
                self.config.optimizers[self.config.train.disc_optimizer]
            )
            optimizer_args["params"] = params
            self.disc_optimizer = optimizers[self.config.train.disc_optimizer](
                **optimizer_args
            )

            # `ckpt` was already loaded above under the same condition.
            if self.config.model.checkpoint_path != "":
                if "disc_opt" in ckpt.keys():
                    self.disc_optimizer.load_state_dict(ckpt["disc_opt"])
                else:
                    print("WARNING: continuing training without loading disc optimizer state!")

    def _setup_loss(self):
        enc_losses_dict = self.config.encoder_losses
        disc_losses_dict = self.config.disc_losses

        self.loss_builder = LossBuilder(
            enc_losses_dict,
            disc_losses_dict,
            self.device,
        )

    def _setup_experiment_dir(self):
        base_root = Path(__file__).resolve().parent.parent
        num = 0
        exp_dir = self.config.exp.exp_dir
        exp_dir_name = "{}_{}".format(self.config.exp.name, str(num).zfill(3))

        # Find the first free "<exp name>_NNN" directory.
        exp_path = base_root / exp_dir / exp_dir_name
        while True:
            if exp_path.exists():
                num += 1
                exp_dir_name = "{}_{}".format(self.config.exp.name, str(num).zfill(3))
                print(exp_path, "already exists: move to", exp_dir_name)
            else:
                break
            exp_path = base_root / exp_dir / exp_dir_name
        self.experiment_dir = str(exp_path)
        os.makedirs(self.experiment_dir)
        print(f"Experiment directory: {self.experiment_dir}")

        with open(os.path.join(self.experiment_dir, "config.yaml"), "w") as f:
            omegaconf.OmegaConf.save(config=self.config, f=f)

        with open(os.path.join(self.experiment_dir, "run_command.sh"), "w") as f:
            f.write(" ".join(sys.argv))
            f.write("\n")

        self.metrics_dir = os.path.join(self.experiment_dir, "metrics")
        os.mkdir(self.metrics_dir)
        self.inference_results_dir = os.path.join(
            self.experiment_dir, "inference_results"
        )
        os.mkdir(self.inference_results_dir)

    def _setup_metrics(self):
        metrics_names = self.config.train.val_metrics

        self.metrics = []
        for metric_name in metrics_names:
            metric_args = {}
            if hasattr(self.config.metrics, metric_name):
                metric_args = getattr(self.config.metrics, metric_name)
            self.metrics.append(metrics_registry[metric_name](**metric_args))

    def to_train(self):
        self.method.train()

    def to_eval(self):
        self.method.eval()

    def run(self):
        iter_info = StreamingMeans()
        self.to_train()

        for self.global_step in range(self.start_step, self.config.train.steps + 1):
            with Timer(iter_info, "iter_train"):
                loss_dict = self.train_step()
            iter_info.update({f"iter_train/{k}": v for k, v in loss_dict.items()})

            if self.global_step % self.config.train.val_step == 0:
                with Timer(iter_info, "iter_val"):
                    val_loss_dict = self.validate()
                iter_info.update({f"iter_val/{k}": v for k, v in val_loss_dict.items()})

                orig_pics, method_pics, captions = self.inference_special()
                self.logger.save_validation_logs(
                    orig_pics,
                    method_pics,
                    captions,
                    special_paths=self.special_paths,
                )

            if self.global_step % self.config.train.log_step == 0:
                self.logger.save_train_logs(iter_info, self.global_step)
                iter_info.clear()

            if self.global_step % self.config.train.checkpoint_step == 0:
                self.save_checkpoint()

    def train_step(self):
        x = next(self.train_dataloader)
        x = x.to(self.device).float()
        output = self.forward(x)

        enc_loss, loss_dict = self.loss_builder.encoder_loss(output["encoder"])

        self.encoder_optimizer.zero_grad()
        enc_loss.backward()
        self.encoder_optimizer.step()
        loss_dict["enc_loss"] = float(enc_loss)

        if (
            self.config.train.train_dis
            and self.global_step >= self.config.train.dis_train_start_step
        ):
            if self.global_step == self.config.train.dis_train_start_step:
                print("Start training with discriminator")
                if self.train_dataloader.batch_size != self.config.model.batch_size:
                    print(f"Changing batch size from {self.train_dataloader.batch_size} to {self.config.model.batch_size}")
                    self._setup_dataloaders(self.config.model.batch_size)

            toogle_grad(self.method.discriminator, True)
            self.method.discriminator.train()

            disc_loss, disc_losses_dict = self.loss_builder.disc_loss(
                self.method.discriminator,
                output["to_disc"],
            )
            loss_dict.update(disc_losses_dict)

            self.disc_optimizer.zero_grad()
            disc_loss.backward()
            self.disc_optimizer.step()

            # Freeze the discriminator again for the next encoder update.
            toogle_grad(self.method.discriminator, False)
            self.method.discriminator.eval()

        self.method.latent_avg = self.method.latent_avg.detach()

        return loss_dict

    def save_checkpoint(self):
        save_name = f"iteration_{self.global_step}.pt"
        checkpoint_path = os.path.join(self.experiment_dir, save_name)
        save_dict = self.get_save_dict()
        print(f"Saving checkpoint to {checkpoint_path}")
        torch.save(save_dict, checkpoint_path)

        # Persist what is needed to resume: the next step and, if enabled,
        # the wandb run arguments.
        options_path = os.path.join(self.experiment_dir, "save_options.json")
        save_options = {"start_step": self.global_step + 1}

        if self.config.exp.wandb:
            save_options.update(self.logger.wandb_logger.wandb_args)

        with open(options_path, "w") as f:
            json.dump(save_options, f)

    def get_save_dict(self):
        save_dict = {
            "state_dict": self.method.state_dict(),
            "encoder_opt": self.encoder_optimizer.state_dict(),
            "latent_avg": self.method.latent_avg,
        }

        if self.config.train.train_dis:
            save_dict["disc_opt"] = self.disc_optimizer.state_dict()
        return save_dict
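
    # A minimal resume sketch (hypothetical standalone usage; during training
    # the restore path is `_setup_optimizers` above, which reads "encoder_opt"
    # and "disc_opt" from the checkpoint at config.model.checkpoint_path):
    #
    #   ckpt = torch.load("iteration_1000.pt", map_location="cpu")
    #   method.load_state_dict(ckpt["state_dict"])
    #   encoder_optimizer.load_state_dict(ckpt["encoder_opt"])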

    @torch.inference_mode()
    def inference_special(self):
        print("Running inversion for special images")
        self.validate(special=True)

        captions = defaultdict(str)
        for metric in self.metrics:
            # FID is distribution-level, so it has no per-image value to caption.
            if metric.get_name() == "FID":
                continue

            from_data_arg = {
                "fake_data": self.val_pics_res,
                "inp_data": self.val_pics_orig,
                "paths": self.special_paths,
            }
            metric_data, _, _ = metric(
                None, None, out_path=None, from_data=from_data_arg
            )
            for path in self.special_paths:
                metric_value = metric_data[os.path.basename(path)]
                captions[path] += f"{metric.get_name()}: {metric_value:.3}\n"

        return self.val_pics_orig, self.val_pics_res, captions

    @torch.inference_mode()
    def validate(self, special=False):
        if not special:
            print("Start validating")

        self.to_eval()
        self.val_pics_res = []
        self.val_pics_orig = []

        if not special:
            dataloader = self.test_dataloader
            paths = self.paths
        else:
            dataloader = self.special_dataloader
            paths = self.special_paths

        global_i = 0
        for input_batch in tqdm(dataloader):
            input_batch = input_batch.to(self.device).float()
            result_batch = self._run_on_batch(input_batch)

            for i in range(result_batch.shape[0]):
                result = tensor2im(result_batch[i])
                img = Image.fromarray(np.array(result)).convert("RGB")

                # Round-trip through an in-memory JPEG, presumably to give
                # logged results the same compression as images saved to disk.
                memory_tmp = BytesIO()
                img.save(memory_tmp, format="jpeg")
                img = Image.open(memory_tmp).convert("RGB")
                memory_tmp.close()

                self.val_pics_res.append(img)
                self.val_pics_orig.append(
                    Image.open(paths[global_i]).convert("RGB")
                )

                global_i += 1

        metrics_dict = {}
        if not special:
            for metric in self.metrics:
                from_data_arg = {
                    "fake_data": self.val_pics_res,
                    "inp_data": self.val_pics_orig,
                    "paths": paths,
                }
                _, metric_mean, _ = metric(
                    None, None, out_path=None, from_data=from_data_arg
                )
                metrics_dict[metric.get_name()] = metric_mean

        self.to_train()
        return metrics_dict

    @abstractmethod
    def _run_on_batch(self, inputs):
        raise NotImplementedError()

    @abstractmethod
    def forward(self, x):
        raise NotImplementedError()


@training_runners.add_to_registry(name="fse_inverter")
class FSEInverterTrainingRunner(BaseTrainingRunner):
    def forward(self, x):
        # Feature-based inversion; also returns w latents and intermediate features.
        y_hat_inv, w_inv, fused_feat, w_feat = self.method(
            x,
            return_latents=True,
            n_iter=self.global_step,
        )

        # A second reconstruction from the w latents alone, without predicted features.
        y_hat_inv_w, _ = self.method.decoder(
            [w_inv],
            input_is_latent=True,
            is_stylespace=False,
            randomize_noise=False,
        )

        y_hat = torch.cat([y_hat_inv, y_hat_inv_w], dim=0)

        output = {"encoder": {}, "to_disc": {}}
        use_adv_loss = (
            self.config.train.train_dis
            and self.global_step >= self.config.train.dis_train_start_step
        )
        output["encoder"]["use_adv_loss"] = use_adv_loss
        if use_adv_loss:
            output["encoder"]["fake_preds"] = self.method.discriminator(y_hat, None)
            output["to_disc"]["y_hat"] = y_hat
            output["to_disc"]["x"] = x
            output["to_disc"]["step"] = self.global_step

        # Both reconstructions are compared against the same input, so x is duplicated.
        y_hat = self.method.pool(y_hat)
        x = self.method.pool(x)
        x = torch.cat([x, x], dim=0)

        output["encoder"]["x"] = x
        output["encoder"]["y_hat"] = y_hat
        output["encoder"]["feat_recon"] = fused_feat
        output["encoder"]["feat_real"] = w_feat

        return output

    def _run_on_batch(self, inputs):
        result_batch = self.method(inputs)
        return result_batch


@training_runners.add_to_registry(name="fse_editor")
class FSEEditorTrainingRunner(BaseTrainingRunner):
    def forward(self, x):
        # Feature-style inversion of the input.
        y_hat_inv, w, fused_feat, w_feat = self.method(x, return_latents=True)

        with torch.no_grad():
            # Sample a random editing direction and strength for this step.
            d, strength = get_random_edit()

            x_resh = F.interpolate(x, size=(256, 256), mode="bilinear", align_corners=False)

            # e4e inversion of the input, keeping the decoder features.
            w_e4e = self.method.e4e_encoder(x_resh)
            w_e4e = w_e4e + self.method.latent_avg
            x_E, fx_e4e = self.method.decoder(
                [w_e4e],
                input_is_latent=True,
                randomize_noise=False,
                return_latents=False,
                return_features=True,
            )

            # Apply the edit to the e4e latents; the editor may return either
            # stylespace latents (a tuple) or a list of w latents.
            edited_w_e4e = self.get_edited_latent(w_e4e, d, [strength])
            if isinstance(edited_w_e4e, tuple):
                y_E, fy_e4e = self.method.decoder(
                    edited_w_e4e,
                    is_stylespace=True,
                    input_is_latent=True,
                    randomize_noise=False,
                    return_features=True,
                )
            else:
                edited_w_e4e = torch.cat(edited_w_e4e, dim=0)
                y_E, fy_e4e = self.method.decoder(
                    [edited_w_e4e],
                    is_stylespace=False,
                    input_is_latent=True,
                    randomize_noise=False,
                    return_features=True,
                )

            y_E_256 = F.interpolate(y_E, size=(256, 256), mode="bilinear", align_corners=False)
            x_E_256 = F.interpolate(x_E, size=(256, 256), mode="bilinear", align_corners=False)
            # Feature delta between the original and edited e4e reconstructions
            # at decoder layer 9.
            delta = fx_e4e[9] - fy_e4e[9]

            # For selected directions, also push the real image through the
            # editing branch by doubling the batch.
            if d in self.config.train.disc_edits:
                x_E_256 = torch.cat([x_E_256, x_resh], dim=0)
                delta = torch.cat([delta, delta], dim=0)

        # Invert the e4e reconstruction with the feature-style backbone.
        w_x_E, x_E_predicted_feats = self.method.inverter.fs_backbone(x_E_256)
        w_x_E = w_x_E + self.method.latent_avg

        w_x_E_edited = self.get_edited_latent(w_x_E, d, [strength])
        is_stylespace = isinstance(w_x_E_edited, tuple)
        if not is_stylespace:
            w_x_E_edited = [torch.cat(w_x_E_edited, dim=0)]

        _, x_E_w_feats = self.method.decoder(
            [w_x_E],
            input_is_latent=True,
            return_features=True,
            is_stylespace=False,
            randomize_noise=False,
            early_stop=64,
        )
        x_E_w_feat = x_E_w_feats[9]
        to_fuser = torch.cat([x_E_predicted_feats, x_E_w_feat], dim=1)
        x_E_fused_feat = self.method.inverter.fuser(to_fuser)

        # The trainable feature editor predicts the edited layer-9 feature map
        # from the fused features and the e4e feature delta.
        to_feature_editor = torch.cat([x_E_fused_feat, delta], dim=1)
        x_E_edited_feat = self.method.encoder(to_feature_editor)
        # Only decoder layer 9 gets a replacement feature map.
        x_E_edited_feats = [None] * 9 + [x_E_edited_feat] + [None] * (17 - 9)

        y_hat_edit, _ = self.method.decoder(
            w_x_E_edited,
            input_is_latent=True,
            new_features=x_E_edited_feats,
            feature_scale=1.0,
            is_stylespace=is_stylespace,
            randomize_noise=False,
        )

        bs = x_resh.size(0)
        output = {"encoder": {}, "to_disc": {}}
        use_adv_loss = (
            self.config.train.train_dis
            and self.global_step >= self.config.train.dis_train_start_step
        )
        output["encoder"]["use_adv_loss"] = use_adv_loss
        if use_adv_loss:
            if x_E_256.size(0) > x_resh.size(0):
                # The batch was doubled for disc_edits directions above.
                assert y_hat_edit.size(0) == bs * 2
                output["encoder"]["fake_preds"] = self.method.discriminator(
                    torch.cat([y_hat_inv, y_hat_edit[bs:]], dim=0),
                    None,
                )
            else:
                output["encoder"]["fake_preds"] = self.method.discriminator(y_hat_inv, None)
            output["to_disc"]["y_hat"] = y_hat_inv
            output["to_disc"]["x"] = x
            output["to_disc"]["step"] = self.global_step

        if x_E_256.size(0) > x_resh.size(0):
            assert y_hat_edit.size(0) == bs * 2
            y_hat_edit = y_hat_edit[:bs]

        # Reconstruction targets: the input for the inversion branch and the
        # e4e edit for the editing branch.
        x = torch.cat([x, y_E], dim=0)
        y_hat = torch.cat([y_hat_inv, y_hat_edit])

        y_hat = self.method.pool(y_hat)
        x = self.method.pool(x)
        output["encoder"]["x"] = x
        output["encoder"]["y_hat"] = y_hat

        return output

    def _run_on_batch(self, inputs):
        result_batch = self.method(inputs)
        return result_batch
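
# Sketch of how a training script might drive these runners (assumed wiring;
# the repo's actual entry point may differ, and the constructor signature is
# a hypothesis):
#
#   config = omegaconf.OmegaConf.load("configs/train.yaml")  # hypothetical path
#   runner = training_runners["fse_inverter"](config)        # or "fse_editor"
#   runner.setup()
#   runner.run()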