| | import datetime
|
| | import numpy as np
|
| | import os
|
| | from PIL import Image
|
| | import pytest
|
| | from pytest import fixture
|
| | from typing import Tuple, List
|
| |
|
| | from cv2 import imread, cvtColor, COLOR_BGR2RGB
|
| | from skimage.metrics import structural_similarity as ssim
|
| |
|
| |
|
| | """
|
| | This test suite compares images in 2 directories by file name
|
| | The directories are specified by the command line arguments --baseline_dir and --test_dir
|
| |
|
| | """
|
| |
|
| |
|
| | def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
|
| | score, diff = ssim(img0, img1, channel_axis=-1, full=True)
|
| |
|
| | diff = (diff * 255).astype("uint8")
|
| | return score, diff
|
| |
|
| |
|
| | METRICS = {"ssim": ssim_score}
|
| | METRICS_PASS_THRESHOLD = {"ssim": 0.95}
|
| |
|
| |
|
| | class TestCompareImageMetrics:
|
| | @fixture(scope="class")
|
| | def test_file_names(self, args_pytest):
|
| | test_dir = args_pytest['test_dir']
|
| | fnames = self.gather_file_basenames(test_dir)
|
| | yield fnames
|
| | del fnames
|
| |
|
| | @fixture(scope="class", autouse=True)
|
| | def teardown(self, args_pytest):
|
| | yield
|
| |
|
| |
|
| | baseline_dir = args_pytest['baseline_dir']
|
| | test_dir = args_pytest['test_dir']
|
| | img_output_dir = args_pytest['img_output_dir']
|
| | metrics_file = args_pytest['metrics_file']
|
| |
|
| | grid_dir = os.path.join(img_output_dir, "grid")
|
| | os.makedirs(grid_dir, exist_ok=True)
|
| |
|
| | for metric_dir in METRICS.keys():
|
| | metric_path = os.path.join(img_output_dir, metric_dir)
|
| | for file in os.listdir(metric_path):
|
| | if file.endswith(".png"):
|
| | score = self.lookup_score_from_fname(file, metrics_file)
|
| | image_file_list = []
|
| | image_file_list.append([
|
| | os.path.join(baseline_dir, file),
|
| | os.path.join(test_dir, file),
|
| | os.path.join(metric_path, file)
|
| | ])
|
| |
|
| | image_list = [[Image.open(file) for file in files] for files in image_file_list]
|
| | grid = self.image_grid(image_list)
|
| | grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
|
| |
|
| |
|
| | @fixture()
|
| | def fname(self, baseline_fname):
|
| | yield baseline_fname
|
| | del baseline_fname
|
| |
|
| | def test_directories_not_empty(self, args_pytest):
|
| | baseline_dir = args_pytest['baseline_dir']
|
| | test_dir = args_pytest['test_dir']
|
| | assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
|
| | assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"
|
| |
|
| | def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest):
|
| |
|
| | baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
|
| | file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
|
| | file_match = self.find_file_match(baseline_file_path, file_paths)
|
| | assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
|
| |
|
| |
|
| |
|
| | @pytest.mark.parametrize("metric", METRICS.keys())
|
| | def test_pipeline_compare(
|
| | self,
|
| | args_pytest,
|
| | fname,
|
| | test_file_names,
|
| | metric,
|
| | ):
|
| | baseline_dir = args_pytest['baseline_dir']
|
| | test_dir = args_pytest['test_dir']
|
| | metrics_output_file = args_pytest['metrics_file']
|
| | img_output_dir = args_pytest['img_output_dir']
|
| |
|
| | baseline_file_path = os.path.join(baseline_dir, fname)
|
| |
|
| |
|
| | file_paths = [os.path.join(test_dir, f) for f in test_file_names]
|
| | test_file = self.find_file_match(baseline_file_path, file_paths)
|
| |
|
| |
|
| | sample_baseline = self.read_img(baseline_file_path)
|
| | sample_secondary = self.read_img(test_file)
|
| |
|
| | score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
|
| | metric_status = score > METRICS_PASS_THRESHOLD[metric]
|
| |
|
| |
|
| | with open(metrics_output_file, 'a') as f:
|
| | run_info = os.path.splitext(fname)[0]
|
| | metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
|
| | date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| | f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n")
|
| |
|
| |
|
| | metric_img_dir = os.path.join(img_output_dir, metric)
|
| | os.makedirs(metric_img_dir, exist_ok=True)
|
| | output_filename = f'{fname}'
|
| | Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))
|
| |
|
| | assert score > METRICS_PASS_THRESHOLD[metric]
|
| |
|
| | def read_img(self, filename: str) -> np.ndarray:
|
| | cvImg = imread(filename)
|
| | cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
|
| | return cvImg
|
| |
|
| | def image_grid(self, img_list: list[list[Image.Image]]):
|
| |
|
| |
|
| | rows = len(img_list)
|
| | cols = len(img_list[0])
|
| |
|
| | w, h = img_list[0][0].size
|
| | grid = Image.new('RGB', size=(cols*w, rows*h))
|
| |
|
| | for i, row in enumerate(img_list):
|
| | for j, img in enumerate(row):
|
| | grid.paste(img, box=(j*w, i*h))
|
| | return grid
|
| |
|
| | def lookup_score_from_fname(self,
|
| | fname: str,
|
| | metrics_output_file: str
|
| | ) -> float:
|
| | fname_basestr = os.path.splitext(fname)[0]
|
| | with open(metrics_output_file, 'r') as f:
|
| | for line in f:
|
| | if fname_basestr in line:
|
| | score = float(line.split('|')[5])
|
| | return score
|
| | raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
|
| |
|
| | def gather_file_basenames(self, directory: str):
|
| | files = []
|
| | for file in os.listdir(directory):
|
| | if file.endswith(".png"):
|
| | files.append(file)
|
| | return files
|
| |
|
| | def read_file_prompt(self, fname:str) -> str:
|
| |
|
| | img = Image.open(fname)
|
| | img.load()
|
| | return img.info['prompt']
|
| |
|
| | def find_file_match(self, baseline_file: str, file_paths: List[str]):
|
| |
|
| | baseline_prompt = self.read_file_prompt(baseline_file)
|
| |
|
| |
|
| | if baseline_prompt is None or baseline_prompt == "":
|
| | return None
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | basename = os.path.basename(baseline_file)
|
| | file_path_basenames = [os.path.basename(f) for f in file_paths]
|
| | if basename in file_path_basenames:
|
| | match_index = file_path_basenames.index(basename)
|
| | file_paths.insert(0, file_paths.pop(match_index))
|
| |
|
| | for f in file_paths:
|
| | test_file_prompt = self.read_file_prompt(f)
|
| | if baseline_prompt == test_file_prompt:
|
| | return f
|
| |
|