RyanHangZhou commited on
Commit
c92a66b
·
verified ·
1 Parent(s): ad982db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -21
app.py CHANGED
@@ -10,22 +10,16 @@ from omegaconf import OmegaConf
10
  from huggingface_hub import snapshot_download
11
 
12
  REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
13
- # os.chdir(REPO_DIR)
14
  sys.path.insert(0, REPO_DIR)
15
  sys.path.insert(0, os.path.join(REPO_DIR, "dinov2"))
16
 
17
  from cldm.model import create_model, load_state_dict
18
  from cldm.ddim_hacked import DDIMSampler
19
  from datasets.data_utils import *
20
- # config = OmegaConf.load('configs/inference.yaml')
21
- config_path = os.path.join(REPO_DIR, 'configs/inference.yaml')
22
- config = OmegaConf.load(config_path)
23
- # pretrained_path = os.path.join(REPO_DIR, config.pretrained_model)
24
-
25
- full_config_file = os.path.join(REPO_DIR, config.config_file)
26
- full_pretrained_path = os.path.join(REPO_DIR, config.pretrained_model)
27
- model = create_model(full_config_file).cpu()
28
- model.load_state_dict(load_state_dict(full_pretrained_path, location='cpu'))
29
  model.eval()
30
 
31
  def get_input(batch, k):
@@ -96,24 +90,17 @@ def process_composition(item, obj_thr):
96
 
97
  @spaces.GPU(duration=120)
98
  def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
99
- def to_numpy(img):
100
- if img is None: return None
101
- # 如果 Gradio 传过来的是路径字符串,手动读取它
102
- if isinstance(img, str):
103
- return cv2.cvtColor(cv2.imread(img), cv2.COLOR_BGR2RGB)
104
- return np.array(img)
105
-
106
  device = "cuda"
107
  model.to(device)
108
  ddim_sampler = DDIMSampler(model)
109
- back_image = to_numpy(background)
110
 
111
  item_with_collage = {}
112
  objs = [(img_a, mask_a), (img_b, mask_b)]
113
  for j, (img, mask) in enumerate(objs):
114
  temp_patch = f"temp_obj_{j}.png"
115
- cv2.imwrite(temp_patch, cv2.cvtColor(to_numpy(img), cv2.COLOR_RGB2BGR))
116
- tar_mask = (to_numpy(mask)[:, :, 0] > 128).astype(np.uint8)
117
  item_with_collage.update(process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j))
118
 
119
  item_with_collage = process_composition(item_with_collage, obj_thr=2)
@@ -217,4 +204,4 @@ with gr.Blocks(title="PICS: Pairwise Spatial Compositing with Spatial Interactio
217
  )
218
 
219
  if __name__ == "__main__":
220
- demo.launch()
 
10
  from huggingface_hub import snapshot_download
11
 
12
  REPO_DIR = snapshot_download(repo_id="Hang2991/PICS")
13
+ os.chdir(REPO_DIR)
14
  sys.path.insert(0, REPO_DIR)
15
  sys.path.insert(0, os.path.join(REPO_DIR, "dinov2"))
16
 
17
  from cldm.model import create_model, load_state_dict
18
  from cldm.ddim_hacked import DDIMSampler
19
  from datasets.data_utils import *
20
+ config = OmegaConf.load('configs/inference.yaml')
21
+ model = create_model(config.config_file).cpu()
22
+ model.load_state_dict(load_state_dict(config.pretrained_model, location='cpu'))
 
 
 
 
 
 
23
  model.eval()
24
 
25
  def get_input(batch, k):
 
90
 
91
  @spaces.GPU(duration=120)
92
  def pics_pairwise_inference(background, img_a, mask_a, img_b, mask_b):
 
 
 
 
 
 
 
93
  device = "cuda"
94
  model.to(device)
95
  ddim_sampler = DDIMSampler(model)
96
+ back_image = np.array(background)
97
 
98
  item_with_collage = {}
99
  objs = [(img_a, mask_a), (img_b, mask_b)]
100
  for j, (img, mask) in enumerate(objs):
101
  temp_patch = f"temp_obj_{j}.png"
102
+ cv2.imwrite(temp_patch, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
103
+ tar_mask = (np.array(mask)[:, :, 0] > 128).astype(np.uint8)
104
  item_with_collage.update(process_pairs_multiple(tar_mask, back_image, temp_patch, counter=j))
105
 
106
  item_with_collage = process_composition(item_with_collage, obj_thr=2)
 
204
  )
205
 
206
  if __name__ == "__main__":
207
+ demo.launch(allowed_paths=[REPO_DIR])