Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,39 +1,62 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
| 3 |
import os
|
| 4 |
-
|
| 5 |
#os.system('pip install torch -q')
|
| 6 |
|
| 7 |
import sys
|
| 8 |
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
from PIL import Image
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from codes import *
|
| 13 |
|
| 14 |
## Print samples
|
| 15 |
-
vxm_model_loaded = create_model(dim = 128)
|
| 16 |
-
vxm_model_loaded.load_weights('./modelbest/')
|
| 17 |
|
| 18 |
-
|
| 19 |
-
vxm_model_loaded_affine.load_weights('./modelbest/')
|
| 20 |
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def imgnt_reg(img1,img2, model_selected):
|
| 24 |
-
fixed_images = np.empty((1, 128, 128, 3))
|
| 25 |
-
moving_images = np.empty((1, 128, 128, 3))
|
| 26 |
# prepare inputs:
|
| 27 |
-
fixed_images
|
| 28 |
-
moving_images
|
| 29 |
-
|
| 30 |
-
model_inputs =
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
img_out =
|
| 36 |
-
registered_img =
|
| 37 |
return registered_img
|
| 38 |
|
| 39 |
|
|
@@ -66,9 +89,9 @@ with gr.Blocks() as demo:
|
|
| 66 |
)
|
| 67 |
inputs = [image_1, image_2, model_list]
|
| 68 |
iface = gr.Interface(fn=imgnt_reg, inputs=inputs,outputs=out_image,
|
| 69 |
-
title="Imagenet registration
|
| 70 |
description="Upload 2 images to generate a registered one:",
|
| 71 |
examples=[["./examples/ex1.jpg","./examples/ex2.jpg"]],
|
| 72 |
)
|
| 73 |
|
| 74 |
-
demo.queue(default_enabled = True).launch(debug = True)
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
import os
|
|
|
|
| 3 |
#os.system('pip install torch -q')
|
| 4 |
|
| 5 |
import sys
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
+
import itertools
|
| 9 |
+
import glob
|
| 10 |
+
import random
|
| 11 |
+
import torch
|
| 12 |
+
import torchvision
|
| 13 |
+
import torchvision.transforms as transforms
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.optim as optim
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
from codes import *
|
| 18 |
|
| 19 |
## Print samples
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
from torchvision.models import resnet18
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
+
# Select the inference device: use the GPU when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Folder holding the best-validation checkpoints, and the filename suffix
# they were saved with.
file_savingfolder = './modelbest/'
ext = '_bestVal'

# Backbone: ImageNet-pretrained ResNet-18 with its classifier head replaced
# by Identity (from the project's `codes` module), then restored from the
# saved checkpoint and moved to the inference device.
core_model_tst = resnet18(pretrained=True)
core_model_tst.fc = Identity()
core_model_tst.load_state_dict(torch.load(file_savingfolder+'core_model'+ext+'.pth'))
core_model_tst.to(device)

# Full image-registration model built around the backbone
# (`registration_method` comes from `codes` — presumably selects the
# registration head; confirm against that module), restored from its own
# checkpoint and switched to evaluation mode for inference.
IR_Model_tst = Build_IRmodel_Resnet(core_model_tst, registration_method)
IR_Model_tst.load_state_dict(torch.load(file_savingfolder+'IR_Model'+ext+'.pth'))
IR_Model_tst.to(device)
IR_Model_tst.eval()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
    """Warp `source_img` by a batch of 2x3 affine matrices, cropped back to size.

    The image is zero-padded from (dim2, dim2) up toward (dim1, dim1) before
    warping so that content pushed past the original border is not lost, then
    center-cropped back to (dim2, dim2).

    Args:
        Affine_mtrx: (N, 2, 3) affine parameters for `affine_grid`.
        source_img: (N, C, dim2, dim2) image batch to warp.
        dim1: padded working size (default 224).
        dim2: output crop size (default 128).

    Returns:
        (N, C, dim2, dim2) warped image batch.
    """
    pad_width = int((dim1 - dim2) / 2)
    padded = torch.nn.ZeroPad2d(pad_width)(source_img)
    sample_grid = torch.nn.functional.affine_grid(
        Affine_mtrx, size=padded.shape, align_corners=False)
    warped = torch.nn.functional.grid_sample(
        padded, grid=sample_grid,
        mode='bilinear', padding_mode='zeros', align_corners=False)
    return torchvision.transforms.CenterCrop((dim2, dim2))(warped)
|
| 45 |
|
| 46 |
def imgnt_reg(img1, img2, model_selected):
    """Register two images with the pre-trained affine IR model.

    Args:
        img1: fixed (reference) image from the first Gradio input.
        img2: moving image from the second Gradio input.
        model_selected: model choice from the UI dropdown (currently unused).

    Returns:
        A PIL image with the warped result.
    """
    # prepare inputs: project helper resizes/normalizes to 128x128 tensors
    fixed_images = preprocess_image(img1, dim=128)
    moving_images = preprocess_image(img2, dim=128)
    # Random initial affine guess for the registration head.
    # NOTE(review): a fresh random draw each call makes the output
    # non-deterministic across identical uploads — confirm this is intended.
    M_i = torch.normal(torch.zeros([2, 3]), torch.ones([2, 3]))
    model_inputs = {'source': fixed_images,
                    'target': moving_images,
                    # FIX: was M_i.usqueeze(0) — a typo for unsqueeze(0) that
                    # raised AttributeError on every call (the Space's
                    # "Runtime error"). unsqueeze(0) adds the batch dim.
                    'M_i': M_i.unsqueeze(0)}
    # Model is in eval() mode; no_grad avoids building an autograd graph
    # during pure inference.
    with torch.no_grad():
        pred = IR_Model_tst(model_inputs)
    # Warp the source with the predicted affine matrix, then convert the
    # first (only) batch element back to a PIL image for Gradio.
    img_out = wrap_imge_cropped(pred['Affine_mtrx'], fixed_images, dim1=224, dim2=128)
    registered_img = torchvision.transforms.ToPILImage()(img_out[0])
    return registered_img
|
| 61 |
|
| 62 |
|
|
|
|
| 89 |
)
|
| 90 |
inputs = [image_1, image_2, model_list]
|
| 91 |
iface = gr.Interface(fn=imgnt_reg, inputs=inputs,outputs=out_image,
|
| 92 |
+
title="Imagenet registration V2",
|
| 93 |
description="Upload 2 images to generate a registered one:",
|
| 94 |
examples=[["./examples/ex1.jpg","./examples/ex2.jpg"]],
|
| 95 |
)
|
| 96 |
|
| 97 |
+
demo.queue(default_enabled = True).launch(debug = True)
|