File size: 4,633 Bytes
b75b26b
 
e03c91e
e103b0c
 
079c64d
b75b26b
 
 
545e895
 
 
 
 
 
 
 
 
b75b26b
 
 
 
545e895
b75b26b
 
d9c618d
7222107
1c553d4
 
 
 
545e895
 
 
 
6224444
d9c618d
545e895
6224444
d9c618d
545e895
 
7222107
545e895
 
b75b26b
545e895
 
b45b4ad
545e895
 
b45b4ad
545e895
 
4205531
545e895
7222107
 
 
 
4205531
4718ba8
7222107
 
 
 
4205531
047575b
7222107
 
047575b
4205531
 
b75b26b
 
 
 
 
 
 
6c73196
cdf9ddd
b75b26b
 
 
 
6c73196
cdf9ddd
b75b26b
 
 
7222107
b75b26b
dd61989
7222107
dd61989
6c73196
b75b26b
 
 
 
6c73196
4205531
 
6c73196
4205531
7222107
 
4205531
b75b26b
545e895
b75b26b
4066298
b75b26b
 
7222107
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import os
os.system('pip3 install torch torchvision')# torchaudio')

#pip3 install torch  -q')
import torch
import sys
import numpy as np
from PIL import Image
import itertools
import glob
import random
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from codes import *

## Model configuration and loading

from torchvision.models import resnet18


# Inference runs on the CPU: both checkpoints are loaded with
# map_location='cpu' and the models are never moved to another device.

# When True, imgnt_reg also marks reference points on the source image and
# their predicted positions on the registered image.
with_points = True
# Passed as `dim` to preprocess_image / destandarize_point and as `dim2` to
# wrap_imge_cropped (presumably the model input side length — confirm).
dim = 128
# Passed as `dim1` to wrap_imge_cropped (presumably the uncropped size).
dim0 = 224
crop_ratio = dim / dim0

# Checkpoint directory and the suffix of the checkpoint files to load.
file_savingfolder = './modelbest/'
ext = '_bestVal'

# Backbone: ResNet-18 with its classification head replaced by Identity.
# No pretrained weights are requested: the strict load_state_dict below
# overwrites every parameter and buffer anyway, so downloading ImageNet
# weights only slowed startup (and failed without network access).
core_model_tst = resnet18()
core_model_tst.fc = Identity()
core_model_tst.load_state_dict(
    torch.load(os.path.join(file_savingfolder, 'core_model' + ext + '.pth'),
               map_location=torch.device('cpu'))
)

# Full registration model wrapping the backbone.
IR_Model_tst = Build_IRmodel_Resnet(core_model_tst, registration_method)
IR_Model_tst.load_state_dict(
    torch.load(os.path.join(file_savingfolder, 'IR_Model' + ext + '.pth'),
               map_location=torch.device('cpu'))
)
IR_Model_tst.eval()

def imgnt_reg(img1, img2):  # , model_selected):
    """Register a pair of images with the loaded model and render the results.

    Args:
        img1: The "Fixed Image" input (a file path, since the Gradio inputs
            use ``type="filepath"``). It is fed to the model as ``'source'``
            and is the image that ``wrap_imge_cropped`` transforms with the
            predicted affine matrix.
        img2: The "Moving Image" input (a file path), fed to the model as
            ``'target'``.

    Returns:
        A list of three PIL images:
        ``[registered_img, source_im_w_points, marked_image]``.
        These are the transformed ``img1``; ``img1`` with the reference
        points marked; and the transformed image with the transformed points
        marked. When ``with_points`` is False, the last two are black
        ``dim`` x ``dim`` placeholders.

    Raises:
        gr.Error: If either input is missing (e.g. an upload was cleared), so
            the UI shows a message instead of a stack trace.
    """
    if img1 is None or img2 is None:
        raise gr.Error("Please provide both a fixed and a moving image.")

    to_pil = torchvision.transforms.ToPILImage()

    # Pure inference: no autograd graph is needed for any tensor below.
    with torch.no_grad():
        fixed_images = preprocess_image(img1, dim=dim)
        moving_images = preprocess_image(img2, dim=dim)
        # Random 1x2x3 matrix fed to the model as 'M_i'. It is unseeded, so
        # results may vary between runs on the same image pair.
        M_i = torch.normal(torch.zeros([1, 2, 3]), torch.ones([1, 2, 3]))
        model_inputs = {'source': fixed_images,
                        'target': moving_images,
                        'M_i': M_i}
        pred = IR_Model_tst(model_inputs)
        affine_mtrx = pred['Affine_mtrx'].detach()

        img_out = wrap_imge_cropped(affine_mtrx, fixed_images, dim1=dim0, dim2=dim)
        registered_img = to_pil(img_out[0])

        if not with_points:
            source_im_w_points = to_pil(torch.zeros(3, dim, dim))
            marked_image = to_pil(torch.zeros(3, dim, dim))
            return [registered_img, source_im_w_points, marked_image]

        # Reference points from generate_standard_elips, marked on the
        # source image.
        x0_source, y0_source = generate_standard_elips(N_samples=100)
        x_source = destandarize_point(x0_source, dim=dim, flip=False)
        y_source = destandarize_point(y0_source, dim=dim, flip=False)
        source_im_w_points = wrap_points(fixed_images.detach(), x_source, y_source, l=1)
        source_im_w_points = to_pil(source_im_w_points[0])

        # The same points mapped through the predicted matrix, marked on the
        # registered image.
        M_Predicted = workaround_matrix(affine_mtrx, acc=0.5 / crop_ratio)
        x0_transformed, y0_transformed = transform_standard_points(
            M_Predicted[0], x0_source, y0_source)
        x_transformed = destandarize_point(x0_transformed, dim=dim, flip=False)
        y_transformed = destandarize_point(y0_transformed, dim=dim, flip=False)
        wrapped_img = wrap_points(img_out.detach(), x_transformed, y_transformed, l=1)
        marked_image = to_pil(wrapped_img[0])

    return [registered_img, source_im_w_points, marked_image]


with gr.Blocks() as demo:
    # Inputs: both images are handed to imgnt_reg as file paths.
    image_1 = gr.Image(
        label="Fixed Image", height=224, width=180,
        type="filepath",
        elem_id="image-in",
    )
    image_2 = gr.Image(
        label="Moving Image", height=224, width=180,
        type="filepath",
        # Element ids must be unique within the page; this previously
        # duplicated image_1's "image-in".
        elem_id="image-in2",
    )
    # TODO: re-enable model selection once imgnt_reg accepts a
    # `model_selected` argument, e.g.
    # gr.Dropdown(["Additive_Recurence", "Rawblock"], label="Model",
    #             info="select a model")

    # Outputs, in the order imgnt_reg returns them.
    out_image1 = gr.Image(label="Registered image", height=224, width=224,
                          elem_id="image-out")
    out_image2 = gr.Image(label="Marked source image", height=224, width=224,
                          elem_id="image-out2")
    out_image3 = gr.Image(label="Marked wrapped image", height=224, width=224,
                          elem_id="image-out3")
    inputs = [image_1, image_2]
    out_image = [out_image1, out_image2, out_image3]
    # NOTE(review): the Interface is constructed inside the Blocks context and
    # reuses the components above; confirm it renders as part of `demo`.
    iface = gr.Interface(
        fn=imgnt_reg,
        inputs=inputs,
        outputs=out_image,
        title="Imagenet registration V2",
        description="Upload 2 images to generate a registered one:",
        examples=[["./examples/ex5.png", "./examples/ex6.png"],
                  ["./examples/ex1.jpg", "./examples/ex2.jpg"]],
    )

demo.queue().launch(share=True)