RyanHangZhou committed on
Commit
35c4eee
·
verified ·
1 Parent(s): ff8dd04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -46
app.py CHANGED
@@ -140,67 +140,83 @@ def process_composition(item, obj_thr):
140
  def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
141
  device = "cuda"
142
  model.to(device)
143
- ddim_sampler = DDIMSampler(model)
144
-
145
- # 将 Gradio 的 PIL 转为 OpenCV 格式
 
146
  back_image = cv2.cvtColor(np.array(background), cv2.COLOR_RGB2BGR)
147
-
148
- # 模拟你的 run_inference 循环逻辑组装 item_with_collage
149
  item_with_collage = {}
150
-
151
- # 处理 Object 0 和 Object 1
152
  objs = [(img_a, mask_a), (img_b, mask_b)]
153
  for j, (img, mask) in enumerate(objs):
154
- # 将 PIL 转存为临时文件以适配你的 process_pairs_multiple (或者直接改写该函数接受 numpy)
155
  temp_patch = f"temp_obj_{j}.png"
156
  cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
157
-
158
  tar_mask = (np.array(mask)[:, :, 0] > 128).astype(np.uint8)
159
 
160
- # 调用你的原生函数
161
  item = process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j)
162
  item_with_collage.update(item)
163
-
164
- # 调用你的合成函数
165
  item_with_collage = process_composition(item_with_collage, obj_thr=2)
166
-
167
- # 执行 Sampling (对应你的 inference 函数内容)
 
 
 
168
  H, W = 512, 512
169
- with torch.no_grad():
170
- xc = [get_input_tensor(item_with_collage, f"view{i}", device) for i in range(2)]
171
- xc_mask = [get_input_tensor(item_with_collage, f"mask{i}", device) for i in range(2)]
172
-
173
- c_list = [model.get_learned_conditioning(xc_i) for xc_i in xc]
174
- cond_cross = {"pch_code": torch.stack(c_list).permute(1, 2, 3, 0)}
175
- c_mask = torch.stack(xc_mask).permute(1, 2, 3, 4, 0).to(device) # 这里保证是 5 维
176
-
177
- hint = item_with_collage['hint']
178
- control = torch.from_numpy(hint.copy()).float().to(device).unsqueeze(0)
179
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
180
-
181
- cond = {"c_concat": [control], "c_crossattn": [cond_cross], "c_mask": [c_mask]}
182
- uc_pch = get_unconditional_conditioning(1, 2, device)
183
- un_cond = {"c_concat": [control], "c_crossattn": [uc_pch], "c_mask": [c_mask]}
184
-
185
- shape = (4, H // 8, W // 8)
186
- model.control_scales = [1.0] * 13
187
-
188
- samples, _ = ddim_sampler.sample(50, 1, shape, cond, verbose=False,
189
- unconditional_guidance_scale=5.0, unconditional_conditioning=un_cond)
190
-
191
- x_samples = model.decode_first_stage(samples)
192
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()
193
- pred_rgb = np.clip(x_samples[0], 0, 255).astype(np.uint8)
194
 
195
- pred_bgr = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
196
-
197
- # 结果后处理
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  side = max(back_image.shape[0], back_image.shape[1])
199
- pred = cv2.resize(pred_bgr, (side, side))
200
- pred = crop_back(pred, back_image, item_with_collage['extra_sizes'],
201
- item_with_collage['hint_sizes0'], item_with_collage['hint_sizes1'], is_masked=True)
 
 
202
 
203
- return cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)
 
204
 
205
  with gr.Blocks(title="PICS: Pairwise Spatial Compositing") as demo:
206
  gr.Markdown("# 🚀 PICS: Pairwise Image Compositing (5-Input Framework)")
 
140
  def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
141
  device = "cuda"
142
  model.to(device)
143
+ # 必须在函数内部重新定义一次,确保它拿到的是 cuda 上的 model
144
+ ddim_sampler = DDIMSampler(model)
145
+
146
+ # 1. 转换 Gradio 输入为 OpenCV BGR 格式 (因为你 process_pairs 内部用的是 BGR 背景)
147
  back_image = cv2.cvtColor(np.array(background), cv2.COLOR_RGB2BGR)
148
+
149
+ # 2. 模拟 run_inference 循环,构造 item
150
  item_with_collage = {}
 
 
151
  objs = [(img_a, mask_a), (img_b, mask_b)]
152
  for j, (img, mask) in enumerate(objs):
 
153
  temp_patch = f"temp_obj_{j}.png"
154
  cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
 
155
  tar_mask = (np.array(mask)[:, :, 0] > 128).astype(np.uint8)
156
 
157
+ # 调用你那段“正确” process 函数
158
  item = process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j)
159
  item_with_collage.update(item)
160
+
161
+ # 3. 合成 Hint
162
  item_with_collage = process_composition(item_with_collage, obj_thr=2)
163
+
164
+ # 4. 执行你那段“正确”的 inference 函数逻辑
165
+ # --- START 原装逻辑 ---
166
+ obj_thr = 2
167
+ num_samples = 1
168
  H, W = 512, 512
169
+ guidance_scale = 5.0
170
+
171
+ xc = []
172
+ xc_mask = []
173
+ for i in range(obj_thr):
174
+ xc.append(get_input(item_with_collage, f"view{i}").to(device))
175
+ xc_mask.append(get_input(item_with_collage, f"mask{i}"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
+ c_list = [model.get_learned_conditioning(xc_i) for xc_i in xc]
178
+ c_tensor = torch.stack(c_list).permute(1, 2, 3, 0)
179
+ cond_cross = {"pch_code": c_tensor}
180
+ c_mask = torch.stack(xc_mask).permute(1, 2, 3, 4, 0).to(device)
181
+
182
+ hint = item_with_collage['hint']
183
+ control = torch.from_numpy(hint.copy()).float().to(device)
184
+ control = torch.stack([control] * num_samples, dim=0)
185
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
186
+
187
+ cond = {"c_concat": [control], "c_crossattn": [cond_cross], "c_mask": [c_mask]}
188
+
189
+ # 这里的 UC 逻辑极其关键,决定了 BBox 会不会变蓝
190
+ uc_pch = get_unconditional_conditioning(num_samples, obj_thr)
191
+ # 这里的 get_unconditional_conditioning 内部会用到 model.device,确保它已经是 cuda
192
+ un_cond = {"c_concat": [control], "c_crossattn": [uc_pch], "c_mask": [c_mask]}
193
+
194
+ shape = (4, H // 8, W // 8)
195
+ model.control_scales = [1.0] * 13
196
+
197
+ samples, _ = ddim_sampler.sample(
198
+ 50, num_samples, shape, cond,
199
+ verbose=False, eta=0.0,
200
+ unconditional_guidance_scale=guidance_scale,
201
+ unconditional_conditioning=un_cond
202
+ )
203
+
204
+ x_samples = model.decode_first_stage(samples)
205
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()
206
+ pred_rgb = np.clip(x_samples[0], 0, 255).astype(np.uint8)
207
+ # --- END 原装逻辑 ---
208
+
209
+ # 5. 后处理 (注意 RGB/BGR 转换以适配 crop_back)
210
+ pred_bgr = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
211
  side = max(back_image.shape[0], back_image.shape[1])
212
+ pred_res = cv2.resize(pred_bgr, (side, side))
213
+
214
+ # 这里的 crop_back 依赖 BGR 格式
215
+ pred_final = crop_back(pred_res, back_image, item_with_collage['extra_sizes'],
216
+ item_with_collage['hint_sizes0'], item_with_collage['hint_sizes1'], is_masked=True)
217
 
218
+ # 最后转回 RGB 给 Gradio
219
+ return cv2.cvtColor(pred_final, cv2.COLOR_BGR2RGB)
220
 
221
  with gr.Blocks(title="PICS: Pairwise Spatial Compositing") as demo:
222
  gr.Markdown("# 🚀 PICS: Pairwise Image Compositing (5-Input Framework)")