# SDXL-Shadow_Generation / generate_shadow_main.py
import os
import cv2
import time
import torch
import numpy as np
import shadow_utils
from PIL import Image
from datetime import datetime
from diffusers.utils import make_image_grid
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
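
# This script builds a depth-conditioned SDXL ControlNet img2img pipeline that
# synthesizes a natural-looking shadow around a masked foreground object.
# `shadow_utils` is a local helper module; judging from how it is called below,
# `get_mask_bbox` returns the (xmin, xmax, ymin, ymax) bounding box of the mask
# and `padding_image` pads an image to a `resolution` x `resolution` square with
# the given fill color. Inputs are RGBA images whose alpha channel serves as the
# foreground mask (see the __main__ block).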


class Shadow_Diffusion:
    def __init__(self,
                 base_model_path="stabilityai/stable-diffusion-xl-base-1.0",
                 vae_path="madebyollin/sdxl-vae-fp16-fix",
                 controlnet_path="./checkpoints/shadow_checkpoint",
                 depth_midas_path="Intel/dpt-hybrid-midas",
                 resolution=1024,
                 precision_type=torch.float16):
        self.resolution = resolution
        self.depth_estimator = DPTForDepthEstimation.from_pretrained(depth_midas_path).to("cuda")
        self.feature_extractor = DPTFeatureExtractor.from_pretrained(depth_midas_path)
        self.controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=precision_type).to("cuda")
        self.vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=precision_type).to("cuda")
        self.shadow_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            base_model_path,
            controlnet=self.controlnet,
            vae=self.vae,
            variant="fp16",
            use_safetensors=True,
            torch_dtype=precision_type,
        ).to("cuda")
        self.shadow_pipe.enable_model_cpu_offload()

    def get_depth_map(self, image):
        # Run MiDaS/DPT depth estimation and return a 3-channel depth image for ControlNet.
        image = self.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        with torch.no_grad(), torch.autocast("cuda"):
            depth_map = self.depth_estimator(image).predicted_depth

        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(self.resolution, self.resolution),
            mode="bicubic",
            align_corners=False,
        )
        # Normalize depth to [0, 1] and replicate it to three channels.
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        image = torch.cat([depth_map] * 3, dim=1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
        return image

    def generate_shadow(self,
                        img_pil,
                        mask_pil,
                        prompt="",
                        padding_rate=0.4,
                        denoise_strength=1.0,
                        num_inference_steps=20,
                        controlnet_conditioning_scale=0.5,
                        cfg=5.0,
                        seed=-1):
        # 1. Crop the image and mask to the minimum bounding box of the foreground.
        objcut_xmin, objcut_xmax, objcut_ymin, objcut_ymax = shadow_utils.get_mask_bbox(np.array(mask_pil))
        img_pil = Image.fromarray(np.array(img_pil)[objcut_ymin:objcut_ymax, objcut_xmin:objcut_xmax])
        mask_pil = Image.fromarray(np.array(mask_pil)[objcut_ymin:objcut_ymax, objcut_xmin:objcut_xmax])

        # 2. Pad the crop to the working resolution: white borders for the image, black for the mask.
        img_pil = shadow_utils.padding_image(img_pil, padding_rate, self.resolution, (255, 255, 255))
        mask_pil = shadow_utils.padding_image(mask_pil, padding_rate, self.resolution, (0, 0, 0))
        mask_np = np.array(mask_pil)

        # 3. Build the conditioning inputs: depth restricted to the foreground, and the
        #    foreground composited onto a white background.
        depth_pil = self.get_depth_map(img_pil)
        masked_depth_pil = Image.fromarray(
            ((1 - mask_np / 255.0) * np.zeros_like(mask_np) +
             mask_np / 255.0 * np.array(depth_pil)).astype(np.uint8))
        masked_image_pil = Image.fromarray(
            ((1 - mask_np / 255.0) * np.array([255, 255, 255])[np.newaxis, np.newaxis, :] +
             mask_np / 255.0 * np.array(img_pil)).astype(np.uint8))

        # 4. Run the depth-conditioned img2img pipeline to synthesize the shadow.
        generated_image = self.shadow_pipe(
            prompt,
            image=masked_image_pil,
            control_image=masked_depth_pil,
            strength=denoise_strength,
            num_inference_steps=num_inference_steps,
            generator=None if seed == -1 else torch.manual_seed(seed),
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            guidance_scale=cfg,
        ).images[0]

        # 5. Paste the original foreground back over the generated result.
        composed_image = mask_np / 255.0 * np.array(img_pil) + (1 - mask_np / 255.0) * np.array(generated_image)
        composed_image = Image.fromarray(composed_image.astype(np.uint8))
        return masked_image_pil, masked_depth_pil, generated_image, composed_image


if __name__ == '__main__':
    shadowfree_img_dir = "./test_images"
    save_dir = "./test_results_{}".format(datetime.now().strftime("%Y%m%d%H%M%S"))
    os.makedirs(save_dir, exist_ok=True)
    prompts = {"soft shadow": "a product with soft natural shadow, white background",
               "hard shadow": "a product with hard natural shadow, white background",
               "floating shadow": "a product with floating natural shadow, white background"}
    shadow_diffuser = Shadow_Diffusion()
    for image_name in os.listdir(shadowfree_img_dir):
        results = []
        for shadow_type, prompt in prompts.items():
            # Input images are RGBA; the alpha channel serves as the foreground mask.
            org_image = cv2.imread(os.path.join(shadowfree_img_dir, image_name), cv2.IMREAD_UNCHANGED)
            org_image = cv2.cvtColor(org_image, cv2.COLOR_BGRA2RGBA)
            img_pil = Image.fromarray(org_image[..., :-1])
            mask_pil = Image.fromarray(np.repeat(org_image[..., -1:], 3, -1))
            start_time = time.time()
            masked_img, masked_depth, gen_img, compose_img = shadow_diffuser.generate_shadow(
                img_pil,
                mask_pil,
                prompt,
                num_inference_steps=50,
                padding_rate=0.5,
                seed=42,
                denoise_strength=1.0,
                cfg=5.0,
                controlnet_conditioning_scale=0.5,
            )
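            # Log how long this shadow variant took to generate.
            print("{} | {}: generated in {:.2f}s".format(image_name, shadow_type, time.time() - start_time))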
            results.append(compose_img)
        make_image_grid(results, rows=1, cols=3).save(os.path.join(save_dir, image_name))