zjuJish commited on
Commit
42ee7a0
·
verified ·
1 Parent(s): c771b65

Upload layer_diff_dataset/test_inp_2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. layer_diff_dataset/test_inp_2.py +72 -0
layer_diff_dataset/test_inp_2.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ from modelscope.outputs import OutputKeys
4
+ from modelscope.pipelines import pipeline
5
+ from modelscope.utils.constant import Tasks
6
+ import PIL
7
+ import numpy as np
8
+ import os
9
+ from tqdm import tqdm
10
+ import random
11
+
12
+ # input_location = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_inpainting/image_inpainting_1.png'
13
+ # input_mask_location = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_inpainting/image_inpainting_mask_1.png'
14
+ prompt = 'hazy background with nothing on'#'background'
15
+ data_dir = '/mnt/workspace/workgroup/sihui.jsh/layer_diff_dataset'
16
+ split = 'try'
17
+ input_dir = f'{data_dir}/{split}/im'
18
+ input_mask_dir = f'{data_dir}/{split}/gt'
19
+ output_image_path = f'{data_dir}/{split}/bg_inpaint/'
20
+ if not os.path.exists(output_image_path):
21
+ os.mkdir(output_image_path)
22
+ image_inpainting = pipeline(
23
+ Tasks.image_inpainting,
24
+ model='/mnt/workspace/workgroup/sihui.jsh/alpha_work/diffusers/iic/cv_stable-diffusion-v2_image-inpainting_base',
25
+ device='gpu',
26
+ torch_dtype=torch.float16,
27
+ enable_attention_slicing=True,
28
+ )
29
+
30
+ all_images = os.listdir(input_dir)
31
+ all_images.sort()
32
+ # random.shuffle(all_images)
33
+ for img in tqdm(all_images):
34
+ img_name = img[:-4]
35
+ input_location = os.path.join(input_dir,img_name+'.jpg')
36
+ input_mask_location = os.path.join(input_mask_dir,img_name+'.png')
37
+ input_image = PIL.Image.open(input_location).convert('RGB')
38
+ mask_image = PIL.Image.open(input_mask_location).convert('L')
39
+ h,w = input_image.size
40
+ # print(h,w)
41
+ target_size = 1024
42
+ if h>target_size:
43
+ input_image = input_image.resize((target_size,int(input_image.size[-1]*(target_size/input_image.size[0]))))
44
+ mask_image = mask_image.resize((target_size,int(input_image.size[-1]*(target_size/input_image.size[0]))))
45
+ input_image.save('input_image.png')
46
+ mask_image = np.array(mask_image)
47
+
48
+ kernel = np.ones((40, 40), np.uint8)
49
+ mask_image = cv2.dilate(mask_image,kernel,20)
50
+ # mask_indice = np.argwhere(mask_image>0)
51
+ # x0,x1 = mask_indice[:,0].min(),mask_indice[:,0].max(),
52
+ # y0,y1 = mask_indice[:,1].min(),mask_indice[:,1].max(),
53
+
54
+ # mask_new = np.zeros_like(mask_image)
55
+ # mask_new[x0:x1,y0:y1] = 1
56
+
57
+ mask_image = mask_image / 255.0
58
+ mask_image = (mask_image>0)
59
+ PIL.Image.fromarray((mask_image*255.0).astype(np.uint8)).save('mask_dilate.png')
60
+
61
+ input = {
62
+ 'image': 'input_image.png',
63
+ 'mask': 'mask_dilate.png',
64
+ 'prompt': prompt
65
+ }
66
+
67
+ output = image_inpainting(input)[OutputKeys.OUTPUT_IMG]
68
+ # print(output.shape)
69
+ input_image = input_image.resize((output.shape[1],output.shape[0]))
70
+ res = np.concatenate([np.array(input_image)[:,:,::-1],output],axis=1)
71
+ cv2.imwrite(os.path.join(output_image_path,img_name+'.png'),res)
72
+ # print('pipeline: the output image path is {}'.format(output_image_path))