import os

# Runtime dependency install (Hugging Face Spaces workaround); torchaudio is not needed.
os.system('pip3 install torch torchvision')

import glob
import itertools
import random
import sys

import gradio as gr
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18

from codes import *  # project helpers (e.g., preprocess_image, Identity, Build_IRmodel_Resnet)

# Inference runs on CPU; swap in the line below to use a GPU when available:
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Configuration
with_points = True             # overlay ellipse landmark points on the outputs
dim = 128                      # working resolution of the registration model
dim0 = 224                     # resolution of the uncropped input images
crop_ratio = dim / dim0
file_savingfolder = './modelbest/'
ext = '_bestVal'
# Load the pretrained ResNet-18 backbone and the image-registration model (CPU weights).
core_model_tst = resnet18(pretrained=True)
core_model_tst.fc = Identity()  # drop the classification head; features feed the registration model
core_model_tst.load_state_dict(
    torch.load(file_savingfolder + 'core_model' + ext + '.pth', map_location=torch.device('cpu')))

IR_Model_tst = Build_IRmodel_Resnet(core_model_tst, registration_method)
IR_Model_tst.load_state_dict(
    torch.load(file_savingfolder + 'IR_Model' + ext + '.pth', map_location=torch.device('cpu')))
IR_Model_tst.eval()  # inference mode: freezes dropout and batch-norm statistics
def imgnt_reg(img1, img2):  # add a model_selected argument here if the model dropdown is re-enabled
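    """Estimate an affine transform between the fixed image (img1) and the
    moving image (img2), and return three PIL images: the registered result,
    the source image marked with ellipse landmarks, and the warped result
    marked with the transformed landmarks. Both inputs are image file paths.
    """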
    # Prepare inputs: load both images and resize/normalize them to the working size.
    fixed_images = preprocess_image(img1, dim=128)
    moving_images = preprocess_image(img2, dim=128)
    # Random initial affine matrix M_i (1 x 2 x 3) for the model to refine.
    M_i = torch.normal(torch.zeros([1, 2, 3]), torch.ones([1, 2, 3]))
    model_inputs = {'source': fixed_images,
                    'target': moving_images,
                    'M_i': M_i}
    pred = IR_Model_tst(model_inputs)
    # Warp the source with the predicted affine matrix, cropping from 224 to 128.
    img_out = wrap_imge_cropped(pred['Affine_mtrx'].detach(), fixed_images, dim1=224, dim2=128)
    registered_img = torchvision.transforms.ToPILImage()(img_out[0])
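    # Optionally overlay a standard ellipse of sample points on both images to
    # visualize how the predicted affine transform moves landmarks.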
    if with_points:
        # Draw the reference ellipse on the source image.
        x0_source, y0_source = generate_standard_elips(N_samples=100)
        x_source = destandarize_point(x0_source, dim=dim, flip=False)
        y_source = destandarize_point(y0_source, dim=dim, flip=False)
        source_im_w_points = wrap_points(fixed_images.detach(), x_source, y_source, l=1)
        source_im_w_points = torchvision.transforms.ToPILImage()(source_im_w_points[0])
        # Transform the same points with the predicted matrix and draw them on the output.
        M_Predicted = workaround_matrix(pred['Affine_mtrx'].detach(), acc=0.5 / crop_ratio)
        x0_transformed, y0_transformed = transform_standard_points(M_Predicted[0], x0_source, y0_source)
        x_transformed = destandarize_point(x0_transformed, dim=dim, flip=False)
        y_transformed = destandarize_point(y0_transformed, dim=dim, flip=False)
        wrapped_img = wrap_points(img_out.detach(), x_transformed, y_transformed, l=1)
        marked_image = torchvision.transforms.ToPILImage()(wrapped_img[0])
    else:
        # No landmarks requested: return black placeholders.
        source_im_w_points = torchvision.transforms.ToPILImage()(torch.zeros(3, 128, 128))
        marked_image = torchvision.transforms.ToPILImage()(torch.zeros(3, 128, 128))
    return [registered_img, source_im_w_points, marked_image]
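# A minimal programmatic check, as a sketch (assumes the bundled example images exist):
#   registered, src_marked, warped_marked = imgnt_reg('./examples/ex1.jpg', './examples/ex2.jpg')
#   registered.save('registered.png')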
with gr.Blocks() as demo:
    image_1 = gr.Image(
        label="Fixed Image", height=224, width=180,
        type="filepath",
        elem_id="image-in",
    )
    image_2 = gr.Image(
        label="Moving Image", height=224, width=180,
        type="filepath",
        elem_id="image-in",
    )
    # Disabled model selector; re-enable here and append it to `inputs` below:
    # model_list = gr.Dropdown(
    #     ["Additive_Recurence", "Rawblock"], label="Model", info="select a model")
    out_image1 = gr.Image(label="Registered image", height=224, width=224,
                          elem_id="image-out")
    out_image2 = gr.Image(label="Marked source image", height=224, width=224,
                          elem_id="image-out2")
    out_image3 = gr.Image(label="Marked warped image", height=224, width=224,
                          elem_id="image-out3")
    inputs = [image_1, image_2]  # append model_list here if the dropdown is re-enabled
    out_images = [out_image1, out_image2, out_image3]
    iface = gr.Interface(fn=imgnt_reg, inputs=inputs, outputs=out_images,
                         title="Imagenet registration V2",
                         description="Upload 2 images to generate a registered one:",
                         examples=[["./examples/ex5.png", "./examples/ex6.png"],
                                   ["./examples/ex1.jpg", "./examples/ex2.jpg"]],
                         )

demo.queue().launch(share=True)  # share=True also serves a temporary public link