| import os |
| import shutil |
| from io import BytesIO |
|
|
| import numpy as np |
| import pytest |
| import requests |
| from PIL import Image |
|
|
| from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector, |
| LeresDetector, LineartAnimeDetector, |
| LineartDetector, MediapipeFaceDetector, |
| MidasDetector, MLSDdetector, NormalBaeDetector, |
| OpenposeDetector, PidiNetDetector, SamDetector, |
| ZoeDetector, TileDetector) |
|
|
# Directory that receives every image written by the helpers below.
OUTPUT_DIR = "tests/outputs"


def output(name, img):
    """Save ``img`` as ``<name>.png`` inside ``OUTPUT_DIR``.

    Args:
        name: File stem; formatted with the ``s`` format spec, so it is
            expected to be a string.
        img: Any object exposing ``save(path)``.
    """
    filename = f"{name:s}.png"
    destination = os.path.join(OUTPUT_DIR, filename)
    img.save(destination)
|
|
def common(name, processor, img):
    """Run ``processor`` through the shared input/output combinations.

    Five results are saved via ``output``, each under ``name`` plus a suffix:

    * ``name``: ``img`` passed as-is with default arguments.
    * ``name_pil_np``: ``img`` as-is, ``output_type="np"``.
    * ``name_np_np``: ``img`` converted to a uint8 array, ``output_type="np"``.
    * ``name_np_pil``: ``img`` converted to a uint8 array, ``output_type="pil"``.
    * ``name_scaled``: ``img`` as-is with explicit detect/image resolutions.

    Results requested with ``output_type="np"`` are wrapped with
    ``Image.fromarray`` before saving.

    Args:
        name: Base file stem for the saved images.
        processor: Callable detector under test.
        img: Input image (the tests pass the PIL image from the ``img`` fixture).
    """
    default_result = processor(img)
    output(name, default_result)

    array_from_image = processor(img, output_type="np")
    output(name + "_pil_np", Image.fromarray(array_from_image))

    # A fresh array is built for each call, exactly one conversion per run.
    array_from_array = processor(np.array(img, dtype=np.uint8), output_type="np")
    output(name + "_np_np", Image.fromarray(array_from_array))

    image_from_array = processor(np.array(img, dtype=np.uint8), output_type="pil")
    output(name + "_np_pil", image_from_array)

    rescaled = processor(img, detect_resolution=640, image_resolution=768)
    output(name + "_scaled", rescaled)
|
|
def return_pil(name, processor, img):
    """Exercise the ``return_pil`` keyword of ``processor`` with both values.

    The ``return_pil=False`` result is wrapped with ``Image.fromarray`` and
    saved as ``<name>_pil_false``; the ``return_pil=True`` result is saved
    directly as ``<name>_pil_true``.

    Args:
        name: Base file stem for the saved images.
        processor: Callable detector accepting a ``return_pil`` keyword.
        img: Input image handed to the detector.
    """
    raw_result = processor(img, return_pil=False)
    output(name + "_pil_false", Image.fromarray(raw_result))

    pil_result = processor(img, return_pil=True)
    output(name + "_pil_true", pil_result)
|
|
@pytest.fixture(scope="module")
def img():
    """Reset the output directory and return the shared test image.

    Any existing ``OUTPUT_DIR`` is removed and recreated empty, then the
    pose sample image is downloaded, converted to RGB and resized to
    512x512.

    Returns:
        The downloaded image as a 512x512 RGB PIL image.

    Raises:
        requests.RequestException: If the download times out, fails, or the
            server answers with an HTTP error status.
    """
    # Start every module run from an empty output directory.
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    # makedirs (rather than mkdir) also creates a missing parent directory.
    os.makedirs(OUTPUT_DIR)

    url = "https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png"
    # Without a timeout a stalled connection would hang the whole test run.
    response = requests.get(url, timeout=60)
    # Fail here on an HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()

    image = Image.open(BytesIO(response.content)).convert("RGB").resize((512, 512))
    return image
|
|
def test_canny(img):
    """Run CannyDetector through the shared checks, then with ``img`` by keyword."""
    detector = CannyDetector()
    common("canny", detector, img)

    keyword_result = detector(img=img)
    output("canny_img", keyword_result)
|
|
def test_hed(img):
    """Run HEDdetector through the shared checks plus ``safe`` and ``scribble``."""
    detector = HEDdetector.from_pretrained("lllyasviel/Annotators")
    common("hed", detector, img)
    return_pil("hed", detector, img)

    safe_result = detector(img, safe=True)
    output("hed_safe", safe_result)

    scribble_result = detector(img, scribble=True)
    output("hed_scribble", scribble_result)
|
|
def test_leres(img):
    """Run LeresDetector through the shared checks plus one ``boost=True`` run."""
    detector = LeresDetector.from_pretrained("lllyasviel/Annotators")
    common("leres", detector, img)

    boosted = detector(img, boost=True)
    output("leres_boost", boosted)
|
|
def test_lineart(img):
    """Run LineartDetector through the shared checks plus one ``coarse=True`` run."""
    detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
    common("lineart", detector, img)
    return_pil("lineart", detector, img)

    coarse_result = detector(img, coarse=True)
    output("lineart_coarse", coarse_result)
|
|
def test_lineart_anime(img):
    """Run LineartAnimeDetector through the shared and ``return_pil`` checks."""
    detector = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
    for check in (common, return_pil):
        check("lineart_anime", detector, img)
|
|
def test_mediapipe_face(img):
    """Run MediapipeFaceDetector through the shared checks, then with ``image`` by keyword."""
    detector = MediapipeFaceDetector()
    common("mediapipe", detector, img)

    keyword_result = detector(image=img)
    output("mediapipe_image", keyword_result)
|
|
def test_midas(img):
    """Run MidasDetector through the shared checks plus a ``depth_and_normal`` run.

    Only the element at index 1 of the ``depth_and_normal=True`` result is
    saved (presumably the normal map, going by the output name).
    """
    detector = MidasDetector.from_pretrained("lllyasviel/Annotators")
    common("midas", detector, img)

    combined = detector(img, depth_and_normal=True)
    output("midas_normal", combined[1])
|
|
def test_mlsd(img):
    """Run MLSDdetector through the shared and ``return_pil`` checks."""
    detector = MLSDdetector.from_pretrained("lllyasviel/Annotators")
    for check in (common, return_pil):
        check("mlsd", detector, img)
|
|
def test_normalbae(img):
    """Run NormalBaeDetector through the shared and ``return_pil`` checks."""
    detector = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
    for check in (common, return_pil):
        check("normal_bae", detector, img)
|
|
def test_openpose(img):
    """Run OpenposeDetector through the shared checks and its part-selection options.

    After the shared and ``return_pil`` checks, the detector is run with
    ``hand_and_face`` set to both values and with four combinations of
    ``include_body`` / ``include_hand`` / ``include_face``; each result is
    saved under its own name.
    """
    detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
    common("openpose", detector, img)
    return_pil("openpose", detector, img)

    hand_and_face_runs = (
        ("openpose_hand_and_face_false", False),
        ("openpose_hand_and_face_true", True),
    )
    for filename, enabled in hand_and_face_runs:
        output(filename, detector(img, hand_and_face=enabled))

    # Columns: output name, include_body, include_hand, include_face.
    part_runs = (
        ("openpose_face", True, False, True),
        ("openpose_faceonly", False, False, True),
        ("openpose_full", True, True, True),
        ("openpose_hand", True, True, False),
    )
    for filename, body, hand, face in part_runs:
        result = detector(img, include_body=body, include_hand=hand, include_face=face)
        output(filename, result)
|
|
def test_pidi(img):
    """Run PidiNetDetector through the shared checks plus ``safe`` and ``scribble``."""
    detector = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
    common("pidi", detector, img)
    return_pil("pidi", detector, img)

    safe_result = detector(img, safe=True)
    output("pidi_safe", safe_result)

    scribble_result = detector(img, scribble=True)
    output("pidi_scribble", scribble_result)
|
|
def test_sam(img):
    """Run SamDetector through the shared checks, then with ``image`` by keyword."""
    detector = SamDetector.from_pretrained(
        "ybelkada/segment-anything",
        subfolder="checkpoints",
    )
    common("sam", detector, img)

    keyword_result = detector(image=img)
    output("sam_image", keyword_result)
|
|
def test_shuffle(img):
    """Run ContentShuffleDetector through the shared and ``return_pil`` checks."""
    detector = ContentShuffleDetector()
    for check in (common, return_pil):
        check("shuffle", detector, img)
|
|
def test_zoe(img):
    """Run ZoeDetector through the shared checks."""
    detector = ZoeDetector.from_pretrained("lllyasviel/Annotators")
    common("zoe", detector, img)
|
|
def test_tile(img):
    """Run TileDetector through the shared checks plus one more default-argument run."""
    detector = TileDetector()
    common("tile", detector, img)

    plain_result = detector(img)
    output("tile_img", plain_result)