File size: 2,676 Bytes
95b1715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import sys
import torch

sys.path =  ['.'] + sys.path

from argparse import ArgumentParser
from pathlib import Path
from metrics.metrics import metrics_registry
from datasets.transforms import transforms_registry
from datasets.datasets import CelebaAttributeDataset
from utils.common_utils import tensor2im, setup_seed


# Seed once at import time, before any dataset or metric object is built, so
# repeated runs of this script yield the same numbers.
# NOTE(review): assumes setup_seed seeds every RNG the datasets/metric rely on
# (Python / NumPy / torch) — confirm in utils.common_utils.
SEED = 777
setup_seed(SEED)


def _dataset_to_rgb_images(dataset):
    """Return every item of *dataset* converted to an RGB image via ``tensor2im``."""
    # NOTE(review): assumes each dataset item is an image tensor accepted by
    # tensor2im and that tensor2im returns a PIL image (it exposes .convert).
    return [tensor2im(image_tensor).convert("RGB") for image_tensor in dataset]


def inferece_fid_editing(opts):
    """Compute and print the FID between real and edited CelebA images for one attribute.

    Two ``CelebaAttributeDataset`` objects are built with the same attribute
    name, attribute table and "test" transform:

    * ``opts.orig_path`` with ``use_attr=not opts.attr_is_reversed`` -- the real
      reference images;
    * ``opts.synt_path`` with ``use_attr=opts.attr_is_reversed`` -- the edited
      (synthesized) images.

    Both sets are converted to RGB images and passed to the registered "fid"
    metric through its ``from_data`` argument. The middle element of the
    metric's returned triple is printed as the FID value.

    Args:
        opts: parsed command-line namespace. Reads ``orig_path``, ``synt_path``,
            ``attr_name``, ``attr_is_reversed``, ``celeba_attr_table_pth`` and
            ``transforms``.

    Raises:
        ValueError: if either dataset contains no images. FID cannot be computed
            without both image sets.
    """
    fid_metric = metrics_registry["fid"]()
    transform = transforms_registry[opts.transforms]().get_transforms()["test"]
    attr_name = opts.attr_name

    # NOTE(review): use_attr presumably keeps only images whose annotation has
    # attr_name present (True) or absent (False) — confirm in
    # datasets.datasets.CelebaAttributeDataset. The flags are opposite so that
    # real images show the post-edit state and edited images come from the
    # pre-edit state (swapped when --attr_is_reversed, i.e. attribute removed).
    attr_dataset = CelebaAttributeDataset(
        opts.orig_path,
        attr_name,
        transform,
        opts.celeba_attr_table_pth,
        use_attr=not opts.attr_is_reversed,
    )
    not_attr_dataset = CelebaAttributeDataset(
        opts.synt_path,
        attr_name,
        transform,
        opts.celeba_attr_table_pth,
        use_attr=opts.attr_is_reversed,
    )

    n_attr = len(attr_dataset)
    n_not_attr = len(not_attr_dataset)
    # Fail early with a clear message: an empty set makes FID undefined, and
    # two empty sets would make the ratio below divide by zero.
    if n_attr == 0 or n_not_attr == 0:
        raise ValueError(
            f"Cannot compute FID for attribute {attr_name!r}: "
            f"{n_attr} real images in {opts.orig_path!r}, "
            f"{n_not_attr} edited images in {opts.synt_path!r}"
        )

    # The label says "Percent", so format the fraction as a percentage
    # (the original printed the raw 0..1 ratio).
    print(f"Percent of Images of attribute {attr_name} is "
          f"{n_attr / (n_attr + n_not_attr):.2%}")

    from_data_arg = {
        "inp_data": _dataset_to_rgb_images(attr_dataset),
        "fake_data": _dataset_to_rgb_images(not_attr_dataset),
        "paths": [],
    }

    # Input/output paths are left empty because the images are supplied
    # in-memory via from_data.
    _, fid_value, _ = fid_metric("", "", out_path="", from_data=from_data_arg)
    print(f"FID for {attr_name} is {fid_value:.4f}")


if __name__ == "__main__":
    parser = ArgumentParser(
        description="Compute FID between real CelebA images and attribute-edited images."
    )
    # The two paths and the attribute name have no sensible default: without
    # them None would be passed into the dataset constructor and the script
    # would fail far from the real cause. Mark them required so argparse
    # reports a usage error instead.
    parser.add_argument(
        "--orig_path",
        type=str,
        required=True,
        help="Path to directory of original Celeba images ",
    )
    parser.add_argument(
        "--synt_path",
        type=str,
        required=True,
        help="Path to synthesized edited images",
    )
    parser.add_argument(
        "--attr_name",
        type=str,
        required=True,
        help="Name of Celeba attribute that is added during editing.",
    )
    parser.add_argument(
        "--attr_is_reversed",
        action="store_true",
        help="Means that attribute was not added but removed during editing",
    )
    parser.add_argument(
        "--celeba_attr_table_pth",
        default="CelebAMask-HQ-attribute-anno.txt",
        type=str,
        help="Path to celeba attributes .txt",
    )
    parser.add_argument(
        "--transforms",
        default="face_1024",
        type=str,
        help="Which transforms from datasets.transforms.transforms_registry should be used",
    )

    opts = parser.parse_args()
    inferece_fid_editing(opts)