import gradio as gr import os os.system('pip3 install torch torchvision')# torchaudio') #pip3 install torch -q') import sys import numpy as np import random import matplotlib.pyplot as plt import tqdm import torch from codes import * #device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") x_t,y_t = generate_standard_elips(N_samples = 30, a= 0.8,b = 1.5) pointcloud_target = To_pointcloud(x_t,y_t) dim = 1.5 RefRotation, Translation = random_rigid_transformation(dim=dim) pointcloud_source = np.matmul(RefRotation, pointcloud_target.T).T + Translation x_s, y_s, z_s = To_xyz(pointcloud_source) PC1_mean = np.mean(pointcloud_source, axis=0) pointcloud_source_norm = pointcloud_source - PC1_mean PC2_mean = np.mean(pointcloud_target, axis=0) pointcloud_target_norm = pointcloud_target - PC2_mean x_sn, y_sn, z_sn = To_xyz(pointcloud_source_norm) x_tn, y_tn, z_tn = To_xyz(pointcloud_target_norm) pointcloud_source_norm_torch = torch.tensor(pointcloud_source_norm, requires_grad=False).to(torch.float32) pointcloud_target_norm_torch = torch.tensor(pointcloud_target_norm, requires_grad=False).to(torch.float32) def imgnt_reg(img1,img2): fixed_images = preprocess_image(img1, dim = 128) moving_images = preprocess_image(img2, dim = 128) M_i = torch.normal(torch.zeros([1,2,3]), torch.ones([1,2,3])) model_inputs = {'source':fixed_images, 'target':moving_images, 'M_i' :M_i } #[moving_images/255, fixed_images/255] pred = IR_Model_tst(model_inputs) img_out = wrap_imge_cropped(pred['Affine_mtrx'].detach(), fixed_images, dim1=224, dim2=128) registered_img = torchvision.transforms.ToPILImage()(img_out[0]) if with_points: x0_source, y0_source = generate_standard_elips(N_samples= 100) x_source = destandarize_point(x0_source, dim=dim, flip = False) y_source = destandarize_point(y0_source, dim=dim, flip = False) source_im_w_points = wrap_points(fixed_images.detach(), x_source, y_source, l=1) source_im_w_points = torchvision.transforms.ToPILImage()(source_im_w_points[0]) M_Predicted = workaround_matrix(pred['Affine_mtrx'].detach(), acc = 0.5/crop_ratio) x0_transformed, y0_transformed = transform_standard_points(M_Predicted[0], x0_source, y0_source) x_transformed = destandarize_point(x0_transformed, dim=dim, flip = False) y_transformed = destandarize_point(y0_transformed, dim=dim, flip = False) wrapped_img = wrap_points(img_out.detach(), x_transformed, y_transformed, l=1) img_out2 = wrapped_img marked_image = torchvision.transforms.ToPILImage()(img_out2[0]) else: source_im_w_points = torchvision.transforms.ToPILImage()(torch.zeros(3,128,128)) marked_image = torchvision.transforms.ToPILImage()(torch.zeros(3,128,128)) return [registered_img,source_im_w_points, marked_image] with gr.Blocks() as demo: #gr.Markdown("