djl234 commited on
Commit
09d07ec
·
verified ·
1 Parent(s): f698025

Update app.py

Browse files

[fix] dont cat when single

Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -124,7 +124,7 @@ for k in weight.keys():
124
  net.load_state_dict(new_dict)
125
  net.eval()
126
  net = net.to(device)
127
- def test(gpu_id, net, img_list, group_size, img_size):
128
  print('test')
129
  #device=device
130
  hl,wl=[_.shape[0] for _ in img_list],[_.shape[1] for _ in img_list]
@@ -143,6 +143,8 @@ def test(gpu_id, net, img_list, group_size, img_size):
143
  img_resize=[((group_img[i]-group_img[i].min())/(group_img[i].max()-group_img[i].min())*255).permute(1,2,0).contiguous().numpy().astype(np.uint8)
144
  for i in range(5)]
145
  pred_mask=[(pred_mask[i].numpy().astype(np.uint8)) for i in range(5)]#[(img_resize[i],pred_mask[i].numpy().astype(np.uint8)) for i in range(5)]
 
 
146
  #for i in range(5):
147
  # print(img_list[i].shape,pred_mask[i].shape)
148
  #pred_mask=[crf_refine(img_list[i],pred_mask[i]) for i in range(5)]
@@ -177,12 +179,13 @@ def sepia(img1,img2,img3,img4,img5,stack_image=True):
177
  h_list,w_list=[_.shape[0] for _ in img_list],[_.shape[1] for _ in img_list]
178
  #print(type(img1))
179
  #print(img1.shape)
180
- result_list=test(device,net,img_list,5,224)
 
 
181
  #result_list=[result_list[i].resize((w_list[i], h_list[i]), Image.BILINEAR) for i in range(5)]
182
  img1,img2,img3,img4,img5=result_list#test('cpu',net,img_list,5,224)
183
  white=(torch.ones(img1.shape[0],2,3)*255).numpy().astype(np.uint8)
184
- if not stack_image:
185
- return img1
186
  return np.concatenate([img1,white,img2,white,img3,white,img4,white,img5],axis=1)
187
 
188
  #gr.Image(shape=(224, 2))
 
124
  net.load_state_dict(new_dict)
125
  net.eval()
126
  net = net.to(device)
127
+ def test(gpu_id, net, img_list, group_size, img_size,stack_image=True):
128
  print('test')
129
  #device=device
130
  hl,wl=[_.shape[0] for _ in img_list],[_.shape[1] for _ in img_list]
 
143
  img_resize=[((group_img[i]-group_img[i].min())/(group_img[i].max()-group_img[i].min())*255).permute(1,2,0).contiguous().numpy().astype(np.uint8)
144
  for i in range(5)]
145
  pred_mask=[(pred_mask[i].numpy().astype(np.uint8)) for i in range(5)]#[(img_resize[i],pred_mask[i].numpy().astype(np.uint8)) for i in range(5)]
146
+ if not stack_image:
147
+ return pred_mask[0]
148
  #for i in range(5):
149
  # print(img_list[i].shape,pred_mask[i].shape)
150
  #pred_mask=[crf_refine(img_list[i],pred_mask[i]) for i in range(5)]
 
179
  h_list,w_list=[_.shape[0] for _ in img_list],[_.shape[1] for _ in img_list]
180
  #print(type(img1))
181
  #print(img1.shape)
182
+ result_list=test(device,net,img_list,5,224,stack_image)
183
+ if not stack_image:
184
+ return result_list
185
  #result_list=[result_list[i].resize((w_list[i], h_list[i]), Image.BILINEAR) for i in range(5)]
186
  img1,img2,img3,img4,img5=result_list#test('cpu',net,img_list,5,224)
187
  white=(torch.ones(img1.shape[0],2,3)*255).numpy().astype(np.uint8)
188
+
 
189
  return np.concatenate([img1,white,img2,white,img3,white,img4,white,img5],axis=1)
190
 
191
  #gr.Image(shape=(224, 2))