from oldm.hack import disable_verbosity
disable_verbosity()
import sys
import os
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import json
import argparse

# Make the repository root importable when this script lives in a subdirectory.
file_path = os.path.abspath(__file__)
parent_dir = os.path.abspath(os.path.dirname(file_path) + '/..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from PIL import Image
from pytorch_lightning import seed_everything
from oldm.model import create_model, load_state_dict
from oldm.ddim_hacked import DDIMSampler
from oft import inject_trainable_oft, inject_trainable_oft_conv, inject_trainable_oft_extended, inject_trainable_oft_with_norm
from hra import inject_trainable_hra
from lora import inject_trainable_lora
from dataset.utils import return_dataset


def process(input_image, prompt, hint_image, a_prompt, n_prompt, num_samples,
            image_resolution, ddim_steps, guess_mode, strength, scale, seed,
            eta, low_threshold, high_threshold):
    """Run conditional DDIM sampling for one (image, prompt, hint) triple.

    Uses the module-level globals ``model`` (the conditioned LDM) and
    ``ddim_sampler`` (its DDIM sampler), both created in the ``__main__``
    block below.

    Args:
        input_image: HWC numpy image; only its spatial size is used to pick
            the latent shape.  # assumes H and W divisible by 8 — TODO confirm
        prompt: text prompt; joined with ``a_prompt`` for the positive branch.
        hint_image: HWC numpy control signal, used as-is (NOT rescaled by 255
            here — the /255 normalization is deliberately not applied).
        a_prompt / n_prompt: additional positive / negative prompt text.
        num_samples: number of samples to draw for this input.
        image_resolution, low_threshold, high_threshold: accepted for
            interface compatibility but unused (the canny path is disabled).
        ddim_steps: number of DDIM denoising steps.
        guess_mode: when True, drop the control from the unconditional branch
            and ramp the 13 control scales geometrically.
        strength: base multiplier for the control scales.
        scale: classifier-free guidance scale.
        seed: RNG seed; -1 draws a random seed in [0, 65535].
        eta: DDIM eta (0.0 = deterministic).

    Returns:
        list: ``[input_image, hint_image, sample_0, ..., sample_{n-1}]`` where
        samples are uint8 HWC arrays in [0, 255].
    """
    with torch.no_grad():
        H, W, C = input_image.shape

        # Replicate the control signal across the batch and go HWC -> CHW.
        control = torch.from_numpy(hint_image.copy()).float().cuda()
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        cond = {"c_concat": [control],
                "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control],
                   "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        # Latent space is 1/8 of pixel space with 4 channels.
        shape = (4, H // 8, W // 8)

        # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
        model.control_scales = ([strength * (0.825 ** float(12 - i)) for i in range(13)]
                                if guess_mode else ([strength] * 13))

        samples, intermediates = ddim_sampler.sample(
            ddim_steps, num_samples, shape, cond, verbose=False, eta=eta,
            unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)

        x_samples = model.decode_first_stage(samples)
        # [-1, 1] float -> [0, 255] uint8, CHW -> HWC.
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5)\
            .cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]
    return [input_image] + [hint_image] + results


parser = argparse.ArgumentParser()
parser.add_argument('--d', type=int, help='the index of GPU', default=0)
# HRA
parser.add_argument('--hra_r', type=int, default=8)
parser.add_argument('--hra_apply_GS', action="store_true", default=False)
# OFT
parser.add_argument('--oft_r', type=int, default=4)
parser.add_argument('--oft_eps', type=float, choices=[1e-3, 2e-5, 7e-6], default=7e-6)
parser.add_argument('--oft_coft', action="store_true", default=True)
parser.add_argument('--oft_block_share', action="store_true", default=False)
# LoRA (previously the lora branch read the undefined args.r and crashed)
parser.add_argument('--lora_r', type=int, default=4)
parser.add_argument('--img_ID', type=int, default=1)
parser.add_argument('--num_samples', type=int, default=1)
parser.add_argument('--batch', type=int, default=20)
parser.add_argument('--sd_locked', action="store_true", default=True)
parser.add_argument('--only_mid_control', action="store_true", default=False)
parser.add_argument('--num_gpus', type=int, default=8)
parser.add_argument('--time_str', type=str, default='2024-03-18-10-55-21-089985')
parser.add_argument('--epoch', type=int, default=19)
parser.add_argument('--control', type=str,
                    help='control signal. Options are [segm, sketch, densepose, depth, canny, landmark]',
                    default="segm")
args = parser.parse_args()


if __name__ == '__main__':
    # Configs
    epoch = args.epoch
    control = args.control
    _, dataset, data_name, logger_freq, max_epochs = return_dataset(control, full=True)

    num_gpus = args.num_gpus
    time_str = args.time_str
    # Experiment directory is currently hard-coded; the adapter type is
    # selected by substring match ('hra' / 'lora' / otherwise OFT).
    experiment = './log/image_log_hra_0.0_ADE20K_segm_pe_diff_mlp_r_8_8gpu_2024-06-27-19-57-34-979197'
    if args.control not in experiment:
        # Explicit raise instead of assert: assert is stripped under -O.
        raise ValueError(f'--control={args.control!r} does not match experiment {experiment!r}')

    # Checkpoint epoch depends on how the experiment was trained.
    if 'train_with_norm' in experiment:
        epoch = 4
    elif 'COCO' in experiment:
        epoch = 9
    else:
        epoch = 19
    resume_path = os.path.join(experiment, f'model-epoch={epoch:02d}.ckpt')

    sd_locked = args.sd_locked
    only_mid_control = args.only_mid_control
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Output directories: generated samples, source images, control hints.
    result_dir = os.path.join(experiment, 'results', str(epoch))
    os.makedirs(result_dir, exist_ok=True)
    source_dir = os.path.join(experiment, 'source', str(epoch))
    os.makedirs(source_dir, exist_ok=True)
    hint_dir = os.path.join(experiment, 'hints', str(epoch))
    os.makedirs(hint_dir, exist_ok=True)

    # Build the base model on CPU, freeze it, then inject trainable adapters
    # matching the experiment type before loading the fine-tuned weights.
    model = create_model('./configs/oft_ldm_v15.yaml').cpu()
    model.model.requires_grad_(False)
    if 'hra' in experiment:
        unet_lora_params, train_names = inject_trainable_hra(
            model.model, r=args.hra_r, apply_GS=args.hra_apply_GS)
    elif 'lora' in experiment:
        unet_lora_params, train_names = inject_trainable_lora(
            model.model, rank=args.lora_r, network_alpha=None)
    elif 'train_with_norm' in experiment:
        unet_opt_params, train_names = inject_trainable_oft_with_norm(
            model.model, r=args.oft_r, eps=args.oft_eps,
            is_coft=args.oft_coft, block_share=args.oft_block_share)
    else:
        unet_lora_params, train_names = inject_trainable_oft(
            model.model, r=args.oft_r, eps=args.oft_eps,
            is_coft=args.oft_coft, block_share=args.oft_block_share)

    model.load_state_dict(load_state_dict(resume_path, location='cuda'))
    model = model.cuda()
    ddim_sampler = DDIMSampler(model)

    # Shard the dataset across num_gpus workers; this process handles shard
    # args.d, and the last shard absorbs the remainder.
    num_pack = len(dataset) // args.num_gpus
    start_idx = args.d * num_pack
    end_idx = (args.d + 1) * num_pack if args.d < args.num_gpus - 1 else len(dataset)

    for idx in range(start_idx, end_idx):
        data = dataset[idx]
        input_image, prompt, hint = data['jpg'], data['txt'], data['hint']

        # Skip items whose first result image already exists (resumability).
        if os.path.exists(os.path.join(result_dir, f'result_{idx}_0.png')):
            continue

        result_images = process(
            input_image=input_image,
            prompt=prompt,
            hint_image=hint,
            a_prompt="",
            n_prompt="",
            num_samples=args.num_samples,
            image_resolution=512,
            ddim_steps=50,
            guess_mode=False,
            strength=1,
            scale=9.0,
            seed=-1,
            eta=0.0,
            low_threshold=100,
            high_threshold=200,
        )

        # result_images = [source, hint, sample_0, ...]; each goes to its own
        # directory with the value range that class of image was stored in.
        for i, image in enumerate(result_images):
            if i == 0:
                # Source image: assumed stored in [-1, 1] — TODO confirm
                # against the dataset implementation.
                image = ((image + 1) * 127.5).clip(0, 255).astype(np.uint8)
                pil_image = Image.fromarray(image)
                output_path = os.path.join(source_dir, f'image_{idx}.png')
                pil_image.save(output_path)
            elif i == 1:
                # Hint: assumed stored in [0, 1].
                image = (image * 255).clip(0, 255).astype(np.uint8)
                pil_image = Image.fromarray(image)
                output_path = os.path.join(hint_dir, f'hint_{idx}.png')
                pil_image.save(output_path)
            else:
                # Generated samples are already uint8 in [0, 255].
                n = i - 2
                pil_image = Image.fromarray(image)
                output_path = os.path.join(result_dir, f'result_{idx}_{n}.png')
                pil_image.save(output_path)