Hang Zhou commited on
Commit
5e662f5
·
verified ·
1 Parent(s): b230236

Upload run_test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_test.py +271 -0
run_test.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import einops
4
+ import numpy as np
5
+ import torch
6
+ import argparse
7
+ from cldm.model import create_model, load_state_dict
8
+ from cldm.ddim_hacked import DDIMSampler
9
+ from cldm.hack import disable_verbosity, enable_sliced_attention
10
+ from datasets.data_utils import *
11
+ from omegaconf import OmegaConf
12
+ from tqdm import tqdm
13
+ import albumentations as A
14
+
15
+ save_memory = False
16
+ disable_verbosity()
17
+ if save_memory:
18
+ enable_sliced_attention()
19
+
20
+ config = OmegaConf.load('./configs/inference.yaml')
21
+ model_ckpt = config.pretrained_model
22
+ model_config = config.config_file
23
+
24
+ model = create_model(model_config).cpu()
25
+ model.load_state_dict(load_state_dict(model_ckpt, location='cuda'))
26
+ model = model.cuda()
27
+ ddim_sampler = DDIMSampler(model)
28
+
29
+
30
+ def get_input(batch, k):
31
+ x = batch[k]
32
+ if len(x.shape) == 3:
33
+ x = x[None, ...]
34
+
35
+ x = torch.tensor(x)
36
+ x = einops.rearrange(x, 'b h w c -> b c h w')
37
+ x = x.to(memory_format=torch.contiguous_format).float()
38
+ return x
39
+
40
+ def get_unconditional_conditioning(N, obj_thr):
41
+ x = [torch.zeros((1, 3, 224, 224)).to(model.device)] * N
42
+ single_uc = model.get_learned_conditioning(x)
43
+ uc = single_uc.unsqueeze(-1).repeat(1, 1, 1, obj_thr)
44
+ return {"pch_code": uc}
45
+
46
+ def inference(item, back_image):
47
+ obj_thr = 2
48
+ num_samples = 1
49
+ H, W = 512, 512
50
+ guidance_scale = 5.0
51
+
52
+ # 1. Condition & Mask Extraction
53
+ xc = []
54
+ xc_mask = []
55
+ for i in range(obj_thr):
56
+ xc.append(get_input(item, f"view{i}").cuda())
57
+ xc_mask.append(get_input(item, f"mask{i}"))
58
+
59
+ # 2. Cross-Attention Condition (pch_code)
60
+ c_list = [model.get_learned_conditioning(xc_i) for xc_i in xc]
61
+ c_tensor = torch.stack(c_list).permute(1, 2, 3, 0) # [B, Tokens, Dim, Obj]
62
+ cond_cross = {"pch_code": c_tensor}
63
+
64
+ # 3. Mask Condition
65
+ c_mask = torch.stack(xc_mask).permute(1, 2, 3, 4, 0) # Align with BasicTransformerBlock
66
+
67
+ # 4. ControlNet / Concat Condition
68
+ hint = item['hint']
69
+ control = torch.from_numpy(hint.copy()).float().cuda()
70
+ control = torch.stack([control] * num_samples, dim=0)
71
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
72
+
73
+ # 5. Build Final Condition Dictionaries
74
+ cond = {
75
+ "c_concat": [control],
76
+ "c_crossattn": [cond_cross],
77
+ "c_mask": [c_mask]
78
+ }
79
+
80
+ # Correctly unwrap the UC dictionary
81
+ uc_pch = get_unconditional_conditioning(num_samples, obj_thr)
82
+ un_cond = {
83
+ "c_concat": [control],
84
+ "c_crossattn": [uc_pch],
85
+ "c_mask": [c_mask]
86
+ }
87
+
88
+ # 6. Sampling
89
+ if save_memory:
90
+ model.low_vram_shift(is_diffusing=True)
91
+
92
+ shape = (4, H // 8, W // 8)
93
+ model.control_scales = [1.0] * 13
94
+
95
+ samples, _ = ddim_sampler.sample(
96
+ 50, num_samples, shape, cond,
97
+ verbose=False, eta=0.0,
98
+ unconditional_guidance_scale=guidance_scale,
99
+ unconditional_conditioning=un_cond
100
+ )
101
+
102
+ if save_memory:
103
+ model.low_vram_shift(is_diffusing=False)
104
+
105
+ # 7. Post-processing
106
+ x_samples = model.decode_first_stage(samples)
107
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()
108
+
109
+ pred = np.clip(x_samples[0], 0, 255).astype(np.uint8)
110
+
111
+ # Resize and crop
112
+ side = max(back_image.shape[0], back_image.shape[1])
113
+ pred = cv2.resize(pred, (side, side))
114
+ pred = crop_back(pred, back_image, item['extra_sizes'], item['hint_sizes0'], item['hint_sizes1'], is_masked=True)
115
+
116
+ return pred
117
+
118
+
119
+ def process_pairs_multiple(mask, tar_image, patch_dir, counter=0, max_ratio=0.8):
120
+ # 1. Process Reference Object (View)
121
+ view = cv2.imread(patch_dir)
122
+ view = cv2.cvtColor(view, cv2.COLOR_BGR2RGB)
123
+ view = pad_to_square(view, pad_value=255, random=False)
124
+ view = cv2.resize(view.astype(np.uint8), (224, 224))
125
+ view = view.astype(np.float32) / 255.0
126
+
127
+ # 2. BBox and Mask Logic
128
+ box_yyxx = get_bbox_from_mask(mask)
129
+
130
+ # Define crop area (using full image here)
131
+ H1, W1 = tar_image.shape[0], tar_image.shape[1]
132
+ box_yyxx_crop = [0, H1, 0, W1]
133
+
134
+ # Handle box within crop
135
+ y1, y2, x1, x2 = box_in_box(box_yyxx, box_yyxx_crop)
136
+
137
+ # 3. Create Collage (Input Hint)
138
+ # Background with hole (zeroed out at object position)
139
+ collage = tar_image.copy()
140
+ source_collage = collage.copy()
141
+ collage[y1:y2, x1:x2, :] = 0
142
+
143
+ # Binary mask for the current object hole
144
+ collage_mask = np.zeros_like(tar_image, dtype=np.float32)
145
+ collage_mask[y1:y2, x1:x2, :] = 1.0
146
+
147
+ # 4. Square Padding & Resizing
148
+ # Pad all to square (pad_value 2 for mask indicates padding area)
149
+ tar_square = pad_to_square(tar_image, pad_value=0, random=False)
150
+ collage_square = pad_to_square(collage, pad_value=0, random=False)
151
+ mask_square = pad_to_square(collage_mask, pad_value=2, random=False)
152
+
153
+ H2, W2 = collage_square.shape[0], collage_square.shape[1]
154
+
155
+ # Resize to model input size
156
+ tar_res = cv2.resize(tar_square, (512, 512)).astype(np.float32)
157
+ col_res = cv2.resize(collage_square, (512, 512)).astype(np.float32)
158
+ mask_res = cv2.resize(mask_square, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
159
+
160
+ # 5. Mask Value Normalization
161
+ # Original logic: mask=1 for object, 0 for background, -1 for padding
162
+ mask_res[mask_res == 2] = -1
163
+
164
+ # For conditioning: keep a 0/1 version for cross-attn mask
165
+ c_mask = np.where(mask_res[..., 0:1] == 1, 1.0, 0.0).astype(np.float32)
166
+
167
+ # 6. Final Item Assembly
168
+ # Normalize images to [-1, 1]
169
+ tar_res = tar_res / 127.5 - 1.0
170
+ col_res = col_res / 127.5 - 1.0
171
+
172
+ # Hint: Concatenate background with the (-1, 0, 1) mask
173
+ hint_final = np.concatenate([col_res, mask_res[..., :1]], axis=-1)
174
+
175
+ item = {
176
+ f'view{counter}': view,
177
+ f'hint{counter}': hint_final,
178
+ f'mask{counter}': c_mask,
179
+ f'hint_sizes{counter}': np.array([y1, x1, y2, x2]),
180
+ 'jpg': tar_res, # Targets are same for all counters in a pair
181
+ 'collage': source_collage,
182
+ 'extra_sizes': np.array([H1, W1, H2, W2])
183
+ }
184
+
185
+ return item
186
+
187
+
188
+ def process_composition(item, obj_thr):
189
+ collage = item['collage'].copy()
190
+ collage_mask = np.zeros((collage.shape[0], collage.shape[1], 1), dtype=np.float32)
191
+
192
+ for i in reversed(range(obj_thr)):
193
+ y1, x1, y2, x2 = item['hint_sizes'+str(i)]
194
+ collage[y1:y2, x1:x2, :] = 0
195
+ collage_mask[y1:y2,x1:x2,:] = 1.0
196
+
197
+ collage = pad_to_square(collage, pad_value = 0, random = False).astype(np.uint8)
198
+
199
+ collage_mask = pad_to_square(collage_mask, pad_value = 2, random = False).astype(np.float32)
200
+
201
+ collage = cv2.resize(collage.astype(np.uint8), (512, 512)).astype(np.float32) / 127.5 - 1.0
202
+ collage_mask = cv2.resize(collage_mask, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
203
+
204
+ if len(collage_mask.shape) == 2:
205
+ collage_mask = collage_mask[..., None]
206
+
207
+ collage_mask[collage_mask == 2] = -1.0
208
+
209
+ collage_final = np.concatenate([collage, collage_mask[:,:,:1]] , -1)
210
+
211
+ item.update({'hint': collage_final.copy()})
212
+ return item
213
+
214
+ def run_inference(input_dir, output_dir, sample_num=31, obj_thr=2):
215
+ """
216
+ Core inference loop for multi-object composition.
217
+ """
218
+ os.makedirs(output_dir, exist_ok=True)
219
+ comp_image_dir = os.path.join(output_dir, 'composed')
220
+ os.makedirs(comp_image_dir, exist_ok=True)
221
+
222
+ img_ids = sorted([d for d in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, d))])
223
+
224
+ for img_id in tqdm(img_ids, desc="Processing images"):
225
+ img_folder = os.path.join(input_dir, img_id)
226
+ img_path = os.path.join(img_folder, 'image.jpg')
227
+
228
+ if not os.path.exists(img_path):
229
+ continue
230
+
231
+ # 1. Load background image
232
+ back_image = cv2.imread(img_path)
233
+ back_image = cv2.cvtColor(back_image, cv2.COLOR_BGR2RGB)
234
+
235
+ # 2. Iteratively process multiple objects
236
+ item_with_collage = {}
237
+ for j in range(obj_thr):
238
+ # for j in reversed(range(obj_thr)):
239
+ patch_path = os.path.join(img_folder, f"object_{j}.png")
240
+ mask_path = os.path.join(img_folder, f"object_{j}_mask.png")
241
+
242
+ if not (os.path.exists(patch_path) and os.path.exists(mask_path)):
243
+ print(f"Warning: Object {j} missing in {img_id}")
244
+ continue
245
+
246
+ tar_mask = (cv2.imread(mask_path)[:, :, 0] > 128).astype(np.uint8)
247
+
248
+ # Pass counter=j to ensure keys like 'view0', 'view1' are unique
249
+ item = process_pairs_multiple(tar_mask, back_image, patch_path, counter=j)
250
+ item_with_collage.update(item)
251
+
252
+ # 3. Composition & Model Prediction
253
+ # Ensure process_composition merges 'hint0', 'hint1' into a single 'hint'
254
+ item_with_collage = process_composition(item_with_collage, obj_thr)
255
+
256
+ # Using inference_single_image_multi as defined previously
257
+ gen_image = inference(item_with_collage, back_image)
258
+
259
+ # 4. Save result
260
+ save_name = f'composed_{img_id}.png'
261
+ cv2.imwrite(os.path.join(comp_image_dir, save_name), gen_image[:, :, ::-1])
262
+
263
+
264
+ if __name__ == '__main__':
265
+ parser = argparse.ArgumentParser()
266
+ parser.add_argument('--input', type=str, help='Input data directory')
267
+ parser.add_argument('--output', type=str, help='Output save directory')
268
+ parser.add_argument('--obj_thr', type=int, default=2, help='Number of objects to compose')
269
+ args = parser.parse_args()
270
+
271
+ run_inference(args.input, args.output, obj_thr=args.obj_thr)