File size: 9,095 Bytes
fac3291
4d9ba5b
 
 
4d307ef
 
 
 
4d9ba5b
 
dba7566
4d9ba5b
44efe5b
c92a66b
bf9cd8f
 
1090c15
bf9cd8f
4d307ef
cc28535
c92a66b
 
 
6f9b747
2ad8b09
7bd1eec
 
 
 
 
 
 
 
 
64b9e67
 
4d307ef
 
 
f28925f
d4c5a1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64b9e67
 
 
 
d4c5a1e
 
 
 
 
 
 
 
 
 
 
 
 
64b9e67
d4c5a1e
 
 
 
 
dba7566
 
64b9e67
f28925f
3d7998e
4d307ef
 
64b9e67
c92a66b
64b9e67
ca2fd62
fd13683
 
 
 
 
 
35c4eee
4d307ef
35c4eee
 
 
4d307ef
35c4eee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64b9e67
35c4eee
ff8dd04
35c4eee
 
 
64b9e67
 
35c4eee
 
 
64b9e67
35c4eee
64b9e67
 
cc28535
f576d1d
 
 
 
fac3291
3c748e1
e2919fd
3d7998e
fac3291
 
 
 
 
 
 
3d7998e
 
 
 
fac3291
3d7998e
 
 
fac3291
3d7998e
2c54e51
 
9d8211a
2c54e51
cc28535
 
 
2c54e51
 
8298ad4
 
fd13683
8298ad4
dba7566
 
 
 
 
eaa6991
9d8211a
 
 
 
 
 
 
8298ad4
 
dba7566
8298ad4
 
2c54e51
 
3d7998e
2c54e51
 
 
 
c92a66b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import gradio as gr
import os
import sys
import torch
import cv2
import einops
import numpy as np
import spaces
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
from PIL import Image

# Download the full PICS repository (code, configs, weights) from the Hub,
# then chdir into it so relative paths like 'configs/inference.yaml' resolve,
# and put it (plus its vendored dinov2 copy) on sys.path for the imports below.
REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
os.chdir(REPO_DIR)
sys.path.insert(0, REPO_DIR)
sys.path.insert(0, os.path.join(REPO_DIR, "dinov2"))

from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from datasets.data_utils import *
# Build the model on CPU from the inference config and load the pretrained
# checkpoint; it is moved to the GPU lazily inside pics_pairwise_inference
# (required for ZeroGPU-style @spaces.GPU execution).
config = OmegaConf.load('configs/inference.yaml')
model = create_model(config.config_file).cpu()
model.load_state_dict(load_state_dict(config.pretrained_model, location='cpu'))
model.eval()

def get_input(batch, k):
    """Fetch ``batch[k]`` and return it as a float32 BCHW tensor.

    Parameters
    ----------
    batch : dict
        Mapping of keys to HWC (single sample) or BHWC numpy arrays.
    k : str
        Key to read from the batch.

    Returns
    -------
    torch.Tensor
        Contiguous float32 tensor of shape (B, C, H, W).
    """
    x = batch[k]
    if len(x.shape) == 3:
        # Single HWC sample: prepend a batch dimension.
        x = x[None, ...]
    x = torch.tensor(x)
    # BHWC -> BCHW; native permute, no need for the einops dependency here.
    x = x.permute(0, 3, 1, 2)
    x = x.to(memory_format=torch.contiguous_format).float()
    return x

def get_unconditional_conditioning(N, obj_thr):
    """Build the unconditional patch-code dict for classifier-free guidance.

    Encodes N blank 224x224 images through the model's conditioning
    encoder, then tiles the code along a trailing per-object axis so it
    matches the shape of the conditional "pch_code" tensor.
    """
    blanks = [torch.zeros((1, 3, 224, 224)).to(model.device) for _ in range(N)]
    code = model.get_learned_conditioning(blanks)
    code = code.unsqueeze(-1).repeat(1, 1, 1, obj_thr)
    return {"pch_code": code}

def process_pairs_multiple(mask, tar_image, patch_dir, counter=0, max_ratio=0.8):
    """Prepare one object's inputs for the compositing model.

    Loads the object's reference image from disk, normalizes it to a
    224x224 float view, locates the object's target box in the background
    via its mask, and builds the per-object inpainting hint (background
    with the box zeroed out plus a box mask channel).

    Parameters:
        mask: binary (H, W) mask marking where the object goes in tar_image.
        tar_image: background image as an (H, W, 3) uint8 RGB array.
        patch_dir: path to the object's reference image on disk.
        counter: index used to suffix the per-object output keys.
        max_ratio: NOTE(review): currently unused in this function.

    Returns a dict with per-object keys ('view{i}', 'hint{i}', 'mask{i}',
    'hint_sizes{i}') plus shared keys ('jpg', 'collage', 'extra_sizes')
    that are overwritten on each call and consumed by process_composition.
    """
    # Reference view: read (cv2 loads BGR), convert to RGB, pad to square
    # with white, resize to the 224x224 encoder input, scale to [0, 1].
    view = cv2.imread(patch_dir)
    view = cv2.cvtColor(view, cv2.COLOR_BGR2RGB)
    view = pad_to_square(view, pad_value=255, random=False)
    view = cv2.resize(view.astype(np.uint8), (224, 224))
    view = view.astype(np.float32) / 255.0
    # Tight bounding box of the mask, in (y1, y2, x1, x2) order.
    box_yyxx = get_bbox_from_mask(mask)
    H1, W1 = tar_image.shape[0], tar_image.shape[1]
    # Crop region is the whole background; box_in_box presumably maps the
    # object box into the crop's coordinates — TODO confirm helper semantics.
    box_yyxx_crop = [0, H1, 0, W1]
    y1, y2, x1, x2 = box_in_box(box_yyxx, box_yyxx_crop)
    collage = tar_image.copy()
    # Untouched copy of the background, kept for process_composition.
    source_collage = collage.copy()
    # Zero out the object's box and flag the same region in a float mask.
    collage[y1:y2, x1:x2, :] = 0
    collage_mask = np.zeros_like(tar_image, dtype=np.float32)
    collage_mask[y1:y2, x1:x2, :] = 1.0
    # Square-pad everything; pad_value=2 is a sentinel in the mask marking
    # padded (out-of-image) area, remapped to -1 after resizing below.
    tar_square = pad_to_square(tar_image, pad_value=0, random=False)
    collage_square = pad_to_square(collage, pad_value=0, random=False)
    mask_square = pad_to_square(collage_mask, pad_value=2, random=False)
    H2, W2 = collage_square.shape[0], collage_square.shape[1]
    tar_res = cv2.resize(tar_square, (512, 512)).astype(np.float32)
    col_res = cv2.resize(collage_square, (512, 512)).astype(np.float32)
    # Nearest-neighbour keeps the {0, 1, 2} mask labels exact.
    mask_res = cv2.resize(mask_square, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
    mask_res[mask_res == 2] = -1
    # Binary per-object mask: 1 inside the object's box, 0 elsewhere.
    c_mask = np.where(mask_res[..., 0:1] == 1, 1.0, 0.0).astype(np.float32)
    # Normalize images from [0, 255] to [-1, 1].
    tar_res = tar_res / 127.5 - 1.0
    col_res = col_res / 127.5 - 1.0
    # Hint = holed background (3ch) + box-mask channel (1ch) -> (512, 512, 4).
    hint_final = np.concatenate([col_res, mask_res[..., :1]], axis=-1)
    return {
        f'view{counter}': view, f'hint{counter}': hint_final, f'mask{counter}': c_mask,
        f'hint_sizes{counter}': np.array([y1, x1, y2, x2]), 'jpg': tar_res,
        'collage': source_collage, 'extra_sizes': np.array([H1, W1, H2, W2])
    }

def process_composition(item, obj_thr):
    """Merge all per-object boxes into a single 4-channel 512x512 hint.

    Punches every object's box out of a copy of the source background,
    marks those boxes in one shared mask channel, square-pads and resizes
    both, and stores the concatenated result under item['hint'].
    """
    img = item['collage'].copy()
    hole_mask = np.zeros((img.shape[0], img.shape[1], 1), dtype=np.float32)
    # Zero each object's region and flag it in the shared mask (back-to-front).
    for idx in reversed(range(obj_thr)):
        y1, x1, y2, x2 = item[f'hint_sizes{idx}']
        img[y1:y2, x1:x2, :] = 0
        hole_mask[y1:y2, x1:x2, :] = 1.0
    # Square-pad; pad_value=2 is a sentinel for padded area, remapped to -1 below.
    img = pad_to_square(img, pad_value=0, random=False).astype(np.uint8)
    hole_mask = pad_to_square(hole_mask, pad_value=2, random=False).astype(np.float32)
    # Normalize the image to [-1, 1]; nearest-neighbour keeps mask labels exact.
    img = cv2.resize(img.astype(np.uint8), (512, 512)).astype(np.float32) / 127.5 - 1.0
    hole_mask = cv2.resize(hole_mask, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
    if len(hole_mask.shape) == 2:
        # cv2.resize drops the trailing singleton channel; restore it.
        hole_mask = hole_mask[..., None]
    hole_mask[hole_mask == 2] = -1.0
    combined = np.concatenate([img, hole_mask[:, :, :1]], -1)
    item.update({'hint': combined.copy()})
    return item

def load_example_pil(path):
    """Load an example image from disk as an RGB PIL image.

    Image.open is lazy and keeps the underlying file handle open; the
    context manager releases it promptly. convert() forces a full load
    and returns a new, detached image, so closing the source is safe.
    """
    with Image.open(path) as im:
        return im.convert("RGB")

@spaces.GPU(duration=120)
def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
    """Composite two objects into a background via guided DDIM sampling.

    Parameters (all PIL images from the Gradio inputs):
        background: the scene to composite into.
        img_a, img_b: the two object reference images.
        mask_a, mask_b: per-object placement masks; a pixel counts as
            "inside" when its first channel exceeds 128.

    Returns the final composited image (numpy array from crop_back).
    Runs on CUDA inside a @spaces.GPU window; moves the global model to
    the GPU for the duration of the call.
    """
    device = "cuda"
    model.to(device)
    ddim_sampler = DDIMSampler(model)
    back_image = np.array(background)
    
    item_with_collage = {}
    objs = [(img_a, mask_a), (img_b, mask_b)]
    for j, (img, mask) in enumerate(objs):
        # process_pairs_multiple reads the patch from disk, so round-trip
        # the PIL image through a temp file (cv2 expects BGR on write).
        temp_patch = f"temp_obj_{j}.png"
        cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
        tar_mask = (np.array(mask)[:, :, 0] > 128).astype(np.uint8)
        item_with_collage.update(process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j))

    # Merge both objects' holes into the single shared 'hint' tensor.
    item_with_collage = process_composition(item_with_collage, obj_thr=2)

    obj_thr = 2
    num_samples = 1
    H, W = 512, 512
    guidance_scale = 5.0
    
    # Per-object encoder inputs (view{i}) and box masks (mask{i}) as BCHW.
    xc = []
    xc_mask = []
    for i in range(obj_thr):
        xc.append(get_input(item_with_collage, f"view{i}").to(device))
        xc_mask.append(get_input(item_with_collage, f"mask{i}"))

    # Stack per-object codes along a trailing object axis — presumably
    # (B, tokens, dim, obj) to match get_unconditional_conditioning's shape.
    c_list = [model.get_learned_conditioning(xc_i) for xc_i in xc]
    c_tensor = torch.stack(c_list).permute(1, 2, 3, 0)
    cond_cross = {"pch_code": c_tensor}
    c_mask = torch.stack(xc_mask).permute(1, 2, 3, 4, 0).to(device)

    # ControlNet conditioning: the 4-channel hint as a BCHW tensor.
    hint = item_with_collage['hint']
    control = torch.from_numpy(hint.copy()).float().to(device)
    control = torch.stack([control] * num_samples, dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()

    # Unconditional branch keeps the same control/mask but blank patch codes
    # (classifier-free guidance over the object appearance only).
    cond = {"c_concat": [control], "c_crossattn": [cond_cross], "c_mask": [c_mask]}
    uc_pch = get_unconditional_conditioning(num_samples, obj_thr)
    un_cond = {"c_concat": [control], "c_crossattn": [uc_pch], "c_mask": [c_mask]}

    # Latent shape: 4 channels at 1/8 spatial resolution (512 -> 64).
    shape = (4, H // 8, W // 8)
    model.control_scales = [1.0] * 13
    
    # 50 DDIM steps, deterministic (eta=0).
    samples, _ = ddim_sampler.sample(50, num_samples, shape, cond, verbose=False, eta=0.0,
                                     unconditional_guidance_scale=guidance_scale, unconditional_conditioning=un_cond)

    # Decode latents and map from [-1, 1] back to 8-bit pixel range.
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()
    pred = np.clip(x_samples[0], 0, 255).astype(np.uint8)
    
    # Resize the square prediction back to the padded-square side length,
    # then let crop_back undo the padding and paste only the masked boxes.
    side = max(back_image.shape[0], back_image.shape[1])
    pred_res = cv2.resize(pred, (side, side))

    final_image = crop_back(pred_res, back_image, item_with_collage['extra_sizes'], 
                            item_with_collage['hint_sizes0'], item_with_collage['hint_sizes1'], is_masked=True)
    
    return final_image

# Gradio UI: one background + two (image, mask) object pairs in, one
# composited image out. All heavy work happens in pics_pairwise_inference.
with gr.Blocks(title="PICS: Pairwise Spatial Compositing with Spatial Interactions") as demo:
    gr.Markdown("# 🚀 PICS: Pairwise Image Compositing with Spatial Interactions")
    gr.Markdown("Submit **Background**, **Two Objects**, and their **Two Masks** to reason about spatial interactions.")

    with gr.Row():
        # Left column: all five inputs plus the run button.
        with gr.Column(scale=2):
            bg_input = gr.Image(label="1. Scene Background", type="pil")
            
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Object A")
                    obj_a_img = gr.Image(label="Image A", type="pil")
                    obj_a_mask = gr.Image(label="Mask A", type="pil")
                
                with gr.Column():
                    gr.Markdown("### Object B")
                    obj_b_img = gr.Image(label="Image B", type="pil")
                    obj_b_mask = gr.Image(label="Mask B", type="pil")
            
            run_btn = gr.Button("Execute PICS Inference ✨", variant="primary")

        # Right column: result display and a short explanation.
        with gr.Column(scale=1):
            output_img = gr.Image(label="PICS Composite Result")
            gr.Markdown("""
            ### 🎨 PICS Reasoning Logic
            * **Pairwise Interaction**: Model reasons about spatial relations between Object A and B.
            * **Composition**: It intelligently composites objects into the provided scene background.
            """)

    # Clickable examples; paths are relative to REPO_DIR (chdir'd above).
    # cache_examples=False so the GPU inference is not run at startup.
    gr.Markdown("### 💡 Quick Examples")
    gr.Examples(
        examples=[
            [
                load_example_pil("sample/bread_basket/image.jpg"), 
                load_example_pil("sample/bread_basket/object_0.png"), 
                load_example_pil("sample/bread_basket/object_0_mask.png"), 
                load_example_pil("sample/bread_basket/object_1.png"), 
                load_example_pil("sample/bread_basket/object_1_mask.png")
            ], 
            [
                load_example_pil("sample/pen_penholder/image.jpg"), 
                load_example_pil("sample/pen_penholder/object_0.png"), 
                load_example_pil("sample/pen_penholder/object_0_mask.png"), 
                load_example_pil("sample/pen_penholder/object_1.png"), 
                load_example_pil("sample/pen_penholder/object_1_mask.png")
            ]
        ],
        inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
        cache_examples=False, 
    )

    run_btn.click(
        fn=pics_pairwise_inference,
        inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
        outputs=output_img
    )

if __name__ == "__main__":
    # allowed_paths lets Gradio serve the example assets from the downloaded repo.
    demo.launch(allowed_paths=[REPO_DIR])