|
|
import os |
|
|
import cv2 |
|
|
import time |
|
|
import torch |
|
|
import numpy as np |
|
|
import shadow_utils |
|
|
from PIL import Image |
|
|
|
|
|
from datetime import datetime |
|
|
|
|
|
from diffusers.utils import make_image_grid |
|
|
from transformers import DPTFeatureExtractor, DPTForDepthEstimation |
|
|
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL |
|
|
|
|
|
|
|
|
class Shadow_Diffusion:
    """Generate natural-looking shadows for cut-out product images.

    Wraps an SDXL ControlNet img2img pipeline conditioned on a MiDaS depth
    map of the (masked) object: the object pixels are kept, and the model
    paints a plausible shadow on the white background around it.
    """

    def __init__(self,
                 base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
                 vae_path="madebyollin/sdxl-vae-fp16-fix",
                 controlnet_path="./checkpoints/shadow_checkpoint",
                 depth_midas_path="Intel/dpt-hybrid-midas",
                 resolution=1024,
                 precision_type=torch.float16):
        """Load the depth estimator, ControlNet, VAE and SDXL pipeline.

        Args:
            base_model_path: HF repo id or path of the SDXL base model.
            vae_path: fp16-safe SDXL VAE (avoids NaNs in half precision).
            controlnet_path: path to the shadow-trained ControlNet checkpoint.
            depth_midas_path: DPT/MiDaS depth-estimation model.
            resolution: square working resolution for depth maps / padding.
            precision_type: torch dtype for the diffusion components.
        """
        self.resolution = resolution

        # The depth estimator runs as a standalone model, so it lives on CUDA.
        self.depth_estimator = DPTForDepthEstimation.from_pretrained(depth_midas_path).to("cuda")
        self.feature_extractor = DPTFeatureExtractor.from_pretrained(depth_midas_path)

        # BUG FIX: do NOT call .to("cuda") on the ControlNet, VAE or pipeline
        # here. enable_model_cpu_offload() manages device placement itself;
        # moving the components to CUDA first defeats the offload (and raises
        # an error in recent diffusers versions).
        self.controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=precision_type)
        self.vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=precision_type)
        self.shadow_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            base_model_path,
            controlnet=self.controlnet,
            vae=self.vae,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=precision_type)
        self.shadow_pipe.enable_model_cpu_offload()

    def get_depth_map(self, image):
        """Estimate a depth map and return it as an RGB PIL image.

        Args:
            image: PIL image (or anything the DPT feature extractor accepts).

        Returns:
            PIL.Image of shape (resolution, resolution), 3-channel uint8,
            with depth min-max normalized to [0, 255].
        """
        pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        with torch.no_grad(), torch.autocast("cuda"):
            depth_map = self.depth_estimator(pixel_values).predicted_depth

        # Resize to the working resolution expected by the ControlNet.
        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(self.resolution, self.resolution),
            mode="bicubic",
            align_corners=False,
        )

        # Min-max normalize per image. NOTE(review): a constant depth map
        # would divide by zero here -- assumed not to occur for real photos.
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)

        # Replicate the single channel to RGB and convert to a PIL image.
        rgb = torch.cat([depth_map] * 3, dim=1)
        rgb = rgb.permute(0, 2, 3, 1).cpu().numpy()[0]
        return Image.fromarray((rgb * 255.0).clip(0, 255).astype(np.uint8))

    def generate_shadow(self,
                        img_pil,
                        mask_pil,
                        prompt="",
                        padding_rate=0.4,
                        denoise_strength=1.0,
                        num_inference_steps=20,
                        controlnet_conditioning_scale=0.5,
                        cfg=5.0,
                        seed=-1):
        """Synthesize a shadow around the masked object.

        Args:
            img_pil: RGB PIL image containing the object.
            mask_pil: 3-channel PIL mask (255 = object, 0 = background).
            prompt: text prompt describing the desired shadow style.
            padding_rate: fraction of border padding added around the object.
            denoise_strength: img2img strength (1.0 = full re-generation).
            num_inference_steps: diffusion steps.
            controlnet_conditioning_scale: ControlNet guidance weight.
            cfg: classifier-free guidance scale.
            seed: RNG seed; -1 means non-deterministic.

        Returns:
            Tuple of PIL images:
            (masked input, masked depth, raw generation, final composite).
        """
        # Crop both image and mask to the object's bounding box so the
        # object fills the frame before padding.
        xmin, xmax, ymin, ymax = shadow_utils.get_mask_bbox(np.array(mask_pil))
        img_pil = Image.fromarray(np.array(img_pil)[ymin:ymax, xmin:xmax])
        mask_pil = Image.fromarray(np.array(mask_pil)[ymin:ymax, xmin:xmax])

        # Pad to the square working resolution: white canvas for the image,
        # black for the mask.
        img_pil = shadow_utils.padding_image(img_pil, padding_rate, self.resolution, (255, 255, 255))
        mask_pil = shadow_utils.padding_image(mask_pil, padding_rate, self.resolution, (0, 0, 0))

        mask_np = np.array(mask_pil)
        alpha = mask_np / 255.0  # soft object mask in [0, 1]
        depth_pil = self.get_depth_map(img_pil)

        # Depth restricted to the object; background depth is zeroed out.
        # (Simplified from `(1 - alpha) * zeros + alpha * depth`.)
        masked_depth_pil = Image.fromarray((alpha * np.array(depth_pil)).astype(np.uint8))

        # Object composited onto a pure white background.
        white = np.array([255, 255, 255])[np.newaxis, np.newaxis, :]
        masked_image_pil = Image.fromarray(
            ((1 - alpha) * white + alpha * np.array(img_pil)).astype(np.uint8))

        generated_image = self.shadow_pipe(
            prompt,
            image=masked_image_pil,
            control_image=masked_depth_pil,
            strength=denoise_strength,
            num_inference_steps=num_inference_steps,
            generator=None if seed == -1 else torch.manual_seed(seed),
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            guidance_scale=cfg,
        ).images[0]

        # Keep the original object pixels; take the generated pixels
        # (background + shadow) everywhere else.
        composed = alpha * np.array(img_pil) + (1 - alpha) * np.array(generated_image)
        composed_image = Image.fromarray(composed.astype(np.uint8))

        return masked_image_pil, masked_depth_pil, generated_image, composed_image
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    shadowfree_img_dir = "./test_images"

    # BUG FIX: the original wrote "./test_results".format(...), which has no
    # {} placeholder and silently discarded the timestamp. Embed it so each
    # run saves into its own folder, as evidently intended.
    save_dir = "./test_results_{}".format(datetime.now().strftime("%Y%m%d%H%M%S"))
    os.makedirs(save_dir, exist_ok=True)

    # One prompt per shadow style (soft / hard / floating).
    prompts = {"软阴影": "a product with soft natural shadow, white background",
               "硬阴影": "a product with hard natural shadow, white background",
               "悬浮阴影": "a product with floating natural shadow, white background"}
    shadow_diffuser = Shadow_Diffusion()

    for image_name in os.listdir(shadowfree_img_dir):
        # Hoisted out of the prompt loop: the source image is loop-invariant.
        # Images are RGBA; the alpha channel doubles as the object mask.
        org_image = cv2.imread(os.path.join(shadowfree_img_dir, image_name), cv2.IMREAD_UNCHANGED)
        org_image = cv2.cvtColor(org_image, cv2.COLOR_BGRA2RGBA)
        img_pil = Image.fromarray(org_image[..., :-1])
        mask_pil = Image.fromarray(np.repeat(org_image[..., -1:], 3, -1))

        results = []
        for prompt in prompts.values():
            masked_img, masked_depth, gen_img, compose_img = shadow_diffuser.generate_shadow(
                img_pil,
                mask_pil,
                prompt,
                num_inference_steps=50,
                padding_rate=0.5,
                seed=42,
                denoise_strength=1.0,
                cfg=5.0,
                controlnet_conditioning_scale=0.5,
            )
            results.append(compose_img)

        # Grid width follows the number of prompts instead of a hard-coded 3.
        make_image_grid(results, rows=1, cols=len(results)).save(os.path.join(save_dir, image_name))
|