import gradio as gr
import os
import sys
import torch
import cv2
import einops
import numpy as np
import spaces
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
from PIL import Image
# Fetch the PICS repository snapshot from the Hub and make its packages
# importable (the repo ships its own `cldm`, `datasets` and a vendored dinov2).
REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
os.chdir(REPO_DIR)  # relative paths below ('configs/...', 'sample/...') resolve against the repo
sys.path.insert(0, REPO_DIR)
sys.path.insert(0, os.path.join(REPO_DIR, "dinov2"))
# These imports come from the downloaded repo, so they must follow the
# sys.path setup above.
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from datasets.data_utils import *
# Build the model on CPU at import time; it is moved to "cuda" per request
# inside the @spaces.GPU handler below.
config = OmegaConf.load('configs/inference.yaml')
model = create_model(config.config_file).cpu()
model.load_state_dict(load_state_dict(config.pretrained_model, location='cpu'))
model.eval()
def get_input(batch, k):
    """Fetch entry ``k`` from ``batch`` as a float32 NCHW tensor.

    Accepts an HWC array (a leading batch dimension is added) or an NHWC
    array, and returns a contiguous float32 tensor in NCHW layout.

    Args:
        batch: mapping of keys to HWC / NHWC array-likes.
        k: key to look up.

    Returns:
        torch.Tensor of shape (N, C, H, W), contiguous, dtype float32.
    """
    x = batch[k]
    if len(x.shape) == 3:
        # Single image: promote HWC -> 1HWC.
        x = x[None, ...]
    x = torch.tensor(x)
    # NHWC -> NCHW; native permute replaces the einops.rearrange call,
    # avoiding the extra dependency for a fixed-axis transpose.
    x = x.permute(0, 3, 1, 2)
    x = x.to(memory_format=torch.contiguous_format).float()
    return x
def get_unconditional_conditioning(N, obj_thr):
    """Return the null-image conditioning used for classifier-free guidance.

    Encodes N blank (all-zero) 224x224 patches with the global ``model`` and
    repeats the resulting code along a new trailing axis so there is one
    entry per object slot.
    """
    blanks = [torch.zeros((1, 3, 224, 224)).to(model.device) for _ in range(N)]
    null_code = model.get_learned_conditioning(blanks)
    tiled = null_code.unsqueeze(-1).repeat(1, 1, 1, obj_thr)
    return {"pch_code": tiled}
def process_pairs_multiple(mask, tar_image, patch_dir, counter=0, max_ratio=0.8):
    """Prepare one object's reference view, collage hint and hole mask.

    Args:
        mask: binary HxW mask marking the object's placement region in
            ``tar_image``.
        tar_image: RGB background image (H1 x W1 x 3) — presumably uint8;
            TODO confirm against callers.
        patch_dir: filesystem path of the object patch image.
        counter: object index used to suffix the per-object keys below.
        max_ratio: unused in this implementation — NOTE(review): confirm
            whether it was meant to bound the object/scene size ratio.

    Returns:
        dict with per-object keys 'view{i}', 'hint{i}', 'mask{i}',
        'hint_sizes{i}' and shared keys 'jpg', 'collage', 'extra_sizes'.
    """
    # Load the object patch (BGR on disk -> RGB), square-pad with white,
    # resize to the 224x224 encoder input, scale to [0, 1].
    view = cv2.imread(patch_dir)
    view = cv2.cvtColor(view, cv2.COLOR_BGR2RGB)
    view = pad_to_square(view, pad_value=255, random=False)
    view = cv2.resize(view.astype(np.uint8), (224, 224))
    view = view.astype(np.float32) / 255.0
    # Bounding box of the placement mask, expressed inside a "crop" that
    # spans the whole image.
    box_yyxx = get_bbox_from_mask(mask)
    H1, W1 = tar_image.shape[0], tar_image.shape[1]
    box_yyxx_crop = [0, H1, 0, W1]
    y1, y2, x1, x2 = box_in_box(box_yyxx, box_yyxx_crop)
    # Collage: the background with the object's box zeroed out; the mask
    # carries 1.0 inside the hole, 0.0 elsewhere.
    collage = tar_image.copy()
    source_collage = collage.copy()
    collage[y1:y2, x1:x2, :] = 0
    collage_mask = np.zeros_like(tar_image, dtype=np.float32)
    collage_mask[y1:y2, x1:x2, :] = 1.0
    # Square-pad. Mask padding uses sentinel value 2 so it can be remapped
    # to -1 below, keeping it distinct from hole (1) and background (0).
    tar_square = pad_to_square(tar_image, pad_value=0, random=False)
    collage_square = pad_to_square(collage, pad_value=0, random=False)
    mask_square = pad_to_square(collage_mask, pad_value=2, random=False)
    H2, W2 = collage_square.shape[0], collage_square.shape[1]
    tar_res = cv2.resize(tar_square, (512, 512)).astype(np.float32)
    col_res = cv2.resize(collage_square, (512, 512)).astype(np.float32)
    # Nearest-neighbour keeps the mask values exactly {0, 1, 2}.
    mask_res = cv2.resize(mask_square, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
    mask_res[mask_res == 2] = -1
    # Binary per-object mask (hole only) for the cross-attention masking.
    c_mask = np.where(mask_res[..., 0:1] == 1, 1.0, 0.0).astype(np.float32)
    # Normalize images to [-1, 1] for the diffusion model.
    tar_res = tar_res / 127.5 - 1.0
    col_res = col_res / 127.5 - 1.0
    # Hint = RGB collage + one mask channel, concatenated along channels.
    hint_final = np.concatenate([col_res, mask_res[..., :1]], axis=-1)
    return {
        f'view{counter}': view, f'hint{counter}': hint_final, f'mask{counter}': c_mask,
        f'hint_sizes{counter}': np.array([y1, x1, y2, x2]), 'jpg': tar_res,
        'collage': source_collage, 'extra_sizes': np.array([H1, W1, H2, W2])
    }
def process_composition(item, obj_thr):
    """Build the combined 4-channel hint covering all ``obj_thr`` objects.

    Zeroes every object's box in the source collage, builds the matching
    hole mask, square-pads, resizes to 512x512, normalizes the RGB part to
    [-1, 1], and stores the result under ``item['hint']``.
    """
    canvas = item['collage'].copy()
    hole_mask = np.zeros((canvas.shape[0], canvas.shape[1], 1), dtype=np.float32)
    # Punch out each object's box; iteration order mirrors the original
    # reversed pass (the result is order-independent).
    for idx in reversed(range(obj_thr)):
        top, left, bottom, right = item['hint_sizes' + str(idx)]
        canvas[top:bottom, left:right, :] = 0
        hole_mask[top:bottom, left:right, :] = 1.0
    # Square-pad; mask padding is tagged 2 and remapped to -1 below.
    canvas = pad_to_square(canvas, pad_value=0, random=False).astype(np.uint8)
    hole_mask = pad_to_square(hole_mask, pad_value=2, random=False).astype(np.float32)
    canvas = cv2.resize(canvas.astype(np.uint8), (512, 512)).astype(np.float32) / 127.5 - 1.0
    hole_mask = cv2.resize(hole_mask, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
    if len(hole_mask.shape) == 2:
        # cv2.resize drops a trailing singleton channel; restore it.
        hole_mask = hole_mask[..., None]
    hole_mask[hole_mask == 2] = -1.0
    combined = np.concatenate([canvas, hole_mask[:, :, :1]], -1)
    item.update({'hint': combined.copy()})
    return item
def load_example_pil(path):
    """Open the image at ``path`` and return it as an RGB PIL image."""
    img = Image.open(path)
    return img.convert("RGB")
@spaces.GPU(duration=120)
def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
    """Composite two objects into ``background`` via PICS pairwise sampling.

    Args:
        background: PIL image of the target scene.
        img_a, img_b: PIL images of the two object patches.
        mask_a, mask_b: PIL mask images; first-channel values > 128 mark each
            object's placement region in the background.

    Returns:
        The composited image as a numpy array (output of ``crop_back``).
    """
    import tempfile

    device = "cuda"
    model.to(device)
    ddim_sampler = DDIMSampler(model)
    back_image = np.array(background)
    item_with_collage = {}
    temp_paths = []
    try:
        # Per-object preprocessing: each patch goes through the same
        # view/hint/mask pipeline, keyed by its index.
        for j, (img, obj_mask) in enumerate([(img_a, mask_a), (img_b, mask_b)]):
            # Fix: use a unique temp file instead of a fixed-name file in the
            # CWD — the old 'temp_obj_{j}.png' leaked into the repo dir and
            # raced between concurrent requests.
            fd, temp_patch = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            temp_paths.append(temp_patch)
            cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
            tar_mask = (np.array(obj_mask)[:, :, 0] > 128).astype(np.uint8)
            item_with_collage.update(process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j))
    finally:
        # Remove temp patches even if preprocessing fails.
        for p in temp_paths:
            try:
                os.remove(p)
            except OSError:
                pass
    item_with_collage = process_composition(item_with_collage, obj_thr=2)
    obj_thr = 2
    num_samples = 1
    H, W = 512, 512
    guidance_scale = 5.0
    # Encode each 224x224 object view; the codes are stacked along a
    # trailing object axis to match the model's conditioning layout.
    xc = []
    xc_mask = []
    for i in range(obj_thr):
        xc.append(get_input(item_with_collage, f"view{i}").to(device))
        xc_mask.append(get_input(item_with_collage, f"mask{i}"))
    c_list = [model.get_learned_conditioning(xc_i) for xc_i in xc]
    c_tensor = torch.stack(c_list).permute(1, 2, 3, 0)
    cond_cross = {"pch_code": c_tensor}
    c_mask = torch.stack(xc_mask).permute(1, 2, 3, 4, 0).to(device)
    # 4-channel spatial hint (collage RGB + hole mask) as the control input.
    hint = item_with_collage['hint']
    control = torch.from_numpy(hint.copy()).float().to(device)
    control = torch.stack([control] * num_samples, dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()
    cond = {"c_concat": [control], "c_crossattn": [cond_cross], "c_mask": [c_mask]}
    # Classifier-free guidance: same spatial control, null patch codes.
    uc_pch = get_unconditional_conditioning(num_samples, obj_thr)
    un_cond = {"c_concat": [control], "c_crossattn": [uc_pch], "c_mask": [c_mask]}
    shape = (4, H // 8, W // 8)
    model.control_scales = [1.0] * 13
    samples, _ = ddim_sampler.sample(50, num_samples, shape, cond, verbose=False, eta=0.0,
                                     unconditional_guidance_scale=guidance_scale,
                                     unconditional_conditioning=un_cond)
    # Decode latents to RGB in [-1, 1], rescale to [0, 255].
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()
    pred = np.clip(x_samples[0], 0, 255).astype(np.uint8)
    # Undo the square pad/resize and paste the prediction back into the
    # original background geometry.
    side = max(back_image.shape[0], back_image.shape[1])
    pred_res = cv2.resize(pred, (side, side))
    final_image = crop_back(pred_res, back_image, item_with_collage['extra_sizes'],
                            item_with_collage['hint_sizes0'], item_with_collage['hint_sizes1'], is_masked=True)
    return final_image
# Gradio UI: background + two object/mask pairs in, one composite image out.
with gr.Blocks(title="PICS: Pairwise Spatial Compositing with Spatial Interactions") as demo:
    gr.Markdown("# 🚀 PICS: Pairwise Image Compositing with Spatial Interactions")
    gr.Markdown("Submit **Background**, **Two Objects**, and their **Two Masks** to reason about spatial interactions.")
    with gr.Row():
        with gr.Column(scale=2):
            # All inputs use type="pil" to match pics_pairwise_inference,
            # which converts them with np.array().
            bg_input = gr.Image(label="1. Scene Background", type="pil")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Object A")
                    obj_a_img = gr.Image(label="Image A", type="pil")
                    obj_a_mask = gr.Image(label="Mask A", type="pil")
                with gr.Column():
                    gr.Markdown("### Object B")
                    obj_b_img = gr.Image(label="Image B", type="pil")
                    obj_b_mask = gr.Image(label="Mask B", type="pil")
            run_btn = gr.Button("Execute PICS Inference ✨", variant="primary")
        with gr.Column(scale=1):
            output_img = gr.Image(label="PICS Composite Result")
            gr.Markdown("""
            ### 🎨 PICS Reasoning Logic
            * **Pairwise Interaction**: Model reasons about spatial relations between Object A and B.
            * **Composition**: It intelligently composites objects into the provided scene background.
            """)
    gr.Markdown("### 💡 Quick Examples")
    # Example rows are loaded eagerly as PIL images from the downloaded repo
    # (paths are relative to REPO_DIR because of the os.chdir above).
    gr.Examples(
        examples=[
            [
                load_example_pil("sample/bread_basket/image.jpg"),
                load_example_pil("sample/bread_basket/object_0.png"),
                load_example_pil("sample/bread_basket/object_0_mask.png"),
                load_example_pil("sample/bread_basket/object_1.png"),
                load_example_pil("sample/bread_basket/object_1_mask.png")
            ],
            [
                load_example_pil("sample/pen_penholder/image.jpg"),
                load_example_pil("sample/pen_penholder/object_0.png"),
                load_example_pil("sample/pen_penholder/object_0_mask.png"),
                load_example_pil("sample/pen_penholder/object_1.png"),
                load_example_pil("sample/pen_penholder/object_1_mask.png")
            ]
        ],
        inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
        cache_examples=False,
    )
    run_btn.click(
        fn=pics_pairwise_inference,
        inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
        outputs=output_img
    )
if __name__ == "__main__":
    # allowed_paths lets Gradio serve the example files from the downloaded repo.
    demo.launch(allowed_paths=[REPO_DIR])