import gradio as gr import os os.system('pip3 install torch torchvision')# torchaudio') #pip3 install torch -q') import torch import sys import numpy as np from PIL import Image import itertools import glob import random import torch import torchvision import torchvision.transforms as transforms import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from codes import * ## Print samples from torchvision.models import resnet18 #device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") with_points = True dim = 128 dim0 =224 crop_ratio = dim/dim0 file_savingfolder = './modelbest/' ext = '_bestVal' core_model_tst = resnet18(pretrained=True) core_model_tst.fc = Identity() core_model_tst.load_state_dict(torch.load(file_savingfolder+'core_model'+ext+'.pth', map_location=torch.device('cpu') )) #core_model_tst.to(device) IR_Model_tst = Build_IRmodel_Resnet(core_model_tst, registration_method) IR_Model_tst.load_state_dict(torch.load(file_savingfolder+'IR_Model'+ext+'.pth', map_location=torch.device('cpu'))) #IR_Model_tst.to(device) IR_Model_tst.eval() def imgnt_reg(img1,img2):#, model_selected): #fixed_images = np.empty((1, 128, 128, 3)) #moving_images = np.empty((1, 128, 128, 3)) # prepare inputs: fixed_images = preprocess_image(img1, dim = 128) moving_images = preprocess_image(img2, dim = 128) M_i = torch.normal(torch.zeros([1,2,3]), torch.ones([1,2,3])) model_inputs = {'source':fixed_images, 'target':moving_images, 'M_i' :M_i } #[moving_images/255, fixed_images/255] pred = IR_Model_tst(model_inputs) img_out = wrap_imge_cropped(pred['Affine_mtrx'].detach(), fixed_images, dim1=224, dim2=128) registered_img = torchvision.transforms.ToPILImage()(img_out[0]) if with_points: x0_source, y0_source = generate_standard_elips(N_samples= 100) x_source = destandarize_point(x0_source, dim=dim, flip = False) y_source = destandarize_point(y0_source, dim=dim, flip = False) source_im_w_points = wrap_points(fixed_images.detach(), x_source, y_source, l=1) source_im_w_points = torchvision.transforms.ToPILImage()(source_im_w_points[0]) M_Predicted = workaround_matrix(pred['Affine_mtrx'].detach(), acc = 0.5/crop_ratio) x0_transformed, y0_transformed = transform_standard_points(M_Predicted[0], x0_source, y0_source) x_transformed = destandarize_point(x0_transformed, dim=dim, flip = False) y_transformed = destandarize_point(y0_transformed, dim=dim, flip = False) wrapped_img = wrap_points(img_out.detach(), x_transformed, y_transformed, l=1) img_out2 = wrapped_img marked_image = torchvision.transforms.ToPILImage()(img_out2[0]) else: source_im_w_points = torchvision.transforms.ToPILImage()(torch.zeros(3,128,128)) marked_image = torchvision.transforms.ToPILImage()(torch.zeros(3,128,128)) return [registered_img,source_im_w_points, marked_image] with gr.Blocks() as demo: #gr.Markdown("