Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -24,7 +24,7 @@ from torchvision.models import resnet18
|
|
| 24 |
|
| 25 |
|
| 26 |
#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 27 |
-
|
| 28 |
file_savingfolder = './modelbest/'
|
| 29 |
ext = '_bestVal'
|
| 30 |
core_model_tst = resnet18(pretrained=True)
|
|
@@ -37,6 +37,49 @@ IR_Model_tst.load_state_dict(torch.load(file_savingfolder+'IR_Model'+ext+'.pth',
|
|
| 37 |
IR_Model_tst.eval()
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
|
| 41 |
source_img224 = torch.nn.ZeroPad2d(int((dim1-dim2)/2))(source_img)
|
| 42 |
grd = torch.nn.functional.affine_grid(Affine_mtrx, size=source_img224.shape,align_corners=False)
|
|
@@ -45,7 +88,7 @@ def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
|
|
| 45 |
wrapped_img = torchvision.transforms.CenterCrop((dim2, dim2))(wrapped_img)
|
| 46 |
return wrapped_img
|
| 47 |
|
| 48 |
-
def imgnt_reg(img1,img2
|
| 49 |
#fixed_images = np.empty((1, 128, 128, 3))
|
| 50 |
#moving_images = np.empty((1, 128, 128, 3))
|
| 51 |
# prepare inputs:
|
|
@@ -59,7 +102,21 @@ def imgnt_reg(img1,img2, model_selected):
|
|
| 59 |
pred = IR_Model_tst(model_inputs)
|
| 60 |
img_out = wrap_imge_cropped(pred['Affine_mtrx'], fixed_images, dim1=224, dim2=128)
|
| 61 |
registered_img = torchvision.transforms.ToPILImage()(img_out[0])
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
|
|
@@ -79,21 +136,25 @@ with gr.Blocks() as demo:
|
|
| 79 |
type = "filepath",
|
| 80 |
elem_id = "image-in",
|
| 81 |
)
|
| 82 |
-
|
| 83 |
model_list = gr.Dropdown(
|
| 84 |
["Additive_Recurence", "Rawblock"], label="Model", info="select a model"
|
| 85 |
-
)
|
| 86 |
|
| 87 |
-
|
| 88 |
#source = "upload",
|
| 89 |
#type = "filepath",
|
| 90 |
elem_id = "image-out"
|
| 91 |
)
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
iface = gr.Interface(fn=imgnt_reg, inputs=inputs,outputs=out_image,
|
| 94 |
title="Imagenet registration V2",
|
| 95 |
description="Upload 2 images to generate a registered one:",
|
| 96 |
examples=[["./examples/ex1.jpg","./examples/ex2.jpg"]],
|
| 97 |
)
|
| 98 |
|
| 99 |
-
demo.queue().launch()
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 27 |
+
with_points = True
|
| 28 |
file_savingfolder = './modelbest/'
|
| 29 |
ext = '_bestVal'
|
| 30 |
core_model_tst = resnet18(pretrained=True)
|
|
|
|
| 37 |
IR_Model_tst.eval()
|
| 38 |
|
| 39 |
|
| 40 |
+
|
| 41 |
+
def standarize_point(d, dim=128, flip = False):
|
| 42 |
+
if flip:
|
| 43 |
+
d = -d
|
| 44 |
+
return d/dim - 0.5
|
| 45 |
+
|
| 46 |
+
def destandarize_point(d, dim=128, flip = False):
|
| 47 |
+
if flip:
|
| 48 |
+
d = -d
|
| 49 |
+
return dim*(d + 0.5)
|
| 50 |
+
|
| 51 |
+
def generate_standard_elips(N_samples = 100, a= 1,b = 1):
|
| 52 |
+
radius = 0.25
|
| 53 |
+
center = 0
|
| 54 |
+
N_samples1 = int(N_samples/2 - 1)
|
| 55 |
+
N_samples2 = N_samples - N_samples1
|
| 56 |
+
x1 = torch.distributions.uniform.Uniform(center-radius,center + radius).sample([N_samples1])
|
| 57 |
+
x1_ordered = torch.sort(x1).values
|
| 58 |
+
y1 = center + b*torch.sqrt(radius**2 - ((x1_ordered-center)/a)**2)
|
| 59 |
+
x2 = torch.distributions.uniform.Uniform(center-radius,center + radius).sample([N_samples2])
|
| 60 |
+
x2_ordered = torch.sort(x2, descending=True).values
|
| 61 |
+
y2 = center - b*torch.sqrt(radius**2 - ((x2_ordered-center)/a)**2)
|
| 62 |
+
x = torch.cat([x1_ordered, x2_ordered])
|
| 63 |
+
y = torch.cat([y1, y2])
|
| 64 |
+
return x, y
|
| 65 |
+
|
| 66 |
+
def transform_standard_points(Affine_mat, x,y):
|
| 67 |
+
XY = torch.ones([3,x.shape[0]])
|
| 68 |
+
XY[0,:]= x
|
| 69 |
+
XY[1,:]= y
|
| 70 |
+
XYt = torch.matmul(Affine_mat.to('cpu').detach(),XY)
|
| 71 |
+
xt0 = XYt[0]
|
| 72 |
+
yt0 = XYt[1]
|
| 73 |
+
return xt0, yt0
|
| 74 |
+
|
| 75 |
+
def wrap_points(img, x_source, y_source, l=1):
|
| 76 |
+
for i in range(len(y_source)):
|
| 77 |
+
x0 = x_source[i].int()
|
| 78 |
+
y0 = y_source[i].int()
|
| 79 |
+
img[:,:,x0-l:x0+l,y0-l:y0+l] = 0
|
| 80 |
+
return img
|
| 81 |
+
|
| 82 |
+
|
| 83 |
def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
|
| 84 |
source_img224 = torch.nn.ZeroPad2d(int((dim1-dim2)/2))(source_img)
|
| 85 |
grd = torch.nn.functional.affine_grid(Affine_mtrx, size=source_img224.shape,align_corners=False)
|
|
|
|
| 88 |
wrapped_img = torchvision.transforms.CenterCrop((dim2, dim2))(wrapped_img)
|
| 89 |
return wrapped_img
|
| 90 |
|
| 91 |
+
def imgnt_reg(img1,img2):#, model_selected):
|
| 92 |
#fixed_images = np.empty((1, 128, 128, 3))
|
| 93 |
#moving_images = np.empty((1, 128, 128, 3))
|
| 94 |
# prepare inputs:
|
|
|
|
| 102 |
pred = IR_Model_tst(model_inputs)
|
| 103 |
img_out = wrap_imge_cropped(pred['Affine_mtrx'], fixed_images, dim1=224, dim2=128)
|
| 104 |
registered_img = torchvision.transforms.ToPILImage()(img_out[0])
|
| 105 |
+
if with_points:
|
| 106 |
+
x0_source, y0_source = generate_standard_elips(N_samples= 100)
|
| 107 |
+
x_source = destandarize_point(x0_source, dim=dim, flip = False)
|
| 108 |
+
y_source = destandarize_point(y0_source, dim=dim, flip = False)
|
| 109 |
+
source_im_w_points = wrap_points(fixed_images, x_source, y_source, l=1)
|
| 110 |
+
M_Predicted = workaround_matrix(pred['Affine_mtrx'].detach(), acc = 0.5/crop_ratio)
|
| 111 |
+
x0_transformed, y0_transformed = transform_standard_points(M_Predicted[0], x0_source, y0_source)
|
| 112 |
+
x_transformed = destandarize_point(x0_transformed, dim=dim, flip = False)
|
| 113 |
+
y_transformed = destandarize_point(y0_transformed, dim=dim, flip = False)
|
| 114 |
+
wrapped_img = wrap_points(img_out, x_transformed, y_transformed, l=1)
|
| 115 |
+
img_out2 = wrapped_img
|
| 116 |
+
marked_image = torchvision.transforms.ToPILImage()(img_out2[0])
|
| 117 |
+
else:
|
| 118 |
+
marked_image = torchvision.transforms.ToPILImage()(torch.zeros(3,10,10))
|
| 119 |
+
return [registered_img, marked_image]
|
| 120 |
|
| 121 |
|
| 122 |
|
|
|
|
| 136 |
type = "filepath",
|
| 137 |
elem_id = "image-in",
|
| 138 |
)
|
| 139 |
+
'''
|
| 140 |
model_list = gr.Dropdown(
|
| 141 |
["Additive_Recurence", "Rawblock"], label="Model", info="select a model"
|
| 142 |
+
)'''
|
| 143 |
|
| 144 |
+
out_image1 = gr.Image(label = "ٌRegistered image",
|
| 145 |
#source = "upload",
|
| 146 |
#type = "filepath",
|
| 147 |
elem_id = "image-out"
|
| 148 |
)
|
| 149 |
+
out_image2 = gr.Image(label = "ٌMarked image",
|
| 150 |
+
elem_id = "image-out"
|
| 151 |
+
)
|
| 152 |
+
inputs = [image_1, image_2]#, model_list]
|
| 153 |
+
out_image = [out_image1, out_image2]
|
| 154 |
iface = gr.Interface(fn=imgnt_reg, inputs=inputs,outputs=out_image,
|
| 155 |
title="Imagenet registration V2",
|
| 156 |
description="Upload 2 images to generate a registered one:",
|
| 157 |
examples=[["./examples/ex1.jpg","./examples/ex2.jpg"]],
|
| 158 |
)
|
| 159 |
|
| 160 |
+
demo.queue().launch(share=True)
|