""" Lingteng Qiu Baseline I2Normal to show the difference demo. """ import sys sys.path.append("./") import cv2 import einops import numpy as np import torch from tqdm import tqdm import matplotlib.pyplot as plt import os import torch.nn.functional as F import matplotlib.pyplot as plt import json import pdb import argparse import tqlt import random from tqlt import utils as tu from tqlt import op as tqlo from human_generate_system.engineer.NormalEstimator.data_utils import ( HWC3, resize_image, norm_normalize, center_crop, flip_x, ) from PIL import Image from os.path import join ABS_PATH = join(os.path.dirname(os.path.abspath(__file__)), "DSINE") if __name__ == "__main__": parser = argparse.ArgumentParser(description="") parser.add_argument("--num_samples", default=1, type=int) parser.add_argument("--image_resolution", default=768, type=int) parser.add_argument("--strength", default=1.0, type=float) parser.add_argument("--ng_scale", default=1.0, type=float) parser.add_argument("--ddim_steps", default=10, type=int) parser.add_argument("--seed", default=23012, type=int) parser.add_argument("--eta", default=0.0, type=float) parser.add_argument("--temperature", default=0.0, type=float) parser.add_argument("--save_memory", action="store_true") parser.add_argument( "--negative_prompt", default="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality", type=str, ) parser.add_argument("--input", "-i", default=None, type=str) parser.add_argument( "--prior", default="DSINE", type=str, choices=["DSINE", "geowizard"] ) parser.add_argument( "--flip", action="store_true", help="flip init normal and out normal" ) parser.add_argument("--wo_center", action="store_true", help="without center crop") parser.add_argument("--num_gpu", default=1, type=int) parser.add_argument("--rank", default=0, type=int) opt = parser.parse_args() assert opt.input is not None and os.path.exists(opt.input) if os.path.isfile(opt.input): input_list = [opt.input] else: input_name_list = sorted(os.listdir(opt.input)) input_list = [os.path.join(opt.input, name) for name in input_name_list] all_data = tu.is_img(input_list) num_samples = opt.num_samples image_resolution = opt.image_resolution strength = opt.strength neg_scale = opt.ng_scale guess_mode = False ddim_steps = opt.ddim_steps seed = opt.seed eta = opt.eta temperature = opt.temperature save_memory = opt.save_memory tu.seed_everything(seed, verbose=True) # all_data = sorted(all_data, key=lambda x: x['image']) if opt.num_gpu > 1: bucket = len(all_data) // opt.num_gpu rank = opt.rank if rank == opt.num_gpu - 1: all_data = all_data[bucket * rank :] else: all_data = all_data[bucket * rank : bucket * (rank + 1)] if opt.prior == "DSINE": normal_predictor = torch.hub.load( ABS_PATH, "DSINE", local_file_path="./pretrained_models/dsine.pt", source="local", ) else: raise NotImplementedError if torch.cuda.is_available(): current_device_id = torch.cuda.current_device() device = f"cuda:{current_device_id}" else: device = "cpu" output_dir = os.path.join(opt.input, "normal") os.makedirs(output_dir, exist_ok=True) for item in tqdm(all_data): input_image_path = item basename = os.path.basename(item) if opt.wo_center: input_image = cv2.imread(input_image_path) else: input_image = center_crop(cv2.imread(input_image_path)) height, width = input_image.shape[:2] with torch.no_grad(): raw_input_image = HWC3(input_image) ori_H, ori_W, _ = raw_input_image.shape img = resize_image(raw_input_image, image_resolution) H, W, C = img.shape if opt.prior == "DSINE": pred_normal = normal_predictor.infer_cv2(img)[0] # (3, H, W) pred_normal = (pred_normal + 1) / 2 * 255 pred_normal = pred_normal.cpu().numpy().transpose(1, 2, 0) pred_normal = cv2.cvtColor( pred_normal.astype(np.uint8), cv2.COLOR_RGB2BGR ) elif opt.prior == "geowizard": pred_normal = normal_predictor(img, image_resolution) pred_normal = (pred_normal + 1) / 2 * 255 pred_normal = cv2.cvtColor( pred_normal.astype(np.uint8), cv2.COLOR_RGB2BGR ) pred_normal = cv2.resize(pred_normal, (ori_W, ori_H)) basename = os.path.splitext(basename)[0] cv2.imwrite(f"{output_dir}/normal_{basename}.png", pred_normal)