# PICS / app.py — Gradio demo for pairwise image compositing with spatial interactions.
import os
import sys
import tempfile

import cv2
import einops
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from PIL import Image
# Download the PICS code + checkpoint bundle and work from inside it so the
# relative paths below ('configs/...', 'sample/...') resolve.
REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
os.chdir(REPO_DIR)
sys.path.insert(0, REPO_DIR)
sys.path.insert(0, os.path.join(REPO_DIR, "dinov2"))

# These imports resolve only after REPO_DIR (and its dinov2/ subdir) are on
# sys.path, so they must stay below the path setup above.
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from datasets.data_utils import *

# Build the model on CPU at import time; it is moved to GPU per request in
# pics_pairwise_inference (ZeroGPU pattern).
config = OmegaConf.load('configs/inference.yaml')
model = create_model(config.config_file).cpu()
model.load_state_dict(load_state_dict(config.pretrained_model, location='cpu'))
model.eval()
def get_input(batch, k):
    """Fetch ``batch[k]`` and return it as a contiguous float32 BCHW tensor.

    Accepts an HWC array (a singleton batch dimension is added) or an
    already-batched BHWC array/tensor.

    Args:
        batch: mapping holding channel-last image data under key ``k``.
        k: key to look up.

    Returns:
        torch.Tensor of shape (B, C, H, W), float32, contiguous.
    """
    x = batch[k]
    if len(x.shape) == 3:
        # Promote a single HWC sample to a batch of one: (1, H, W, C).
        x = x[None, ...]
    # as_tensor avoids the copy (and the UserWarning) that torch.tensor()
    # produces when the input is already a tensor.
    x = torch.as_tensor(x)
    # BHWC -> BCHW; native permute replaces the einops.rearrange call.
    x = x.permute(0, 3, 1, 2)
    return x.to(memory_format=torch.contiguous_format).float()
def get_unconditional_conditioning(N, obj_thr):
    """Build the unconditional ("null patch") conditioning dict.

    Encodes N all-zero 224x224 RGB patches with the model's image encoder,
    then tiles the embedding along a trailing axis, one slot per object.
    Returns ``{"pch_code": tensor}`` matching the conditional layout.
    """
    blank = torch.zeros((1, 3, 224, 224)).to(model.device)
    encoded = model.get_learned_conditioning([blank] * N)
    tiled = encoded.unsqueeze(-1).repeat(1, 1, 1, obj_thr)
    return {"pch_code": tiled}
def process_pairs_multiple(mask, tar_image, patch_dir, counter=0, max_ratio=0.8):
    """Prepare one object's reference view and hint tensors for compositing.

    Args:
        mask: binary (H, W) mask locating the object inside ``tar_image``.
        tar_image: RGB background/target image, uint8 HWC.
        patch_dir: filesystem path of the object patch image.
        counter: object index; suffixes the per-object keys in the result.
        max_ratio: kept for interface compatibility; unused here.

    Returns:
        dict with the 224x224 normalized object view, the 512x512 hint
        (blanked collage + mask channel), the object mask, the object box,
        the normalized target, the clean collage and size metadata.
    """
    # Object patch -> square-padded, 224x224, float RGB in [0, 1].
    patch = cv2.imread(patch_dir)
    patch = cv2.cvtColor(patch, cv2.COLOR_BGR2RGB)
    patch = pad_to_square(patch, pad_value=255, random=False)
    patch = cv2.resize(patch.astype(np.uint8), (224, 224)).astype(np.float32) / 255.0

    # Object bounding box, clipped to the full (uncropped) target image.
    tall, wide = tar_image.shape[0], tar_image.shape[1]
    y1, y2, x1, x2 = box_in_box(get_bbox_from_mask(mask), [0, tall, 0, wide])

    # Collage = target with the object region blanked; keep a clean copy too.
    blanked = tar_image.copy()
    clean_collage = blanked.copy()
    blanked[y1:y2, x1:x2, :] = 0
    region_mask = np.zeros_like(tar_image, dtype=np.float32)
    region_mask[y1:y2, x1:x2, :] = 1.0

    # Square-pad (padding flagged as 2 in the mask) and resize to 512x512.
    tar_sq = pad_to_square(tar_image, pad_value=0, random=False)
    col_sq = pad_to_square(blanked, pad_value=0, random=False)
    msk_sq = pad_to_square(region_mask, pad_value=2, random=False)
    sq_h, sq_w = col_sq.shape[0], col_sq.shape[1]

    tar_res = cv2.resize(tar_sq, (512, 512)).astype(np.float32)
    col_res = cv2.resize(col_sq, (512, 512)).astype(np.float32)
    msk_res = cv2.resize(msk_sq, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
    msk_res[msk_res == 2] = -1  # padded area encoded as -1

    obj_mask = np.where(msk_res[..., 0:1] == 1, 1.0, 0.0).astype(np.float32)

    # Pixel values from [0, 255] to [-1, 1].
    tar_res = tar_res / 127.5 - 1.0
    col_res = col_res / 127.5 - 1.0

    hint = np.concatenate([col_res, msk_res[..., :1]], axis=-1)
    return {
        f'view{counter}': patch,
        f'hint{counter}': hint,
        f'mask{counter}': obj_mask,
        f'hint_sizes{counter}': np.array([y1, x1, y2, x2]),
        'jpg': tar_res,
        'collage': clean_collage,
        'extra_sizes': np.array([tall, wide, sq_h, sq_w]),
    }
def process_composition(item, obj_thr):
    """Merge all per-object boxes into one composite hint for the sampler.

    Blanks each object's box in the clean collage, square-pads, resizes to
    512x512, flags padding as -1 in the mask channel, and stores the result
    under ``item['hint']``. Returns the (mutated) ``item``.
    """
    canvas = item['collage'].copy()
    box_mask = np.zeros((canvas.shape[0], canvas.shape[1], 1), dtype=np.float32)
    for idx in reversed(range(obj_thr)):
        y1, x1, y2, x2 = item['hint_sizes' + str(idx)]
        canvas[y1:y2, x1:x2, :] = 0
        box_mask[y1:y2, x1:x2, :] = 1.0
    canvas = pad_to_square(canvas, pad_value=0, random=False).astype(np.uint8)
    box_mask = pad_to_square(box_mask, pad_value=2, random=False).astype(np.float32)
    canvas = cv2.resize(canvas.astype(np.uint8), (512, 512)).astype(np.float32) / 127.5 - 1.0
    box_mask = cv2.resize(box_mask, (512, 512), interpolation=cv2.INTER_NEAREST).astype(np.float32)
    if box_mask.ndim == 2:
        # cv2.resize drops a trailing singleton channel; restore it.
        box_mask = box_mask[..., None]
    box_mask[box_mask == 2] = -1.0
    item.update({'hint': np.concatenate([canvas, box_mask[:, :, :1]], -1).copy()})
    return item
def load_example_pil(path):
    """Open an example image from disk and return it as an RGB PIL image."""
    img = Image.open(path)
    return img.convert("RGB")
@spaces.GPU(duration=120)
def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
    """Composite two objects into a background with pairwise spatial reasoning.

    Args:
        background: PIL RGB scene image.
        img_a, img_b: PIL RGB object patch images.
        mask_a, mask_b: PIL masks for the objects; pixels > 128 count as object.

    Returns:
        The final composited image as a numpy array.
    """
    device = "cuda"
    model.to(device)
    ddim_sampler = DDIMSampler(model)
    back_image = np.array(background)

    item_with_collage = {}
    objs = [(img_a, mask_a), (img_b, mask_b)]
    temp_paths = []
    try:
        for j, (img, mask) in enumerate(objs):
            # Unique per-request temp file: the previous fixed name
            # "temp_obj_{j}.png" raced between concurrent Gradio sessions
            # and was never cleaned up.
            fd, temp_patch = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            temp_paths.append(temp_patch)
            cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))

            mask_arr = np.array(mask)
            if mask_arr.ndim == 3:
                # RGB(A) mask: any one channel carries the shape; the old
                # unconditional [:, :, 0] crashed on grayscale masks.
                mask_arr = mask_arr[:, :, 0]
            tar_mask = (mask_arr > 128).astype(np.uint8)

            item_with_collage.update(
                process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j))
        item_with_collage = process_composition(item_with_collage, obj_thr=2)
    finally:
        # The patches were only needed for process_pairs_multiple's imread.
        for p in temp_paths:
            try:
                os.remove(p)
            except OSError:
                pass

    obj_thr = 2
    num_samples = 1
    H, W = 512, 512
    guidance_scale = 5.0

    # Per-object reference-view embeddings and masks, stacked on a trailing
    # object axis to match the model's multi-object conditioning layout.
    xc = [get_input(item_with_collage, f"view{i}").to(device) for i in range(obj_thr)]
    xc_mask = [get_input(item_with_collage, f"mask{i}") for i in range(obj_thr)]
    c_list = [model.get_learned_conditioning(xc_i) for xc_i in xc]
    cond_cross = {"pch_code": torch.stack(c_list).permute(1, 2, 3, 0)}
    c_mask = torch.stack(xc_mask).permute(1, 2, 3, 4, 0).to(device)

    # ControlNet hint: blanked collage + mask channel, as (B, 4, 512, 512).
    hint = item_with_collage['hint']
    control = torch.from_numpy(hint.copy()).float().to(device)
    control = torch.stack([control] * num_samples, dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()

    cond = {"c_concat": [control], "c_crossattn": [cond_cross], "c_mask": [c_mask]}
    uc_pch = get_unconditional_conditioning(num_samples, obj_thr)
    un_cond = {"c_concat": [control], "c_crossattn": [uc_pch], "c_mask": [c_mask]}

    shape = (4, H // 8, W // 8)
    model.control_scales = [1.0] * 13
    samples, _ = ddim_sampler.sample(50, num_samples, shape, cond, verbose=False, eta=0.0,
        unconditional_guidance_scale=guidance_scale, unconditional_conditioning=un_cond)

    # Decode latents back to pixel space in [0, 255].
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()
    pred = np.clip(x_samples[0], 0, 255).astype(np.uint8)

    # Undo the square padding/resizing and paste the prediction back into
    # the original background geometry.
    side = max(back_image.shape[0], back_image.shape[1])
    pred_res = cv2.resize(pred, (side, side))
    final_image = crop_back(pred_res, back_image, item_with_collage['extra_sizes'],
        item_with_collage['hint_sizes0'], item_with_collage['hint_sizes1'], is_masked=True)
    return final_image
# ---------------------------------------------------------------------------
# Gradio UI: background + two objects + two masks in, one composite out.
# ---------------------------------------------------------------------------
with gr.Blocks(title="PICS: Pairwise Spatial Compositing with Spatial Interactions") as demo:
    gr.Markdown("# 🚀 PICS: Pairwise Image Compositing with Spatial Interactions")
    gr.Markdown("Submit **Background**, **Two Objects**, and their **Two Masks** to reason about spatial interactions.")
    with gr.Row():
        # Left column: all inputs.
        with gr.Column(scale=2):
            bg_input = gr.Image(label="1. Scene Background", type="pil")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Object A")
                    obj_a_img = gr.Image(label="Image A", type="pil")
                    obj_a_mask = gr.Image(label="Mask A", type="pil")
                with gr.Column():
                    gr.Markdown("### Object B")
                    obj_b_img = gr.Image(label="Image B", type="pil")
                    obj_b_mask = gr.Image(label="Mask B", type="pil")
            run_btn = gr.Button("Execute PICS Inference ✨", variant="primary")
        # Right column: result and a short explainer.
        with gr.Column(scale=1):
            output_img = gr.Image(label="PICS Composite Result")
            gr.Markdown("""
### 🎨 PICS Reasoning Logic
* **Pairwise Interaction**: Model reasons about spatial relations between Object A and B.
* **Composition**: It intelligently composites objects into the provided scene background.
""")
    gr.Markdown("### 💡 Quick Examples")
    # Example images live inside the downloaded snapshot (cwd is REPO_DIR);
    # they are loaded eagerly at app build time.
    gr.Examples(
        examples=[
            [
                load_example_pil("sample/bread_basket/image.jpg"),
                load_example_pil("sample/bread_basket/object_0.png"),
                load_example_pil("sample/bread_basket/object_0_mask.png"),
                load_example_pil("sample/bread_basket/object_1.png"),
                load_example_pil("sample/bread_basket/object_1_mask.png")
            ],
            [
                load_example_pil("sample/pen_penholder/image.jpg"),
                load_example_pil("sample/pen_penholder/object_0.png"),
                load_example_pil("sample/pen_penholder/object_0_mask.png"),
                load_example_pil("sample/pen_penholder/object_1.png"),
                load_example_pil("sample/pen_penholder/object_1_mask.png")
            ]
        ],
        inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
        cache_examples=False,
    )
    run_btn.click(
        fn=pics_pairwise_inference,
        inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
        outputs=output_img
    )
if __name__ == "__main__":
    # allowed_paths lets Gradio serve the example images from the snapshot dir.
    demo.launch(allowed_paths=[REPO_DIR])