CUHKWilliam commited on
Commit
6e64d71
·
verified ·
1 Parent(s): bdf0ffa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -26,6 +26,7 @@ def inference_mask1(
26
  query_img,
27
  *prompt,
28
  ):
 
29
  query_img = Image.fromarray(query_img)
30
  query_img_np = np.asarray(query_img)
31
  query_img = transformation(query_img)
@@ -73,10 +74,10 @@ def inference_mask1(
73
  "support_masks": support_masks,
74
  "support_imgs": support_img,
75
  "query_img": query_img,
 
76
  }
77
  nshot = support_masks.size(1)
78
  pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
79
- pred_mask = F.interpolate(pred_mask, shape)
80
  pred_mask = pred.detach().cpu().numpy()
81
  output_img = query_img_np * 0.5 + 0.5 * np.array([1, 0, 0]) * np.expand_dims(pred_mask, axis=0)
82
  output_img = (output_img * 255).astype(np.uint8)
 
26
  query_img,
27
  *prompt,
28
  ):
29
+ org_qry_imsize = query_img.size
30
  query_img = Image.fromarray(query_img)
31
  query_img_np = np.asarray(query_img)
32
  query_img = transformation(query_img)
 
74
  "support_masks": support_masks,
75
  "support_imgs": support_img,
76
  "query_img": query_img,
77
+ "org_query_imsize": [torch.tensor([org_qry_imsize[0]]), torch.tensor([org_qry_imsize[1]])],
78
  }
79
  nshot = support_masks.size(1)
80
  pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
 
81
  pred_mask = pred.detach().cpu().numpy()
82
  output_img = query_img_np * 0.5 + 0.5 * np.array([1, 0, 0]) * np.expand_dims(pred_mask, axis=0)
83
  output_img = (output_img * 255).astype(np.uint8)