File size: 7,420 Bytes
d048c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de27f8f
d048c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import cv2
import time
import torch
import numpy as np
import shadow_utils
from PIL import Image

from datetime import datetime

from diffusers.utils import make_image_grid
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL


class Shadow_Diffusion:
    """Generate a natural-looking shadow for a cut-out object using SDXL + ControlNet.

    Pipeline: a DPT model estimates depth for the (padded) object image; the depth
    map, masked to the object region, conditions an SDXL ControlNet img2img pass
    that paints a shadow onto the white background around the object.
    """

    def __init__(self,
                 base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
                 vae_path="madebyollin/sdxl-vae-fp16-fix",
                 controlnet_path="./checkpoints/shadow_checkpoint",
                 depth_midas_path="Intel/dpt-hybrid-midas",
                 resolution=1024,
                 precision_type=torch.float16):
        """Load all models onto CUDA.

        Args:
            base_model_path: HF repo or local path of the SDXL base model.
            vae_path: fp16-safe VAE (the stock SDXL VAE overflows in fp16).
            controlnet_path: local checkpoint of the shadow-trained ControlNet.
            depth_midas_path: DPT depth-estimation model used for conditioning.
            resolution: square side length (px) the depth map is resized to;
                inputs are expected to be padded to this size before inference.
            precision_type: dtype for the diffusion models (fp16 by default).
        """
        self.resolution=resolution
        self.depth_estimator = DPTForDepthEstimation.from_pretrained(depth_midas_path).to("cuda")
        self.feature_extractor = DPTFeatureExtractor.from_pretrained(depth_midas_path)
        self.controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=precision_type).to("cuda")
        self.vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=precision_type).to("cuda")
        self.shadow_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(base_model_path,
                                                                               controlnet=self.controlnet,
                                                                               vae=self.vae,
                                                                               variant="fp16",
                                                                               use_safetensors=True,
                                                                               torch_dtype=precision_type).to("cuda")
        # NOTE(review): diffusers recommends NOT moving the pipeline to CUDA before
        # enabling CPU offload — the explicit .to("cuda") above may defeat or conflict
        # with the offload hooks. TODO confirm against the installed diffusers version.
        self.shadow_pipe.enable_model_cpu_offload()

    def get_depth_map(self, image):
        """Return a 3-channel PIL depth map of `image`, resized to the working resolution.

        Args:
            image: PIL image (or anything DPTFeatureExtractor accepts).

        Returns:
            PIL.Image of size (self.resolution, self.resolution), uint8, where the
            per-image depth has been min-max normalized to [0, 255] and replicated
            across 3 channels.
        """
        image = self.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        with torch.no_grad(), torch.autocast("cuda"):
            depth_map = self.depth_estimator(image).predicted_depth

        # predicted_depth is (batch, H, W); add a channel dim so interpolate sees NCHW.
        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(self.resolution, self.resolution),
            mode="bicubic",
            align_corners=False,
        )
        # Per-image min-max normalization to [0, 1].
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        image = torch.cat([depth_map] * 3, dim=1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
        return image

    def generate_shadow(self,
                        img_pil,
                        mask_pil,
                        prompt="",
                        padding_rate=0.4,
                        denoise_strength=1.0,
                        num_inference_steps=20,
                        controlnet_conditioning_scale=0.5,
                        cfg=5.0,
                        seed=-1):
        """Run the full shadow-generation pipeline for one object.

        Args:
            img_pil: RGB PIL image of the object on (presumably) a white background.
            mask_pil: 3-channel PIL mask, 255 on the object, 0 elsewhere.
            prompt: text prompt steering the shadow style.
            padding_rate: white-border fraction added around the object bbox.
            denoise_strength: img2img strength (1.0 = regenerate background fully).
            num_inference_steps: diffusion steps.
            controlnet_conditioning_scale: weight of the depth conditioning.
            cfg: classifier-free guidance scale.
            seed: RNG seed; -1 means non-deterministic (no generator passed).

        Returns:
            Tuple (masked_image_pil, masked_depth_pil, generated_image, composed_image);
            composed_image keeps the original object pixels and takes the generated
            shadow/background everywhere outside the mask.
        """

        # 1.Extract the foreground area according to the minimum bounding box
        objcut_xmin, objcut_xmax, objcut_ymin, objcut_ymax = shadow_utils.get_mask_bbox(np.array(mask_pil))
        img_pil = Image.fromarray(np.array(img_pil)[objcut_ymin:objcut_ymax, objcut_xmin:objcut_xmax])
        mask_pil = Image.fromarray(np.array(mask_pil)[objcut_ymin:objcut_ymax, objcut_xmin:objcut_xmax])

        # 2.Fill the smallest foreground area with white edges
        img_pil = shadow_utils.padding_image(img_pil, padding_rate, self.resolution, (255, 255, 255))
        mask_pil = shadow_utils.padding_image(mask_pil, padding_rate, self.resolution, (0, 0, 0))

        mask_np = np.array(mask_pil)
        depth_pil = self.get_depth_map(img_pil)
        # Depth zeroed outside the object. NOTE(review): the first term multiplies by
        # np.zeros_like(mask_np) and is therefore always zero — it only documents the
        # "background := 0" intent; the expression reduces to mask * depth.
        masked_depth_pil = Image.fromarray(((1 - mask_np / 255.0) * np.zeros_like(mask_np) +
                                                 mask_np / 255.0 * np.array(depth_pil)).astype(np.uint8))
        # Object pixels kept, background forced to pure white before img2img.
        masked_image_pil = Image.fromarray(((1 - mask_np / 255.0) * np.array([255, 255, 255])[np.newaxis, np.newaxis, :] +
                                                 mask_np / 255.0 * np.array(img_pil)).astype(np.uint8))

        generated_image = self.shadow_pipe(prompt,
                                           image=masked_image_pil,
                                           control_image=masked_depth_pil,
                                           strength=denoise_strength,
                                           num_inference_steps=num_inference_steps,
                                           generator=None if seed == -1 else torch.manual_seed(seed),
                                           controlnet_conditioning_scale=controlnet_conditioning_scale,
                                           guidance_scale=cfg
                                           ).images[0]
        # Alpha-composite: original object over the generated shadow background.
        composed_image = mask_np / 255.0 * np.array(img_pil) + (1 - mask_np / 255.0) * np.array(generated_image)
        composed_image = Image.fromarray(composed_image.astype(np.uint8))

        return masked_image_pil, masked_depth_pil, generated_image, composed_image


if __name__ == '__main__':

    shadowfree_img_dir = "./test_images"

    # BUG FIX: the original template was "./test_results" with no "{}" placeholder,
    # so str.format() silently discarded the timestamp. Add the placeholder so each
    # run writes to its own timestamped results directory, as clearly intended.
    save_dir = "./test_results_{}".format(datetime.now().strftime("%Y%m%d%H%M%S"))
    os.makedirs(save_dir, exist_ok=True)

    # One prompt per shadow style (keys are Chinese labels: soft / hard / floating shadow).
    prompts = {"软阴影": "a product with soft natural shadow, white background",
               "硬阴影": "a product with hard natural shadow, white background",
               "悬浮阴影": "a product with floating natural shadow, white background"}
    shadow_diffuser = Shadow_Diffusion()

    for image_name in os.listdir(shadowfree_img_dir):
        # Hoisted out of the per-prompt loop: the source image is loop-invariant,
        # so read and convert it once per file instead of once per shadow style.
        # Expects an image with an alpha channel (IMREAD_UNCHANGED keeps it);
        # alpha is used as the object mask.
        org_image = cv2.imread(os.path.join(shadowfree_img_dir, image_name), cv2.IMREAD_UNCHANGED)
        org_image = cv2.cvtColor(org_image, cv2.COLOR_BGRA2RGBA)
        img_pil = Image.fromarray(org_image[..., :-1])          # RGB channels
        mask_pil = Image.fromarray(np.repeat(org_image[..., -1:], 3, -1))  # alpha -> 3-channel mask

        results = []
        for shadow_type, prompt in prompts.items():
            masked_img, masked_depth, gen_img, compose_img = shadow_diffuser.generate_shadow(img_pil,
                                                                                             mask_pil,
                                                                                             prompt,
                                                                                             num_inference_steps=50,
                                                                                             padding_rate=0.5,
                                                                                             seed=42,
                                                                                             denoise_strength=1.0,
                                                                                             cfg=5.0,
                                                                                             controlnet_conditioning_scale=0.5,
                                                                                             )
            results.append(compose_img)
        # One row per image: soft | hard | floating shadow composites.
        make_image_grid(results, rows=1, cols=3).save(os.path.join(save_dir, image_name))