| | from oldm.hack import disable_verbosity |
| | disable_verbosity() |
| |
|
| | import sys |
| | import os |
| | import cv2 |
| | import einops |
| | import gradio as gr |
| | import numpy as np |
| | import torch |
| | import random |
| | import json |
| | import argparse |
| |
|
# Put the directory one level above this script on the import path.
# Presumably this is what lets the project-local imports that follow
# (oft, hra, lora, dataset) resolve -- confirm against the repo layout.
file_path = os.path.abspath(__file__)
parent_dir = os.path.abspath(os.path.split(file_path)[0] + '/..')
if not any(entry == parent_dir for entry in sys.path):
    sys.path.append(parent_dir)
| |
|
| | from PIL import Image |
| | from pytorch_lightning import seed_everything |
| | from oldm.model import create_model, load_state_dict |
| | from oldm.ddim_hacked import DDIMSampler |
| | from oft import inject_trainable_oft, inject_trainable_oft_conv, inject_trainable_oft_extended, inject_trainable_oft_with_norm |
| | from hra import inject_trainable_hra |
| | from lora import inject_trainable_lora |
| |
|
| | from dataset.utils import return_dataset |
| |
|
| |
|
def process(input_image, prompt, hint_image, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
    """Sample ``num_samples`` images conditioned on a hint image and a prompt.

    Uses the module-level ``model`` and ``ddim_sampler`` objects, so those
    must exist before this is called. The control tensor is moved to the
    current CUDA device.

    Args:
        input_image: Array with exactly three dimensions (unpacked as
            height, width, channels). Only height and width are used, to size
            the latent; the array itself is returned untouched.
        prompt: Text prompt. ``a_prompt`` is appended to it after ``', '``.
        hint_image: NumPy array in (h, w, c) layout. It is copied, cast to
            float and repeated ``num_samples`` times to form the control batch.
        a_prompt: Extra text appended to ``prompt``.
        n_prompt: Prompt used for the unconditional branch.
        num_samples: Batch size, i.e. how many images are generated.
        image_resolution: Not used by this function.
        ddim_steps: Step count handed to ``ddim_sampler.sample``.
        guess_mode: When truthy, the unconditional branch receives no control
            tensor and the 13 control scales decay by a factor of 0.825 per
            level instead of all being ``strength``.
        strength: Base value for ``model.control_scales``.
        scale: Passed as ``unconditional_guidance_scale``.
        seed: RNG seed. ``-1`` means "draw one from ``[0, 65535]``".
        eta: Passed through to the sampler.
        low_threshold: Not used by this function.
        high_threshold: Not used by this function.

    Returns:
        A list ``[input_image, hint_image, sample_0, ...]`` where every
        sample is a ``uint8`` array in (h, w, c) layout.
    """
    with torch.no_grad():
        height, width, _ = input_image.shape

        # Replicate the hint along a new batch axis, then go channels-first.
        hint_tensor = torch.from_numpy(hint_image.copy()).float().cuda()
        control = torch.stack([hint_tensor] * num_samples, dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        positive = model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)
        negative = model.get_learned_conditioning([n_prompt] * num_samples)
        cond = {"c_concat": [control], "c_crossattn": [positive]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [negative]}
        latent_shape = (4, height // 8, width // 8)

        if guess_mode:
            model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)]
        else:
            model.control_scales = [strength] * 13

        samples, _ = ddim_sampler.sample(
            ddim_steps,
            num_samples,
            latent_shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=un_cond,
        )

        # Decode, move channels last, and map the decoder output to 0..255.
        decoded = model.decode_first_stage(samples)
        decoded = einops.rearrange(decoded, 'b c h w -> b h w c') * 127.5 + 127.5
        x_samples = decoded.cpu().numpy().clip(0, 255).astype(np.uint8)

        outputs = [input_image, hint_image]
        outputs.extend(x_samples[i] for i in range(num_samples))
        return outputs
| |
|
# Command-line options for the batch generation run. Parsing happens at
# import time and the result is published as the module-level ``args``.
parser = argparse.ArgumentParser()

# Index of this worker; the __main__ section uses it to pick a dataset shard.
parser.add_argument('--d', type=int, help='the index of GPU', default=0)

# HRA adapter options.
parser.add_argument('--hra_r', type=int, default=8)
parser.add_argument('--hra_apply_GS', action="store_true", default=False)

# LoRA adapter options. The LoRA branch of the __main__ section reads
# ``args.r``, which previously had no option behind it and raised
# AttributeError. ``--lora_r`` follows the naming of the other adapters and
# ``--r`` is accepted as an alias; both store into ``args.r``.
# NOTE(review): the default of 4 is a placeholder -- the value must match the
# rank the checkpoint was trained with.
parser.add_argument('--lora_r', '--r', dest='r', type=int, default=4,
                    help='LoRA rank; must match the rank used for training')

# OFT adapter options.
parser.add_argument('--oft_r', type=int, default=4)
parser.add_argument('--oft_eps',
                    type=float,
                    choices=[1e-3, 2e-5, 7e-6],
                    default=7e-6,
                    )
# ``store_true`` combined with ``default=True`` can never yield False, so the
# companion ``--no_oft_coft`` switch is the only way to turn COFT off.
parser.add_argument('--oft_coft', action="store_true", default=True)
parser.add_argument('--no_oft_coft', dest='oft_coft', action="store_false",
                    help='disable the constrained (COFT) variant')
parser.add_argument('--oft_block_share', action="store_true", default=False)

# Run configuration.
parser.add_argument('--img_ID', type=int, default=1)
parser.add_argument('--num_samples', type=int, default=1)
parser.add_argument('--batch', type=int, default=20)
# NOTE(review): same store_true/default=True pattern, so this is always True.
parser.add_argument('--sd_locked', action="store_true", default=True)
parser.add_argument('--only_mid_control', action="store_true", default=False)
parser.add_argument('--num_gpus', type=int, default=8)

parser.add_argument('--time_str', type=str, default='2024-03-18-10-55-21-089985')
parser.add_argument('--epoch', type=int, default=19)
parser.add_argument('--control',
                    type=str,
                    help='control signal. Options are [segm, sketch, densepose, depth, canny, landmark]',
                    default="segm")

args = parser.parse_args()
| |
|
if __name__ == '__main__':
    # Batch generation driver: rebuild the adapted model, load a checkpoint
    # from the experiment directory, and render this worker's share of the
    # dataset to PNG files (source image, hint, and generated samples).

    epoch = args.epoch
    control = args.control
    _, dataset, data_name, logger_freq, max_epochs = return_dataset(control, full=True)

    num_gpus = args.num_gpus
    time_str = args.time_str  # not read below

    # NOTE(review): the run directory is hard-coded; edit it to evaluate
    # another experiment.
    experiment = './log/image_log_hra_0.0_ADE20K_segm_pe_diff_mlp_r_8_8gpu_2024-06-27-19-57-34-979197'

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if control not in experiment:
        raise ValueError(
            f"control signal {control!r} does not appear in the experiment path {experiment!r}"
        )

    # NOTE(review): the checkpoint epoch is derived from the experiment name
    # and always replaces the value given via --epoch.
    if 'train_with_norm' in experiment:
        epoch = 4
    elif 'COCO' in experiment:
        epoch = 9
    else:
        epoch = 19

    resume_path = os.path.join(experiment, f'model-epoch={epoch:02d}.ckpt')
    sd_locked = args.sd_locked  # not read below
    only_mid_control = args.only_mid_control  # not read below
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # not read below

    # Per-epoch output folders inside the experiment directory.
    result_dir = os.path.join(experiment, 'results', str(epoch))
    source_dir = os.path.join(experiment, 'source', str(epoch))
    hint_dir = os.path.join(experiment, 'hints', str(epoch))
    for out_dir in (result_dir, source_dir, hint_dir):
        os.makedirs(out_dir, exist_ok=True)

    model = create_model('./configs/oft_ldm_v15.yaml').cpu()
    model.model.requires_grad_(False)

    # Inject the adapter layers before loading the weights. The adapter type
    # is picked from the experiment name.
    oft_kwargs = dict(r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
    if 'hra' in experiment:
        unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.hra_r, apply_GS=args.hra_apply_GS)
    elif 'lora' in experiment:
        # Fail with a readable message if no LoRA rank is available on args.
        lora_rank = getattr(args, 'r', None)
        if lora_rank is None:
            raise ValueError(
                "the experiment name selects LoRA but no LoRA rank is available (expected args.r)"
            )
        unet_lora_params, train_names = inject_trainable_lora(model.model, rank=lora_rank, network_alpha=None)
    elif 'train_with_norm' in experiment:
        unet_opt_params, train_names = inject_trainable_oft_with_norm(model.model, **oft_kwargs)
    else:
        unet_lora_params, train_names = inject_trainable_oft(model.model, **oft_kwargs)

    model.load_state_dict(load_state_dict(resume_path, location='cuda'))
    model = model.cuda()
    ddim_sampler = DDIMSampler(model)

    # Split the dataset into ``num_gpus`` contiguous shards; the last shard
    # also takes the remainder of the integer division.
    num_pack = len(dataset) // num_gpus
    start_idx = args.d * num_pack
    end_idx = (args.d + 1) * num_pack if args.d < num_gpus - 1 else len(dataset)

    for idx in range(start_idx, end_idx):
        data = dataset[idx]
        input_image, prompt, hint = data['jpg'], data['txt'], data['hint']

        # Skip items whose first sample already exists on disk, so an
        # interrupted run can be restarted.
        if os.path.exists(os.path.join(result_dir, f'result_{idx}_0.png')):
            continue

        result_images = process(
            input_image=input_image,
            prompt=prompt,
            hint_image=hint,
            a_prompt="",
            n_prompt="",
            num_samples=args.num_samples,
            image_resolution=512,
            ddim_steps=50,
            guess_mode=False,
            strength=1,
            scale=9.0,
            seed=-1,
            eta=0.0,
            low_threshold=100,
            high_threshold=200,
        )

        # process() returns [input_image, hint_image, sample_0, sample_1, ...].
        source, hint_out, *samples = result_images

        # Assumes the dataset yields the source image in [-1, 1] and the hint
        # in [0, 1] -- TODO confirm against the dataset implementation.
        source_u8 = ((source + 1) * 127.5).clip(0, 255).astype(np.uint8)
        Image.fromarray(source_u8).save(os.path.join(source_dir, f'image_{idx}.png'))

        hint_u8 = (hint_out * 255).clip(0, 255).astype(np.uint8)
        Image.fromarray(hint_u8).save(os.path.join(hint_dir, f'hint_{idx}.png'))

        # Generated samples are already uint8.
        for n, sample in enumerate(samples):
            Image.fromarray(sample).save(os.path.join(result_dir, f'result_{idx}_{n}.png'))
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| |
|