# (Scraped HuggingFace Spaces page residue — Space status "Sleeping", file size
#  4,633 bytes, git blame hashes and line-number gutter. Not part of the code;
#  kept as a comment so the file remains valid Python.)
import gradio as gr
import os
# HACK: installs torch/torchvision at startup via a shell call (Space runtime
# workaround). Prefer declaring these in requirements.txt instead.
os.system('pip3 install torch torchvision')# torchaudio')
#pip3 install torch -q')
import torch
import sys
import numpy as np
from PIL import Image
import itertools
import glob
import random
# NOTE(review): duplicate of the `import torch` above — harmless but redundant.
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Project-local helpers: Identity, Build_IRmodel_Resnet, registration_method,
# preprocess_image, wrap_imge_cropped, wrap_points, generate_standard_elips,
# destandarize_point, workaround_matrix, transform_standard_points, ...
from codes import *
## Print samples
from torchvision.models import resnet18
#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# When True, the demo overlays ellipse sample points on the source image and
# their affine-transformed positions on the registered image.
with_points = True
# Working (cropped) resolution fed to the model, and the full input resolution.
dim = 128
dim0 =224
# Ratio between cropped and full resolution; used to rescale the predicted
# affine matrix in imgnt_reg.
crop_ratio = dim/dim0
# Folder and suffix of the best-validation checkpoints saved during training.
file_savingfolder = './modelbest/'
ext = '_bestVal'
# Backbone: ImageNet-pretrained ResNet-18 with its classification head replaced
# by an identity, so it outputs features instead of class logits.
core_model_tst = resnet18(pretrained=True)
core_model_tst.fc = Identity()
# Checkpoints are loaded onto CPU; GPU transfer is disabled (see commented lines).
core_model_tst.load_state_dict(torch.load(file_savingfolder+'core_model'+ext+'.pth', map_location=torch.device('cpu') ))
#core_model_tst.to(device)
# Full image-registration model built around the backbone; loaded and switched
# to eval mode for inference-only use in the demo.
IR_Model_tst = Build_IRmodel_Resnet(core_model_tst, registration_method)
IR_Model_tst.load_state_dict(torch.load(file_savingfolder+'IR_Model'+ext+'.pth', map_location=torch.device('cpu')))
#IR_Model_tst.to(device)
IR_Model_tst.eval()
def imgnt_reg(img1, img2):
    """Register the moving image onto the fixed image and return preview images.

    Parameters
    ----------
    img1 : str
        File path of the fixed (reference) image.
    img2 : str
        File path of the moving image to be registered.

    Returns
    -------
    list[PIL.Image.Image]
        ``[registered_img, source_im_w_points, marked_image]`` — the warped
        image, the fixed image with ellipse marker points, and the warped
        image with the affine-transformed marker points. The two marked
        images are black placeholders when ``with_points`` is False.
    """
    to_pil = torchvision.transforms.ToPILImage()
    # Preprocess both inputs at the working resolution (module constant `dim`,
    # previously hard-coded as 128 here).
    fixed_images = preprocess_image(img1, dim=dim)
    moving_images = preprocess_image(img2, dim=dim)
    # Random initial 1x2x3 affine matrix for the registration model.
    M_i = torch.normal(torch.zeros([1, 2, 3]), torch.ones([1, 2, 3]))
    model_inputs = {'source': fixed_images,
                    'target': moving_images,
                    'M_i': M_i}
    pred = IR_Model_tst(model_inputs)
    # Warp the fixed image with the predicted affine matrix; dim0/dim were
    # previously hard-coded as 224/128.
    img_out = wrap_imge_cropped(pred['Affine_mtrx'].detach(), fixed_images,
                                dim1=dim0, dim2=dim)
    registered_img = to_pil(img_out[0])
    if with_points:
        # Sample a standard ellipse of points and draw it on the source image.
        x0_source, y0_source = generate_standard_elips(N_samples=100)
        x_source = destandarize_point(x0_source, dim=dim, flip=False)
        y_source = destandarize_point(y0_source, dim=dim, flip=False)
        source_im_w_points = wrap_points(fixed_images.detach(), x_source, y_source, l=1)
        source_im_w_points = to_pil(source_im_w_points[0])
        # Map the same points through the (crop-rescaled) predicted matrix and
        # draw them on the warped image to visualize the transformation.
        M_Predicted = workaround_matrix(pred['Affine_mtrx'].detach(), acc=0.5/crop_ratio)
        x0_transformed, y0_transformed = transform_standard_points(M_Predicted[0], x0_source, y0_source)
        x_transformed = destandarize_point(x0_transformed, dim=dim, flip=False)
        y_transformed = destandarize_point(y0_transformed, dim=dim, flip=False)
        wrapped_img = wrap_points(img_out.detach(), x_transformed, y_transformed, l=1)
        marked_image = to_pil(wrapped_img[0])
    else:
        # Point overlay disabled: return black placeholders of the working size.
        blank = torch.zeros(3, dim, dim)
        source_im_w_points = to_pil(blank)
        marked_image = to_pil(blank)
    return [registered_img, source_im_w_points, marked_image]
# --- Gradio UI -------------------------------------------------------------
# Two image-path inputs feed imgnt_reg; three images come back: the registered
# result plus the two point-marked preview images.
with gr.Blocks() as demo:
    image_1 = gr.Image(
        label="Fixed Image", height=224, width=180,
        type="filepath",
        elem_id="image-in",
    )
    image_2 = gr.Image(
        label="Moving Image", height=224, width=180,
        type="filepath",
        elem_id="image-in",
    )
    # Output labels previously started with a stray mis-encoded character
    # ("ٌ", an Arabic damma diacritic) — removed.
    out_image1 = gr.Image(label="Registered image", height=224, width=224,
                          elem_id="image-out")
    out_image2 = gr.Image(label="Marked source image", height=224, width=224,
                          elem_id="image-out2")
    out_image3 = gr.Image(label="Marked wrapped image", height=224, width=224,
                          elem_id="image-out3")
    inputs = [image_1, image_2]
    out_image = [out_image1, out_image2, out_image3]
    # The Interface renders inside the Blocks context and wires the callback.
    iface = gr.Interface(fn=imgnt_reg, inputs=inputs, outputs=out_image,
                         title="Imagenet registration V2",
                         description="Upload 2 images to generate a registered one:",
                         examples=[["./examples/ex5.png", "./examples/ex6.png"],
                                   ["./examples/ex1.jpg", "./examples/ex2.jpg"]],
                         )
# queue() enables request queuing; share=True exposes a public link.
demo.queue().launch(share=True)
# (end of scraped page residue)