KurtLin committed on
Commit
070d37c
·
1 Parent(s): 46a64cf

Add rescale function

Browse files
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -21,7 +21,7 @@ def filesort(img, model):
21
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
22
  h, w = img.shape
23
  img_out = preprocessing(img, model)
24
- return img_out, h, w, img, ori
25
 
26
  def preprocessing(img, model='SE-RegUNet 4GF'):
27
  # print(img.shape, img.dtype)
@@ -46,7 +46,14 @@ def preprocessing(img, model='SE-RegUNet 4GF'):
46
  img_out = np.stack((image1,)*3, axis=0)
47
  return img_out
48
 
49
- def process_input_image(img, model, rescale):
 
 
 
 
 
 
 
50
  ori_img = img.copy()
51
  h, w, _ = ori_img.shape
52
  pad_h = h % 32
@@ -72,13 +79,21 @@ def process_input_image(img, model, rescale):
72
  pipe = torch.jit.load('./model/UNet3plus.pt')
73
  pipe = pipe.to(device).eval()
74
 
 
 
75
  start = time.time()
76
- img, h, w, ori_gray, ori = filesort(img, model)
77
- img = torch.FloatTensor(img).unsqueeze(0).to(device)
78
- with torch.no_grad():
79
- if model == 'AngioNet':
80
- img = torch.cat([img, img], dim=0)
81
- logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
 
 
 
 
 
 
82
  spent = time.time() - start
83
  spent = f"{spent:.3f} seconds"
84
 
@@ -108,7 +123,7 @@ with my_app:
108
  img_source = gr.Image(label="Please select angiogram.", value='./example/angio.png', shape=(512, 512))
109
  model_choice = gr.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5',
110
  'Reg-SA-UNet++', 'UNet3+'], label='Model', info='Which model to infer?')
111
- model_rescale = gr.Dropdown(['2x2', '4x4', '8x8', '16x16'], label='Rescale', info='How many batches?')
112
  source_image_loader = gr.Button("Vessel Segment")
113
  with gr.Column():
114
  time_spent = gr.Label(label="Time Spent (Preprocessing + Inference)")
 
21
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
22
  h, w = img.shape
23
  img_out = preprocessing(img, model)
24
+ return img_out, h, w, ori
25
 
26
  def preprocessing(img, model='SE-RegUNet 4GF'):
27
  # print(img.shape, img.dtype)
 
46
  img_out = np.stack((image1,)*3, axis=0)
47
  return img_out
48
 
49
+ def inference(pipe, img, model):
50
+ with torch.no_grad():
51
+ if model == 'AngioNet':
52
+ img = torch.cat([img, img], dim=0)
53
+ logit = np.round(torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]).astype(np.uint8)
54
+ return logit
55
+
56
+ def process_input_image(img, model, scale):
57
  ori_img = img.copy()
58
  h, w, _ = ori_img.shape
59
  pad_h = h % 32
 
79
  pipe = torch.jit.load('./model/UNet3plus.pt')
80
  pipe = pipe.to(device).eval()
81
 
82
+ scale = int(scale.split('x')[0])
83
+
84
  start = time.time()
85
+ if scale == 1:
86
+ img, h, w, ori = filesort(img, model)
87
+ img = torch.FloatTensor(img).unsqueeze(0).to(device)
88
+ logit = inference(pipe, img, model)
89
+ else:
90
+ len_h, len_w = img.shape[0] // scale, img.shape[1] // scale
91
+ logit = np.zeros(img.shape[:2], np.float32)
92
+ for x in range(scale):
93
+ for y in range(scale):
94
+ temp_img, _, _, _ = filesort(img[len_h * x : len_h * (x + 1), len_w * y : len_w * (y + 1)], model)
95
+ temp_img = torch.FloatTensor(temp_img).unsqueeze(0).to(device)
96
+ logit[len_h * x : len_h * (x + 1), len_w * y : len_w * (y + 1)] = inference(pipe, temp_img, model)
97
  spent = time.time() - start
98
  spent = f"{spent:.3f} seconds"
99
 
 
123
  img_source = gr.Image(label="Please select angiogram.", value='./example/angio.png', shape=(512, 512))
124
  model_choice = gr.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5',
125
  'Reg-SA-UNet++', 'UNet3+'], label='Model', info='Which model to infer?')
126
+ model_rescale = gr.Dropdown(['1x1', '2x2', '4x4', '8x8', '16x16'], label='Rescale', info='How many batches?')
127
  source_image_loader = gr.Button("Vessel Segment")
128
  with gr.Column():
129
  time_spent = gr.Label(label="Time Spent (Preprocessing + Inference)")