Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -140,67 +140,83 @@ def process_composition(item, obj_thr):
|
|
| 140 |
def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
|
| 141 |
device = "cuda"
|
| 142 |
model.to(device)
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
| 146 |
back_image = cv2.cvtColor(np.array(background), cv2.COLOR_RGB2BGR)
|
| 147 |
-
|
| 148 |
-
# 模拟
|
| 149 |
item_with_collage = {}
|
| 150 |
-
|
| 151 |
-
# 处理 Object 0 和 Object 1
|
| 152 |
objs = [(img_a, mask_a), (img_b, mask_b)]
|
| 153 |
for j, (img, mask) in enumerate(objs):
|
| 154 |
-
# 将 PIL 转存为临时文件以适配你的 process_pairs_multiple (或者直接改写该函数接受 numpy)
|
| 155 |
temp_patch = f"temp_obj_{j}.png"
|
| 156 |
cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
| 157 |
-
|
| 158 |
tar_mask = (np.array(mask)[:, :, 0] > 128).astype(np.uint8)
|
| 159 |
|
| 160 |
-
# 调用你的
|
| 161 |
item = process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j)
|
| 162 |
item_with_collage.update(item)
|
| 163 |
-
|
| 164 |
-
#
|
| 165 |
item_with_collage = process_composition(item_with_collage, obj_thr=2)
|
| 166 |
-
|
| 167 |
-
# 执行
|
|
|
|
|
|
|
|
|
|
| 168 |
H, W = 512, 512
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
hint = item_with_collage['hint']
|
| 178 |
-
control = torch.from_numpy(hint.copy()).float().to(device).unsqueeze(0)
|
| 179 |
-
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 180 |
-
|
| 181 |
-
cond = {"c_concat": [control], "c_crossattn": [cond_cross], "c_mask": [c_mask]}
|
| 182 |
-
uc_pch = get_unconditional_conditioning(1, 2, device)
|
| 183 |
-
un_cond = {"c_concat": [control], "c_crossattn": [uc_pch], "c_mask": [c_mask]}
|
| 184 |
-
|
| 185 |
-
shape = (4, H // 8, W // 8)
|
| 186 |
-
model.control_scales = [1.0] * 13
|
| 187 |
-
|
| 188 |
-
samples, _ = ddim_sampler.sample(50, 1, shape, cond, verbose=False,
|
| 189 |
-
unconditional_guidance_scale=5.0, unconditional_conditioning=un_cond)
|
| 190 |
-
|
| 191 |
-
x_samples = model.decode_first_stage(samples)
|
| 192 |
-
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()
|
| 193 |
-
pred_rgb = np.clip(x_samples[0], 0, 255).astype(np.uint8)
|
| 194 |
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
side = max(back_image.shape[0], back_image.shape[1])
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
|
| 203 |
-
|
|
|
|
| 204 |
|
| 205 |
with gr.Blocks(title="PICS: Pairwise Spatial Compositing") as demo:
|
| 206 |
gr.Markdown("# 🚀 PICS: Pairwise Image Compositing (5-Input Framework)")
|
|
|
|
| 140 |
def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
|
| 141 |
device = "cuda"
|
| 142 |
model.to(device)
|
| 143 |
+
# 必须在函数内部重新定义一次,确保它拿到的是 cuda 上的 model
|
| 144 |
+
ddim_sampler = DDIMSampler(model)
|
| 145 |
+
|
| 146 |
+
# 1. 转换 Gradio 输入为 OpenCV BGR 格式 (因为你 process_pairs 内部用的是 BGR 背景)
|
| 147 |
back_image = cv2.cvtColor(np.array(background), cv2.COLOR_RGB2BGR)
|
| 148 |
+
|
| 149 |
+
# 2. 模拟 run_inference 的循环,构造 item
|
| 150 |
item_with_collage = {}
|
|
|
|
|
|
|
| 151 |
objs = [(img_a, mask_a), (img_b, mask_b)]
|
| 152 |
for j, (img, mask) in enumerate(objs):
|
|
|
|
| 153 |
temp_patch = f"temp_obj_{j}.png"
|
| 154 |
cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
|
|
|
| 155 |
tar_mask = (np.array(mask)[:, :, 0] > 128).astype(np.uint8)
|
| 156 |
|
| 157 |
+
# 调用你那段“正确”的 process 函数
|
| 158 |
item = process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j)
|
| 159 |
item_with_collage.update(item)
|
| 160 |
+
|
| 161 |
+
# 3. 合成 Hint
|
| 162 |
item_with_collage = process_composition(item_with_collage, obj_thr=2)
|
| 163 |
+
|
| 164 |
+
# 4. 执行你那段“正确”的 inference 函数逻辑
|
| 165 |
+
# --- START 原装逻辑 ---
|
| 166 |
+
obj_thr = 2
|
| 167 |
+
num_samples = 1
|
| 168 |
H, W = 512, 512
|
| 169 |
+
guidance_scale = 5.0
|
| 170 |
+
|
| 171 |
+
xc = []
|
| 172 |
+
xc_mask = []
|
| 173 |
+
for i in range(obj_thr):
|
| 174 |
+
xc.append(get_input(item_with_collage, f"view{i}").to(device))
|
| 175 |
+
xc_mask.append(get_input(item_with_collage, f"mask{i}"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
+
c_list = [model.get_learned_conditioning(xc_i) for xc_i in xc]
|
| 178 |
+
c_tensor = torch.stack(c_list).permute(1, 2, 3, 0)
|
| 179 |
+
cond_cross = {"pch_code": c_tensor}
|
| 180 |
+
c_mask = torch.stack(xc_mask).permute(1, 2, 3, 4, 0).to(device)
|
| 181 |
+
|
| 182 |
+
hint = item_with_collage['hint']
|
| 183 |
+
control = torch.from_numpy(hint.copy()).float().to(device)
|
| 184 |
+
control = torch.stack([control] * num_samples, dim=0)
|
| 185 |
+
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
|
| 186 |
+
|
| 187 |
+
cond = {"c_concat": [control], "c_crossattn": [cond_cross], "c_mask": [c_mask]}
|
| 188 |
+
|
| 189 |
+
# 这里的 UC 逻辑极其关键,决定了 BBox 会不会变蓝
|
| 190 |
+
uc_pch = get_unconditional_conditioning(num_samples, obj_thr)
|
| 191 |
+
# 这里的 get_unconditional_conditioning 内部会用到 model.device,确保它已经是 cuda
|
| 192 |
+
un_cond = {"c_concat": [control], "c_crossattn": [uc_pch], "c_mask": [c_mask]}
|
| 193 |
+
|
| 194 |
+
shape = (4, H // 8, W // 8)
|
| 195 |
+
model.control_scales = [1.0] * 13
|
| 196 |
+
|
| 197 |
+
samples, _ = ddim_sampler.sample(
|
| 198 |
+
50, num_samples, shape, cond,
|
| 199 |
+
verbose=False, eta=0.0,
|
| 200 |
+
unconditional_guidance_scale=guidance_scale,
|
| 201 |
+
unconditional_conditioning=un_cond
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
x_samples = model.decode_first_stage(samples)
|
| 205 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()
|
| 206 |
+
pred_rgb = np.clip(x_samples[0], 0, 255).astype(np.uint8)
|
| 207 |
+
# --- END 原装逻辑 ---
|
| 208 |
+
|
| 209 |
+
# 5. 后处理 (注意 RGB/BGR 转换以适配 crop_back)
|
| 210 |
+
pred_bgr = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
|
| 211 |
side = max(back_image.shape[0], back_image.shape[1])
|
| 212 |
+
pred_res = cv2.resize(pred_bgr, (side, side))
|
| 213 |
+
|
| 214 |
+
# 这里的 crop_back 依赖 BGR 格式
|
| 215 |
+
pred_final = crop_back(pred_res, back_image, item_with_collage['extra_sizes'],
|
| 216 |
+
item_with_collage['hint_sizes0'], item_with_collage['hint_sizes1'], is_masked=True)
|
| 217 |
|
| 218 |
+
# 最后转回 RGB 给 Gradio
|
| 219 |
+
return cv2.cvtColor(pred_final, cv2.COLOR_BGR2RGB)
|
| 220 |
|
| 221 |
with gr.Blocks(title="PICS: Pairwise Spatial Compositing") as demo:
|
| 222 |
gr.Markdown("# 🚀 PICS: Pairwise Image Compositing (5-Input Framework)")
|