| | import io |
| | import os |
| | import cv2 |
| | import base64 |
| | import functools |
| | from typing import Dict, Any, List, Union, Literal, Optional |
| | from pathlib import Path |
| | import datetime |
| | from enum import Enum |
| | import numpy as np |
| | import pytest |
| |
|
| | import requests |
| | from PIL import Image |
| |
|
| |
|
def disable_in_cq(func):
    """Decorator that turns the wrapped test into a skip during a CQ run.

    The ``APITestTemplate.is_cq_run`` flag is looked up each time the wrapped
    test is called, not when the decorator is applied. Outside a CQ run the
    test is invoked unchanged and its return value is passed through.
    """

    @functools.wraps(func)
    def _skip_when_cq(*args, **kwargs):
        if not APITestTemplate.is_cq_run:
            return func(*args, **kwargs)
        # pytest.skip() raises, so nothing after this line would execute.
        pytest.skip()

    return _skip_when_cq
| |
|
| |
|
# Type of the partial dicts used to override payload / ControlNet unit fields.
PayloadOverrideType = Dict[str, Any]

# The timestamp is taken once, when this module is imported, and names the
# result directory used for that import.
timestamp = f"{datetime.datetime.now():%Y%m%d-%H%M%S}"
test_result_dir = Path(__file__).parent.joinpath("results", f"test_result_{timestamp}")

# Expected images live next to this file; the directory is created eagerly.
test_expectation_dir = Path(__file__).parent.joinpath("expectations")
os.makedirs(test_expectation_dir, exist_ok=True)

# Input images are read from the "images" directory one level above this file's directory.
resource_dir = Path(__file__).parents[1].joinpath("images")
| |
|
| |
|
def get_dest_dir():
    """Return the directory that generated images should be written to.

    During a set-expectation run this is ``test_expectation_dir``; otherwise
    it is the timestamped ``test_result_dir``.
    """
    return (
        test_expectation_dir
        if APITestTemplate.is_set_expectation_run
        else test_result_dir
    )
| |
|
| |
|
def save_base64(base64img: str, dest: Path):
    """Decode a base64-encoded image and save it to ``dest``.

    Args:
        base64img: The image as base64 text, either bare or prefixed with a
            data-URI header such as ``data:image/png;base64,``.
        dest: Target file path; it is passed straight to ``Image.save``.
    """
    # The base64 alphabet contains no comma, so a comma can only come from a
    # data-URI header. Taking the LAST piece keeps the payload in both cases;
    # taking the first piece would keep the header instead of the image data.
    encoded = base64img.split(",", 1)[-1]
    with Image.open(io.BytesIO(base64.b64decode(encoded))) as img:
        img.save(dest)
| |
|
| |
|
def read_image(img_path: Path) -> str:
    """Load an image from disk and return it as base64-encoded PNG text.

    Args:
        img_path: Location of the image; converted with ``str()`` before it is
            handed to OpenCV, so a ``Path`` or a plain string both work.

    Returns:
        The image re-encoded as PNG, base64-encoded and decoded as UTF-8 text.

    Raises:
        IOError: If OpenCV cannot read the file or cannot encode it as PNG.
    """
    img = cv2.imread(str(img_path))
    # cv2.imread reports a missing or unreadable file by returning None rather
    # than raising; fail here with a clear error instead of inside imencode.
    if img is None:
        raise IOError(f"Unable to read image: {img_path}")

    encoded_ok, png_buffer = cv2.imencode(".png", img)
    if not encoded_ok:
        raise IOError(f"Unable to encode image as PNG: {img_path}")

    return base64.b64encode(png_buffer).decode("utf-8")
| |
|
| |
|
def read_image_dir(img_dir: Path, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]:
    """Read every image in ``img_dir`` whose file name ends with one of ``suffixes``.

    Files are processed in sorted file-name order so the returned list (and
    any indexing into it) is reproducible across machines and file systems.
    A file that cannot be read is reported on stdout and skipped.

    Args:
        img_dir: Directory to scan (not recursive).
        suffixes: A single suffix or any iterable of suffixes. Matching is
            case-sensitive.

    Returns:
        The base64-encoded images produced by ``read_image``, one per readable
        matching file.
    """
    img_dir = str(img_dir)
    # str.endswith accepts only a str or a tuple of str, so normalise lists,
    # sets and similar iterables to a tuple.
    if not isinstance(suffixes, str):
        suffixes = tuple(suffixes)

    images = []
    # os.listdir returns entries in arbitrary order; sort for determinism.
    for filename in sorted(os.listdir(img_dir)):
        if not filename.endswith(suffixes):
            continue
        img_path = os.path.join(img_dir, filename)
        try:
            images.append(read_image(img_path))
        except (IOError, cv2.error):
            # OpenCV reports failures as cv2.error, which is not an IOError.
            print(f"Error opening {img_path}")
    return images
| |
|
| |
|
# Fixture images, loaded and base64-encoded once when this module is imported.
girl_img, mask_img, mask_small_img = (
    read_image(resource_dir / file_name)
    for file_name in ("1girl.png", "mask.png", "mask_small.png")
)

# Every readable image under images/portrait; the first one is the face fixture.
portrait_imgs = read_image_dir(resource_dir.joinpath("portrait"))
realistic_girl_face_img = portrait_imgs[0]
| |
|
| |
|
# Stock negative prompt kept as one literal. It is not referenced elsewhere in
# this part of the file -- presumably tests import it; verify against callers.
# Note that the string begins with a newline (the text starts on the line after
# the opening quotes).
general_negative_prompt = """
(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality,
((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot,
backlight,(ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21),
(tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, (bad anatomy:1.21),
(bad proportions:1.331), extra limbs, (missing arms:1.331), (extra legs:1.331),
(fused fingers:1.61051), (too many fingers:1.61051), (unclear eyes:1.331), bad hands,
missing fingers, extra digit, bad body, easynegative, nsfw"""
| |
|
class StableDiffusionVersion(Enum):
    """Stable Diffusion model families, identified by a small integer.

    The integer values are what ``StableDiffusionVersion(<int>)`` looks up, so
    they must stay stable.
    """

    # Fallback member for a family that is none of the ones below.
    UNKNOWN = 0
    SD1x = 1
    SD2x = 2
    SDXL = 3
| |
|
| |
|
# Model family under test. CONTROLNET_TEST_SD_VERSION carries the enum's
# integer value; when it is unset the SD1x family is used.
sd_version = StableDiffusionVersion(
    int(os.getenv("CONTROLNET_TEST_SD_VERSION", StableDiffusionVersion.SD1x.value))
)

# Full coverage is enabled by the mere presence of the variable, whatever its value.
is_full_coverage = "CONTROLNET_TEST_FULL_COVERAGE" in os.environ
| |
|
| |
|
class APITestTemplate:
    """Builds one txt2img/img2img request with ControlNet units and checks the result.

    Class attributes:
        is_set_expectation_run: True unless the CONTROLNET_SET_EXP environment
            variable is set to something other than "True". When True, local
            runs write their images to the expectation directory and skip the
            comparison step.
        is_cq_run: True only when FORGE_CQ_TEST is "True". Selects the
            lightweight ``exec_cq`` path and lowers the default step count.
        BASE_URL: Root URL of the web server under test.
    """

    is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True"
    is_cq_run = os.environ.get("FORGE_CQ_TEST", "False") == "True"
    BASE_URL = "http://localhost:7860/"

    def __init__(
        self,
        name: str,
        gen_type: Union[Literal["img2img"], Literal["txt2img"]],
        payload_overrides: PayloadOverrideType,
        unit_overrides: Union[PayloadOverrideType, List[PayloadOverrideType]],
        input_image: Optional[str] = None,
    ):
        """Assemble the request payload for a single test case.

        Args:
            name: Test name; used as the prefix of every saved image file.
            gen_type: Which endpoint to call; also selects the payload template.
            payload_overrides: Top-level payload fields that replace the
                template's values.
            unit_overrides: One dict, or a list/tuple of dicts, each merged on
                top of ``default_unit`` to form one ControlNet unit.
            input_image: Optional base64 image. For img2img it becomes the
                single entry of ``init_images``; for txt2img it becomes the
                ``image`` of every ControlNet unit, replacing any ``image``
                given in the unit override.
        """
        self.name = name
        self.url = APITestTemplate.BASE_URL + "sdapi/v1/" + gen_type
        self.payload = {
            **(txt2img_payload if gen_type == "txt2img" else img2img_payload),
            **payload_overrides,
        }
        if gen_type == "img2img" and input_image is not None:
            self.payload["init_images"] = [input_image]

        # CQ runs only check that the API answers, so keep generation cheap
        # unless the test asked for a specific step count.
        if "steps" not in payload_overrides and APITestTemplate.is_cq_run:
            self.payload["steps"] = 3

        unit_overrides = (
            unit_overrides
            if isinstance(unit_overrides, (list, tuple))
            else [unit_overrides]
        )
        units = [
            {
                **default_unit,
                **unit_override,
                **({"image": input_image} if gen_type == "txt2img" and input_image is not None else {}),
            }
            for unit_override in unit_overrides
        ]

        # The payload above is only a shallow copy, so its "alwayson_scripts"
        # value is the very dict owned by the module-level template (or by the
        # caller's overrides). Writing into it would leak this instance's
        # units into every other instance built from the same template, so
        # build fresh nested dicts and leave the shared ones untouched.
        shared_scripts = self.payload.get("alwayson_scripts") or {}
        self.payload["alwayson_scripts"] = {
            **shared_scripts,
            "ControlNet": {**shared_scripts.get("ControlNet", {}), "args": units},
        }
        self.active_unit_count = len(unit_overrides)

    def exec(self, *args, **kwargs) -> bool:
        """Run the test with the executor matching the current environment.

        All arguments are forwarded unchanged to ``exec_cq`` during a CQ run
        and to ``exec_local`` otherwise.

        Returns:
            True when the selected executor reports success.
        """
        if APITestTemplate.is_cq_run:
            return self.exec_cq(*args, **kwargs)
        else:
            return self.exec_local(*args, **kwargs)

    def exec_cq(self, expected_output_num: Optional[int] = None, *args, **kwargs) -> bool:
        """Execute test in CQ environment.

        Only the HTTP status and the number of returned images are checked;
        nothing is written to disk.

        Args:
            expected_output_num: Number of images the response must contain.
                Defaults to ``n_iter * batch_size`` plus one image per active
                ControlNet unit.

        Returns:
            True when the request succeeded and the image count matches.
        """
        res = requests.post(url=self.url, json=self.payload)
        if res.status_code != 200:
            print(f"Unexpected status code {res.status_code}")
            return False

        response = res.json()
        if "images" not in response:
            print(response.keys())
            return False

        if expected_output_num is None:
            expected_output_num = self.payload["n_iter"] * self.payload["batch_size"] + self.active_unit_count

        if len(response["images"]) != expected_output_num:
            print(f"{len(response['images'])} != {expected_output_num}")
            return False

        return True

    def exec_local(self, result_only: bool = True, *args, **kwargs) -> bool:
        """Execute test in local environment.

        The returned images are saved to the directory chosen by
        ``get_dest_dir``. Unless this is a set-expectation run, each saved
        image is then compared with the expectation file of the same name.

        Args:
            result_only: When True, only the first returned image is saved
                and compared; when False, every returned image is.

        Returns:
            False if the response carries no images or any comparison failed,
            True otherwise. A missing expectation file is only a warning.
        """
        if not APITestTemplate.is_set_expectation_run:
            os.makedirs(test_result_dir, exist_ok=True)

        response = requests.post(url=self.url, json=self.payload).json()
        if "images" not in response:
            print(response.keys())
            return False

        dest_dir = get_dest_dir()
        results = response["images"][:1] if result_only else response["images"]

        failed = False
        for i, base64image in enumerate(results):
            img_file_name = f"{self.name}_{i}.png"
            save_base64(base64image, dest_dir / img_file_name)

            # A set-expectation run only records images; there is nothing to compare.
            if APITestTemplate.is_set_expectation_run:
                continue

            if not self._matches_expectation(img_file_name):
                failed = True
        return not failed

    @staticmethod
    def _matches_expectation(img_file_name: str) -> bool:
        """Compare one saved result image with its expectation of the same name."""
        try:
            img1 = cv2.imread(os.path.join(test_expectation_dir, img_file_name))
            img2 = cv2.imread(os.path.join(test_result_dir, img_file_name))
        except Exception as e:
            print(f"Get exception reading imgs: {e}")
            return False

        # cv2.imread returns None for a file that is missing or unreadable.
        if img1 is None:
            # No expectation recorded yet: warn, but do not fail the test.
            print(f"Warn: No expectation file found {img_file_name}.")
            return True

        if img2 is None:
            print(f"Error: Result image {img_file_name} could not be read.")
            return False

        return bool(
            expect_same_image(
                img1,
                img2,
                diff_img_path=str(test_result_dir
                                  / img_file_name.replace(".png", "_diff.png")),
            )
        )
| |
|
| |
|
def expect_same_image(
    img1,
    img2,
    diff_img_path: str,
    threshold: int = 30,
    rtol: float = 0.5,
    atol: float = 1,
    min_match_ratio: float = 0.95,
) -> bool:
    """Decide whether two images are similar enough to count as the same.

    Array elements are compared with ``np.isclose(img1, img2, rtol, atol)``.
    If any element falls outside the tolerance, a black/white diff image is
    written to ``diff_img_path`` for inspection. This happens even when the
    images still pass overall.

    Args:
        img1: Expected image as a numpy array (as returned by ``cv2.imread``).
        img2: Actual image as a numpy array.
        diff_img_path: Where the diff image is written when the images are
            not element-wise close.
        threshold: Absolute difference above which an element is painted
            white (255) in the diff image.
        rtol: Relative tolerance passed to ``np.isclose``.
        atol: Absolute tolerance passed to ``np.isclose``.
        min_match_ratio: Minimum fraction of array elements that must be
            within tolerance for the images to be considered the same.

    Returns:
        True when at least ``min_match_ratio`` of the elements are close.
        False when either image is None or the shapes differ.
    """
    # cv2.imread yields None for unreadable files; such a pair can never match.
    if img1 is None or img2 is None:
        print("Cannot compare images: at least one of them could not be read.")
        return False

    # Images of different sizes cannot be compared element by element.
    if img1.shape != img2.shape:
        print(f"Image shapes differ: {img1.shape} != {img2.shape}")
        return False

    # One mask serves both checks: "everything close" and "fraction close".
    matching_pixels = np.isclose(img1, img2, rtol=rtol, atol=atol)

    if not matching_pixels.all():
        # Only build and save the diff visualisation when there is a difference.
        diff = cv2.absdiff(img1, img2)
        diff_highlighted = np.where(diff > threshold, 255, 0).astype(np.uint8)
        cv2.imwrite(diff_img_path, diff_highlighted)

    return bool((matching_pixels.sum() / matching_pixels.size) >= min_match_ratio)
| |
|
| |
|
def get_model(model_name: str) -> str:
    """Resolve a (partial) model name to a model known to the server.

    The name "none", in any letter case, maps to the literal string "None"
    without contacting the server. Any other name is matched
    case-insensitively as a substring against the list returned by the
    ``controlnet/model_list`` endpoint, and the first match wins.

    Raises:
        ValueError: If the server response has no model list, or no listed
            model contains ``model_name``.
    """
    wanted = model_name.lower()
    if wanted == "none":
        return "None"

    listing = requests.get(APITestTemplate.BASE_URL + "controlnet/model_list").json()
    if "model_list" not in listing:
        raise ValueError("No model available")

    matches = [name for name in listing["model_list"] if wanted in name.lower()]
    if not matches:
        raise ValueError("No suitable model available")

    return matches[0]
| |
|
| |
|
# Baseline ControlNet unit. APITestTemplate merges each per-test unit override
# on top of a copy of these fields.
# NOTE: get_model() runs when this module is imported and, for any name other
# than "none", queries the server's model list, so the server must be reachable
# at import time.
default_unit = {
    "control_mode": 0,
    "enabled": True,
    "guidance_end": 1,
    "guidance_start": 0,
    "pixel_perfect": True,
    "processor_res": 512,
    "resize_mode": 1,
    "threshold_a": 64,
    "threshold_b": 64,
    "weight": 1,
    "module": "canny",
    "model": get_model("sd15_canny"),
}
| |
|
# Baseline request body for the img2img endpoint. APITestTemplate selects it
# whenever gen_type is not "txt2img", copies its top level and applies the
# per-test payload overrides. The ControlNet "args" list below is only a
# placeholder: APITestTemplate.__init__ replaces it with the units built for
# the test. No "init_images" entry is defined here; APITestTemplate adds one
# when an input image is supplied.
img2img_payload = {
    "batch_size": 1,
    "cfg_scale": 7,
    "height": 768,
    "width": 512,
    "n_iter": 1,
    "steps": 10,
    "sampler_name": "Euler a",
    "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,",
    "negative_prompt": "",
    "seed": 42,
    "seed_enable_extras": False,
    "seed_resize_from_h": 0,
    "seed_resize_from_w": 0,
    "subseed": -1,
    "subseed_strength": 0,
    "override_settings": {},
    "override_settings_restore_afterwards": False,
    "do_not_save_grid": False,
    "do_not_save_samples": False,
    "s_churn": 0,
    "s_min_uncond": 0,
    "s_noise": 1,
    "s_tmax": None,
    "s_tmin": 0,
    "script_args": [],
    "script_name": None,
    "styles": [],
    "alwayson_scripts": {"ControlNet": {"args": [default_unit]}},
    "denoising_strength": 0.75,
    "initial_noise_multiplier": 1,
    "inpaint_full_res": 0,
    "inpaint_full_res_padding": 32,
    "inpainting_fill": 1,
    "inpainting_mask_invert": 0,
    "mask_blur_x": 4,
    "mask_blur_y": 4,
    "mask_blur": 4,
    "resize_mode": 0,
}
| |
|
# Baseline request body for the txt2img endpoint. APITestTemplate selects it
# when gen_type is "txt2img", copies its top level and applies the per-test
# payload overrides. The ControlNet "args" list below is only a placeholder:
# APITestTemplate.__init__ replaces it with the units built for the test.
# The fixed seed (42) is shared with the img2img template.
txt2img_payload = {
    "alwayson_scripts": {"ControlNet": {"args": [default_unit]}},
    "batch_size": 1,
    "cfg_scale": 7,
    "comments": {},
    "disable_extra_networks": False,
    "do_not_save_grid": False,
    "do_not_save_samples": False,
    "enable_hr": False,
    "height": 768,
    "hr_negative_prompt": "",
    "hr_prompt": "",
    "hr_resize_x": 0,
    "hr_resize_y": 0,
    "hr_scale": 2,
    "hr_second_pass_steps": 0,
    "hr_upscaler": "Latent",
    "n_iter": 1,
    "negative_prompt": "",
    "override_settings": {},
    "override_settings_restore_afterwards": True,
    "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,",
    "restore_faces": False,
    "s_churn": 0.0,
    "s_min_uncond": 0,
    "s_noise": 1.0,
    "s_tmax": None,
    "s_tmin": 0.0,
    "sampler_name": "Euler a",
    "script_args": [],
    "script_name": None,
    "seed": 42,
    "seed_enable_extras": True,
    "seed_resize_from_h": -1,
    "seed_resize_from_w": -1,
    "steps": 10,
    "styles": [],
    "subseed": -1,
    "subseed_strength": 0,
    "tiling": False,
    "width": 512,
}
| |
|