Amould committed on
Commit
7222107
·
verified ·
1 Parent(s): 80dd9dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -8
app.py CHANGED
@@ -24,7 +24,7 @@ from torchvision.models import resnet18
24
 
25
 
26
  #device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
27
-
28
  file_savingfolder = './modelbest/'
29
  ext = '_bestVal'
30
  core_model_tst = resnet18(pretrained=True)
@@ -37,6 +37,49 @@ IR_Model_tst.load_state_dict(torch.load(file_savingfolder+'IR_Model'+ext+'.pth',
37
  IR_Model_tst.eval()
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
41
  source_img224 = torch.nn.ZeroPad2d(int((dim1-dim2)/2))(source_img)
42
  grd = torch.nn.functional.affine_grid(Affine_mtrx, size=source_img224.shape,align_corners=False)
@@ -45,7 +88,7 @@ def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
45
  wrapped_img = torchvision.transforms.CenterCrop((dim2, dim2))(wrapped_img)
46
  return wrapped_img
47
 
48
- def imgnt_reg(img1,img2, model_selected):
49
  #fixed_images = np.empty((1, 128, 128, 3))
50
  #moving_images = np.empty((1, 128, 128, 3))
51
  # prepare inputs:
@@ -59,7 +102,21 @@ def imgnt_reg(img1,img2, model_selected):
59
  pred = IR_Model_tst(model_inputs)
60
  img_out = wrap_imge_cropped(pred['Affine_mtrx'], fixed_images, dim1=224, dim2=128)
61
  registered_img = torchvision.transforms.ToPILImage()(img_out[0])
62
- return registered_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
 
@@ -79,21 +136,25 @@ with gr.Blocks() as demo:
79
  type = "filepath",
80
  elem_id = "image-in",
81
  )
82
-
83
  model_list = gr.Dropdown(
84
  ["Additive_Recurence", "Rawblock"], label="Model", info="select a model"
85
- )
86
 
87
- out_image = gr.Image(label = "ٌRegistered image",
88
  #source = "upload",
89
  #type = "filepath",
90
  elem_id = "image-out"
91
  )
92
- inputs = [image_1, image_2, model_list]
 
 
 
 
93
  iface = gr.Interface(fn=imgnt_reg, inputs=inputs,outputs=out_image,
94
  title="Imagenet registration V2",
95
  description="Upload 2 images to generate a registered one:",
96
  examples=[["./examples/ex1.jpg","./examples/ex2.jpg"]],
97
  )
98
 
99
- demo.queue().launch()
 
24
 
25
 
26
  #device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
27
+ with_points = True
28
  file_savingfolder = './modelbest/'
29
  ext = '_bestVal'
30
  core_model_tst = resnet18(pretrained=True)
 
37
  IR_Model_tst.eval()
38
 
39
 
40
+
41
def standarize_point(d, dim=128, flip = False):
    """Map a pixel coordinate in [0, dim] to the normalized range [-0.5, 0.5].

    Inverse of ``destandarize_point`` when ``flip`` is False.

    Parameters
    ----------
    d : scalar or tensor pixel coordinate(s).
    dim : image size used for normalization.
    flip : when True, negate the coordinate before normalizing.
    """
    value = -d if flip else d
    return value / dim - 0.5
45
+
46
def destandarize_point(d, dim=128, flip = False):
    """Map a normalized coordinate in [-0.5, 0.5] back to pixels in [0, dim].

    Inverse of ``standarize_point`` when ``flip`` is False.

    Parameters
    ----------
    d : scalar or tensor normalized coordinate(s).
    dim : image size used for de-normalization.
    flip : when True, negate the coordinate before scaling.
    """
    value = -d if flip else d
    return dim * (value + 0.5)
50
+
51
def generate_standard_elips(N_samples = 100, a= 1,b = 1):
    """Sample N_samples points tracing an axis-aligned ellipse centred at 0.

    The ellipse has semi-axes a*0.25 (horizontal) and b*0.25 (vertical).
    Points are ordered so consecutive points follow the contour: the
    upper arc left-to-right, then the lower arc right-to-left.

    Parameters
    ----------
    N_samples : total number of points (split between the two arcs).
    a, b : horizontal / vertical stretch factors.

    Returns
    -------
    (x, y) : two 1-D tensors of length N_samples in standardized
    coordinates (roughly [-0.5, 0.5] for a = b = 1).
    """
    radius = 0.25
    center = 0
    n_upper = int(N_samples/2 - 1)
    n_lower = N_samples - n_upper
    # Sample x over the ellipse's true horizontal extent (a*radius).
    # Sampling over plain 'radius' made the sqrt argument below go
    # negative for a < 1 (producing NaNs) and truncated the ellipse
    # for a > 1; with the default a = 1 this range is unchanged.
    x_dist = torch.distributions.uniform.Uniform(center - a*radius, center + a*radius)
    x1_ordered = torch.sort(x_dist.sample([n_upper])).values
    x2_ordered = torch.sort(x_dist.sample([n_lower]), descending=True).values
    # clamp guards against tiny negative values from float round-off
    y1 = center + b*torch.sqrt(torch.clamp(radius**2 - ((x1_ordered-center)/a)**2, min=0.0))
    y2 = center - b*torch.sqrt(torch.clamp(radius**2 - ((x2_ordered-center)/a)**2, min=0.0))
    x = torch.cat([x1_ordered, x2_ordered])
    y = torch.cat([y1, y2])
    return x, y
65
+
66
def transform_standard_points(Affine_mat, x,y):
    """Apply an affine matrix to 1-D point-coordinate tensors.

    The points are lifted to homogeneous coordinates [x; y; 1], the
    matrix is moved to CPU and detached from autograd, and the mapped
    x and y coordinates are returned.

    Parameters
    ----------
    Affine_mat : (2, 3) affine transform tensor (any device).
    x, y : 1-D tensors of equal length.
    """
    n_points = x.shape[0]
    homogeneous = torch.ones([3, n_points])
    homogeneous[0] = x
    homogeneous[1] = y
    mapped = Affine_mat.to('cpu').detach() @ homogeneous
    return mapped[0], mapped[1]
74
+
75
def wrap_points(img, x_source, y_source, l=1):
    """Mark each point on ``img`` in place with a black (zeroed) square.

    Parameters
    ----------
    img : tensor of shape (batch, channels, H, W); modified in place.
    x_source, y_source : 1-D tensors of pixel coordinates.
        NOTE(review): x indexes dim 2 and y indexes dim 3 here, i.e.
        x is treated as the row and y as the column — confirm this
        matches the callers' convention.
    l : half-size of the square marker.

    Returns
    -------
    The same tensor ``img`` (mutated in place).
    """
    for i in range(len(y_source)):
        x0 = int(x_source[i])
        y0 = int(y_source[i])
        # Clamp bounds at 0: the original x0-l / y0-l could go negative,
        # which Python slicing interprets as wrap-around from the end,
        # so markers near the top/left edge were silently dropped (or
        # drawn in the wrong place). Upper bounds beyond H/W are clipped
        # by slicing automatically.
        r0, r1 = max(x0 - l, 0), max(x0 + l, 0)
        c0, c1 = max(y0 - l, 0), max(y0 + l, 0)
        img[:, :, r0:r1, c0:c1] = 0
    return img
81
+
82
+
83
  def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
84
  source_img224 = torch.nn.ZeroPad2d(int((dim1-dim2)/2))(source_img)
85
  grd = torch.nn.functional.affine_grid(Affine_mtrx, size=source_img224.shape,align_corners=False)
 
88
  wrapped_img = torchvision.transforms.CenterCrop((dim2, dim2))(wrapped_img)
89
  return wrapped_img
90
 
91
+ def imgnt_reg(img1,img2):#, model_selected):
92
  #fixed_images = np.empty((1, 128, 128, 3))
93
  #moving_images = np.empty((1, 128, 128, 3))
94
  # prepare inputs:
 
102
  pred = IR_Model_tst(model_inputs)
103
  img_out = wrap_imge_cropped(pred['Affine_mtrx'], fixed_images, dim1=224, dim2=128)
104
  registered_img = torchvision.transforms.ToPILImage()(img_out[0])
105
+ if with_points:
106
+ x0_source, y0_source = generate_standard_elips(N_samples= 100)
107
+ x_source = destandarize_point(x0_source, dim=dim, flip = False)
108
+ y_source = destandarize_point(y0_source, dim=dim, flip = False)
109
+ source_im_w_points = wrap_points(fixed_images, x_source, y_source, l=1)
110
+ M_Predicted = workaround_matrix(pred['Affine_mtrx'].detach(), acc = 0.5/crop_ratio)
111
+ x0_transformed, y0_transformed = transform_standard_points(M_Predicted[0], x0_source, y0_source)
112
+ x_transformed = destandarize_point(x0_transformed, dim=dim, flip = False)
113
+ y_transformed = destandarize_point(y0_transformed, dim=dim, flip = False)
114
+ wrapped_img = wrap_points(img_out, x_transformed, y_transformed, l=1)
115
+ img_out2 = wrapped_img
116
+ marked_image = torchvision.transforms.ToPILImage()(img_out2[0])
117
+ else:
118
+ marked_image = torchvision.transforms.ToPILImage()(torch.zeros(3,10,10))
119
+ return [registered_img, marked_image]
120
 
121
 
122
 
 
136
  type = "filepath",
137
  elem_id = "image-in",
138
  )
139
+ '''
140
  model_list = gr.Dropdown(
141
  ["Additive_Recurence", "Rawblock"], label="Model", info="select a model"
142
+ )'''
143
 
144
+ out_image1 = gr.Image(label = "ٌRegistered image",
145
  #source = "upload",
146
  #type = "filepath",
147
  elem_id = "image-out"
148
  )
149
+ out_image2 = gr.Image(label = "ٌMarked image",
150
+ elem_id = "image-out"
151
+ )
152
+ inputs = [image_1, image_2]#, model_list]
153
+ out_image = [out_image1, out_image2]
154
  iface = gr.Interface(fn=imgnt_reg, inputs=inputs,outputs=out_image,
155
  title="Imagenet registration V2",
156
  description="Upload 2 images to generate a registered one:",
157
  examples=[["./examples/ex1.jpg","./examples/ex2.jpg"]],
158
  )
159
 
160
+ demo.queue().launch(share=True)