wenpeng commited on
Commit
7079dcb
·
1 Parent(s): 2e36fe0

add sod output

Browse files
Files changed (2) hide show
  1. app.py +10 -8
  2. sod/infer_model.py +1 -1
app.py CHANGED
@@ -9,10 +9,10 @@ import cv2
9
  # cmd = 'sh download.sh'
10
  # os.system(cmd)
11
 
12
- device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
13
  print(device)
14
  inpaint_model = inpaint.IVModel(device=device)
15
- sod_model = sod.IVModel(device=torch.device("cpu"))
16
  max_size=512
17
  scale_factor = 8
18
  count = 0
@@ -28,17 +28,19 @@ def sod_inpaint(img):
28
  w = w // scale_factor * scale_factor
29
  img = cv2.resize(img, (w,h))
30
  img = img[:,:,::-1]
31
- res = sod_model.forward(img,None)
32
- res = np.uint8(res)
33
- res = inpaint_model.forward(res,None)
34
- res = np.uint8(res)
 
 
35
  count +=1
36
  print(count, ' images have been processed')
37
- return res[:,:,::-1]
38
 
39
 
40
 
41
  examples = glob.glob('examples/*.*')
42
  inputs = gr.inputs.Image(shape=None, image_mode="RGB", invert_colors=False, source="upload", tool="editor", type="numpy", label=None, optional=False)
43
- iface = gr.Interface(fn=sod_inpaint, inputs=inputs, outputs="image", examples=examples, title='Salient Object Detection + Inpaint', description='Upload an image and you will find something disappears', theme='huggingface')
44
  iface.launch()
 
9
  # cmd = 'sh download.sh'
10
  # os.system(cmd)
11
 
12
+ device = torch.device("cpu")
13
  print(device)
14
  inpaint_model = inpaint.IVModel(device=device)
15
+ sod_model = sod.IVModel(device=device)
16
  max_size=512
17
  scale_factor = 8
18
  count = 0
 
28
  w = w // scale_factor * scale_factor
29
  img = cv2.resize(img, (w,h))
30
  img = img[:,:,::-1]
31
+ sod_res = sod_model.forward(img,None)
32
+ sod_res = np.uint8(sod_res)
33
+ h,w = sod_res.shape[:2]
34
+ so = np.uint8(sod_res[:,:w//2,:] * (sod_res[:,w//2:,:]/255))
35
+ inpaint_res = inpaint_model.forward(sod_res,None)
36
+ inpaint_res = np.uint8(inpaint_res)
37
  count +=1
38
  print(count, ' images have been processed')
39
+ return so[:,:,::-1], inpaint_res[:,:,::-1]
40
 
41
 
42
 
43
  examples = glob.glob('examples/*.*')
44
  inputs = gr.inputs.Image(shape=None, image_mode="RGB", invert_colors=False, source="upload", tool="editor", type="numpy", label=None, optional=False)
45
+ iface = gr.Interface(fn=sod_inpaint, inputs=inputs, outputs=["image", "image"], examples=examples, title='Salient Object Detection + Inpaint', description='Upload an image and you will see the fg and inpainted bg', theme='huggingface')
46
  iface.launch()
sod/infer_model.py CHANGED
@@ -76,7 +76,7 @@ class IVModel():
76
  img_t = self.input_preprocess_tensor(img)
77
  shape = [torch.as_tensor([img_t.shape[2]]), torch.as_tensor([img_t.shape[3]])]
78
  h, w = img_t.shape[2], img_t.shape[3]
79
- img_t_temp = F.interpolate(img_t, (512, 512), mode='area')
80
  with torch.no_grad():
81
  res = self.net(img_t_temp, shape=shape)
82
  res = F.interpolate(res[0],size=shape, mode='bilinear')
 
76
  img_t = self.input_preprocess_tensor(img)
77
  shape = [torch.as_tensor([img_t.shape[2]]), torch.as_tensor([img_t.shape[3]])]
78
  h, w = img_t.shape[2], img_t.shape[3]
79
+ img_t_temp = F.interpolate(img_t, (1024, 1024), mode='area')
80
  with torch.no_grad():
81
  res = self.net(img_t_temp, shape=shape)
82
  res = F.interpolate(res[0],size=shape, mode='bilinear')