import os import cv2 import time import torch import numpy as np import shadow_utils from PIL import Image from datetime import datetime from diffusers.utils import make_image_grid from transformers import DPTFeatureExtractor, DPTForDepthEstimation from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL class Shadow_Diffusion: def __init__(self, base_model_path="stabilityai/stable-diffusion-xl-base-1.0", vae_path="madebyollin/sdxl-vae-fp16-fix", controlnet_path="./checkpoints/shadow_checkpoint", depth_midas_path="Intel/dpt-hybrid-midas", resolution=1024, precision_type=torch.float16): self.resolution=resolution self.depth_estimator = DPTForDepthEstimation.from_pretrained(depth_midas_path).to("cuda") self.feature_extractor = DPTFeatureExtractor.from_pretrained(depth_midas_path) self.controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=precision_type).to("cuda") self.vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=precision_type).to("cuda") self.shadow_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(base_model_path, controlnet=self.controlnet, vae=self.vae, variant="fp16", use_safetensors=True, torch_dtype=precision_type).to("cuda") self.shadow_pipe.enable_model_cpu_offload() def get_depth_map(self, image): image = self.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") with torch.no_grad(), torch.autocast("cuda"): depth_map = self.depth_estimator(image).predicted_depth depth_map = torch.nn.functional.interpolate( depth_map.unsqueeze(1), size=(self.resolution, self.resolution), mode="bicubic", align_corners=False, ) depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) depth_map = (depth_map - depth_min) / (depth_max - depth_min) image = torch.cat([depth_map] * 3, dim=1) image = image.permute(0, 2, 3, 1).cpu().numpy()[0] image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) return image def generate_shadow(self, img_pil, mask_pil, prompt="", padding_rate=0.4, denoise_strength=1.0, num_inference_steps=20, controlnet_conditioning_scale=0.5, cfg=5.0, seed=-1): # 1.Extract the foreground area according to the minimum bounding box objcut_xmin, objcut_xmax, objcut_ymin, objcut_ymax = shadow_utils.get_mask_bbox(np.array(mask_pil)) img_pil = Image.fromarray(np.array(img_pil)[objcut_ymin:objcut_ymax, objcut_xmin:objcut_xmax]) mask_pil = Image.fromarray(np.array(mask_pil)[objcut_ymin:objcut_ymax, objcut_xmin:objcut_xmax]) # 2.Fill the smallest foreground area with white edges img_pil = shadow_utils.padding_image(img_pil, padding_rate, self.resolution, (255, 255, 255)) mask_pil = shadow_utils.padding_image(mask_pil, padding_rate, self.resolution, (0, 0, 0)) mask_np = np.array(mask_pil) depth_pil = self.get_depth_map(img_pil) masked_depth_pil = Image.fromarray(((1 - mask_np / 255.0) * np.zeros_like(mask_np) + mask_np / 255.0 * np.array(depth_pil)).astype(np.uint8)) masked_image_pil = Image.fromarray(((1 - mask_np / 255.0) * np.array([255, 255, 255])[np.newaxis, np.newaxis, :] + mask_np / 255.0 * np.array(img_pil)).astype(np.uint8)) generated_image = self.shadow_pipe(prompt, image=masked_image_pil, control_image=masked_depth_pil, strength=denoise_strength, num_inference_steps=num_inference_steps, generator=None if seed == -1 else torch.manual_seed(seed), controlnet_conditioning_scale=controlnet_conditioning_scale, guidance_scale=cfg ).images[0] composed_image = mask_np / 255.0 * np.array(img_pil) + (1 - mask_np / 255.0) * np.array(generated_image) composed_image = Image.fromarray(composed_image.astype(np.uint8)) return masked_image_pil, masked_depth_pil, generated_image, composed_image if __name__ == '__main__': shadowfree_img_dir = "./test_images" save_dir = "./test_results".format(datetime.now().strftime("%Y%m%d%H%M%S")) os.makedirs(save_dir, exist_ok=True) prompts = {"软阴影": "a product with soft natural shadow, white background", "硬阴影": "a product with hard natural shadow, white background", "悬浮阴影": "a product with floating natural shadow, white background"} shadow_diffuser = Shadow_Diffusion() for image_name in os.listdir(shadowfree_img_dir): results = [] for shadow_type, prompt in prompts.items(): org_image = cv2.imread(os.path.join(shadowfree_img_dir, image_name), cv2.IMREAD_UNCHANGED) org_image = cv2.cvtColor(org_image, cv2.COLOR_BGRA2RGBA) img_pil = Image.fromarray(org_image[..., :-1]) mask_pil = Image.fromarray(np.repeat(org_image[..., -1:], 3, -1)) start_time = time.time() masked_img, masked_depth, gen_img, compose_img = shadow_diffuser.generate_shadow(img_pil, mask_pil, prompt, num_inference_steps=50, padding_rate=0.5, seed=42, denoise_strength=1.0, cfg=5.0, controlnet_conditioning_scale=0.5, ) results.append(compose_img) make_image_grid(results, rows=1, cols=3).save(os.path.join(save_dir, image_name))