Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| os.system('pip3 install torch torchvision')# torchaudio') | |
| #pip3 install torch -q') | |
| import torch | |
| import sys | |
| import numpy as np | |
| from PIL import Image | |
| import itertools | |
| import glob | |
| import random | |
| import torch | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
| from codes import * | |
from torchvision.models import resnet18

# --- Configuration -----------------------------------------------------------
# When True, imgnt_reg also returns images with a sampled ellipse of points
# drawn on the source image and on its predicted warp.
with_points = True
dim = 128   # side length the input images are preprocessed to in imgnt_reg
dim0 = 224  # presumably the full (uncropped) image size -- TODO confirm
crop_ratio = dim / dim0  # used to rescale the predicted matrix for the points
file_savingfolder = './modelbest/'
ext = '_bestVal'

# --- Model loading ------------------------------------------------------------
# Checkpoints are mapped onto the CPU so the demo runs without a GPU.
_cpu = torch.device('cpu')

# Backbone: ResNet-18 with its classification head replaced by Identity().
# Every parameter/buffer is restored by the (strict) load_state_dict below,
# so the ImageNet weights are not downloaded first -- they would be
# overwritten anyway.
core_model_tst = resnet18()
core_model_tst.fc = Identity()
core_model_tst.load_state_dict(
    torch.load(f'{file_savingfolder}core_model{ext}.pth', map_location=_cpu)
)

# Full registration model built around the backbone, then restored from its
# own checkpoint.
IR_Model_tst = Build_IRmodel_Resnet(core_model_tst, registration_method)
IR_Model_tst.load_state_dict(
    torch.load(f'{file_savingfolder}IR_Model{ext}.pth', map_location=_cpu)
)
# Evaluation mode (affects layers such as dropout / batch-norm, if present).
IR_Model_tst.eval()
def imgnt_reg(img1, img2):
    """Predict an affine transform for an image pair and apply it to ``img1``.

    Args:
        img1: the "Fixed Image" input (a file path, since the Gradio input
            uses ``type="filepath"``). It is fed to the model as ``'source'``
            and is the image that gets warped.
        img2: the "Moving Image" input (a file path). It is fed to the model
            as ``'target'``.

    Returns:
        A list of three PIL images:
        ``[registered_img, source_im_w_points, marked_image]``.

        * ``registered_img`` is ``img1`` warped by the predicted affine
          matrix (via ``wrap_imge_cropped``).
        * When ``with_points`` is True, ``source_im_w_points`` is ``img1``
          with sampled ellipse points marked, and ``marked_image`` is the
          warped image with those points mapped through the predicted
          matrix.
        * Otherwise the last two are blank ``dim`` x ``dim`` images.

    Raises:
        gr.Error: if either input image is missing.
    """
    if img1 is None or img2 is None:
        # Gradio passes None for an empty input; report it clearly instead
        # of failing somewhere inside preprocessing.
        raise gr.Error("Please provide both a fixed and a moving image.")

    to_pil = torchvision.transforms.ToPILImage()

    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        # Use the module-level sizes so preprocessing, warping and point
        # mapping (which already uses `dim`) always agree.
        fixed_images = preprocess_image(img1, dim=dim)
        moving_images = preprocess_image(img2, dim=dim)
        # Random 1x2x3 matrix given to the model as 'M_i'; its exact role
        # is defined by the model in `codes` -- TODO confirm.
        M_i = torch.normal(torch.zeros([1, 2, 3]), torch.ones([1, 2, 3]))
        model_inputs = {'source': fixed_images,
                        'target': moving_images,
                        'M_i': M_i}
        pred = IR_Model_tst(model_inputs)
        affine = pred['Affine_mtrx'].detach()

        img_out = wrap_imge_cropped(affine, fixed_images, dim1=dim0, dim2=dim)
        registered_img = to_pil(img_out[0])

        if with_points:
            # Sample points on a standard ellipse, convert them to pixel
            # coordinates and mark them on the source image.
            x0_source, y0_source = generate_standard_elips(N_samples=100)
            x_source = destandarize_point(x0_source, dim=dim, flip=False)
            y_source = destandarize_point(y0_source, dim=dim, flip=False)
            source_im_w_points = wrap_points(fixed_images.detach(), x_source, y_source, l=1)
            source_im_w_points = to_pil(source_im_w_points[0])

            # Map the same points through the predicted matrix (rescaled
            # for the crop) and mark them on the warped image.
            M_Predicted = workaround_matrix(affine, acc=0.5 / crop_ratio)
            x0_transformed, y0_transformed = transform_standard_points(
                M_Predicted[0], x0_source, y0_source)
            x_transformed = destandarize_point(x0_transformed, dim=dim, flip=False)
            y_transformed = destandarize_point(y0_transformed, dim=dim, flip=False)
            wrapped_img = wrap_points(img_out.detach(), x_transformed, y_transformed, l=1)
            marked_image = to_pil(wrapped_img[0])
        else:
            blank = torch.zeros(3, dim, dim)
            source_im_w_points = to_pil(blank)
            marked_image = to_pil(blank)

    return [registered_img, source_im_w_points, marked_image]
# --- Gradio UI -----------------------------------------------------------------
# Components are created outside any Blocks context and handed to gr.Interface,
# which lays them out and wires them to imgnt_reg.
image_1 = gr.Image(
    label="Fixed Image", height=224, width=180,
    type="filepath",  # imgnt_reg receives file paths
    elem_id="image-in",
)
image_2 = gr.Image(
    label="Moving Image", height=224, width=180,
    type="filepath",
    elem_id="image-in2",  # elem_id must be unique on the page
)
out_image1 = gr.Image(label="Registered image", height=224, width=224,
                      elem_id="image-out")
out_image2 = gr.Image(label="Marked source image", height=224, width=224,
                      elem_id="image-out2")
out_image3 = gr.Image(label="Marked wrapped image", height=224, width=224,
                      elem_id="image-out3")

# Order must match imgnt_reg's parameters and its returned list.
inputs = [image_1, image_2]
out_image = [out_image1, out_image2, out_image3]

iface = gr.Interface(
    fn=imgnt_reg,
    inputs=inputs,
    outputs=out_image,
    title="Imagenet registration V2",
    description="Upload 2 images to generate a registered one:",
    examples=[["./examples/ex5.png", "./examples/ex6.png"],
              ["./examples/ex1.jpg", "./examples/ex2.jpg"]],
)
# The app object is also exposed as `demo`, the conventional Gradio app name.
demo = iface

# share=True asks Gradio for a temporary public share link when run locally;
# anyone with that URL can use the app.
demo.queue().launch(share=True)