Upload layer_diff_dataset/test_inp_2.py with huggingface_hub
Browse files
layer_diff_dataset/test_inp_2.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import torch
|
| 3 |
+
from modelscope.outputs import OutputKeys
|
| 4 |
+
from modelscope.pipelines import pipeline
|
| 5 |
+
from modelscope.utils.constant import Tasks
|
| 6 |
+
import PIL
|
| 7 |
+
import numpy as np
|
| 8 |
+
import os
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import random
|
| 11 |
+
|
| 12 |
+
# input_location = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_inpainting/image_inpainting_1.png'
# input_mask_location = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_inpainting/image_inpainting_mask_1.png'

# Text prompt fed to the inpainting model for every image in the split.
prompt = 'hazy background with nothing on'#'background'

# Dataset layout: {data_dir}/{split}/im holds input .jpg images,
# {data_dir}/{split}/gt holds the matching .png masks; results go to bg_inpaint/.
data_dir = '/mnt/workspace/workgroup/sihui.jsh/layer_diff_dataset'
split = 'try'
input_dir = f'{data_dir}/{split}/im'
input_mask_dir = f'{data_dir}/{split}/gt'
output_image_path = f'{data_dir}/{split}/bg_inpaint/'
# makedirs(exist_ok=True) replaces the exists()+mkdir pair: it also creates
# any missing parent directories and avoids the check-then-create race.
os.makedirs(output_image_path, exist_ok=True)

# Stable-Diffusion v2 inpainting pipeline: fp16 on GPU, with attention
# slicing enabled to reduce peak memory during inference.
image_inpainting = pipeline(
    Tasks.image_inpainting,
    model='/mnt/workspace/workgroup/sihui.jsh/alpha_work/diffusers/iic/cv_stable-diffusion-v2_image-inpainting_base',
    device='gpu',
    torch_dtype=torch.float16,
    enable_attention_slicing=True,
)
|
| 29 |
+
|
| 30 |
+
# Process every image in the split, inpainting the (dilated) masked region
# and writing an input|output side-by-side comparison per image.
all_images = os.listdir(input_dir)
all_images.sort()
# random.shuffle(all_images)
for img in tqdm(all_images):
    # splitext handles any extension length; the original `img[:-4]`
    # silently mangled names whose extension is not exactly 3 characters.
    img_name = os.path.splitext(img)[0]
    input_location = os.path.join(input_dir, img_name + '.jpg')
    input_mask_location = os.path.join(input_mask_dir, img_name + '.png')
    input_image = PIL.Image.open(input_location).convert('RGB')
    mask_image = PIL.Image.open(input_mask_location).convert('L')

    # PIL .size is (width, height); the original unpacked this as (h, w),
    # inverting the names even though the resize logic happened to be right.
    w, h = input_image.size
    # print(w, h)
    target_size = 1024
    if w > target_size:
        # Compute the scaled size ONCE and apply it to both image and mask.
        # (The original resized the image first and then derived the mask
        # size from the already-resized image, which only matched by accident.)
        new_size = (target_size, int(h * (target_size / w)))
        input_image = input_image.resize(new_size)
        mask_image = mask_image.resize(new_size)
    input_image.save('input_image.png')
    mask_image = np.array(mask_image)

    # Grow the mask so the inpainting region safely covers the foreground.
    # BUG FIX: the original `cv2.dilate(mask_image, kernel, 20)` passed 20
    # positionally into the `dst` output argument, not `iterations`;
    # `iterations=20` is the evident intent.
    kernel = np.ones((40, 40), np.uint8)
    mask_image = cv2.dilate(mask_image, kernel, iterations=20)
    # mask_indice = np.argwhere(mask_image>0)
    # x0,x1 = mask_indice[:,0].min(),mask_indice[:,0].max(),
    # y0,y1 = mask_indice[:,1].min(),mask_indice[:,1].max(),

    # mask_new = np.zeros_like(mask_image)
    # mask_new[x0:x1,y0:y1] = 1

    # Binarize: any non-zero pixel belongs to the inpainting mask. The
    # original divided by 255.0 first, which is redundant before a >0 test.
    mask_image = (mask_image > 0)
    PIL.Image.fromarray((mask_image * 255.0).astype(np.uint8)).save('mask_dilate.png')

    # Renamed from `input` to avoid shadowing the builtin.
    pipe_input = {
        'image': 'input_image.png',
        'mask': 'mask_dilate.png',
        'prompt': prompt
    }

    output = image_inpainting(pipe_input)[OutputKeys.OUTPUT_IMG]
    # print(output.shape)
    # The pipeline may return a different resolution; match the input image
    # to it before concatenating side by side. `output` is assumed to be in
    # BGR (OpenCV) channel order, hence the [:, :, ::-1] flip of the RGB
    # PIL array — TODO confirm against the pipeline's output convention.
    input_image = input_image.resize((output.shape[1], output.shape[0]))
    res = np.concatenate([np.array(input_image)[:, :, ::-1], output], axis=1)
    cv2.imwrite(os.path.join(output_image_path, img_name + '.png'), res)
    # print('pipeline: the output image path is {}'.format(output_image_path))
|