zjuJish commited on
Commit
1d0a838
·
verified ·
1 Parent(s): f49c3a1

Upload layer_diff_dataset/test_inp_sd_3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. layer_diff_dataset/test_inp_sd_3.py +102 -0
layer_diff_dataset/test_inp_sd_3.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoPipelineForInpainting
2
+ from diffusers.utils import load_image
3
+ import torch
4
+ import os
5
+ from tqdm import tqdm
6
+ import cv2
7
+ from PIL import Image
8
+ import imageio
9
+ import numpy as np
10
+
11
+ pipe = AutoPipelineForInpainting.from_pretrained("../alpha_work/diffusers/stable-diffusion-xl-1.0-inpainting_", torch_dtype=torch.float16, variant="fp16").to("cuda")
12
+ # print('pipe',pipe)
13
+ # StableDiffusionXLInpaintPipeline
14
+
15
+ preprocessed_root_path = '../data/video_dataset/YoutubeVOS/train/impainting_256'
16
+ # folder_path_0 = '../codes/Inpaint-Anything/results/0b6f9105fc'
17
+ mask_root_path = '../data/video_dataset/YoutubeVOS/train/mask'
18
+ output_root_path = '../data/video_dataset/YoutubeVOS/train/inp_preprocess_sd_0.9_base_ppt_w'
19
+ os.makedirs(output_root_path,exist_ok=True)
20
+ # img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
21
+ # mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
22
+
23
+
24
+ prompt = "hazy background with nothing on"
25
+ generator = torch.Generator(device="cuda").manual_seed(22)
26
+ # base_image = Image.open(base_image_path).resize((1024, 1024))
27
+
28
+ vid_list = os.listdir(preprocessed_root_path)
29
+ pbar = tqdm(enumerate(vid_list),total=len(vid_list))
30
+ for i, vid_name in pbar:
31
+ # if not vid_name=='7d18074fef':
32
+ # continue
33
+ # if i<=800:
34
+ # continue
35
+ if i>550:
36
+ break
37
+
38
+ output_folder = os.path.join(output_root_path, vid_name)
39
+ output_gif = os.path.join(output_folder, f'{vid_name}.gif') # 输出GIF的文件名
40
+ if os.path.exists(output_gif):
41
+ continue
42
+ os.makedirs(output_folder,exist_ok=True)
43
+
44
+ img_folder = os.path.join(preprocessed_root_path, vid_name)
45
+ img_list = os.listdir(img_folder)
46
+ img_list = [i for i in img_list if i.endswith('.jpg')]
47
+ img_list.sort()
48
+ # print(len(img_list))
49
+ mask_folder = os.path.join(mask_root_path, vid_name)
50
+
51
+
52
+ frames = []
53
+ for i,image_name in enumerate(img_list):
54
+
55
+ # print('i',i)
56
+ # if not i%2 == 0:
57
+ # continue
58
+ image_path = os.path.join(img_folder,image_name)
59
+ mask_path = os.path.join(mask_folder,image_name.split('.')[0]+'.png')
60
+ image = Image.open(image_path).resize((1024, 1024))
61
+ mask_image = Image.open(mask_path).resize((1024, 1024))
62
+ # image = cv2.resize(cv2.imread(image_path),(1024,1024))
63
+ # mask_image = cv2.resize(cv2.imread(mask_path,cv2.IMREAD_GRAYSCALE),(1024,1024))
64
+ # image = load_image(img_url).resize((1024, 1024))
65
+ # mask_image = load_image(mask_url).resize((1024, 1024))
66
+ if i==0:
67
+ base_image_1 = image
68
+ base_image = base_image_1
69
+ strength = 0.99
70
+ else:
71
+ strength = 0.5
72
+ if i==1:
73
+ base_image_2 = base_image_1
74
+ base_image_1 = image_out
75
+ image_array_1 = np.array(base_image_1.convert('RGB'))
76
+ image_array_2 = np.array(base_image_2.convert('RGB'))
77
+ base_image = (image_array_1*0.8+image_array_2*0.2).astype(np.uint8)
78
+ base_image = Image.fromarray(base_image)
79
+ else:
80
+ base_image_3 = base_image_2
81
+ base_image_2 = base_image_1
82
+ base_image_1 = image_out
83
+ image_array_1 = np.array(base_image_1.convert('RGB'))
84
+ image_array_2 = np.array(base_image_2.convert('RGB'))
85
+ image_array_3 = np.array(base_image_3.convert('RGB'))
86
+ base_image = (image_array_1*0.7+image_array_2*0.2+image_array_3*0.1).astype(np.uint8)
87
+ base_image = Image.fromarray(base_image)
88
+
89
+ image_out = pipe(
90
+ prompt=prompt,
91
+ image=image,
92
+ base_image=base_image,
93
+ mask_image=mask_image,
94
+ guidance_scale=8.0,
95
+ num_inference_steps=20, # steps between 15 and 30 work well for us
96
+ strength=strength, # make sure to use `strength` below 1.0
97
+ generator=generator,
98
+ ).images[0]
99
+ image_out.save(os.path.join(output_folder,image_name))
100
+ frames.append(imageio.imread(os.path.join(output_folder,image_name)))
101
+ imageio.mimsave(output_gif, frames, fps=8, loop=0)
102
+ # exit(0)