# Smile_Changer/scripts/fid_calculation.py
# Commit 95b1715 (LogicGoInfotechSpaces): "Bundle StyleFeatureEditor code
# packages in Space to fix ModuleNotFoundError"
import sys
import torch
sys.path = ['.'] + sys.path
from argparse import ArgumentParser
from pathlib import Path
from metrics.metrics import metrics_registry
from datasets.transforms import transforms_registry
from datasets.datasets import CelebaAttributeDataset
from utils.common_utils import tensor2im, setup_seed
# Fixed seed applied as soon as this module is imported, i.e. before any
# dataset or metric is built. setup_seed comes from utils.common_utils;
# presumably it seeds the Python/NumPy/torch RNGs so runs are repeatable —
# confirm there.
setup_seed(777)
def _to_rgb_images(dataset):
    """Convert every item yielded by *dataset* into an RGB PIL image.

    Each item goes through ``tensor2im`` (from ``utils.common_utils``) and then
    ``.convert("RGB")``; the input format ``tensor2im`` expects is defined there.
    """
    return [tensor2im(item).convert("RGB") for item in dataset]


def _attr_fraction(n_attr, n_not_attr):
    """Return ``n_attr / (n_attr + n_not_attr)``.

    For non-negative counts the result lies in ``[0, 1]``.

    Raises:
        ValueError: if both counts are zero, so the share is undefined.
    """
    total = n_attr + n_not_attr
    if total == 0:
        raise ValueError(
            "Both datasets are empty; check --orig_path, --synt_path, "
            "--attr_name and --celeba_attr_table_pth"
        )
    return n_attr / total


def inferece_fid_editing(opts):
    """Compute and print the FID between real and edited images for one attribute.

    Two ``CelebaAttributeDataset`` instances are built with the same attribute
    name, transform and attribute table:

    * from ``opts.orig_path`` with ``use_attr=not opts.attr_is_reversed``;
    * from ``opts.synt_path`` with ``use_attr=opts.attr_is_reversed``.

    NOTE(review): presumably the first holds real images that already show the
    target state of the attribute, and the second holds edited images whose
    source had the opposite state — confirm against ``CelebaAttributeDataset``.

    Both sets are converted to RGB PIL images and fed to the registered
    ``"fid"`` metric through its ``from_data`` argument.

    Args:
        opts: parsed CLI namespace. Reads ``transforms``, ``attr_name``,
            ``orig_path``, ``synt_path``, ``celeba_attr_table_pth`` and
            ``attr_is_reversed``.

    Raises:
        ValueError: if both datasets are empty, or if either one is empty
            (FID cannot be computed from an empty sample).
    """
    fid_metric = metrics_registry["fid"]()
    transform = transforms_registry[opts.transforms]().get_transforms()["test"]
    attr_name = opts.attr_name

    attr_dataset = CelebaAttributeDataset(
        opts.orig_path,
        attr_name,
        transform,
        opts.celeba_attr_table_pth,
        use_attr=not opts.attr_is_reversed,
    )
    not_attr_dataset = CelebaAttributeDataset(
        opts.synt_path,
        attr_name,
        transform,
        opts.celeba_attr_table_pth,
        use_attr=opts.attr_is_reversed,
    )

    n_attr, n_not_attr = len(attr_dataset), len(not_attr_dataset)
    # The label says "Percent", so format as a percentage, not a raw 0-1 ratio.
    fraction = _attr_fraction(n_attr, n_not_attr)
    print(f"Percent of Images of attribute {attr_name} is {fraction:.2%}")

    # Fail before the (potentially slow) image conversion rather than inside
    # the metric with an opaque error.
    if n_attr == 0 or n_not_attr == 0:
        raise ValueError(
            f"FID needs images on both sides for attribute {attr_name!r}: got "
            f"{n_attr} from --orig_path and {n_not_attr} from --synt_path"
        )

    from_data_arg = {
        "inp_data": _to_rgb_images(attr_dataset),
        "fake_data": _to_rgb_images(not_attr_dataset),
        "paths": [],
    }
    # Empty path arguments: the images are supplied in memory via from_data.
    _, fid_value, _ = fid_metric("", "", out_path="", from_data=from_data_arg)
    print(f"FID for {attr_name} is {fid_value:.4f}")
def _build_arg_parser():
    """Build the command-line parser for the FID-after-editing script.

    ``--orig_path``, ``--synt_path`` and ``--attr_name`` are marked required.
    Otherwise argparse would default them to ``None``, and that ``None`` would
    be passed straight into ``CelebaAttributeDataset``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--orig_path",
        type=str,
        required=True,
        help="Path to directory of original Celeba images ",
    )
    parser.add_argument(
        "--synt_path",
        type=str,
        required=True,
        help="Path to synthesized edited images",
    )
    parser.add_argument(
        "--attr_name",
        type=str,
        required=True,
        help="Name of Celeba attribute that is added during editing.",
    )
    parser.add_argument(
        "--attr_is_reversed",
        action='store_true',
        help="Means that attribute was not added but removed during editing",
    )
    parser.add_argument(
        "--celeba_attr_table_pth",
        default="CelebAMask-HQ-attribute-anno.txt",
        type=str,
        help="Path to celeba attributes .txt",
    )
    parser.add_argument(
        "--transforms",
        default="face_1024",
        type=str,
        help="Which transforms from datasets.transforms.transforms_registry should be used",
    )
    return parser


if __name__ == "__main__":
    opts = _build_arg_parser().parse_args()
    inferece_fid_editing(opts)