GeorgeQi committed on
Commit
d048c6f
·
verified ·
1 Parent(s): 5c65631

Upload 3 files

Browse files
app_shadow_generation_gradio.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+ from generate_shadow_main import Shadow_Diffusion
6
# Shared pipeline instance, constructed once at import time because model
# loading is expensive; reused by every Gradio callback invocation.
shadow_diffuser = Shadow_Diffusion()

# Maps the UI radio-button label to the diffusion prompt.
# Keys (Chinese): "软阴影" = soft shadow, "硬阴影" = hard shadow,
# "悬浮阴影" = floating shadow.
shadow_prompt_dict = {"软阴影": "a product with soft natural shadow, white background",
                      "硬阴影": "a product with hard natural shadow, white background",
                      "悬浮阴影": "a product with floating natural shadow, white background",}
10
def generate_shadow(image_rgba, shadow_type, padding_rate, denoise_strength, num_inference_steps,
                    controlnet_conditioning_scale, cfg, seed):
    """Gradio callback: split an RGBA image into RGB + alpha mask, run the
    shadow diffusion pipeline, and return the composed result as a numpy array.
    """
    # The RGB channels are the product image; the alpha channel, replicated to
    # three channels, serves as the foreground mask.
    rgb_pil = Image.fromarray(image_rgba[..., :-1])
    alpha_pil = Image.fromarray(np.repeat(image_rgba[..., -1:], 3, -1))
    prompt = shadow_prompt_dict[shadow_type]
    # Log the raw slider values for debugging.
    print(padding_rate, denoise_strength, num_inference_steps, controlnet_conditioning_scale, cfg, seed)
    outputs = shadow_diffuser.generate_shadow(
        rgb_pil,
        alpha_pil,
        prompt,
        padding_rate=float(padding_rate),
        denoise_strength=float(denoise_strength),
        num_inference_steps=int(num_inference_steps),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        cfg=float(cfg),
        seed=int(seed))
    # generate_shadow returns (masked_image, masked_depth, generated, composed);
    # only the final composition is shown in the UI.
    composed_pil = outputs[-1]
    return np.array(composed_pil)
26
+
27
+
28
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 💡基于Diffusion的白底阴影生成 \n"
                "请确保上传带有透明通道的RGBA图像作为输入")
    with gr.Row():
        with gr.Column():
            # Input must be RGBA: the alpha channel is used as the product mask.
            rgba = gr.Image(image_mode="RGBA", label="输入商品图(RGBA)")
            gr.Examples(label="示例图片", inputs=[rgba],
                        examples=[os.path.join("./test_images", n) for n in os.listdir("./test_images")])
        with gr.Column():
            shadow_output = gr.Image(image_mode="RGB", label="阴影生成结果")

    with gr.Row():
        shadow_type = gr.Radio(["软阴影", "硬阴影", "悬浮阴影"], value="硬阴影", label="阴影类型")
        generate_btn = gr.Button(value="生成阴影")

    with gr.Accordion(label="其他参数>>", open=False) as sku_accordion:
        padding_rate = gr.Slider(0, 0.99, value=0.4, step=0.1, label="白边填充比例")
        # FIX: default value (1.0) was above the slider's previous maximum
        # (0.99); raise the maximum so the default is representable.
        denoise_strength = gr.Slider(0, 1.0, value=1.0, step=0.01, label="去噪程度")
        num_inference_steps = gr.Slider(10, 50, value=20, step=1, label="推理步数")
        controlnet_conditioning_scale = gr.Slider(0, 1.0, value=0.5, step=0.01, label="控制强度(ControlNet)")
        # FIX: removed a second, unused slider (`step_num`) that duplicated the
        # "推理步数" label but was never wired to the callback.
        cfg = gr.Slider(0, 20, value=5, step=0.5, label="CFG")
        # FIX: the seed is cast to int in the callback, so a fractional step
        # (0.01) was meaningless; -1 selects a non-deterministic seed downstream.
        seed = gr.Slider(-1, 99999999999999, value=42, step=1, label="随机种子")

    generate_btn.click(generate_shadow, inputs=[rgba, shadow_type, padding_rate, denoise_strength, num_inference_steps,
                                                controlnet_conditioning_scale, cfg, seed], outputs=shadow_output)

demo.queue().launch(server_name='[::]', share=True)
generate_shadow_main.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import time
4
+ import torch
5
+ import numpy as np
6
+ import shadow_utils
7
+ from PIL import Image
8
+
9
+ from datetime import datetime
10
+
11
+ from diffusers.utils import make_image_grid
12
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
13
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
14
+
15
+
16
class Shadow_Diffusion:
    """SDXL + ControlNet image-to-image pipeline that synthesises a natural
    shadow for a product cut-out on a white background.

    The ControlNet is conditioned on a depth map of the masked product, while
    the img2img input is the product composited onto pure white.
    """

    def __init__(self,
                 base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
                 vae_path="madebyollin/sdxl-vae-fp16-fix",
                 controlnet_path="./checkpoints/shadow",
                 depth_midas_path="Intel/dpt-hybrid-midas",
                 resolution=1024,
                 precision_type=torch.float16):
        # Working resolution: all inputs are cropped/padded to this square size.
        self.resolution = resolution
        # MiDaS depth estimator that produces the ControlNet conditioning image.
        self.depth_estimator = DPTForDepthEstimation.from_pretrained(depth_midas_path).to("cuda")
        self.feature_extractor = DPTFeatureExtractor.from_pretrained(depth_midas_path)
        self.controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=precision_type).to("cuda")
        # fp16-safe replacement for the stock SDXL VAE.
        self.vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=precision_type).to("cuda")
        self.shadow_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            base_model_path,
            controlnet=self.controlnet,
            vae=self.vae,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=precision_type).to("cuda")
        # NOTE(review): enable_model_cpu_offload() is normally called on a
        # pipeline that has NOT already been moved to CUDA — confirm the
        # .to("cuda") above is intentional alongside offloading.
        self.shadow_pipe.enable_model_cpu_offload()

    def get_depth_map(self, image):
        """Estimate depth for *image* and return it as an RGB PIL image of size
        (self.resolution, self.resolution), min-max normalised to [0, 255]."""
        image = self.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        with torch.no_grad(), torch.autocast("cuda"):
            depth_map = self.depth_estimator(image).predicted_depth

        # Upsample the raw prediction to the working resolution.
        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(self.resolution, self.resolution),
            mode="bicubic",
            align_corners=False,
        )
        # Per-image min-max normalisation, then replicate to 3 channels so the
        # result can be treated as an RGB conditioning image.
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        image = torch.cat([depth_map] * 3, dim=1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
        return image

    def generate_shadow(self,
                        img_pil,
                        mask_pil,
                        prompt="",
                        padding_rate=0.4,
                        denoise_strength=1.0,
                        num_inference_steps=20,
                        controlnet_conditioning_scale=0.5,
                        cfg=5.0,
                        seed=-1):
        """Generate a shadow for the masked product.

        Args:
            img_pil: RGB product image (PIL.Image).
            mask_pil: 3-channel foreground mask (PIL.Image), white = product.
            prompt: diffusion text prompt.
            padding_rate: fraction of the square canvas left as border.
            denoise_strength: img2img strength (1.0 = full re-generation).
            num_inference_steps: number of diffusion steps.
            controlnet_conditioning_scale: depth-ControlNet weight.
            cfg: classifier-free guidance scale.
            seed: RNG seed; -1 means no fixed generator (non-deterministic).

        Returns:
            (masked_image_pil, masked_depth_pil, generated_image, composed_image)
        """
        # 1. Crop both image and mask to the mask's tight bounding box.
        xmin, xmax, ymin, ymax = shadow_utils.get_mask_bbox(np.array(mask_pil))
        img_pil = Image.fromarray(np.array(img_pil)[ymin:ymax, xmin:xmax])
        mask_pil = Image.fromarray(np.array(mask_pil)[ymin:ymax, xmin:xmax])

        # 2. Pad onto a square canvas: white for the image, black for the mask.
        img_pil = shadow_utils.padding_image(img_pil, padding_rate, self.resolution, (255, 255, 255))
        mask_pil = shadow_utils.padding_image(mask_pil, padding_rate, self.resolution, (0, 0, 0))

        mask_np = np.array(mask_pil)
        alpha = mask_np / 255.0  # soft foreground weight in [0, 1]
        depth_pil = self.get_depth_map(img_pil)
        # Depth outside the product is zeroed. (The original formulation's
        # "(1 - alpha) * zeros" term was identically zero and has been dropped.)
        masked_depth_pil = Image.fromarray((alpha * np.array(depth_pil)).astype(np.uint8))
        # Product composited onto pure white: the img2img starting image.
        masked_image_pil = Image.fromarray(((1 - alpha) * np.array([255, 255, 255])[np.newaxis, np.newaxis, :] +
                                            alpha * np.array(img_pil)).astype(np.uint8))

        generated_image = self.shadow_pipe(prompt,
                                           image=masked_image_pil,
                                           control_image=masked_depth_pil,
                                           strength=denoise_strength,
                                           num_inference_steps=num_inference_steps,
                                           generator=None if seed == -1 else torch.manual_seed(seed),
                                           controlnet_conditioning_scale=controlnet_conditioning_scale,
                                           guidance_scale=cfg
                                           ).images[0]
        # Keep the original product pixels; take the generated background
        # (which now carries the shadow) everywhere else.
        composed_image = alpha * np.array(img_pil) + (1 - alpha) * np.array(generated_image)
        composed_image = Image.fromarray(composed_image.astype(np.uint8))

        return masked_image_pil, masked_depth_pil, generated_image, composed_image
96
+
97
+
98
if __name__ == '__main__':

    shadowfree_img_dir = "./test_images"

    # FIX: the original string "./test_results" had no "{}" placeholder, so
    # .format(...) was a no-op and the timestamp was silently discarded;
    # every run overwrote the same directory.
    save_dir = "./test_results_{}".format(datetime.now().strftime("%Y%m%d%H%M%S"))
    os.makedirs(save_dir, exist_ok=True)

    # Shadow-type label -> diffusion prompt.
    prompts = {"软阴影": "a product with soft natural shadow, white background",
               "硬阴影": "a product with hard natural shadow, white background",
               "悬浮阴影": "a product with floating natural shadow, white background"}
    shadow_diffuser = Shadow_Diffusion()

    for image_name in os.listdir(shadowfree_img_dir):
        # Hoisted out of the per-prompt loop: the source image is the same for
        # every shadow type, so read and split it once per file.
        org_image = cv2.imread(os.path.join(shadowfree_img_dir, image_name), cv2.IMREAD_UNCHANGED)
        org_image = cv2.cvtColor(org_image, cv2.COLOR_BGRA2RGBA)
        img_pil = Image.fromarray(org_image[..., :-1])
        mask_pil = Image.fromarray(np.repeat(org_image[..., -1:], 3, -1))

        results = []
        for shadow_type, prompt in prompts.items():
            start_time = time.time()
            masked_img, masked_depth, gen_img, compose_img = shadow_diffuser.generate_shadow(
                img_pil,
                mask_pil,
                prompt,
                num_inference_steps=50,
                padding_rate=0.5,
                seed=42,
                denoise_strength=1.0,
                cfg=5.0,
                controlnet_conditioning_scale=0.5,
            )
            results.append(compose_img)
        # One row per image: soft / hard / floating shadow side by side.
        make_image_grid(results, rows=1, cols=3).save(os.path.join(save_dir, image_name))
shadow_utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
+
6
def get_mask_bbox(mask_org, threshold=0):
    """Return the tight bounding box (xmin, xmax, ymin, ymax) of a mask's
    foreground.

    Pixels with value > threshold count as foreground. If the mask is entirely
    background, a symbolic one/two-pixel box at the image centre is returned so
    downstream slicing never yields an empty image.
    """
    mask = mask_org.copy()
    # Binarise to {0, 255} so the coordinate search below is unambiguous.
    mask[mask > threshold] = 255
    mask[mask <= threshold] = 0
    h, w = mask.shape[:2]
    if mask.ndim == 3:
        # Multi-channel mask: a pixel is foreground if its channel mean is.
        coords = np.where(np.mean(mask, axis=-1, keepdims=False) > threshold)
    else:
        # FIX: this branch previously compared against a hard-coded 20 instead
        # of the `threshold` parameter. On the binarised {0, 255} mask the two
        # agree for thresholds in [0, 254], but using the parameter keeps both
        # branches consistent.
        coords = np.where(mask > threshold)

    # coords[0] and coords[1] always have the same length, so a single
    # emptiness check covers both axes (the original checked each separately).
    if len(coords[0]) > 0:
        ymin, ymax = coords[0].min(), coords[0].max()
        xmin, xmax = coords[1].min(), coords[1].max()
    else:
        ymin, ymax = h // 2, h // 2 + 1
        xmin, xmax = w // 2, w // 2 + 1
    return (xmin, xmax, ymin, ymax)
28
+
29
def padding_image(img, padding_rate, canve_reso=1024, padding_color=(0, 0, 0)):
    """Scale *img* (PIL.Image) proportionally so its long side equals
    canve_reso * (1 - padding_rate), then centre it on a square
    canve_reso x canve_reso canvas filled with *padding_color*."""
    target_long = int(canve_reso * (1 - padding_rate))
    pixels = np.array(img)
    src_h, src_w = pixels.shape[:2]

    # Long side becomes target_long; the short side keeps the aspect ratio.
    if src_h > src_w:
        out_h = target_long
        out_w = int(src_w / src_h * target_long)
    else:
        out_w = target_long
        out_h = int(src_h / src_w * target_long)

    shrunk = cv2.resize(pixels, (out_w, out_h), interpolation=cv2.INTER_AREA)

    # Offsets that centre the resized image on the square canvas.
    top = (canve_reso - out_h) // 2
    left = (canve_reso - out_w) // 2

    canvas = np.ones((canve_reso, canve_reso, shrunk.shape[-1])) * \
        np.array(padding_color)[np.newaxis, np.newaxis, :]
    canvas[top:top + out_h, left:left + out_w, :] = shrunk
    return Image.fromarray(canvas.astype(np.uint8))