Upload 3 files
Browse files- app_shadow_generation_gradio.py +56 -0
- generate_shadow_main.py +131 -0
- shadow_utils.py +54 -0
app_shadow_generation_gradio.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from generate_shadow_main import Shadow_Diffusion
|
| 6 |
+
# Module-level singleton: the SDXL + ControlNet shadow pipeline is loaded once
# at import time. NOTE(review): construction moves models to CUDA — assumes a
# GPU is available on the serving host; confirm before deploying CPU-only.
shadow_diffuser = Shadow_Diffusion()

# Maps the UI radio-button label to its diffusion prompt. The keys are runtime
# values produced by the gr.Radio component, so they must remain exactly as
# displayed: "软阴影" = soft shadow, "硬阴影" = hard shadow, "悬浮阴影" = floating shadow.
shadow_prompt_dict = {"软阴影": "a product with soft natural shadow, white background",
                      "硬阴影": "a product with hard natural shadow, white background",
                      "悬浮阴影": "a product with floating natural shadow, white background",}
|
| 10 |
+
def generate_shadow(image_rgba, shadow_type, padding_rate, denoise_strength, num_inference_steps,
                    controlnet_conditioning_scale, cfg, seed):
    """Gradio click handler: split the RGBA upload into image + alpha mask and
    run the shadow diffusion pipeline.

    Args:
        image_rgba: HxWx4 uint8 array from the gr.Image(image_mode="RGBA") input.
        shadow_type: radio label, looked up in shadow_prompt_dict.
        Remaining args: raw slider values; coerced to float/int before the call.

    Returns:
        HxWx3 uint8 array with the composed shadow result.
    """
    # Channels 0-2 are the product image; channel 3 (alpha) becomes a 3-channel mask.
    rgb = image_rgba[..., :3]
    alpha = image_rgba[..., 3:]
    img_pil = Image.fromarray(rgb)
    mask_pil = Image.fromarray(np.repeat(alpha, 3, -1))

    prompt = shadow_prompt_dict[shadow_type]
    # Log the raw slider values for debugging.
    print(padding_rate, denoise_strength, num_inference_steps, controlnet_conditioning_scale, cfg, seed)

    # Only the final composed image is shown in the UI; the intermediate
    # outputs (masked image, depth, raw generation) are discarded.
    outputs = shadow_diffuser.generate_shadow(
        img_pil,
        mask_pil,
        prompt,
        padding_rate=float(padding_rate),
        denoise_strength=float(denoise_strength),
        num_inference_steps=int(num_inference_steps),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        cfg=float(cfg),
        seed=int(seed),
    )
    shadow_result_pil = outputs[-1]
    return np.array(shadow_result_pil)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Gradio UI: RGBA product image in, composed white-background shadow image out.
with gr.Blocks() as demo:
    gr.Markdown("# 💡基于Diffusion的白底阴影生成 \n"
                "请确保上传带有透明通道的RGBA图像作为输入")
    with gr.Row():
        with gr.Column():
            rgba = gr.Image(image_mode="RGBA", label="输入商品图(RGBA)")
            gr.Examples(label="示例图片", inputs=[rgba],
                        examples=[os.path.join("./test_images", n) for n in os.listdir("./test_images")])
        with gr.Column():
            shadow_output = gr.Image(image_mode="RGB", label="阴影生成结果")

    with gr.Row():
        shadow_type = gr.Radio(["软阴影", "硬阴影", "悬浮阴影"], value="硬阴影", label="阴影类型")
        generate_btn = gr.Button(value="生成阴影")

    # Advanced parameters, collapsed by default.
    with gr.Accordion(label="其他参数>>", open=False) as sku_accordion:
        padding_rate = gr.Slider(0, 0.99, value=0.4, step=0.1, label="白边填充比例")
        # Fix: default value 1.0 previously exceeded the slider maximum (0.99).
        denoise_strength = gr.Slider(0, 1.0, value=1.0, step=0.01, label="去噪程度")
        num_inference_steps = gr.Slider(10, 50, value=20, step=1, label="推理步数")
        controlnet_conditioning_scale = gr.Slider(0, 1.0, value=0.5, step=0.01, label="控制强度(ControlNet)")
        # Fix: removed a duplicate, never-wired "推理步数" slider (step_num) that
        # shadowed num_inference_steps and was not passed to the callback.
        cfg = gr.Slider(0, 20, value=5, step=0.5, label="CFG")
        # Fix: the seed is an integer (-1 = random); step was 0.01, which let the
        # UI produce fractional values.
        seed = gr.Slider(-1, 99999999999999, value=42, step=1, label="随机种子")

    generate_btn.click(generate_shadow, inputs=[rgba, shadow_type, padding_rate, denoise_strength, num_inference_steps,
                                                controlnet_conditioning_scale, cfg, seed], outputs=shadow_output)

demo.queue().launch(server_name='[::]', share=True)
|
| 56 |
+
|
generate_shadow_main.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import time
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import shadow_utils
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
from diffusers.utils import make_image_grid
|
| 12 |
+
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
| 13 |
+
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Shadow_Diffusion:
    """SDXL ControlNet img2img pipeline that renders a natural shadow for a
    cut-out product image on a white background.

    Pipeline: MiDaS depth estimation -> ControlNet (depth) conditioning ->
    SDXL img2img, then the original foreground is composited back on top so
    only the background/shadow region comes from the diffusion model.
    """

    def __init__(self,
                 base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
                 vae_path="madebyollin/sdxl-vae-fp16-fix",
                 controlnet_path="./checkpoints/shadow",
                 depth_midas_path="Intel/dpt-hybrid-midas",
                 resolution=1024,
                 precision_type=torch.float16):
        # Working canvas resolution (square); all inputs are padded/resized to this.
        self.resolution=resolution
        # Depth estimator runs in full precision on CUDA; only the diffusion
        # components use `precision_type` (fp16 by default, matching the
        # fp16-fixed VAE checkpoint).
        self.depth_estimator = DPTForDepthEstimation.from_pretrained(depth_midas_path).to("cuda")
        self.feature_extractor = DPTFeatureExtractor.from_pretrained(depth_midas_path)
        self.controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=precision_type).to("cuda")
        self.vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=precision_type).to("cuda")
        self.shadow_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(base_model_path,
                                                                                     controlnet=self.controlnet,
                                                                                     vae=self.vae,
                                                                                     variant="fp16",
                                                                                     use_safetensors=True,
                                                                                     torch_dtype=precision_type).to("cuda")
        # NOTE(review): diffusers documentation advises NOT moving the pipeline
        # to CUDA manually before enable_model_cpu_offload(); combining both may
        # negate the memory savings or error on some versions — confirm against
        # the pinned diffusers release.
        self.shadow_pipe.enable_model_cpu_offload()

    def get_depth_map(self, image):
        """Estimate a MiDaS depth map for `image` and return it as an RGB PIL image.

        The raw depth is resized to (resolution, resolution), min-max normalised
        to [0, 1] per sample, and replicated across 3 channels for ControlNet.
        """
        image = self.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        with torch.no_grad(), torch.autocast("cuda"):
            depth_map = self.depth_estimator(image).predicted_depth

        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(self.resolution, self.resolution),
            mode="bicubic",
            align_corners=False,
        )
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        # NOTE(review): divides by (max - min) with no epsilon — a perfectly
        # constant depth map would produce NaNs; has not been observed here but
        # worth confirming for edge-case inputs.
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        # Replicate the single depth channel to 3 channels (RGB control image).
        image = torch.cat([depth_map] * 3, dim=1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
        return image

    def generate_shadow(self,
                        img_pil,
                        mask_pil,
                        prompt="",
                        padding_rate=0.4,
                        denoise_strength=1.0,
                        num_inference_steps=20,
                        controlnet_conditioning_scale=0.5,
                        cfg=5.0,
                        seed=-1):
        """Generate a shadow for a masked product image.

        Args:
            img_pil: RGB product image (PIL).
            mask_pil: 3-channel foreground mask (PIL), 255 = foreground.
            prompt: diffusion prompt describing the desired shadow.
            padding_rate: fraction of white border added around the foreground.
            denoise_strength: img2img strength in [0, 1].
            num_inference_steps: diffusion sampling steps.
            controlnet_conditioning_scale: depth-control strength.
            cfg: classifier-free guidance scale.
            seed: RNG seed; -1 means non-deterministic (no generator supplied).

        Returns:
            (masked_image_pil, masked_depth_pil, generated_image, composed_image)
            — intermediates plus the final composite with the original
            foreground pasted back over the generated background.
        """
        # 1. Crop to the foreground's minimum bounding box.
        objcut_xmin, objcut_xmax, objcut_ymin, objcut_ymax = shadow_utils.get_mask_bbox(np.array(mask_pil))
        img_pil = Image.fromarray(np.array(img_pil)[objcut_ymin:objcut_ymax, objcut_xmin:objcut_xmax])
        mask_pil = Image.fromarray(np.array(mask_pil)[objcut_ymin:objcut_ymax, objcut_xmin:objcut_xmax])

        # 2. Pad the crop onto a square canvas: white for the image, black for the mask.
        img_pil = shadow_utils.padding_image(img_pil, padding_rate, self.resolution, (255, 255, 255))
        mask_pil = shadow_utils.padding_image(mask_pil, padding_rate, self.resolution, (0, 0, 0))

        mask_np = np.array(mask_pil)
        depth_pil = self.get_depth_map(img_pil)
        # Zero out depth outside the foreground so ControlNet only "sees" the object.
        masked_depth_pil = Image.fromarray(((1 - mask_np / 255.0) * np.zeros_like(mask_np) +
                                            mask_np / 255.0 * np.array(depth_pil)).astype(np.uint8))
        # White-out the background of the input image for img2img.
        masked_image_pil = Image.fromarray(((1 - mask_np / 255.0) * np.array([255, 255, 255])[np.newaxis, np.newaxis, :] +
                                            mask_np / 255.0 * np.array(img_pil)).astype(np.uint8))

        generated_image = self.shadow_pipe(prompt,
                                           image=masked_image_pil,
                                           control_image=masked_depth_pil,
                                           strength=denoise_strength,
                                           num_inference_steps=num_inference_steps,
                                           generator=None if seed == -1 else torch.manual_seed(seed),
                                           controlnet_conditioning_scale=controlnet_conditioning_scale,
                                           guidance_scale=cfg
                                           ).images[0]
        # Paste the untouched foreground back over the generated background so
        # the product itself is never altered by the diffusion model.
        composed_image = mask_np / 255.0 * np.array(img_pil) + (1 - mask_np / 255.0) * np.array(generated_image)
        composed_image = Image.fromarray(composed_image.astype(np.uint8))

        return masked_image_pil, masked_depth_pil, generated_image, composed_image
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == '__main__':

    # Batch driver: render all three shadow styles for every test image and
    # save a 1xN comparison grid per image.
    shadowfree_img_dir = "./test_images"

    # Fix: the original used "./test_results".format(...) — no "{}" placeholder,
    # so the timestamp was silently discarded and every run wrote to the same
    # directory. Include the placeholder so each run gets its own folder.
    save_dir = "./test_results_{}".format(datetime.now().strftime("%Y%m%d%H%M%S"))
    os.makedirs(save_dir, exist_ok=True)

    prompts = {"软阴影": "a product with soft natural shadow, white background",
               "硬阴影": "a product with hard natural shadow, white background",
               "悬浮阴影": "a product with floating natural shadow, white background"}
    shadow_diffuser = Shadow_Diffusion()

    for image_name in os.listdir(shadowfree_img_dir):
        # Fix: the image was previously re-read and re-decoded once per shadow
        # type inside the inner loop; load and split it once per file.
        org_image = cv2.imread(os.path.join(shadowfree_img_dir, image_name), cv2.IMREAD_UNCHANGED)
        org_image = cv2.cvtColor(org_image, cv2.COLOR_BGRA2RGBA)
        img_pil = Image.fromarray(org_image[..., :-1])
        mask_pil = Image.fromarray(np.repeat(org_image[..., -1:], 3, -1))

        results = []
        for shadow_type, prompt in prompts.items():
            start_time = time.time()
            masked_img, masked_depth, gen_img, compose_img = shadow_diffuser.generate_shadow(
                img_pil,
                mask_pil,
                prompt,
                num_inference_steps=50,
                padding_rate=0.5,
                seed=42,
                denoise_strength=1.0,
                cfg=5.0,
                controlnet_conditioning_scale=0.5,
            )
            # Fix: start_time was captured but never used; report per-style timing.
            print("{} [{}]: {:.1f}s".format(image_name, shadow_type, time.time() - start_time))
            results.append(compose_img)
        # Fix: grid width follows the number of prompts instead of a hard-coded 3.
        make_image_grid(results, rows=1, cols=len(results)).save(os.path.join(save_dir, image_name))
|
shadow_utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_mask_bbox(mask_org, threshold=0):
    """Return the foreground bounding box of a mask as (xmin, xmax, ymin, ymax).

    The max bounds are exclusive, so the result can be used directly for
    slicing: arr[ymin:ymax, xmin:xmax]. If the mask is entirely empty, a
    symbolic one-pixel box at the image centre is returned (matching the
    exclusive-bound convention).

    Args:
        mask_org: HxW or HxWxC uint8 array; values > threshold are foreground.
        threshold: binarization cutoff (default 0).
    """
    mask = mask_org.copy()
    # Binarize in place on the copy: foreground -> 255, background -> 0.
    mask[mask > threshold] = 255
    mask[mask <= threshold] = 0
    h, w = mask.shape[:2]
    if len(mask.shape) == 3:
        # Multi-channel: a pixel is foreground if its channel mean clears the threshold.
        coords = np.where(np.mean(mask, axis=-1, keepdims=False) > threshold)
    else:
        # Fix: was a hard-coded `> 20`, inconsistent with the `threshold`
        # parameter used for the multi-channel branch.
        coords = np.where(mask > threshold)

    if len(coords[0]) > 0:
        # Fix: +1 makes the bound exclusive. The old inclusive max dropped the
        # last foreground row/column when callers sliced with [ymin:ymax];
        # the empty-mask fallback below already assumed exclusive bounds.
        ymin, ymax = coords[0].min(), coords[0].max() + 1
    else:
        ymin, ymax = h // 2, h // 2 + 1
    if len(coords[1]) > 0:
        xmin, xmax = coords[1].min(), coords[1].max() + 1
    else:
        xmin, xmax = w // 2, w // 2 + 1
    return (xmin, xmax, ymin, ymax)
|
| 28 |
+
|
| 29 |
+
def padding_image(img, padding_rate, canve_reso=1024, padding_color=(0, 0, 0)):
    """Scale `img` proportionally and centre it on a square canvas.

    The long side of the image occupies (1 - padding_rate) of the canvas edge;
    the remaining border is filled with `padding_color`. (Fix: the old
    docstring claimed the canvas was always white — the fill colour is the
    `padding_color` argument.)

    Args:
        img: PIL.Image, 3-channel, to be resized with cv2.INTER_AREA.
        padding_rate: fraction of the canvas edge left as border, in [0, 1).
        canve_reso: canvas side length in pixels (misspelled name kept for
            backward compatibility with existing callers).
        padding_color: per-channel fill value for the border.

    Returns:
        PIL.Image of size (canve_reso, canve_reso).
    """
    long_size = int(canve_reso * (1 - padding_rate))
    img = np.array(img)
    h, w = img.shape[:2]

    # Fix: clamp target sizes to >= 1 pixel. Extreme aspect ratios (or a
    # padding_rate approaching 1) previously produced a 0-sized dimension,
    # which makes cv2.resize raise.
    if h > w:
        new_h = max(1, long_size)
        new_w = max(1, int(w / h * long_size))
    else:
        new_w = max(1, long_size)
        new_h = max(1, int(h / w * long_size))

    resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # Centre the resized image on the canvas.
    padding_h = (canve_reso - new_h) // 2
    padding_w = (canve_reso - new_w) // 2

    padding_output = np.ones((canve_reso, canve_reso, resized_img.shape[-1])) * \
        np.array(padding_color)[np.newaxis, np.newaxis, :]
    padding_output[padding_h:padding_h + new_h, padding_w:padding_w + new_w, :] = resized_img
    return Image.fromarray(padding_output.astype(np.uint8))
|