"""Image-to-multiview generation with MV-Adapter on SDXL.

Generates multiple consistent views of an object from a single reference
image, optionally removing the background with BiRefNet first.
"""

import argparse

import numpy as np
import torch
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
from mvadapter.utils import (
    get_orthogonal_camera,
    get_plucker_embeds_from_cameras_ortho,
    make_image_grid,
)


def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
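    """Build the MV-Adapter image-to-multiview SDXL pipeline: load the base
    model (optionally with a custom VAE/UNet), wrap its scheduler with
    ShiftSNRScheduler, load the multiview adapter weights, and optionally
    apply LoRA weights.
    """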
    # Load a custom VAE and/or UNet if provided
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    # Prepare the pipeline
    pipe: MVAdapterI2MVSDXLPipeline
    pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)

    # Swap in the requested scheduler, wrapped with shifted-SNR noise scheduling
    scheduler_class = None
    if scheduler == "ddpm":
        scheduler_class = DDPMScheduler
    elif scheduler == "lcm":
        scheduler_class = LCMScheduler

    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )
    pipe.init_custom_adapter(num_views=num_views)
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_i2mv_sdxl.safetensors"
    )

    pipe.to(device=device, dtype=dtype)
    pipe.cond_encoder.to(device=device, dtype=dtype)

    # Load LoRA weights if provided
    if lora_model is not None:
        model_, name_ = lora_model.rsplit("/", 1)
        pipe.load_lora_weights(model_, weight_name=name_)

    return pipe


def remove_bg(image, net, transform, device):
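    """Predict a foreground mask with the segmentation network and attach it
    to the image as an alpha channel."""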
    image_size = image.size
    input_images = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = net(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image


def preprocess_image(image: Image.Image, height, width):
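    """Crop the RGBA image to its alpha bounding box, scale it to fill 90% of
    the target canvas, center it, and composite it over a mid-gray background."""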
    image = np.array(image)
    alpha = image[..., 3] > 0
    H, W = alpha.shape
    # Get the bounding box of the alpha channel
    y, x = np.where(alpha)
    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
    image_center = image[y0:y1, x0:x1]
    # Resize so the longer side spans 90% of the target size
    H, W, _ = image_center.shape
    if H > W:
        W = int(W * (height * 0.9) / H)
        H = int(height * 0.9)
    else:
        H = int(H * (width * 0.9) / W)
        W = int(width * 0.9)
    image_center = np.array(Image.fromarray(image_center).resize((W, H)))
    # Pad to (height, width) and composite over a mid-gray (0.5) background
    start_h = (height - H) // 2
    start_w = (width - W) // 2
    image = np.zeros((height, width, 4), dtype=np.uint8)
    image[start_h : start_h + H, start_w : start_w + W] = image_center
    image = image.astype(np.float32) / 255.0
    image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
    image = (image * 255).clip(0, 255).astype(np.uint8)
    image = Image.fromarray(image)

    return image


def run_pipeline(
    pipe,
    num_views,
    text,
    image,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    remove_bg_fn=None,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    lora_scale=1.0,
    device="cuda",
):
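    """Generate num_views views of the reference image from fixed orthogonal
    cameras, conditioning the diffusion model on Plücker ray embeddings."""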
    # Prepare the target cameras. NOTE: elevation_deg and azimuth_deg are
    # hard-coded for six views, matching the default num_views=6.
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 0, 0],
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
        device=device,
    )

    # Encode the cameras as Plücker ray maps, normalized to [0, 1] for use
    # as control images
    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
        cameras.c2w, [1.1] * num_views, width
    )
    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)

    # Prepare the reference image
    reference_image = Image.open(image) if isinstance(image, str) else image
    if remove_bg_fn is not None:
        reference_image = remove_bg_fn(reference_image)
        reference_image = preprocess_image(reference_image, height, width)
    elif reference_image.mode == "RGBA":
        reference_image = preprocess_image(reference_image, height, width)

    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference_image,
        reference_conditioning_scale=reference_conditioning_scale,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images

    return images, reference_image


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
    )
    parser.add_argument(
        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
    )
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument("--scheduler", type=str, default=None)
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)
    # Device
    parser.add_argument("--device", type=str, default="cuda")
    # Inference
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--text", type=str, default="high quality")
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    parser.add_argument("--output", type=str, default="output.png")
    parser.add_argument("--remove_bg", action="store_true", help="Remove background")
    args = parser.parse_args()

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )

    # Optionally prepare a BiRefNet-based background removal function
    if args.remove_bg:
        birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        birefnet.to(args.device)
        transform_image = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
    else:
        remove_bg_fn = None

    images, reference_image = run_pipeline(
        pipe,
        num_views=args.num_views,
        text=args.text,
        image=args.image,
        height=768,
        width=768,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        lora_scale=args.lora_scale,
        reference_conditioning_scale=args.reference_conditioning_scale,
        negative_prompt=args.negative_prompt,
        device=args.device,
        remove_bg_fn=remove_bg_fn,
    )
    make_image_grid(images, rows=1).save(args.output)
    reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")
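    # Example invocation (script name and input path are illustrative):
    #   python inference_i2mv_sdxl.py --image assets/demo.png --remove_bg --output output.png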