| import cv2 |
| import einops |
| import numpy as np |
| import torch |
| import random |
| from pytorch_lightning import seed_everything |
| from cldm.model import create_model, load_state_dict |
| from cldm.ddim_hacked import DDIMSampler |
| from cldm.hack import disable_verbosity, enable_sliced_attention |
| from datasets.data_utils import * |
| cv2.setNumThreads(0) |
| cv2.ocl.setUseOpenCL(False) |
| import albumentations as A |
| from omegaconf import OmegaConf |
| from PIL import Image |
|
|
|
|
# --- Global inference configuration and one-time model setup. ---
# NOTE(review): this all runs at import time and loads the full model onto
# the GPU — importing this module has heavy side effects.

# When True, sliced attention is enabled to reduce peak GPU memory
# (trades speed for memory); also gates low_vram_shift() calls below.
save_memory = False
disable_verbosity()
if save_memory:
    enable_sliced_attention()


# Checkpoint path and model-architecture config file come from the
# inference YAML.
config = OmegaConf.load('./configs/inference.yaml')
model_ckpt = config.pretrained_model
model_config = config.config_file


# Build the model on CPU first, load the pretrained weights, then move
# everything to the GPU and wrap it in a DDIM sampler.
model = create_model(model_config ).cpu()
model.load_state_dict(load_state_dict(model_ckpt, location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)
|
|
|
|
|
|
def aug_data_mask(image, mask):
    """Jointly augment an image and its mask.

    Applies a random horizontal flip (p=0.5) and a random
    brightness/contrast jitter (p=0.5); the flip is applied to both the
    image and the mask so they stay aligned.

    Args:
        image: HxW(x3) array-like; cast to uint8 before augmentation.
        mask: HxW mask aligned with `image`.

    Returns:
        (augmented_image, augmented_mask) tuple.
    """
    pipeline = A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
        ]
    )
    out = pipeline(image=image.astype(np.uint8), mask=mask)
    return out["image"], out["mask"]
|
|
|
|
def process_pairs(ref_image, ref_mask, tar_image, tar_mask):
    """Prepare one (reference object, target scene) pair for the model.

    Args:
        ref_image: RGB image containing the reference object.
        ref_mask: HxW binary mask (1 = object) for ``ref_image``.
        tar_image: RGB target scene image.
        tar_mask: binary mask marking the insertion region in ``tar_image``.

    Returns:
        dict with:
            ref:  224x224x3 whitened reference crop, scaled to [0, 1];
            jpg:  512x512x3 cropped target, scaled to [-1, 1];
            hint: 512x512x4 collage — RGB edge-map paste in [-1, 1] plus a
                  binary mask channel;
            extra_sizes, tar_box_yyxx_crop: bookkeeping for crop_back().
    """
    # --- Reference branch: isolate the object on a white background. ---
    ref_box_yyxx = get_bbox_from_mask(ref_mask)

    # Whiten everything outside the mask, then crop to the object's bbox.
    ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)
    masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1-ref_mask_3)

    y1,y2,x1,x2 = ref_box_yyxx
    masked_ref_image = masked_ref_image[y1:y2,x1:x2,:]
    ref_mask = ref_mask[y1:y2,x1:x2]

    # NOTE(review): randint(12, 13) always returns 12, so the expansion
    # ratio is a constant 1.2 — leftover from a training-time random range.
    ratio = np.random.randint(12, 13) / 10
    masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
    ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)

    # Square-pad (white) and resize the reference to the 224x224 encoder input size.
    masked_ref_image = pad_to_square(masked_ref_image, pad_value = 255, random = False)
    masked_ref_image = cv2.resize(masked_ref_image, (224,224) ).astype(np.uint8)

    ref_mask_3 = pad_to_square(ref_mask_3 * 255, pad_value = 0, random = False)
    ref_mask_3 = cv2.resize(ref_mask_3, (224,224) ).astype(np.uint8)
    ref_mask = ref_mask_3[:,:,0]

    # No augmentation at inference time: the "aug" image is just the crop itself.
    masked_ref_image_aug = masked_ref_image

    masked_ref_image_compose, ref_mask_compose = masked_ref_image, ref_mask
    masked_ref_image_aug = masked_ref_image_compose.copy()
    ref_mask_3 = np.stack([ref_mask_compose,ref_mask_compose,ref_mask_compose],-1)
    # Edge map of the reference object; this is what gets pasted into the
    # target crop as the spatial hint.
    ref_image_collage = sobel(masked_ref_image_compose, ref_mask_compose/255)

    # --- Target branch: square context crop around the insertion region. ---
    tar_box_yyxx = get_bbox_from_mask(tar_mask)
    tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=[1.1,1.2])

    # A larger square crop keeps scene context around the paste location.
    tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=[1.5, 3])
    tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop)
    y1,y2,x1,x2 = tar_box_yyxx_crop

    cropped_target_image = tar_image[y1:y2,x1:x2,:]
    # Re-express the paste box in the cropped image's coordinate frame.
    tar_box_yyxx = box_in_box(tar_box_yyxx, tar_box_yyxx_crop)
    y1,y2,x1,x2 = tar_box_yyxx

    # Resize the reference hint to the paste box and binarize its mask.
    ref_image_collage = cv2.resize(ref_image_collage, (x2-x1, y2-y1))
    ref_mask_compose = cv2.resize(ref_mask_compose.astype(np.uint8), (x2-x1, y2-y1))
    ref_mask_compose = (ref_mask_compose > 128).astype(np.uint8)

    # Paste the hint into the target crop and record where it was pasted.
    collage = cropped_target_image.copy()
    collage[y1:y2,x1:x2,:] = ref_image_collage

    collage_mask = cropped_target_image.copy() * 0.0
    collage_mask[y1:y2,x1:x2,:] = 1.0

    # Pad to square. H1/W1 (pre-pad) and H2/W2 (post-pad) are returned so
    # crop_back() can undo the padding later.
    # NOTE(review): pad_value=-1 wraps to 255 under astype(np.uint8), so the
    # padded border survives the >0.5 threshold below — presumably intended;
    # confirm against pad_to_square().
    H1, W1 = collage.shape[0], collage.shape[1]
    cropped_target_image = pad_to_square(cropped_target_image, pad_value = 0, random = False).astype(np.uint8)
    collage = pad_to_square(collage, pad_value = 0, random = False).astype(np.uint8)
    collage_mask = pad_to_square(collage_mask, pad_value = -1, random = False).astype(np.uint8)

    # Resize everything to the 512x512 diffusion resolution.
    H2, W2 = collage.shape[0], collage.shape[1]
    cropped_target_image = cv2.resize(cropped_target_image, (512,512)).astype(np.float32)
    collage = cv2.resize(collage, (512,512)).astype(np.float32)
    collage_mask = (cv2.resize(collage_mask, (512,512)).astype(np.float32) > 0.5).astype(np.float32)

    # Normalize: reference to [0,1]; target/collage to [-1,1]; append the
    # binary mask as a 4th hint channel.
    masked_ref_image_aug = masked_ref_image_aug / 255
    cropped_target_image = cropped_target_image / 127.5 - 1.0
    collage = collage / 127.5 - 1.0
    collage = np.concatenate([collage, collage_mask[:,:,:1] ] , -1)

    item = dict(ref=masked_ref_image_aug.copy(), jpg=cropped_target_image.copy(), hint=collage.copy(), extra_sizes=np.array([H1, W1, H2, W2]), tar_box_yyxx_crop=np.array( tar_box_yyxx_crop ) )
    return item
|
|
|
|
def crop_back( pred, tar_image, extra_sizes, tar_box_yyxx_crop):
    """Paste the 512x512 prediction back into the full target image.

    Reverses what process_pairs() did to the target: resizes the prediction
    back to the padded-square crop size, strips the square padding, and
    writes the result into the crop region of a copy of ``tar_image``.

    Args:
        pred: HxWx3 predicted image (model output in pixel range).
        tar_image: original full-resolution target image.
        extra_sizes: (H1, W1, H2, W2) — crop size before / after pad_to_square.
        tar_box_yyxx_crop: (y1, y2, x1, x2) crop box in tar_image coordinates.

    Returns:
        A new image (copy of ``tar_image``) with the prediction pasted in.
    """
    H1, W1, H2, W2 = extra_sizes
    y1, y2, x1, x2 = tar_box_yyxx_crop

    # Undo the 512x512 resize back to the padded-square size.
    pred = cv2.resize(pred, (W2, H2))
    m = 5  # margin (px) of original pixels kept around the paste to hide seams

    if W1 == H1:
        # Crop was already square: no padding to strip.
        # FIX: previously this branch wrote into (and returned) the caller's
        # tar_image in place, unlike the copying path below — which corrupted
        # the caller's image (e.g. the later hconcat visualization).
        gen_image = tar_image.copy()
        gen_image[y1+m :y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
        return gen_image

    # Strip the symmetric padding pad_to_square added along the shorter axis.
    # (pad2 >= 1 whenever padding was added, so the negative slice is safe.)
    if W1 < W2:
        pad1 = int((W2 - W1) / 2)
        pad2 = W2 - W1 - pad1
        pred = pred[:,pad1: -pad2, :]
    else:
        pad1 = int((H2 - H1) / 2)
        pad2 = H2 - H1 - pad1
        pred = pred[pad1: -pad2, :, :]

    gen_image = tar_image.copy()
    gen_image[y1+m :y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
    return gen_image
|
|
|
|
def inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale = 5.0):
    """Run one object-insertion diffusion pass.

    Uses the module-level ``model`` and ``ddim_sampler`` (CUDA required).

    Args:
        ref_image: RGB image containing the reference object.
        ref_mask: binary mask of the object in ``ref_image``.
        tar_image: RGB target scene to insert the object into.
        tar_mask: binary mask of the insertion region in ``tar_image``.
        guidance_scale: classifier-free guidance weight for DDIM sampling.

    Returns:
        The target image with the generated object pasted back in
        (same size as ``tar_image``).
    """
    item = process_pairs(ref_image, ref_mask, tar_image, tar_mask)
    # NOTE(review): the denormalized ref/tar/hint below are overwritten
    # further down, and hint_image/hint_mask/seed are never used — leftover
    # visualization/debug code.
    ref = item['ref'] * 255
    tar = item['jpg'] * 127.5 + 127.5
    hint = item['hint'] * 127.5 + 127.5

    hint_image = hint[:,:,:-1]
    hint_mask = item['hint'][:,:,-1] * 255
    hint_mask = np.stack([hint_mask,hint_mask,hint_mask],-1)
    ref = cv2.resize(ref.astype(np.uint8), (512,512))

    seed = random.randint(0, 65535)
    if save_memory:
        model.low_vram_shift(is_diffusing=False)

    # Actual network inputs: the normalized arrays from process_pairs().
    ref = item['ref']
    tar = item['jpg']
    hint = item['hint']
    num_samples = 1

    # hint: 512x512x4 (collage RGB + mask channel) -> 1x4x512x512 control tensor.
    control = torch.from_numpy(hint.copy()).float().cuda()
    control = torch.stack([control for _ in range(num_samples)], dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()

    # ref: 224x224x3 in [0,1] -> 1x3x224x224 input for the conditioning encoder.
    clip_input = torch.from_numpy(ref.copy()).float().cuda()
    clip_input = torch.stack([clip_input for _ in range(num_samples)], dim=0)
    clip_input = einops.rearrange(clip_input, 'b h w c -> b c h w').clone()

    guess_mode = False
    H,W = 512,512

    # Conditional vs. unconditional (zero-image) conditioning for CFG.
    cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning( clip_input )]}
    un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([torch.zeros((1,3,224,224))] * num_samples)]}
    # Latent shape: 4 channels at 1/8 of the pixel resolution.
    shape = (4, H // 8, W // 8)

    if save_memory:
        model.low_vram_shift(is_diffusing=True)

    # Sampling hyper-parameters. NOTE(review): image_resolution and
    # seed = -1 are set but never used; guess_mode stays False.
    num_samples = 1
    image_resolution = 512
    strength = 1
    guess_mode = False
    ddim_steps = 50
    scale = guidance_scale
    seed = -1
    eta = 0.0

    # With guess_mode off this is simply [1.0] * 13 — full control strength
    # at every ControlNet injection point.
    model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                 shape, cond, verbose=False, eta=eta,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=un_cond)
    if save_memory:
        model.low_vram_shift(is_diffusing=False)

    # Decode latents to pixels and rescale from [-1, 1] to [0, 255].
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()

    # NOTE(review): `result` (BGR-flipped copy) is computed but never used.
    result = x_samples[0][:,:,::-1]
    result = np.clip(result,0,255)

    pred = x_samples[0]
    # NOTE(review): the [1:,:,:] slice drops the top pixel row before
    # crop_back() resizes it back — presumably a seam-artifact hack; confirm.
    pred = np.clip(pred,0,255)[1:,:,:]
    sizes = item['extra_sizes']
    tar_box_yyxx_crop = item['tar_box_yyxx_crop']
    gen_image = crop_back(pred, tar_image, sizes, tar_box_yyxx_crop)
    return gen_image
|
|
|
|
if __name__ == '__main__':
    import os
    import itertools

    # Input/output directories (VITON-HD-style dataset layout).
    save_dir = '/work/we_select/out'
    cloth_dir = '/work/we_select/cloth'
    cloth_mask_dir = '/work/we_select/cloth-mask'
    image_dir = '/work/we_select/image'
    image_parse_v3_dir = '/work/we_select/image-parse-v3'

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    cloth_image_names = os.listdir(cloth_dir)
    ref_image_names = os.listdir(image_dir)

    assert len(ref_image_names) > 0, "No reference images found"

    # Cycle over person images so every cloth image gets paired with one.
    ref_images_cycle = itertools.cycle(ref_image_names)

    def _load_pair(cloth_image_path, cloth_mask_path, ref_image_path, ref_mask_path, parse_label):
        """Load the cloth image/mask and the person image, plus the person
        parse mask for the given parse label (presumably 5/9 select garment
        regions in the parse map — confirm against the dataset's label table).
        Reloading per label keeps inputs fresh in case inference mutates them.
        """
        cloth_image = cv2.imread(cloth_image_path)
        cloth_image = cv2.cvtColor(cloth_image, cv2.COLOR_BGR2RGB)
        cloth_mask = (cv2.imread(cloth_mask_path) > 128).astype(np.uint8)[:, :, 0]

        ref_image = cv2.imread(ref_image_path)
        ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB)
        ref_mask = Image.open(ref_mask_path).convert('P')
        ref_mask = np.array(ref_mask) == parse_label
        return cloth_image, cloth_mask, ref_image, ref_mask

    for cloth_image_name in cloth_image_names:
        ref_image_name = next(ref_images_cycle)

        cloth_image_path = os.path.join(cloth_dir, cloth_image_name)
        cloth_mask_path = os.path.join(cloth_mask_dir, cloth_image_name)

        ref_image_path = os.path.join(image_dir, ref_image_name)
        ref_mask_path = os.path.join(image_parse_v3_dir, ref_image_name.replace('.jpg', '.png'))

        # The original body was duplicated verbatim for parse labels 5 and 9;
        # one loop with the same per-label reload replaces the copy-paste.
        for parse_label in (5, 9):
            cloth_image, cloth_mask, ref_image, ref_mask = _load_pair(
                cloth_image_path, cloth_mask_path, ref_image_path, ref_mask_path, parse_label)

            gen_image = inference_single_image(cloth_image, cloth_mask, ref_image, ref_mask)
            # Same output names as before: '5_<name>' and '9_<name>'.
            gen_path = os.path.join(save_dir, str(parse_label) + '_' + cloth_image_name)

            # Side-by-side visualization; flip RGB back to BGR for imwrite.
            vis_image = cv2.hconcat([cloth_image, ref_image, gen_image])
            cv2.imwrite(gen_path, vis_image[:, :, ::-1])
|
|
|
|