Spaces:
Sleeping
Sleeping
File size: 2,676 Bytes
95b1715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import sys
import torch
sys.path = ['.'] + sys.path
from argparse import ArgumentParser
from pathlib import Path
from metrics.metrics import metrics_registry
from datasets.transforms import transforms_registry
from datasets.datasets import CelebaAttributeDataset
from utils.common_utils import tensor2im, setup_seed
# Module-level seeding with the fixed value 777, executed at import time so
# every run of this script starts from the same state.
# NOTE(review): presumably setup_seed (utils.common_utils) seeds the random
# generators used downstream (python / numpy / torch?) — confirm which ones.
setup_seed(777)
def inferece_fid_editing(opts):
    """Compute FID between real attribute images and edited images.

    The (misspelled) function name is kept as-is so existing callers work.

    Two ``CelebaAttributeDataset`` instances are built for the same attribute
    and attribute table, but with opposite ``use_attr`` flags:

    * ``opts.orig_path`` (original images) with
      ``use_attr=not opts.attr_is_reversed``;
    * ``opts.synt_path`` (edited images) with
      ``use_attr=opts.attr_is_reversed``.

    NOTE(review): presumably ``use_attr=True`` keeps images that carry
    ``opts.attr_name`` in the attribute table and ``False`` keeps those that
    do not — confirm against ``CelebaAttributeDataset``.

    Every dataset item is converted with ``tensor2im`` and forced to RGB, and
    both image lists are handed to the registry's ``"fid"`` metric in-memory.

    Args:
        opts: Namespace with ``orig_path``, ``synt_path``, ``attr_name``,
            ``attr_is_reversed``, ``celeba_attr_table_pth`` and
            ``transforms`` (the key into ``transforms_registry``).

    Returns:
        The FID value produced by the metric (it is also printed).

    Raises:
        ValueError: If either filtered dataset is empty — neither the share of
            attribute images nor FID is defined in that case.
    """
    fid_metric = metrics_registry["fid"]()
    transform = transforms_registry[opts.transforms]().get_transforms()["test"]
    attr_name = opts.attr_name

    attr_dataset = CelebaAttributeDataset(
        opts.orig_path,
        attr_name,
        transform,
        opts.celeba_attr_table_pth,
        use_attr=not opts.attr_is_reversed,
    )
    not_attr_dataset = CelebaAttributeDataset(
        opts.synt_path,
        attr_name,
        transform,
        opts.celeba_attr_table_pth,
        use_attr=opts.attr_is_reversed,
    )

    n_attr = len(attr_dataset)
    n_not_attr = len(not_attr_dataset)
    # Fail fast: an empty side would otherwise divide by zero below (both
    # empty) or hand an empty image set to the FID metric (one empty).
    if n_attr == 0 or n_not_attr == 0:
        raise ValueError(
            f"No images selected for attribute {attr_name!r}: "
            f"{n_attr} in {opts.orig_path!r}, {n_not_attr} in {opts.synt_path!r}"
        )

    # The message says "Percent", so format the share as a percentage
    # (the raw fraction was printed before).
    share = n_attr / (n_attr + n_not_attr)
    print(f"Percent of Images of attribute {attr_name} is {share:.2%}")

    attr_images = [tensor2im(image).convert("RGB") for image in attr_dataset]
    edited_images = [tensor2im(image).convert("RGB") for image in not_attr_dataset]

    from_data_arg = {
        "inp_data": attr_images,
        "fake_data": edited_images,
        "paths": [],
    }
    # Path arguments are left empty because the images are supplied via
    # from_data. NOTE(review): presumably the metric ignores the paths when
    # from_data is given — confirm in metrics.metrics.
    _, fid_value, _ = fid_metric("", "", out_path="", from_data=from_data_arg)
    print(f"FID for {attr_name} is {fid_value:.4f}")
    return fid_value
def build_arg_parser():
    """Build the command-line parser for the FID-editing evaluation script.

    ``--orig_path``, ``--synt_path`` and ``--attr_name`` are always read by
    ``inferece_fid_editing``, so they are required. Omitting one now yields
    an argparse usage error instead of ``None`` being passed downstream.

    Returns:
        An ``ArgumentParser`` whose parsed namespace is the ``opts`` object
        expected by ``inferece_fid_editing``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--orig_path",
        type=str,
        required=True,
        help="Path to directory of original Celeba images ",
    )
    parser.add_argument(
        "--synt_path",
        type=str,
        required=True,
        help="Path to synthesized edited images",
    )
    parser.add_argument(
        "--attr_name",
        type=str,
        required=True,
        help="Name of Celeba attribute that is added during editing.",
    )
    parser.add_argument(
        "--attr_is_reversed",
        action="store_true",
        help="Means that attribute was not added but removed during editing",
    )
    parser.add_argument(
        "--celeba_attr_table_pth",
        default="CelebAMask-HQ-attribute-anno.txt",
        type=str,
        help="Path to celeba attributes .txt",
    )
    parser.add_argument(
        "--transforms",
        default="face_1024",
        type=str,
        help="Which transforms from datasets.transforms.transforms_registry should be used",
    )
    return parser


if __name__ == "__main__":
    opts = build_arg_parser().parse_args()
    inferece_fid_editing(opts)
|