RyanHangZhou commited on
Commit
fd13683
·
verified ·
1 Parent(s): ca2fd62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -40
app.py CHANGED
@@ -8,21 +8,11 @@ import numpy as np
8
  import spaces
9
  from omegaconf import OmegaConf
10
  from huggingface_hub import snapshot_download
11
- from PIL import Image
12
 
13
  REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
14
  os.chdir(REPO_DIR)
15
  sys.path.insert(0, REPO_DIR)
16
  sys.path.insert(0, os.path.join(REPO_DIR, "dinov2"))
17
- sys.path.insert(0, os.path.join(REPO_DIR, "sample"))
18
-
19
-
20
- import shutil
21
- # 1. 强制把 sample 文件夹拷贝到 Gradio 运行的根目录 (/home/user/app)
22
- # 不管你 chdir 去了哪,这一步保证了 /home/user/app/sample 真实存在
23
- if not os.path.exists("/home/user/app/sample"):
24
- # REPO_DIR 是你下载的那个缓存路径
25
- shutil.copytree(os.path.join(REPO_DIR, "sample"), "/home/user/app/sample")
26
 
27
  from cldm.model import create_model, load_state_dict
28
  from cldm.ddim_hacked import DDIMSampler
@@ -32,7 +22,6 @@ model = create_model(config.config_file).cpu()
32
  model.load_state_dict(load_state_dict(config.pretrained_model, location='cpu'))
33
  model.eval()
34
 
35
-
36
  def get_input(batch, k):
37
  x = batch[k]
38
  if len(x.shape) == 3:
@@ -104,31 +93,15 @@ def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
104
  device = "cuda"
105
  model.to(device)
106
  ddim_sampler = DDIMSampler(model)
107
-
108
- def force_to_pil(data):
109
- if data is None: return None
110
- if isinstance(data, str): # 如果是路径字符串
111
- return Image.open(data).convert("RGB")
112
- if isinstance(data, np.ndarray): # 如果已经是 numpy
113
- return Image.fromarray(data).convert("RGB")
114
- return data.convert("RGB") # 假设是 PIL
115
-
116
- background = force_to_pil(background)
117
- img_a = force_to_pil(img_a)
118
- mask_a = force_to_pil(mask_a)
119
- img_b = force_to_pil(img_b)
120
- mask_b = force_to_pil(mask_b)
121
-
122
  back_image = np.array(background)
123
- # back_image = np.array(background)
124
 
125
  item_with_collage = {}
126
- # objs = [(img_a, mask_a), (img_b, mask_b)]
127
- # for j, (img, mask) in enumerate(objs):
128
- # temp_patch = f"temp_obj_{j}.png"
129
- # cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
130
- # tar_mask = (np.array(mask)[:, :, 0] > 128).astype(np.uint8)
131
- # item_with_collage.update(process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j))
132
 
133
  item_with_collage = process_composition(item_with_collage, obj_thr=2)
134
 
@@ -210,19 +183,17 @@ with gr.Blocks(title="PICS: Pairwise Spatial Compositing with Spatial Interactio
210
  # --- 核心修改:把 Examples 放在这里 ---
211
  gr.Markdown("### 💡 Quick Examples")
212
  gr.Examples(
213
- examples=[
214
  [
215
- "/home/user/app/sample/bread_basket/image.jpg",
216
- "/home/user/app/sample/bread_basket/object_0.png",
217
- "/home/user/app/sample/bread_basket/object_0_mask.png",
218
- "/home/user/app/sample/bread_basket/object_1.png",
219
- "/home/user/app/sample/bread_basket/object_1_mask.png"
220
  ]
221
  ],
222
  inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
223
  outputs=output_img, # 现在它认识这个变量了
224
  fn=pics_pairwise_inference,
225
- cache_examples=False,
226
  )
227
 
228
  # 按钮绑定
 
8
  import spaces
9
  from omegaconf import OmegaConf
10
  from huggingface_hub import snapshot_download
 
11
 
12
  REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
13
  os.chdir(REPO_DIR)
14
  sys.path.insert(0, REPO_DIR)
15
  sys.path.insert(0, os.path.join(REPO_DIR, "dinov2"))
 
 
 
 
 
 
 
 
 
16
 
17
  from cldm.model import create_model, load_state_dict
18
  from cldm.ddim_hacked import DDIMSampler
 
22
  model.load_state_dict(load_state_dict(config.pretrained_model, location='cpu'))
23
  model.eval()
24
 
 
25
  def get_input(batch, k):
26
  x = batch[k]
27
  if len(x.shape) == 3:
 
93
  device = "cuda"
94
  model.to(device)
95
  ddim_sampler = DDIMSampler(model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  back_image = np.array(background)
 
97
 
98
  item_with_collage = {}
99
+ objs = [(img_a, mask_a), (img_b, mask_b)]
100
+ for j, (img, mask) in enumerate(objs):
101
+ temp_patch = f"temp_obj_{j}.png"
102
+ cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
103
+ tar_mask = (np.array(mask)[:, :, 0] > 128).astype(np.uint8)
104
+ item_with_collage.update(process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j))
105
 
106
  item_with_collage = process_composition(item_with_collage, obj_thr=2)
107
 
 
183
  # --- 核心修改:把 Examples 放在这里 ---
184
  gr.Markdown("### 💡 Quick Examples")
185
  gr.Examples(
186
+ examples=[
187
  [
188
+ "sample/bread_basket/image.jpg",
189
+ "sample/bread_basket/object_0.png", "sample/bread_basket/object_0_mask.png",
190
+ "sample/bread_basket/object_1.png", "sample/bread_basket/object_1_mask.png"
 
 
191
  ]
192
  ],
193
  inputs=[bg_input, obj_a_img, obj_a_mask, obj_b_img, obj_b_mask],
194
  outputs=output_img, # 现在它认识这个变量了
195
  fn=pics_pairwise_inference,
196
+ cache_examples=True,
197
  )
198
 
199
  # 按钮绑定