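"""Generate and save images with a pretrained IC-GAN generator.

The script loads precomputed instance features, builds the conditionings
(noise vectors, instance features and, for class-conditional models, class
labels), runs the generator in batches, and saves a grid of generated images,
optionally alongside the ground-truth instance images, as a PNG figure.
"""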
import os
import sys

# Allow imports from the parent directory (repository root).
sys.path.insert(1, os.path.join(sys.path[0], ".."))

import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import torch

import data_utils.utils as data_utils
import inference.utils as inference_utils
import BigGAN_PyTorch.utils as biggan_utils
from data_utils.datasets_common import pil_loader
import torchvision.transforms as transforms


def get_data(root_path, model, resolution, which_dataset, visualize_instance_images):
    data_path = os.path.join(root_path, "stored_instances")
    if model == "cc_icgan":
        feature_extractor = "classification"
    else:
        feature_extractor = "selfsupervised"
    filename = "%s_res%i_rn50_%s_kmeans_k1000_instance_features.npy" % (
        which_dataset,
        resolution,
        feature_extractor,
    )
    # Load the precomputed instance features (and their metadata) used as conditioning.
    data = np.load(os.path.join(data_path, filename), allow_pickle=True).item()

    transform_list = None
    if visualize_instance_images:
        # Transformations applied to the ground-truth instance images before plotting.
        transform_list = transforms.Compose(
            [data_utils.CenterCropLongEdge(), transforms.Resize(resolution)]
        )
    return data, transform_list


def get_model(exp_name, root_path, backbone, device="cuda"):
    parser = biggan_utils.prepare_parser()
    parser = biggan_utils.add_sample_parser(parser)
    parser = inference_utils.add_backbone_parser(parser)

    args = ["--experiment_name", exp_name]
    args += ["--base_root", root_path]
    args += ["--model_backbone", backbone]

    config = vars(parser.parse_args(args=args))

    # Load the generator weights for the requested experiment and set it to eval mode.
    config = biggan_utils.update_config_roots(config, change_weight_folder=False)
    generator, config = inference_utils.load_model_inference(config, device=device)
    biggan_utils.count_parameters(generator)
    generator.eval()

    return generator


def get_conditionings(test_config, generator, data):
    # Sample noise vectors: one per generated image, for every conditioning.
    z = torch.empty(
        test_config["num_imgs_gen"] * test_config["num_conditionings_gen"],
        generator.z_dim
        if test_config["model_backbone"] == "stylegan2"
        else generator.dim_z,
    ).normal_(mean=0, std=test_config["z_var"])

    # Randomly pick which stored instances to condition on (unless --index is given).
    if test_config["num_conditionings_gen"] > 1:
        total_idxs = np.random.choice(
            range(1000), test_config["num_conditionings_gen"], replace=False
        )

    # Collect instance features, class labels and image paths for each conditioning.
    all_feats, all_img_paths, all_labels = [], [], []
    for counter in range(test_config["num_conditionings_gen"]):
        # Index of the stored instance to use as conditioning.
        if test_config["index"] is not None:
            idx = test_config["index"]
        else:
            idx = total_idxs[counter]
        # Path of the ground-truth image associated with the instance.
        if test_config["visualize_instance_images"]:
            all_img_paths.append(data["image_path"][idx])
        # Instance features, repeated once per generated image.
        all_feats.append(
            torch.FloatTensor(data["instance_features"][idx : idx + 1]).repeat(
                test_config["num_imgs_gen"], 1
            )
        )
        # Class label: either the one stored with the instance or a swapped-in target.
        if test_config["swap_target"] is not None:
            label_int = test_config["swap_target"]
        else:
            label_int = int(data["labels"][idx])
        # Format the label as expected by each backbone: one-hot vectors for
        # StyleGAN2, integer class indices for class-conditional BigGAN.
        labels = None
        if test_config["model_backbone"] == "stylegan2":
            dim_labels = 1000
            labels = torch.eye(dim_labels)[torch.LongTensor([label_int])].repeat(
                test_config["num_imgs_gen"], 1
            )
        else:
            if test_config["model"] == "cc_icgan":
                labels = torch.LongTensor([label_int]).repeat(
                    test_config["num_imgs_gen"]
                )
        all_labels.append(labels)
    # Stack the conditionings for all generations.
    all_feats = torch.cat(all_feats)
    if all_labels[0] is not None:
        all_labels = torch.cat(all_labels)
    else:
        all_labels = None
    return z, all_feats, all_labels, all_img_paths


def main(test_config):
    suffix = (
        "_nofeataug"
        if test_config["resolution"] == 256
        and test_config["trained_dataset"] == "imagenet"
        else ""
    )
    exp_name = "%s_%s_%s_res%i%s" % (
        test_config["model"],
        test_config["model_backbone"],
        test_config["trained_dataset"],
        test_config["resolution"],
        suffix,
    )
    device = "cuda"
    # Load the stored instance features (and the transforms used to display
    # the ground-truth instance images, if requested).
    data, transform_list = get_data(
        test_config["root_path"],
        test_config["model"],
        test_config["resolution"],
        test_config["which_dataset"],
        test_config["visualize_instance_images"],
    )

    # Load the pretrained generator.
    generator = get_model(
        exp_name, test_config["root_path"], test_config["model_backbone"], device=device
    )

    # Prepare the conditionings: noise vectors, instance features and
    # (for class-conditional models) class labels.
    z, all_feats, all_labels, all_img_paths = get_conditionings(
        test_config, generator, data
    )

    # Generate the images in batches.
    all_generated_images = []
    with torch.no_grad():
        num_batches = int(np.ceil(z.shape[0] / test_config["batch_size"]))
        for i in range(num_batches):
            start = test_config["batch_size"] * i
            end = min(
                test_config["batch_size"] * i + test_config["batch_size"], z.shape[0]
            )
            if all_labels is not None:
                labels_ = all_labels[start:end].to(device)
            else:
                labels_ = None
            gen_img = generator(
                z[start:end].to(device), labels_, all_feats[start:end].to(device)
            )
            # Map generator outputs back to the [0, 255] range.
            if test_config["model_backbone"] == "biggan":
                gen_img = ((gen_img * 0.5 + 0.5) * 255).int()
            elif test_config["model_backbone"] == "stylegan2":
                gen_img = torch.clamp((gen_img * 127.5 + 128), 0, 255).int()
            all_generated_images.append(gen_img.cpu())
    all_generated_images = torch.cat(all_generated_images)
    all_generated_images = all_generated_images.permute(0, 2, 3, 1).numpy()

    # Arrange the generations in a grid: one row per conditioning,
    # one column per noise vector.
    big_plot = []
    for i in range(0, test_config["num_conditionings_gen"]):
        row = []
        for j in range(0, test_config["num_imgs_gen"]):
            subplot_idx = (i * test_config["num_imgs_gen"]) + j
            row.append(all_generated_images[subplot_idx])
        row = np.concatenate(row, axis=1)
        big_plot.append(row)
    big_plot = np.concatenate(big_plot, axis=0)

    # Optionally prepend the ground-truth instance images, separated from the
    # generations by a white margin.
    if test_config["visualize_instance_images"]:
        all_gt_imgs = []
        for i in range(0, len(all_img_paths)):
            all_gt_imgs.append(
                np.array(
                    transform_list(
                        pil_loader(
                            os.path.join(test_config["dataset_path"], all_img_paths[i])
                        )
                    )
                ).astype(np.uint8)
            )
        all_gt_imgs = np.concatenate(all_gt_imgs, axis=0)
        white_space = (
            np.ones((all_gt_imgs.shape[0], 20, all_gt_imgs.shape[2])) * 255
        ).astype(np.uint8)
        big_plot = np.concatenate([all_gt_imgs, white_space, big_plot], axis=1)

    plt.figure(
        figsize=(
            5 * test_config["num_imgs_gen"],
            5 * test_config["num_conditionings_gen"],
        )
    )
    plt.imshow(big_plot)
    plt.axis("off")

    fig_path = "%s_Generations_with_InstanceDataset_%s%s%s_zvar%0.2f.png" % (
        exp_name,
        test_config["which_dataset"],
        "_index" + str(test_config["index"])
        if test_config["index"] is not None
        else "",
        "_class_idx" + str(test_config["swap_target"])
        if test_config["swap_target"] is not None
        else "",
        test_config["z_var"],
    )
    plt.savefig(fig_path, dpi=600, bbox_inches="tight", pad_inches=0)

    print("Done! Figure saved as %s" % fig_path)


if __name__ == "__main__":
    parser = ArgumentParser(
        description="Generate and save images using pre-trained models"
    )

    parser.add_argument(
        "--root_path",
        type=str,
        required=True,
        help="Path where pretrained models + instance features have been downloaded.",
    )
    parser.add_argument(
        "--which_dataset",
        type=str,
        default="imagenet",
        choices=["imagenet", "coco"],
        help="Dataset to sample instances from.",
    )
    parser.add_argument(
        "--trained_dataset",
        type=str,
        default="imagenet",
        choices=["imagenet", "coco"],
        help="Dataset the model has been trained on.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="icgan",
        choices=["icgan", "cc_icgan"],
        help="Model type.",
    )
    parser.add_argument(
        "--model_backbone",
        type=str,
        default="biggan",
        choices=["biggan", "stylegan2"],
        help="Model backbone type.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help="Resolution to generate images with (default: %(default)s).",
    )
    parser.add_argument(
        "--z_var", type=float, default=1.0, help="Noise variance (default: %(default)s)."
    )
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size.")
    parser.add_argument(
        "--num_imgs_gen",
        type=int,
        default=5,
        help="Number of images to generate with different noise vectors, "
        "given an input conditioning.",
    )
    parser.add_argument(
        "--num_conditionings_gen",
        type=int,
        default=5,
        help="Number of conditionings to generate with."
        " Use `num_imgs_gen` to control the number of generated samples per conditioning.",
    )
    parser.add_argument(
        "--index",
        type=int,
        default=None,
        help="Index of the stored instance to use as conditioning, in [0, 1000)."
        " Mutually exclusive with `num_conditionings_gen != 1`.",
    )
    parser.add_argument(
        "--swap_target",
        type=int,
        default=None,
        help="For class-conditional IC-GAN, the class label conditioning can be swapped"
        " for a different one. If swap_target is None, the original label from the"
        " instance is used; if swap_target is in [0, 1000), that ImageNet class is used"
        " instead.",
    )

    parser.add_argument(
        "--visualize_instance_images",
        action="store_true",
        default=False,
        help="Also visualize the ground-truth image corresponding to the instance"
        " conditioning (requires a path to the ImageNet dataset).",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Only needed if --visualize_instance_images is set."
        " Folder where the dataset ground-truth images can be found.",
    )
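    # Example invocations (sketch only: the script filename and the local paths are
    # placeholders and depend on where the pretrained checkpoints and stored
    # instance features were downloaded):
    #   python generate_images.py --root_path /path/to/pretrained \
    #       --model icgan --model_backbone biggan --resolution 256
    #   python generate_images.py --root_path /path/to/pretrained \
    #       --model cc_icgan --model_backbone biggan --resolution 256 --swap_target 281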

    config = vars(parser.parse_args())

    if config["index"] is not None and config["num_conditionings_gen"] != 1:
        raise ValueError(
            "If a specific instance feature vector is selected with --index,"
            " num_conditionings_gen needs to be set to 1."
        )
    if config["swap_target"] is not None and config["model"] == "icgan":
        raise ValueError(
            'Cannot specify a class label for IC-GAN! Only use "swap_target" with'
            " --model=cc_icgan."
        )
    main(config)