Amould commited on
Commit
047575b
·
verified ·
1 Parent(s): 4205531

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -59,16 +59,16 @@ def imgnt_reg(img1,img2):#, model_selected):
59
  x_source = destandarize_point(x0_source, dim=dim, flip = False)
60
  y_source = destandarize_point(y0_source, dim=dim, flip = False)
61
  source_im_w_points = wrap_points(fixed_images.detach(), x_source, y_source, l=1)
62
-
63
  M_Predicted = workaround_matrix(pred['Affine_mtrx'].detach(), acc = 0.5/crop_ratio)
64
  x0_transformed, y0_transformed = transform_standard_points(M_Predicted[0], x0_source, y0_source)
65
  x_transformed = destandarize_point(x0_transformed, dim=dim, flip = False)
66
  y_transformed = destandarize_point(y0_transformed, dim=dim, flip = False)
67
  wrapped_img = wrap_points(img_out.detach(), x_transformed, y_transformed, l=1)
68
- img_out2 = wrapped_img.detach()
69
  marked_image = torchvision.transforms.ToPILImage()(img_out2[0])
70
  else:
71
- #source_im_w_points = torchvision.transforms.ToPILImage()(torch.zeros(3,128,128))
72
  marked_image = torchvision.transforms.ToPILImage()(torch.zeros(3,128,128))
73
  return [registered_img,source_im_w_points, marked_image]
74
 
 
59
  x_source = destandarize_point(x0_source, dim=dim, flip = False)
60
  y_source = destandarize_point(y0_source, dim=dim, flip = False)
61
  source_im_w_points = wrap_points(fixed_images.detach(), x_source, y_source, l=1)
62
+ source_im_w_points = torchvision.transforms.ToPILImage()(source_im_w_points)
63
  M_Predicted = workaround_matrix(pred['Affine_mtrx'].detach(), acc = 0.5/crop_ratio)
64
  x0_transformed, y0_transformed = transform_standard_points(M_Predicted[0], x0_source, y0_source)
65
  x_transformed = destandarize_point(x0_transformed, dim=dim, flip = False)
66
  y_transformed = destandarize_point(y0_transformed, dim=dim, flip = False)
67
  wrapped_img = wrap_points(img_out.detach(), x_transformed, y_transformed, l=1)
68
+ img_out2 = wrapped_img
69
  marked_image = torchvision.transforms.ToPILImage()(img_out2[0])
70
  else:
71
+ source_im_w_points = torchvision.transforms.ToPILImage()(torch.zeros(3,128,128))
72
  marked_image = torchvision.transforms.ToPILImage()(torch.zeros(3,128,128))
73
  return [registered_img,source_im_w_points, marked_image]
74