# copied from https://github.com/huanngzh/MV-Adapter/blob/main/scripts/inference_ig2mv_partial_sdxl.py
import argparse
import json
import os

import numpy as np
import torch
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0
from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
from mvadapter.utils import make_image_grid, tensor_to_image
from mvadapter.utils.mesh_utils import (
    NVDiffRastContextWrapper,
    get_orthogonal_camera,
    load_mesh,
    render,
)
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoModelForImageSegmentation


def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Build the MV-Adapter image+geometry-to-multiview SDXL pipeline.

    Args:
        base_model: Identifier passed to ``MVAdapterI2MVSDXLPipeline.from_pretrained``.
        vae_model: Optional identifier of a replacement VAE; ``None`` keeps
            the one bundled with ``base_model``.
        unet_model: Optional identifier of a replacement UNet; ``None`` keeps
            the one bundled with ``base_model``.
        lora_model: Optional LoRA reference of the form ``"<model>/<weight_name>"``;
            it is split on the last ``/``.
        adapter_path: Location handed to ``pipe.load_custom_adapter``.
        scheduler: ``"ddpm"`` or ``"lcm"`` to select a scheduler class; any
            other value (including ``None``) passes ``scheduler_class=None``.
        num_views: Number of views forwarded to ``pipe.init_custom_adapter``.
        device: Device the pipeline and its condition encoder are moved to.
        dtype: Torch dtype the pipeline and its condition encoder are cast to.

    Returns:
        The configured pipeline, with VAE slicing enabled.

    Raises:
        ValueError: If ``lora_model`` is given but contains no ``/``.
    """
    # Load vae and unet if provided
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    # Prepare pipeline
    pipe: MVAdapterI2MVSDXLPipeline
    pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)

    # Load scheduler if provided
    scheduler_class = None
    if scheduler == "ddpm":
        scheduler_class = DDPMScheduler
    elif scheduler == "lcm":
        scheduler_class = LCMScheduler

    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )
    pipe.init_custom_adapter(
        num_views=num_views, self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0
    )
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_ig2mv_partial_sdxl.safetensors"
    )

    pipe.to(device=device, dtype=dtype)
    pipe.cond_encoder.to(device=device, dtype=dtype)

    # load lora if provided
    if lora_model is not None:
        model_, sep, name_ = lora_model.rpartition("/")
        if not sep:
            # Without a separator there is no way to tell the model from the
            # weight file, so fail with an actionable message.
            raise ValueError(
                "lora_model must look like '<model>/<weight_name>', "
                f"got {lora_model!r}"
            )
        pipe.load_lora_weights(model_, weight_name=name_)

    pipe.enable_vae_slicing()

    return pipe


def remove_bg(image, net, transform, device):
    """Predict a foreground mask with ``net`` and store it as the alpha channel.

    Args:
        image: PIL image; it is modified in place by ``putalpha``.
        net: Segmentation model; the last element of its output is taken,
            passed through a sigmoid and used as the mask.
        transform: Callable turning a PIL image into the tensor fed to ``net``.
        device: Device the input tensor is moved to.

    Returns:
        The same ``image`` object, now carrying the predicted mask as alpha.
    """
    image_size = image.size
    # Feed an RGB copy to the network so inputs that already carry an alpha
    # channel (or are palette/greyscale) do not break a 3-channel transform.
    input_images = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = net(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # The mask is predicted at the transform's resolution; bring it back to
    # the original image size before attaching it.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image


def preprocess_image(image: Image.Image, height, width):
    """Crop an RGBA image to its opaque region, centre it and flatten on grey.

    The crop is resized so that it occupies at most 90% of the target canvas,
    pasted in the middle of a ``height`` x ``width`` canvas and then composited
    over a mid-grey (0.5) background.

    Args:
        image: Input image. PIL images in a mode other than ``RGBA`` are
            converted to ``RGBA`` first.
        height: Height of the output canvas in pixels.
        width: Width of the output canvas in pixels.

    Returns:
        A 3-channel PIL image of size ``(width, height)``.

    Raises:
        ValueError: If the image has no pixel with alpha greater than zero.
    """
    if isinstance(image, Image.Image) and image.mode != "RGBA":
        image = image.convert("RGBA")
    image = np.array(image)
    alpha = image[..., 3] > 0
    H, W = alpha.shape
    # get the bounding box of alpha
    y, x = np.where(alpha)
    if y.size == 0:
        raise ValueError("image is fully transparent; nothing to crop")
    # The slice end is exclusive, so ``max + 1`` keeps the last opaque
    # row/column inside the crop.
    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
    image_center = image[y0:y1, x0:x1]
    # Resize so the binding side fills 90% of the canvas. Comparing aspect
    # ratios (instead of raw H > W) keeps the crop inside non-square canvases;
    # for square canvases the two tests are equivalent.
    H, W, _ = image_center.shape
    if H * width > W * height:
        W = int(W * (height * 0.9) / H)
        H = int(height * 0.9)
    else:
        H = int(H * (width * 0.9) / W)
        W = int(width * 0.9)
    # A very thin crop could round down to zero, which resize() rejects.
    H, W = max(H, 1), max(W, 1)
    image_center = np.array(Image.fromarray(image_center).resize((W, H)))
    # pad to H, W
    start_h = (height - H) // 2
    start_w = (width - W) // 2
    image = np.zeros((height, width, 4), dtype=np.uint8)
    image[start_h : start_h + H, start_w : start_w + W] = image_center
    image = image.astype(np.float32) / 255.0
    # Alpha-composite the colour channels over a 0.5 grey background.
    image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
    image = (image * 255).clip(0, 255).astype(np.uint8)
    image = Image.fromarray(image)

    return image


def run_pipeline(
    pipe,
    mesh_path,
    num_views,
    text,
    image,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    remove_bg_fn=None,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    lora_scale=1.0,
    device="cuda",
):
    """Render geometry maps from a mesh and generate multi-view images.

    Args:
        pipe: Pipeline returned by ``prepare_pipeline``.
        mesh_path: Path of the mesh handed to ``load_mesh``.
        num_views: Number of images requested per prompt; also sizes the
            camera ``distance`` list.
        text: Prompt passed to the pipeline.
        image: Reference image, either a path (``str`` or ``os.PathLike``)
            or an already opened PIL image.
        height: Render and generation height in pixels.
        width: Render and generation width in pixels.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        seed: Integer seed; ``-1`` (or a non-int) leaves the generator unset.
        remove_bg_fn: Optional callable applied to the reference image before
            preprocessing.
        reference_conditioning_scale: Forwarded to the pipeline.
        negative_prompt: Forwarded to the pipeline.
        lora_scale: Passed as ``cross_attention_kwargs={"scale": lora_scale}``.
        device: Device used for cameras, rendering and the random generator.

    Returns:
        Tuple ``(images, pos_images, normal_images, reference_image,
        transform_dict)`` where ``transform_dict`` holds the ``offset`` and
        ``scale`` returned by ``load_mesh`` as plain lists.
    """
    # Prepare cameras
    # NOTE(review): the elevation/azimuth lists are hard-coded for six views
    # while ``distance`` is sized by ``num_views``; presumably num_views must
    # be 6 -- confirm against get_orthogonal_camera.
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
        device=device,
    )
    ctx = NVDiffRastContextWrapper(device=device)

    mesh, offset, scale = load_mesh(
        mesh_path,
        rescale=True,
        move_to_center=True,
        device=device,
        return_transform=True,
    )
    transform_dict = {"offset": offset.tolist(), "scale": scale.tolist()}
    render_out = render(
        ctx,
        mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )
    # Shift position and normal maps into [0, 1] for visualisation/conditioning.
    pos_images = tensor_to_image((render_out.pos + 0.5).clamp(0, 1), batched=True)
    normal_images = tensor_to_image(
        (render_out.normal / 2 + 0.5).clamp(0, 1), batched=True
    )
    # Concatenate position and normal maps on the last axis, then move that
    # axis to position 1 for the pipeline.
    control_images = (
        torch.cat(
            [
                (render_out.pos + 0.5).clamp(0, 1),
                (render_out.normal / 2 + 0.5).clamp(0, 1),
            ],
            dim=-1,
        )
        .permute(0, 3, 1, 2)
        .to(device)
    )

    # Prepare image
    if isinstance(image, (str, os.PathLike)):
        reference_image = Image.open(image)
    else:
        reference_image = image
    if remove_bg_fn is not None:
        reference_image = remove_bg_fn(reference_image)
        reference_image = preprocess_image(reference_image, height, width)
    elif reference_image.mode == "RGBA":
        reference_image = preprocess_image(reference_image, height, width)

    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference_image,
        reference_conditioning_scale=reference_conditioning_scale,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images

    return images, pos_images, normal_images, reference_image, transform_dict


def main():
    """Parse command-line arguments, run the pipeline and save the outputs.

    Besides ``--output`` itself, four sibling files are written next to it,
    sharing its name without the extension: ``_pos.png``, ``_nor.png``,
    ``_reference.png`` and ``_transform.json``.
    """
    parser = argparse.ArgumentParser()
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
    )
    parser.add_argument(
        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
    )
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument("--scheduler", type=str, default=None)
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)
    # Device
    parser.add_argument("--device", type=str, default="cuda")
    # Inference
    parser.add_argument("--mesh", type=str, required=True)
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--text", type=str, required=False, default="high quality")
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    parser.add_argument("--output", type=str, default="output.png")
    # Extra
    parser.add_argument("--remove_bg", action="store_true", help="Remove background")
    args = parser.parse_args()

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )

    if args.remove_bg:
        birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        birefnet.to(args.device)
        transform_image = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        def remove_bg_fn(x):
            # Bind the segmentation model, transform and device for run_pipeline.
            return remove_bg(x, birefnet, transform_image, args.device)

    else:
        remove_bg_fn = None

    images, pos_images, normal_images, reference_image, transform_dict = run_pipeline(
        pipe,
        mesh_path=args.mesh,
        num_views=args.num_views,
        text=args.text,
        image=args.image,
        height=768,
        width=768,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        lora_scale=args.lora_scale,
        reference_conditioning_scale=args.reference_conditioning_scale,
        negative_prompt=args.negative_prompt,
        device=args.device,
        remove_bg_fn=remove_bg_fn,
    )

    # splitext only strips an extension from the final path component, so an
    # extension-less output inside a dotted directory keeps its full path.
    output_base = os.path.splitext(args.output)[0]

    make_image_grid(images, rows=1).save(args.output)
    make_image_grid(pos_images, rows=1).save(output_base + "_pos.png")
    make_image_grid(normal_images, rows=1).save(output_base + "_nor.png")
    reference_image.save(output_base + "_reference.png")

    with open(output_base + "_transform.json", "w") as f:
        json.dump(transform_dict, f, indent=4)


if __name__ == "__main__":
    main()