Spaces:
Sleeping
Sleeping
| import sys | |
| import torch | |
| sys.path = ['.'] + sys.path | |
| from argparse import ArgumentParser | |
| from pathlib import Path | |
| from metrics.metrics import metrics_registry | |
| from datasets.transforms import transforms_registry | |
| from datasets.datasets import CelebaAttributeDataset | |
| from utils.common_utils import tensor2im, setup_seed | |
# Fixed seed passed to setup_seed (imported from utils.common_utils) at import
# time. NOTE(review): presumably this seeds the RNGs for reproducible runs —
# confirm against utils.common_utils.
_SEED = 777
setup_seed(_SEED)
def inferece_fid_editing(opts):
    """Print the FID between original and edited images for one attribute.

    Two ``CelebaAttributeDataset`` objects are built with the same attribute
    name, test transform and attribute table: one over ``opts.orig_path``
    with ``use_attr=not opts.attr_is_reversed`` and one over
    ``opts.synt_path`` with ``use_attr=opts.attr_is_reversed``. Every item of
    both datasets is passed through ``tensor2im`` and converted to RGB, and
    the two resulting image lists are handed to the ``"fid"`` metric taken
    from ``metrics_registry``.

    Args:
        opts: Options object exposing ``orig_path``, ``synt_path``,
            ``attr_name``, ``attr_is_reversed``, ``celeba_attr_table_pth``
            and ``transforms``.

    Returns:
        None. The relative size of the first dataset and the FID value are
        printed to stdout.
    """

    def _as_rgb(sample):
        # Same conversion the script applies to every dataset item.
        return tensor2im(sample).convert("RGB")

    # The metric is instantiated first, before any dataset is touched.
    compute_fid = metrics_registry["fid"]()
    test_transform = transforms_registry[opts.transforms]().get_transforms()["test"]
    attribute = opts.attr_name

    # Dataset over the original images.
    real_subset = CelebaAttributeDataset(
        opts.orig_path,
        attribute,
        test_transform,
        opts.celeba_attr_table_pth,
        use_attr=not opts.attr_is_reversed,
    )
    # Dataset over the synthesized images, selected with the opposite flag.
    edited_subset = CelebaAttributeDataset(
        opts.synt_path,
        attribute,
        test_transform,
        opts.celeba_attr_table_pth,
        use_attr=opts.attr_is_reversed,
    )

    real_count = len(real_subset)
    share = real_count / (real_count + len(edited_subset))
    print(f"Percent of Images of attribute {opts.attr_name} is {share}")

    real_images = [_as_rgb(sample) for sample in real_subset]
    fake_images = [_as_rgb(sample) for sample in edited_subset]

    # The metric is fed in-memory images via ``from_data``; the positional
    # path arguments and ``out_path`` are passed as empty strings.
    _, fid_score, _ = compute_fid(
        "",
        "",
        out_path="",
        from_data={
            "inp_data": real_images,
            "fake_data": fake_images,
            "paths": [],
        },
    )
    print(f"FID for {opts.attr_name} is {fid_score:.4f}")
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the FID evaluation.
    parser = ArgumentParser()
    # The two image directories and the attribute name have no usable
    # default, so they are marked required: a missing option is reported by
    # argparse immediately instead of reaching the dataset code as ``None``.
    parser.add_argument(
        "--orig_path",
        type=str,
        required=True,
        help="Path to directory of original Celeba images ",
    )
    parser.add_argument(
        "--synt_path",
        type=str,
        required=True,
        help="Path to synthesized edited images",
    )
    parser.add_argument(
        "--attr_name",
        type=str,
        required=True,
        help="Name of Celeba attribute that is added during editing.",
    )
    parser.add_argument(
        "--attr_is_reversed",
        action="store_true",
        help="Means that attribute was not added but removed during editing",
    )
    parser.add_argument(
        "--celeba_attr_table_pth",
        default="CelebAMask-HQ-attribute-anno.txt",
        type=str,
        help="Path to celeba attributes .txt",
    )
    parser.add_argument(
        "--transforms",
        default="face_1024",
        type=str,
        help="Which transforms from datasets.transforms.transforms_registry should be used",
    )

    opts = parser.parse_args()
    inferece_fid_editing(opts)