"""Gradio demo for affine image registration with a ResNet-18 based model.

Loads a trained registration network from ``./modelbest/`` and serves a web
UI that registers a pair of images, optionally overlaying a ring of
reference points on the source and on the registered result.
"""
import glob
import importlib.util
import itertools
import os
import random
import subprocess
import sys

# Best-effort bootstrap for hosts where torch/torchvision are not
# preinstalled. Install into *this* interpreter (``sys.executable -m pip``)
# rather than whatever ``pip3`` is on PATH, and skip the call entirely when
# both packages are already importable. Failures are not fatal here (same as
# the original os.system call); the imports below will report them.
if (importlib.util.find_spec("torch") is None
        or importlib.util.find_spec("torchvision") is None):
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "torch", "torchvision"],
        check=False,
    )

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image

# Project helpers (presumably Identity, Build_IRmodel_Resnet,
# registration_method, preprocess_image, wrap_imge_cropped, ... — TODO
# confirm). Imported before resnet18, as in the original ordering, so that
# torchvision's resnet18 wins if ``codes`` happens to export the same name.
from codes import *  # noqa: F401,F403
from torchvision.models import resnet18

# Whether to draw reference points on the outputs ("Print samples").
with_points = True
# Image size handed to preprocess_image / wrap_imge_cropped(dim2=...).
dim = 128
# Size handed to wrap_imge_cropped(dim1=...).
dim0 = 224
crop_ratio = dim / dim0

file_savingfolder = './modelbest/'
ext = '_bestVal'
_CPU = torch.device('cpu')

# Backbone: ResNet-18 with its classifier head replaced by Identity. No
# pretrained weights are requested: the strict load_state_dict below
# overwrites every parameter and buffer, so downloading ImageNet weights at
# startup would be wasted work (and ``pretrained=`` is deprecated).
core_model_tst = resnet18()
core_model_tst.fc = Identity()
core_model_tst.load_state_dict(
    torch.load(
        os.path.join(file_savingfolder, 'core_model' + ext + '.pth'),
        map_location=_CPU,
    )
)

IR_Model_tst = Build_IRmodel_Resnet(core_model_tst, registration_method)
IR_Model_tst.load_state_dict(
    torch.load(
        os.path.join(file_savingfolder, 'IR_Model' + ext + '.pth'),
        map_location=_CPU,
    )
)
IR_Model_tst.eval()

_to_pil = torchvision.transforms.ToPILImage()


def imgnt_reg(img1, img2):
    """Register the two uploaded images and render the results.

    Args:
        img1: File path of the "Fixed Image" input; fed to the model as
            ``'source'``.
        img2: File path of the "Moving Image" input; fed to the model as
            ``'target'``.

    Returns:
        A list of three PIL images:

        - the output of ``wrap_imge_cropped`` applied to ``img1`` with the
          predicted ``'Affine_mtrx'``;
        - ``img1`` with the reference points drawn on it;
        - the registered image with the transformed points drawn on it.

        When ``with_points`` is false, the last two are black
        ``dim`` x ``dim`` images.

    Raises:
        gr.Error: If either image slot was left empty (Gradio passes ``None``).
    """
    if img1 is None or img2 is None:
        raise gr.Error("Please upload both a fixed and a moving image.")

    fixed_images = preprocess_image(img1, dim=dim)
    moving_images = preprocess_image(img2, dim=dim)
    # Random (standard normal) 1x2x3 matrix passed to the model as 'M_i';
    # its role is defined by the model in ``codes`` — TODO confirm.
    M_i = torch.normal(torch.zeros([1, 2, 3]), torch.ones([1, 2, 3]))
    model_inputs = {'source': fixed_images, 'target': moving_images, 'M_i': M_i}

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        pred = IR_Model_tst(model_inputs)
    affine = pred['Affine_mtrx'].detach()

    # NOTE(review): the warp is applied to the fixed ('source') image, not
    # the moving one — confirm this matches the intended direction.
    img_out = wrap_imge_cropped(affine, fixed_images, dim1=dim0, dim2=dim)
    registered_img = _to_pil(img_out[0])

    if not with_points:
        source_im_w_points = _to_pil(torch.zeros(3, dim, dim))
        marked_image = _to_pil(torch.zeros(3, dim, dim))
        return [registered_img, source_im_w_points, marked_image]

    # Reference points generated in standardized coordinates, then mapped
    # into the dim x dim image frame for drawing.
    x0_source, y0_source = generate_standard_elips(N_samples=100)
    x_source = destandarize_point(x0_source, dim=dim, flip=False)
    y_source = destandarize_point(y0_source, dim=dim, flip=False)
    source_im_w_points = _to_pil(
        wrap_points(fixed_images.detach(), x_source, y_source, l=1)[0]
    )

    # Map the same standardized points through the predicted matrix
    # (adjusted by workaround_matrix for the crop ratio) and draw them on
    # the registered output.
    M_Predicted = workaround_matrix(affine, acc=0.5 / crop_ratio)
    x0_transformed, y0_transformed = transform_standard_points(
        M_Predicted[0], x0_source, y0_source
    )
    x_transformed = destandarize_point(x0_transformed, dim=dim, flip=False)
    y_transformed = destandarize_point(y0_transformed, dim=dim, flip=False)
    marked_image = _to_pil(
        wrap_points(img_out.detach(), x_transformed, y_transformed, l=1)[0]
    )

    return [registered_img, source_im_w_points, marked_image]


# Components are created outside any Blocks context and rendered exactly
# once, by the Interface. Building the Interface from components already
# rendered inside another gr.Blocks raises DuplicateBlockError.
image_1 = gr.Image(
    label="Fixed Image", height=224, width=180,
    type="filepath", elem_id="image-in",
)
image_2 = gr.Image(
    label="Moving Image", height=224, width=180,
    type="filepath", elem_id="image-in2",  # element ids must be unique
)
out_image1 = gr.Image(
    label="Registered image", height=224, width=224, elem_id="image-out",
)
out_image2 = gr.Image(
    label="Marked source image", height=224, width=224, elem_id="image-out2",
)
out_image3 = gr.Image(
    label="Marked wrapped image", height=224, width=224, elem_id="image-out3",
)

inputs = [image_1, image_2]
out_image = [out_image1, out_image2, out_image3]

demo = gr.Interface(
    fn=imgnt_reg,
    inputs=inputs,
    outputs=out_image,
    title="Imagenet registration V2",
    description="Upload 2 images to generate a registered one:",
    examples=[
        ["./examples/ex5.png", "./examples/ex6.png"],
        ["./examples/ex1.jpg", "./examples/ex2.jpg"],
    ],
)

if __name__ == "__main__":
    demo.queue().launch(share=True)