Amould commited on
Commit
545e895
·
verified ·
1 Parent(s): fffea84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -22
app.py CHANGED
@@ -1,39 +1,62 @@
1
  import gradio as gr
2
-
3
  import os
4
-
5
  #os.system('pip install torch -q')
6
 
7
  import sys
8
  import numpy as np
9
- import torch
10
  from PIL import Image
11
-
 
 
 
 
 
 
 
 
12
  from codes import *
13
 
14
  ## Print samples
15
- vxm_model_loaded = create_model(dim = 128)
16
- vxm_model_loaded.load_weights('./modelbest/')
17
 
18
- vxm_model_loaded_affine = create_model(dim = 128)
19
- vxm_model_loaded_affine.load_weights('./modelbest/')
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def imgnt_reg(img1,img2, model_selected):
24
- fixed_images = np.empty((1, 128, 128, 3))
25
- moving_images = np.empty((1, 128, 128, 3))
26
  # prepare inputs:
27
- fixed_images[0] = preprocess_image(img1, dim = 128)
28
- moving_images[0] = preprocess_image(img2, dim = 128)
29
-
30
- model_inputs = [moving_images/255, fixed_images/255]
31
- if model_selected == "Imagenet-wild":
32
- Y_predict = vxm_model_loaded(model_inputs)
33
- else:
34
- Y_predict = vxm_model_loaded_affine(model_inputs)
35
- img_out = Y_predict[0][0].numpy()*255
36
- registered_img = Image.fromarray(img_out.astype(np.uint8))
37
  return registered_img
38
 
39
 
@@ -66,9 +89,9 @@ with gr.Blocks() as demo:
66
  )
67
  inputs = [image_1, image_2, model_list]
68
  iface = gr.Interface(fn=imgnt_reg, inputs=inputs,outputs=out_image,
69
- title="Imagenet registration V1",
70
  description="Upload 2 images to generate a registered one:",
71
  examples=[["./examples/ex1.jpg","./examples/ex2.jpg"]],
72
  )
73
 
74
- demo.queue(default_enabled = True).launch(debug = True)
 
1
  import gradio as gr
 
2
  import os
 
3
  #os.system('pip install torch -q')
4
 
5
  import sys
6
  import numpy as np
 
7
  from PIL import Image
8
+ import itertools
9
+ import glob
10
+ import random
11
+ import torch
12
+ import torchvision
13
+ import torchvision.transforms as transforms
14
+ import torch.nn as nn
15
+ import torch.optim as optim
16
+ import torch.nn.functional as F
17
  from codes import *
18
 
19
  ## Print samples
 
 
20
 
21
+ from torchvision.models import resnet18
 
22
 
23
 
24
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
25
+
26
+ file_savingfolder = './modelbest/'
27
+ ext = '_bestVal'
28
+ core_model_tst = resnet18(pretrained=True)
29
+ core_model_tst.fc = Identity()
30
+ core_model_tst.load_state_dict(torch.load(file_savingfolder+'core_model'+ext+'.pth'))
31
+ core_model_tst.to(device)
32
+ IR_Model_tst = Build_IRmodel_Resnet(core_model_tst, registration_method)
33
+ IR_Model_tst.load_state_dict(torch.load(file_savingfolder+'IR_Model'+ext+'.pth'))
34
+ IR_Model_tst.to(device)
35
+ IR_Model_tst.eval()
36
+
37
+
38
+ def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
39
+ source_img224 = torch.nn.ZeroPad2d(int((dim1-dim2)/2))(source_img)
40
+ grd = torch.nn.functional.affine_grid(Affine_mtrx, size=source_img224.shape,align_corners=False)
41
+ wrapped_img = torch.nn.functional.grid_sample(source_img224, grid=grd,
42
+ mode='bilinear', padding_mode='zeros', align_corners=False)
43
+ wrapped_img = torchvision.transforms.CenterCrop((dim2, dim2))(wrapped_img)
44
+ return wrapped_img
45
 
46
  def imgnt_reg(img1,img2, model_selected):
47
+ #fixed_images = np.empty((1, 128, 128, 3))
48
+ #moving_images = np.empty((1, 128, 128, 3))
49
  # prepare inputs:
50
+ fixed_images = preprocess_image(img1, dim = 128)
51
+ moving_images = preprocess_image(img2, dim = 128)
52
+ M_i = torch.normal(torch.zeros([2,3]), torch.ones([2,3]))
53
+ model_inputs = {'source':fixed_images,
54
+ 'target':moving_images,
55
+ 'M_i' :M_i.usqueeze(0) }
56
+ #[moving_images/255, fixed_images/255]
57
+ pred = IR_Model_tst(model_inputs)
58
+ img_out = wrap_imge_cropped(pred['Affine_mtrx'], fixed_images, dim1=224, dim2=128)
59
+ registered_img = torchvision.transforms.ToPILImage()(img_out[0])
60
  return registered_img
61
 
62
 
 
89
  )
90
  inputs = [image_1, image_2, model_list]
91
  iface = gr.Interface(fn=imgnt_reg, inputs=inputs,outputs=out_image,
92
+ title="Imagenet registration V2",
93
  description="Upload 2 images to generate a registered one:",
94
  examples=[["./examples/ex1.jpg","./examples/ex2.jpg"]],
95
  )
96
 
97
+ demo.queue(default_enabled = True).launch(debug = True)