Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,21 +8,11 @@ import numpy as np
|
|
| 8 |
import spaces
|
| 9 |
from omegaconf import OmegaConf
|
| 10 |
from huggingface_hub import snapshot_download
|
| 11 |
-
from PIL import Image
|
| 12 |
|
| 13 |
REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
|
| 14 |
os.chdir(REPO_DIR)
|
| 15 |
sys.path.insert(0, REPO_DIR)
|
| 16 |
sys.path.insert(0, os.path.join(REPO_DIR, "dinov2"))
|
| 17 |
-
sys.path.insert(0, os.path.join(REPO_DIR, "sample"))
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
import shutil
|
| 21 |
-
# 1. 强制把 sample 文件夹拷贝到 Gradio 运行的根目录 (/home/user/app)
|
| 22 |
-
# 不管你 chdir 去了哪,这一步保证了 /home/user/app/sample 真实存在
|
| 23 |
-
if not os.path.exists("/home/user/app/sample"):
|
| 24 |
-
# REPO_DIR 是你下载的那个缓存路径
|
| 25 |
-
shutil.copytree(os.path.join(REPO_DIR, "sample"), "/home/user/app/sample")
|
| 26 |
|
| 27 |
from cldm.model import create_model, load_state_dict
|
| 28 |
from cldm.ddim_hacked import DDIMSampler
|
|
@@ -32,7 +22,6 @@ model = create_model(config.config_file).cpu()
|
|
| 32 |
model.load_state_dict(load_state_dict(config.pretrained_model, location='cpu'))
|
| 33 |
model.eval()
|
| 34 |
|
| 35 |
-
|
| 36 |
def get_input(batch, k):
|
| 37 |
x = batch[k]
|
| 38 |
if len(x.shape) == 3:
|
|
@@ -104,31 +93,15 @@ def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
|
|
| 104 |
device = "cuda"
|
| 105 |
model.to(device)
|
| 106 |
ddim_sampler = DDIMSampler(model)
|
| 107 |
-
|
| 108 |
-
def force_to_pil(data):
|
| 109 |
-
if data is None: return None
|
| 110 |
-
if isinstance(data, str): # 如果是路径字符串
|
| 111 |
-
return Image.open(data).convert("RGB")
|
| 112 |
-
if isinstance(data, np.ndarray): # 如果已经是 numpy
|
| 113 |
-
return Image.fromarray(data).convert("RGB")
|
| 114 |
-
return data.convert("RGB") # 假设是 PIL
|
| 115 |
-
|
| 116 |
-
background = force_to_pil(background)
|
| 117 |
-
img_a = force_to_pil(img_a)
|
| 118 |
-
mask_a = force_to_pil(mask_a)
|
| 119 |
-
img_b = force_to_pil(img_b)
|
| 120 |
-
mask_b = force_to_pil(mask_b)
|
| 121 |
-
|
| 122 |
back_image = np.array(background)
|
| 123 |
-
# back_image = np.array(background)
|
| 124 |
|
| 125 |
item_with_collage = {}
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
|
| 133 |
item_with_collage = process_composition(item_with_collage, obj_thr=2)
|
| 134 |
|
|
@@ -210,19 +183,17 @@ with gr.Blocks(title="PICS: Pairwise Spatial Compositing with Spatial Interactio
|
|
| 210 |
# --- 核心修改:把 Examples 放在这里 ---
|
| 211 |
gr.Markdown("### 💡 Quick Examples")
|
| 212 |
gr.Examples(
|
| 213 |
-
|
| 214 |
[
|
| 215 |
-
"
|
| 216 |
-
"/
|
| 217 |
-
"
|
| 218 |
-
"/home/user/app/sample/bread_basket/object_1.png",
|
| 219 |
-
"/home/user/app/sample/bread_basket/object_1_mask.png"
|
| 220 |
]
|
| 221 |
],
|
| 222 |
inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
|
| 223 |
outputs=output_img, # 现在它认识这个变量了
|
| 224 |
fn=pics_pairwise_inference,
|
| 225 |
-
cache_examples=
|
| 226 |
)
|
| 227 |
|
| 228 |
# 按钮绑定
|
|
|
|
| 8 |
import spaces
|
| 9 |
from omegaconf import OmegaConf
|
| 10 |
from huggingface_hub import snapshot_download
|
|
|
|
| 11 |
|
| 12 |
REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
|
| 13 |
os.chdir(REPO_DIR)
|
| 14 |
sys.path.insert(0, REPO_DIR)
|
| 15 |
sys.path.insert(0, os.path.join(REPO_DIR, "dinov2"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
from cldm.model import create_model, load_state_dict
|
| 18 |
from cldm.ddim_hacked import DDIMSampler
|
|
|
|
| 22 |
model.load_state_dict(load_state_dict(config.pretrained_model, location='cpu'))
|
| 23 |
model.eval()
|
| 24 |
|
|
|
|
| 25 |
def get_input(batch, k):
|
| 26 |
x = batch[k]
|
| 27 |
if len(x.shape) == 3:
|
|
|
|
| 93 |
device = "cuda"
|
| 94 |
model.to(device)
|
| 95 |
ddim_sampler = DDIMSampler(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
back_image = np.array(background)
|
|
|
|
| 97 |
|
| 98 |
item_with_collage = {}
|
| 99 |
+
objs = [(img_a, mask_a), (img_b, mask_b)]
|
| 100 |
+
for j, (img, mask) in enumerate(objs):
|
| 101 |
+
temp_patch = f"temp_obj_{j}.png"
|
| 102 |
+
cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
| 103 |
+
tar_mask = (np.array(mask)[:, :, 0] > 128).astype(np.uint8)
|
| 104 |
+
item_with_collage.update(process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j))
|
| 105 |
|
| 106 |
item_with_collage = process_composition(item_with_collage, obj_thr=2)
|
| 107 |
|
|
|
|
| 183 |
# --- 核心修改:把 Examples 放在这里 ---
|
| 184 |
gr.Markdown("### 💡 Quick Examples")
|
| 185 |
gr.Examples(
|
| 186 |
+
examples=[
|
| 187 |
[
|
| 188 |
+
"sample/bread_basket/image.jpg",
|
| 189 |
+
"sample/bread_basket/object_0.png", "sample/bread_basket/object_0_mask.png",
|
| 190 |
+
"sample/bread_basket/object_1.png", "sample/bread_basket/object_1_mask.png"
|
|
|
|
|
|
|
| 191 |
]
|
| 192 |
],
|
| 193 |
inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
|
| 194 |
outputs=output_img, # 现在它认识这个变量了
|
| 195 |
fn=pics_pairwise_inference,
|
| 196 |
+
cache_examples=True,
|
| 197 |
)
|
| 198 |
|
| 199 |
# 按钮绑定
|