| | """ |
| | Lingteng Qiu |
| | Baseline I2Normal to show the difference demo. |
| | """ |
| |
|
| | import sys |
| |
|
| | sys.path.append("./") |
| |
|
| | import cv2 |
| | import einops |
| | import numpy as np |
| | import torch |
| | from tqdm import tqdm |
| | import matplotlib.pyplot as plt |
| | import os |
| | import torch.nn.functional as F |
| | import matplotlib.pyplot as plt |
| | import json |
| | import pdb |
| | import argparse |
| | import tqlt |
| | import random |
| | from tqlt import utils as tu |
| | from tqlt import op as tqlo |
| | from human_generate_system.engineer.NormalEstimator.data_utils import ( |
| | HWC3, |
| | resize_image, |
| | norm_normalize, |
| | center_crop, |
| | flip_x, |
| | ) |
| | from PIL import Image |
| | from os.path import join |
| |
|
# Directory named "DSINE" next to this script. It is passed to torch.hub.load
# with source="local", so it is resolved from __file__ rather than the cwd.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ABS_PATH = os.path.join(_SCRIPT_DIR, "DSINE")
| |
|
| |
|
def _build_parser():
    """Build the command-line parser for the normal-estimation demo.

    Returns:
        argparse.ArgumentParser with all options of this script.
    """
    parser = argparse.ArgumentParser(description="")
    # NOTE: the sampling options below (num_samples .. negative_prompt) are
    # accepted for CLI compatibility but are not consumed by this script.
    parser.add_argument("--num_samples", default=1, type=int)
    parser.add_argument("--image_resolution", default=768, type=int)
    parser.add_argument("--strength", default=1.0, type=float)
    parser.add_argument("--ng_scale", default=1.0, type=float)
    parser.add_argument("--ddim_steps", default=10, type=int)
    parser.add_argument("--seed", default=23012, type=int)
    parser.add_argument("--eta", default=0.0, type=float)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--save_memory", action="store_true")
    parser.add_argument(
        "--negative_prompt",
        default="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
        type=str,
    )
    parser.add_argument("--input", "-i", default=None, type=str)
    parser.add_argument(
        "--prior", default="DSINE", type=str, choices=["DSINE", "geowizard"]
    )
    # NOTE(review): --flip is parsed but not applied anywhere in this script.
    parser.add_argument(
        "--flip", action="store_true", help="flip init normal and out normal"
    )
    parser.add_argument("--wo_center", action="store_true", help="without center crop")
    parser.add_argument("--num_gpu", default=1, type=int)
    parser.add_argument("--rank", default=0, type=int)
    return parser


def _collect_inputs(input_path):
    """Return candidate input paths.

    Args:
        input_path: an existing file or directory.

    Returns:
        ``[input_path]`` for a file, otherwise the directory entries joined
        onto ``input_path`` in sorted order (not yet filtered to images).
    """
    if os.path.isfile(input_path):
        return [input_path]
    return [os.path.join(input_path, name) for name in sorted(os.listdir(input_path))]


def _shard(items, num_gpu, rank):
    """Return the contiguous slice of ``items`` assigned to ``rank``.

    Items are split into ``num_gpu`` buckets of ``len(items) // num_gpu``.
    The last rank also takes the remainder. With ``num_gpu <= 1`` all items
    are returned unchanged.
    """
    if num_gpu <= 1:
        return items
    bucket = len(items) // num_gpu
    if rank == num_gpu - 1:
        return items[bucket * rank :]
    return items[bucket * rank : bucket * (rank + 1)]


def _predict_normal(normal_predictor, img, prior, image_resolution):
    """Run the normal prior on ``img`` and return a uint8 BGR normal map.

    Predictions are mapped from [-1, 1] to [0, 255]. They are then clipped
    before the uint8 cast so out-of-range values saturate instead of
    wrapping around.

    Raises:
        NotImplementedError: for an unknown ``prior``.
    """
    if prior == "DSINE":
        # Output is indexed [0] and transposed (1, 2, 0), so it is presumably
        # a (1, 3, H, W) tensor -- TODO confirm against DSINE's infer_cv2.
        pred = normal_predictor.infer_cv2(img)[0]
        pred = ((pred + 1) / 2 * 255).cpu().numpy().transpose(1, 2, 0)
    elif prior == "geowizard":
        pred = normal_predictor(img, image_resolution)
        pred = (pred + 1) / 2 * 255
    else:
        raise NotImplementedError(prior)
    pred = np.clip(pred, 0, 255).astype(np.uint8)
    return cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)


def main():
    """Estimate normal maps for every image under ``--input``.

    Each result is saved as ``normal/normal_<stem>.png``. The ``normal``
    folder is placed inside the input directory, or next to a single input
    file.
    """
    parser = _build_parser()
    opt = parser.parse_args()

    # Validate with parser.error rather than assert (asserts vanish under -O).
    if opt.input is None or not os.path.exists(opt.input):
        parser.error(f"--input must be an existing file or directory, got {opt.input!r}")
    if opt.num_gpu > 1 and not 0 <= opt.rank < opt.num_gpu:
        parser.error(f"--rank must be in [0, {opt.num_gpu}), got {opt.rank}")

    # tu.is_img presumably keeps only image files -- confirm in tqlt.utils.
    all_data = tu.is_img(_collect_inputs(opt.input))
    image_resolution = opt.image_resolution

    tu.seed_everything(opt.seed, verbose=True)

    all_data = _shard(all_data, opt.num_gpu, opt.rank)

    if opt.prior == "DSINE":
        normal_predictor = torch.hub.load(
            ABS_PATH,
            "DSINE",
            local_file_path="./pretrained_models/dsine.pt",
            source="local",
        )
    else:
        raise NotImplementedError(f"prior {opt.prior!r} is not supported yet")

    # A single-file input has no directory to nest into, so write beside it
    # (joining onto the file path would make makedirs fail).
    base_dir = opt.input if os.path.isdir(opt.input) else os.path.dirname(opt.input)
    output_dir = os.path.join(base_dir, "normal")
    os.makedirs(output_dir, exist_ok=True)

    for input_image_path in tqdm(all_data):
        image = cv2.imread(input_image_path)
        if image is None:
            # Unreadable/corrupt file: skip it rather than abort the batch.
            tqdm.write(f"[skip] could not read image: {input_image_path}")
            continue
        input_image = image if opt.wo_center else center_crop(image)

        with torch.no_grad():
            raw_input_image = HWC3(input_image)
            ori_H, ori_W = raw_input_image.shape[:2]
            img = resize_image(raw_input_image, image_resolution)
            pred_normal = _predict_normal(
                normal_predictor, img, opt.prior, image_resolution
            )

        # Bring the prediction back to the (cropped) input size; cv2 takes (W, H).
        pred_normal = cv2.resize(pred_normal, (ori_W, ori_H))

        stem = os.path.splitext(os.path.basename(input_image_path))[0]
        cv2.imwrite(os.path.join(output_dir, f"normal_{stem}.png"), pred_normal)


if __name__ == "__main__":
    main()
| |
|