# Hugging Face Space: image-registration demo.
# (Scrape artifact removed — the hosted Space page showed "Runtime error".)
import gradio as gr
import os

# HACK: installs torch/torchvision at startup because the Space image lacks
# them. Better: declare them in requirements.txt so the build installs them.
os.system('pip3 install torch torchvision')

import sys
import numpy as np
import random
import matplotlib.pyplot as plt
import tqdm
import torch
# torchvision is used below (transforms.ToPILImage) but was never imported —
# a NameError unless `from codes import *` happened to re-export it.
import torchvision
from codes import *  # project helpers: generate_standard_elips, To_pointcloud, IR_Model_tst, ...

# Device selection kept for reference; the model currently runs on CPU.
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build a demo pair of point clouds: a target ellipse, plus a "source" cloud
# obtained by applying a random rigid transform to the target.
x_t, y_t = generate_standard_elips(N_samples=30, a=0.8, b=1.5)
pointcloud_target = To_pointcloud(x_t, y_t)

dim = 1.5  # scale passed to random_rigid_transformation / destandarize_point

# Random rotation + translation of the target yields the source cloud.
RefRotation, Translation = random_rigid_transformation(dim=dim)
pointcloud_source = (RefRotation @ pointcloud_target.T).T + Translation
x_s, y_s, z_s = To_xyz(pointcloud_source)

# Zero-center both clouds so registration only has to recover the rotation.
PC1_mean = pointcloud_source.mean(axis=0)
pointcloud_source_norm = pointcloud_source - PC1_mean
PC2_mean = pointcloud_target.mean(axis=0)
pointcloud_target_norm = pointcloud_target - PC2_mean

x_sn, y_sn, z_sn = To_xyz(pointcloud_source_norm)
x_tn, y_tn, z_tn = To_xyz(pointcloud_target_norm)

# Float32 tensors for the torch-based registration model (no gradients needed).
pointcloud_source_norm_torch = torch.tensor(pointcloud_source_norm, requires_grad=False).to(torch.float32)
pointcloud_target_norm_torch = torch.tensor(pointcloud_target_norm, requires_grad=False).to(torch.float32)
def imgnt_reg(img1, img2, with_points=True):
    """Register the moving image onto the fixed image with the affine model.

    Args:
        img1: filepath of the fixed image (Gradio Image input, type="filepath").
        img2: filepath of the moving image.
        with_points: when True, also draw the reference ellipse on the fixed
            image and its transform (via the predicted affine matrix) on the
            registered output. This was an UNDEFINED global in the original,
            so every call raised NameError — now a keyword with a default.

    Returns:
        [registered_img, source_im_w_points, marked_image] as PIL images.
    """
    fixed_images = preprocess_image(img1, dim=128)
    moving_images = preprocess_image(img2, dim=128)

    # Random initial guess for the 2x3 affine matrix fed to the model.
    M_i = torch.normal(torch.zeros([1, 2, 3]), torch.ones([1, 2, 3]))
    model_inputs = {
        'source': fixed_images,   # NOTE(review): 'source'=fixed, 'target'=moving
        'target': moving_images,  # looks swapped — confirm against IR_Model_tst.
        'M_i': M_i,
    }
    pred = IR_Model_tst(model_inputs)  # registration model from `codes`

    # Warp the fixed image with the predicted affine matrix and convert to PIL.
    img_out = wrap_imge_cropped(pred['Affine_mtrx'].detach(), fixed_images,
                                dim1=224, dim2=128)
    registered_img = torchvision.transforms.ToPILImage()(img_out[0])

    if with_points:
        # Reference ellipse drawn on the fixed image...
        x0_source, y0_source = generate_standard_elips(N_samples=100)
        x_source = destandarize_point(x0_source, dim=dim, flip=False)
        y_source = destandarize_point(y0_source, dim=dim, flip=False)
        source_im_w_points = wrap_points(fixed_images.detach(), x_source, y_source, l=1)
        source_im_w_points = torchvision.transforms.ToPILImage()(source_im_w_points[0])

        # ...and the same ellipse mapped through the predicted affine matrix.
        # NOTE(review): crop_ratio is not defined in this file — presumably
        # exported by `codes`; confirm, otherwise this branch raises NameError.
        M_Predicted = workaround_matrix(pred['Affine_mtrx'].detach(), acc=0.5 / crop_ratio)
        x0_transformed, y0_transformed = transform_standard_points(M_Predicted[0], x0_source, y0_source)
        x_transformed = destandarize_point(x0_transformed, dim=dim, flip=False)
        y_transformed = destandarize_point(y0_transformed, dim=dim, flip=False)
        wrapped_img = wrap_points(img_out.detach(), x_transformed, y_transformed, l=1)
        marked_image = torchvision.transforms.ToPILImage()(wrapped_img[0])
    else:
        # Blank placeholders keep the 3-image output signature stable.
        blank = torch.zeros(3, 128, 128)
        source_im_w_points = torchvision.transforms.ToPILImage()(blank)
        marked_image = torchvision.transforms.ToPILImage()(blank)

    return [registered_img, source_im_w_points, marked_image]
with gr.Blocks() as demo:
    # NOTE(review): these components are created inside the Blocks context AND
    # passed to gr.Interface, which renders its own copies — the page likely
    # shows the widgets twice. Consider building the Interface without Blocks.
    image_1 = gr.Image(
        label="Fixed Image", height=224, width=180,
        type="filepath",
        elem_id="image-in",
    )
    image_2 = gr.Image(
        label="Moving Image", height=224, width=180,
        type="filepath",
        elem_id="image-in",
    )
    # Labels below had a stray Arabic diacritic ("ٌ") from a copy/paste —
    # removed so the UI text renders cleanly.
    out_image1 = gr.Image(label="Registered image", height=224, width=224,
                          elem_id="image-out")
    out_image2 = gr.Image(label="Marked source image", height=224, width=224,
                          elem_id="image-out2")
    out_image3 = gr.Image(label="Marked wrapped image", height=224, width=224,
                          elem_id="image-out3")

    inputs = [image_1, image_2]
    out_image = [out_image1, out_image2, out_image3]

    iface = gr.Interface(
        fn=imgnt_reg,
        inputs=inputs,
        outputs=out_image,
        title="Imagenet registration V2",
        description="Upload 2 images to generate a registered one:",
        examples=[["./examples/ex5.png", "./examples/ex6.png"],
                  ["./examples/ex1.jpg", "./examples/ex2.jpg"]],
    )

demo.queue().launch(share=True)