# Hugging Face Hub page header (tags: Diffusers, Safetensors)
# File: icip_source_2/scripts/mvadapter_ig2mv.py
# Uploaded by hansQAQ ("Upload folder using huggingface_hub"), commit 278bf35 (verified)
# copied from https://github.com/huanngzh/MV-Adapter/blob/main/scripts/inference_ig2mv_partial_sdxl.py
import argparse
import json
import numpy as np
import torch
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0
from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
from mvadapter.utils import make_image_grid, tensor_to_image
from mvadapter.utils.mesh_utils import (
NVDiffRastContextWrapper,
get_orthogonal_camera,
load_mesh,
render,
)
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoModelForImageSegmentation
def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Build the MV-Adapter image-guided multi-view SDXL pipeline.

    Args:
        base_model: Model id or path of the SDXL base, passed to
            ``MVAdapterI2MVSDXLPipeline.from_pretrained``.
        vae_model: Optional VAE id/path; when given it replaces the base VAE.
        unet_model: Optional UNet id/path; when given it replaces the base UNet.
        lora_model: Optional LoRA reference. A value of the form
            ``"<repo_or_dir>/<weight_file>"`` is split on the last ``/`` into
            a location and a weight file name; a value without ``/`` is
            handed to ``load_lora_weights`` unchanged.
        adapter_path: Repo id or directory containing
            ``mvadapter_ig2mv_partial_sdxl.safetensors``.
        scheduler: ``"ddpm"`` or ``"lcm"`` selects that scheduler class for
            the shifted-SNR wrapper; any other value (e.g. ``None``) forwards
            ``scheduler_class=None`` to ``ShiftSNRScheduler.from_scheduler``.
        num_views: Number of views the custom adapter is initialised for.
        device: Device the pipeline is moved to.
        dtype: Dtype the pipeline is cast to.

    Returns:
        The configured ``MVAdapterI2MVSDXLPipeline``.
    """
    # Optional component overrides forwarded to from_pretrained.
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    pipe: MVAdapterI2MVSDXLPipeline
    pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)

    # Wrap the current scheduler with an interpolated shifted-SNR schedule.
    scheduler_class = {"ddpm": DDPMScheduler, "lcm": LCMScheduler}.get(scheduler)
    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )

    # Install the multi-view row/column self-attention processors, then load
    # the matching adapter weights.
    pipe.init_custom_adapter(
        num_views=num_views, self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0
    )
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_ig2mv_partial_sdxl.safetensors"
    )

    pipe.to(device=device, dtype=dtype)
    # The condition encoder is moved explicitly as well; presumably pipe.to()
    # does not reach it — confirm against MVAdapterI2MVSDXLPipeline.
    pipe.cond_encoder.to(device=device, dtype=dtype)

    if lora_model is not None:
        if "/" in lora_model:
            lora_location, lora_weight_name = lora_model.rsplit("/", 1)
            pipe.load_lora_weights(lora_location, weight_name=lora_weight_name)
        else:
            # Nothing to split on (e.g. a weight file in the working
            # directory): let diffusers resolve the reference itself instead
            # of failing on tuple unpacking.
            pipe.load_lora_weights(lora_model)

    pipe.enable_vae_slicing()
    return pipe
def remove_bg(image, net, transform, device):
    """Predict a foreground mask with ``net`` and attach it as alpha.

    The image is converted to RGB before segmentation. Grayscale, palette, or
    already-transparent (RGBA) inputs therefore produce the 3-channel tensor
    a 3-channel normalisation expects. The result is always RGBA.

    Args:
        image: PIL image to segment. It is not modified.
        net: Segmentation model. Its output is indexed with ``[-1]`` and passed
            through a sigmoid to obtain foreground probabilities.
        transform: Callable mapping a PIL RGB image to a tensor ``net``
            accepts; presumably (C, H, W), since a batch axis is added here.
        device: Device on which ``net`` runs.

    Returns:
        A new RGBA PIL image, the same size as ``image``, whose alpha channel
        is the predicted mask.
    """
    # convert() returns a copy, so the caller's image is left untouched and
    # any existing alpha is replaced by the predicted mask below.
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = net(batch)[-1].sigmoid().cpu()
    # First (only) batch item; squeeze drops singleton axes before PIL conversion.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(rgb_image.size)
    rgb_image.putalpha(mask)
    return rgb_image
def preprocess_image(image: Image.Image, height, width):
    """Center the opaque region of an image on a gray ``height`` x ``width`` canvas.

    Steps:
    1. Crop the bounding box of pixels with non-zero alpha (with one extra
       row/column before the box).
    2. Scale the crop, keeping its aspect ratio, so that it fits within 90% of
       the canvas on its limiting side.
    3. Paste it in the center of the canvas.
    4. Alpha-composite the result over a 0.5 gray background.

    Args:
        image: Image with an alpha channel. A PIL image in another mode is
            converted to RGBA first; without transparency its full extent is used.
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        An RGB PIL image of size ``(width, height)``.

    Raises:
        ValueError: If no pixel has non-zero alpha.
    """
    if isinstance(image, Image.Image) and image.mode != "RGBA":
        image = image.convert("RGBA")
    rgba = np.array(image)
    alpha = rgba[..., 3] > 0
    img_h, img_w = alpha.shape

    ys, xs = np.where(alpha)
    if ys.size == 0:
        raise ValueError(
            "preprocess_image: image is fully transparent, nothing to center"
        )
    # The start index gets a one-pixel margin. The exclusive end is one past
    # the last opaque row/column.
    y0, y1 = max(ys.min() - 1, 0), min(ys.max() + 1, img_h)
    x0, x1 = max(xs.min() - 1, 0), min(xs.max() + 1, img_w)
    crop = rgba[y0:y1, x0:x1]

    # Pick the side that limits the fit. Comparing crop_h/height with
    # crop_w/width keeps the resized crop inside a non-square canvas; for a
    # square canvas this is the same as comparing crop_h with crop_w.
    crop_h, crop_w = crop.shape[:2]
    if crop_h * width > crop_w * height:
        new_w = int(crop_w * (height * 0.9) / crop_h)
        new_h = int(height * 0.9)
    else:
        new_h = int(crop_h * (width * 0.9) / crop_w)
        new_w = int(width * 0.9)
    # Extremely thin objects could round down to 0 px, which PIL cannot resize to.
    new_w, new_h = max(new_w, 1), max(new_h, 1)
    crop = np.array(Image.fromarray(crop).resize((new_w, new_h)))

    # Paste into the center of a transparent canvas.
    top = (height - new_h) // 2
    left = (width - new_w) // 2
    canvas = np.zeros((height, width, 4), dtype=np.uint8)
    canvas[top : top + new_h, left : left + new_w] = crop

    # Composite over 50% gray using the alpha channel.
    canvas = canvas.astype(np.float32) / 255.0
    rgb = canvas[:, :, :3] * canvas[:, :, 3:4] + (1 - canvas[:, :, 3:4]) * 0.5
    return Image.fromarray((rgb * 255).clip(0, 255).astype(np.uint8))
def run_pipeline(
    pipe,
    mesh_path,
    num_views,
    text,
    image,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    remove_bg_fn=None,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    lora_scale=1.0,
    device="cuda",
):
    """Render position/normal maps of a mesh and generate views from them.

    Args:
        pipe: Pipeline built by ``prepare_pipeline``.
        mesh_path: Mesh file, loaded rescaled and moved to the center.
        num_views: Images generated per prompt; also the length of the camera
            distance list.
        text: Prompt.
        image: Reference image, either a PIL image or a path to open.
        height: Render and generation height in pixels.
        width: Render and generation width in pixels.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        seed: Integer seed. ``-1`` or a non-int value leaves the generator unset.
        remove_bg_fn: Optional callable applied to the reference image. It is
            expected to return an image with an alpha channel, which is then
            centered with ``preprocess_image``.
        reference_conditioning_scale: Forwarded to the pipeline.
        negative_prompt: Negative prompt.
        lora_scale: Forwarded as ``cross_attention_kwargs={"scale": ...}``.
        device: Device for cameras, rasterization, and the generator.

    Returns:
        Tuple ``(images, pos_images, normal_images, reference_image,
        transform_dict)``. ``transform_dict`` holds the ``offset`` and
        ``scale`` returned by ``load_mesh``, as lists.
    """
    # NOTE(review): the elevation/azimuth lists always describe six cameras
    # (azimuths -90/0/90/180 deg at elevation 0, plus two at +/-89.99 deg);
    # only the distance list follows num_views.
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
        device=device,
    )
    ctx = NVDiffRastContextWrapper(device=device)

    mesh, offset, scale = load_mesh(
        mesh_path,
        rescale=True,
        move_to_center=True,
        device=device,
        return_transform=True,
    )
    transform_dict = {"offset": offset.tolist(), "scale": scale.tolist()}

    render_out = render(
        ctx,
        mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )

    # Map positions and normals into [0, 1] once and reuse them for both the
    # conditioning tensor and the preview images.
    pos_map = (render_out.pos + 0.5).clamp(0, 1)
    normal_map = (render_out.normal / 2 + 0.5).clamp(0, 1)
    # The maps are channels-last: concatenate along the last axis, then move
    # channels to axis 1. torch.cat copies, so the preview conversion below
    # cannot affect the control tensor.
    control_images = (
        torch.cat([pos_map, normal_map], dim=-1).permute(0, 3, 1, 2).to(device)
    )
    pos_images = tensor_to_image(pos_map, batched=True)
    normal_images = tensor_to_image(normal_map, batched=True)

    # Reference image: optionally strip the background, then center it when
    # it carries an alpha channel.
    reference_image = Image.open(image) if isinstance(image, str) else image
    if remove_bg_fn is not None:
        reference_image = remove_bg_fn(reference_image)
        reference_image = preprocess_image(reference_image, height, width)
    elif reference_image.mode == "RGBA":
        reference_image = preprocess_image(reference_image, height, width)

    pipe_kwargs = {}
    if isinstance(seed, int) and seed != -1:
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference_image,
        reference_conditioning_scale=reference_conditioning_scale,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images

    return images, pos_images, normal_images, reference_image, transform_dict
if __name__ == "__main__":
    import functools
    import os

    parser = argparse.ArgumentParser(
        description=(
            "Generate multi-view images of a mesh conditioned on a reference "
            "image with MV-Adapter (SDXL)."
        )
    )
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
    )
    parser.add_argument(
        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
    )
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument(
        "--scheduler", type=str, default=None, help="'ddpm' or 'lcm' (optional)"
    )
    parser.add_argument(
        "--lora_model",
        type=str,
        default=None,
        help="LoRA given as '<repo_or_dir>/<weight_file>' (optional)",
    )
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)
    # Device
    parser.add_argument("--device", type=str, default="cuda")
    # Inference
    parser.add_argument("--mesh", type=str, required=True, help="Input mesh path")
    parser.add_argument(
        "--image", type=str, required=True, help="Reference image path"
    )
    parser.add_argument("--text", type=str, required=False, default="high quality")
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument(
        "--seed", type=int, default=-1, help="-1 leaves the generator unseeded"
    )
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="output.png",
        help="Output grid path; companion files share its stem",
    )
    # Extra
    parser.add_argument("--remove_bg", action="store_true", help="Remove background")
    args = parser.parse_args()

    # Derive companion file names from the output path. splitext only splits
    # the final path component, so dots in directory names and extension-less
    # paths are handled correctly.
    output_stem = os.path.splitext(args.output)[0]
    # Create the output directory before the expensive inference, not after.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )

    remove_bg_fn = None
    if args.remove_bg:
        birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        birefnet.to(args.device)
        transform_image = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        remove_bg_fn = functools.partial(
            remove_bg, net=birefnet, transform=transform_image, device=args.device
        )

    images, pos_images, normal_images, reference_image, transform_dict = run_pipeline(
        pipe,
        mesh_path=args.mesh,
        num_views=args.num_views,
        text=args.text,
        image=args.image,
        height=768,
        width=768,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        lora_scale=args.lora_scale,
        reference_conditioning_scale=args.reference_conditioning_scale,
        negative_prompt=args.negative_prompt,
        device=args.device,
        remove_bg_fn=remove_bg_fn,
    )

    make_image_grid(images, rows=1).save(args.output)
    make_image_grid(pos_images, rows=1).save(f"{output_stem}_pos.png")
    make_image_grid(normal_images, rows=1).save(f"{output_stem}_nor.png")
    reference_image.save(f"{output_stem}_reference.png")
    with open(f"{output_stem}_transform.json", "w", encoding="utf-8") as f:
        json.dump(transform_dict, f, indent=4)