Spaces:
Runtime error
Runtime error
Update networks.py
Browse files- networks.py +8 -49
networks.py
CHANGED
|
@@ -1,13 +1,11 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
|
|
|
| 4 |
from torch.nn import init
|
| 5 |
from torchvision import models
|
| 6 |
import os
|
| 7 |
-
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
-
|
| 11 |
def weights_init_normal(m):
|
| 12 |
classname = m.__class__.__name__
|
| 13 |
if classname.find('Conv') != -1:
|
|
@@ -18,7 +16,6 @@ def weights_init_normal(m):
|
|
| 18 |
init.normal_(m.weight.data, 1.0, 0.02)
|
| 19 |
init.constant_(m.bias.data, 0.0)
|
| 20 |
|
| 21 |
-
|
| 22 |
def weights_init_xavier(m):
|
| 23 |
classname = m.__class__.__name__
|
| 24 |
if classname.find('Conv') != -1:
|
|
@@ -29,7 +26,6 @@ def weights_init_xavier(m):
|
|
| 29 |
init.normal_(m.weight.data, 1.0, 0.02)
|
| 30 |
init.constant_(m.bias.data, 0.0)
|
| 31 |
|
| 32 |
-
|
| 33 |
def weights_init_kaiming(m):
|
| 34 |
classname = m.__class__.__name__
|
| 35 |
if classname.find('Conv') != -1:
|
|
@@ -40,7 +36,6 @@ def weights_init_kaiming(m):
|
|
| 40 |
init.normal_(m.weight.data, 1.0, 0.02)
|
| 41 |
init.constant_(m.bias.data, 0.0)
|
| 42 |
|
| 43 |
-
|
| 44 |
def init_weights(net, init_type='normal'):
|
| 45 |
print('initialization method [%s]' % init_type)
|
| 46 |
if init_type == 'normal':
|
|
@@ -53,7 +48,6 @@ def init_weights(net, init_type='normal'):
|
|
| 53 |
raise NotImplementedError(
|
| 54 |
'initialization method [%s] is not implemented' % init_type)
|
| 55 |
|
| 56 |
-
|
| 57 |
class FeatureExtraction(nn.Module):
|
| 58 |
def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
| 59 |
super(FeatureExtraction, self).__init__()
|
|
@@ -78,7 +72,6 @@ class FeatureExtraction(nn.Module):
|
|
| 78 |
def forward(self, x):
|
| 79 |
return self.model(x)
|
| 80 |
|
| 81 |
-
|
| 82 |
class FeatureL2Norm(torch.nn.Module):
|
| 83 |
def __init__(self):
|
| 84 |
super(FeatureL2Norm, self).__init__()
|
|
@@ -89,7 +82,6 @@ class FeatureL2Norm(torch.nn.Module):
|
|
| 89 |
epsilon, 0.5).unsqueeze(1).expand_as(feature)
|
| 90 |
return torch.div(feature, norm)
|
| 91 |
|
| 92 |
-
|
| 93 |
class FeatureCorrelation(nn.Module):
|
| 94 |
def __init__(self):
|
| 95 |
super(FeatureCorrelation, self).__init__()
|
|
@@ -105,9 +97,8 @@ class FeatureCorrelation(nn.Module):
|
|
| 105 |
b, h, w, h*w).transpose(2, 3).transpose(1, 2)
|
| 106 |
return correlation_tensor
|
| 107 |
|
| 108 |
-
|
| 109 |
class FeatureRegression(nn.Module):
|
| 110 |
-
def __init__(self, input_nc=512, output_dim=6
|
| 111 |
super(FeatureRegression, self).__init__()
|
| 112 |
self.conv = nn.Sequential(
|
| 113 |
nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
|
|
@@ -125,10 +116,6 @@ class FeatureRegression(nn.Module):
|
|
| 125 |
)
|
| 126 |
self.linear = nn.Linear(64 * 4 * 3, output_dim)
|
| 127 |
self.tanh = nn.Tanh()
|
| 128 |
-
if use_cuda:
|
| 129 |
-
self.conv.cuda()
|
| 130 |
-
self.linear.cuda()
|
| 131 |
-
self.tanh.cuda()
|
| 132 |
|
| 133 |
def forward(self, x):
|
| 134 |
x = self.conv(x)
|
|
@@ -137,7 +124,6 @@ class FeatureRegression(nn.Module):
|
|
| 137 |
x = self.tanh(x)
|
| 138 |
return x
|
| 139 |
|
| 140 |
-
|
| 141 |
class AffineGridGen(nn.Module):
|
| 142 |
def __init__(self, out_h=256, out_w=192, out_ch=3):
|
| 143 |
super(AffineGridGen, self).__init__()
|
|
@@ -152,13 +138,11 @@ class AffineGridGen(nn.Module):
|
|
| 152 |
(batch_size, self.out_ch, self.out_h, self.out_w))
|
| 153 |
return F.affine_grid(theta, out_size)
|
| 154 |
|
| 155 |
-
|
| 156 |
class TpsGridGen(nn.Module):
|
| 157 |
-
def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0
|
| 158 |
super(TpsGridGen, self).__init__()
|
| 159 |
self.out_h, self.out_w = out_h, out_w
|
| 160 |
self.reg_factor = reg_factor
|
| 161 |
-
self.use_cuda = use_cuda
|
| 162 |
|
| 163 |
# create grid in numpy
|
| 164 |
self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
|
|
@@ -168,9 +152,6 @@ class TpsGridGen(nn.Module):
|
|
| 168 |
# grid_X,grid_Y: size [1,H,W,1,1]
|
| 169 |
self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
|
| 170 |
self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
|
| 171 |
-
if use_cuda:
|
| 172 |
-
self.grid_X = self.grid_X.cuda()
|
| 173 |
-
self.grid_Y = self.grid_Y.cuda()
|
| 174 |
|
| 175 |
# initialize regular grid for control points P_i
|
| 176 |
if use_regular_grid:
|
|
@@ -188,11 +169,6 @@ class TpsGridGen(nn.Module):
|
|
| 188 |
3).unsqueeze(4).transpose(0, 4)
|
| 189 |
self.P_Y = P_Y.unsqueeze(2).unsqueeze(
|
| 190 |
3).unsqueeze(4).transpose(0, 4)
|
| 191 |
-
if use_cuda:
|
| 192 |
-
self.P_X = self.P_X.cuda()
|
| 193 |
-
self.P_Y = self.P_Y.cuda()
|
| 194 |
-
self.P_X_base = self.P_X_base.cuda()
|
| 195 |
-
self.P_Y_base = self.P_Y_base.cuda()
|
| 196 |
|
| 197 |
def forward(self, theta):
|
| 198 |
warped_grid = self.apply_transformation(
|
|
@@ -217,8 +193,6 @@ class TpsGridGen(nn.Module):
|
|
| 217 |
L = torch.cat((torch.cat((K, P), 1), torch.cat(
|
| 218 |
(P.transpose(0, 1), Z), 1)), 0)
|
| 219 |
Li = torch.inverse(L)
|
| 220 |
-
if self.use_cuda:
|
| 221 |
-
Li = Li.cuda()
|
| 222 |
return Li
|
| 223 |
|
| 224 |
def apply_transformation(self, theta, points):
|
|
@@ -315,8 +289,6 @@ class TpsGridGen(nn.Module):
|
|
| 315 |
# |num_downs|: number of downsamplings in UNet. For example,
|
| 316 |
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
|
| 317 |
# at the bottleneck
|
| 318 |
-
|
| 319 |
-
|
| 320 |
class UnetGenerator(nn.Module):
|
| 321 |
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
| 322 |
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
|
@@ -341,7 +313,6 @@ class UnetGenerator(nn.Module):
|
|
| 341 |
def forward(self, input):
|
| 342 |
return self.model(input)
|
| 343 |
|
| 344 |
-
|
| 345 |
# Defines the submodule with skip connection.
|
| 346 |
# X -------------------identity---------------------- X
|
| 347 |
# |-- downsampling -- |submodule| -- upsampling --|
|
|
@@ -395,7 +366,6 @@ class UnetSkipConnectionBlock(nn.Module):
|
|
| 395 |
else:
|
| 396 |
return torch.cat([x, self.model(x)], 1)
|
| 397 |
|
| 398 |
-
|
| 399 |
class Vgg19(nn.Module):
|
| 400 |
def __init__(self, requires_grad=False):
|
| 401 |
super(Vgg19, self).__init__()
|
|
@@ -428,12 +398,10 @@ class Vgg19(nn.Module):
|
|
| 428 |
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
| 429 |
return out
|
| 430 |
|
| 431 |
-
|
| 432 |
class VGGLoss(nn.Module):
|
| 433 |
def __init__(self, layids=None):
|
| 434 |
super(VGGLoss, self).__init__()
|
| 435 |
self.vgg = Vgg19()
|
| 436 |
-
self.vgg.cuda()
|
| 437 |
self.criterion = nn.L1Loss()
|
| 438 |
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
|
| 439 |
self.layids = layids
|
|
@@ -448,7 +416,6 @@ class VGGLoss(nn.Module):
|
|
| 448 |
self.criterion(x_vgg[i], y_vgg[i].detach())
|
| 449 |
return loss
|
| 450 |
|
| 451 |
-
|
| 452 |
class DT(nn.Module):
|
| 453 |
def __init__(self):
|
| 454 |
super(DT, self).__init__()
|
|
@@ -457,17 +424,15 @@ class DT(nn.Module):
|
|
| 457 |
dt = torch.abs(x1 - x2)
|
| 458 |
return dt
|
| 459 |
|
| 460 |
-
|
| 461 |
class DT2(nn.Module):
|
| 462 |
def __init__(self):
|
| 463 |
-
super(
|
| 464 |
|
| 465 |
def forward(self, x1, y1, x2, y2):
|
| 466 |
dt = torch.sqrt(torch.mul(x1 - x2, x1 - x2) +
|
| 467 |
torch.mul(y1 - y2, y1 - y2))
|
| 468 |
return dt
|
| 469 |
|
| 470 |
-
|
| 471 |
class GicLoss(nn.Module):
|
| 472 |
def __init__(self, opt):
|
| 473 |
super(GicLoss, self).__init__()
|
|
@@ -496,7 +461,6 @@ class GicLoss(nn.Module):
|
|
| 496 |
|
| 497 |
return torch.sum(torch.abs(dtleft - dtright) + torch.abs(dtup - dtdown))
|
| 498 |
|
| 499 |
-
|
| 500 |
class GMM(nn.Module):
|
| 501 |
""" Geometric Matching Module
|
| 502 |
"""
|
|
@@ -510,9 +474,9 @@ class GMM(nn.Module):
|
|
| 510 |
self.l2norm = FeatureL2Norm()
|
| 511 |
self.correlation = FeatureCorrelation()
|
| 512 |
self.regression = FeatureRegression(
|
| 513 |
-
input_nc=192, output_dim=2*opt.grid_size**2
|
| 514 |
self.gridGen = TpsGridGen(
|
| 515 |
-
opt.fine_height, opt.fine_width,
|
| 516 |
|
| 517 |
def forward(self, inputA, inputB):
|
| 518 |
featureA = self.extractionA(inputA)
|
|
@@ -525,17 +489,12 @@ class GMM(nn.Module):
|
|
| 525 |
grid = self.gridGen(theta)
|
| 526 |
return grid, theta
|
| 527 |
|
| 528 |
-
|
| 529 |
def save_checkpoint(model, save_path):
|
| 530 |
if not os.path.exists(os.path.dirname(save_path)):
|
| 531 |
os.makedirs(os.path.dirname(save_path))
|
| 532 |
-
|
| 533 |
-
torch.save(model.cpu().state_dict(), save_path)
|
| 534 |
-
model.cuda()
|
| 535 |
-
|
| 536 |
|
| 537 |
def load_checkpoint(model, checkpoint_path):
|
| 538 |
if not os.path.exists(checkpoint_path):
|
| 539 |
return
|
| 540 |
-
model.load_state_dict(torch.load(checkpoint_path))
|
| 541 |
-
model.cuda()
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
from torch.nn import init
|
| 5 |
from torchvision import models
|
| 6 |
import os
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
|
|
|
|
| 9 |
def weights_init_normal(m):
|
| 10 |
classname = m.__class__.__name__
|
| 11 |
if classname.find('Conv') != -1:
|
|
|
|
| 16 |
init.normal_(m.weight.data, 1.0, 0.02)
|
| 17 |
init.constant_(m.bias.data, 0.0)
|
| 18 |
|
|
|
|
| 19 |
def weights_init_xavier(m):
|
| 20 |
classname = m.__class__.__name__
|
| 21 |
if classname.find('Conv') != -1:
|
|
|
|
| 26 |
init.normal_(m.weight.data, 1.0, 0.02)
|
| 27 |
init.constant_(m.bias.data, 0.0)
|
| 28 |
|
|
|
|
| 29 |
def weights_init_kaiming(m):
|
| 30 |
classname = m.__class__.__name__
|
| 31 |
if classname.find('Conv') != -1:
|
|
|
|
| 36 |
init.normal_(m.weight.data, 1.0, 0.02)
|
| 37 |
init.constant_(m.bias.data, 0.0)
|
| 38 |
|
|
|
|
| 39 |
def init_weights(net, init_type='normal'):
|
| 40 |
print('initialization method [%s]' % init_type)
|
| 41 |
if init_type == 'normal':
|
|
|
|
| 48 |
raise NotImplementedError(
|
| 49 |
'initialization method [%s] is not implemented' % init_type)
|
| 50 |
|
|
|
|
| 51 |
class FeatureExtraction(nn.Module):
|
| 52 |
def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
| 53 |
super(FeatureExtraction, self).__init__()
|
|
|
|
| 72 |
def forward(self, x):
|
| 73 |
return self.model(x)
|
| 74 |
|
|
|
|
| 75 |
class FeatureL2Norm(torch.nn.Module):
|
| 76 |
def __init__(self):
|
| 77 |
super(FeatureL2Norm, self).__init__()
|
|
|
|
| 82 |
epsilon, 0.5).unsqueeze(1).expand_as(feature)
|
| 83 |
return torch.div(feature, norm)
|
| 84 |
|
|
|
|
| 85 |
class FeatureCorrelation(nn.Module):
|
| 86 |
def __init__(self):
|
| 87 |
super(FeatureCorrelation, self).__init__()
|
|
|
|
| 97 |
b, h, w, h*w).transpose(2, 3).transpose(1, 2)
|
| 98 |
return correlation_tensor
|
| 99 |
|
|
|
|
| 100 |
class FeatureRegression(nn.Module):
|
| 101 |
+
def __init__(self, input_nc=512, output_dim=6):
|
| 102 |
super(FeatureRegression, self).__init__()
|
| 103 |
self.conv = nn.Sequential(
|
| 104 |
nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
|
|
|
|
| 116 |
)
|
| 117 |
self.linear = nn.Linear(64 * 4 * 3, output_dim)
|
| 118 |
self.tanh = nn.Tanh()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
def forward(self, x):
|
| 121 |
x = self.conv(x)
|
|
|
|
| 124 |
x = self.tanh(x)
|
| 125 |
return x
|
| 126 |
|
|
|
|
| 127 |
class AffineGridGen(nn.Module):
|
| 128 |
def __init__(self, out_h=256, out_w=192, out_ch=3):
|
| 129 |
super(AffineGridGen, self).__init__()
|
|
|
|
| 138 |
(batch_size, self.out_ch, self.out_h, self.out_w))
|
| 139 |
return F.affine_grid(theta, out_size)
|
| 140 |
|
|
|
|
| 141 |
class TpsGridGen(nn.Module):
|
| 142 |
+
def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0):
|
| 143 |
super(TpsGridGen, self).__init__()
|
| 144 |
self.out_h, self.out_w = out_h, out_w
|
| 145 |
self.reg_factor = reg_factor
|
|
|
|
| 146 |
|
| 147 |
# create grid in numpy
|
| 148 |
self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
|
|
|
|
| 152 |
# grid_X,grid_Y: size [1,H,W,1,1]
|
| 153 |
self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
|
| 154 |
self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
# initialize regular grid for control points P_i
|
| 157 |
if use_regular_grid:
|
|
|
|
| 169 |
3).unsqueeze(4).transpose(0, 4)
|
| 170 |
self.P_Y = P_Y.unsqueeze(2).unsqueeze(
|
| 171 |
3).unsqueeze(4).transpose(0, 4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
def forward(self, theta):
|
| 174 |
warped_grid = self.apply_transformation(
|
|
|
|
| 193 |
L = torch.cat((torch.cat((K, P), 1), torch.cat(
|
| 194 |
(P.transpose(0, 1), Z), 1)), 0)
|
| 195 |
Li = torch.inverse(L)
|
|
|
|
|
|
|
| 196 |
return Li
|
| 197 |
|
| 198 |
def apply_transformation(self, theta, points):
|
|
|
|
| 289 |
# |num_downs|: number of downsamplings in UNet. For example,
|
| 290 |
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
|
| 291 |
# at the bottleneck
|
|
|
|
|
|
|
| 292 |
class UnetGenerator(nn.Module):
|
| 293 |
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
| 294 |
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
|
|
|
| 313 |
def forward(self, input):
|
| 314 |
return self.model(input)
|
| 315 |
|
|
|
|
| 316 |
# Defines the submodule with skip connection.
|
| 317 |
# X -------------------identity---------------------- X
|
| 318 |
# |-- downsampling -- |submodule| -- upsampling --|
|
|
|
|
| 366 |
else:
|
| 367 |
return torch.cat([x, self.model(x)], 1)
|
| 368 |
|
|
|
|
| 369 |
class Vgg19(nn.Module):
|
| 370 |
def __init__(self, requires_grad=False):
|
| 371 |
super(Vgg19, self).__init__()
|
|
|
|
| 398 |
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
| 399 |
return out
|
| 400 |
|
|
|
|
| 401 |
class VGGLoss(nn.Module):
|
| 402 |
def __init__(self, layids=None):
|
| 403 |
super(VGGLoss, self).__init__()
|
| 404 |
self.vgg = Vgg19()
|
|
|
|
| 405 |
self.criterion = nn.L1Loss()
|
| 406 |
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
|
| 407 |
self.layids = layids
|
|
|
|
| 416 |
self.criterion(x_vgg[i], y_vgg[i].detach())
|
| 417 |
return loss
|
| 418 |
|
|
|
|
| 419 |
class DT(nn.Module):
|
| 420 |
def __init__(self):
|
| 421 |
super(DT, self).__init__()
|
|
|
|
| 424 |
dt = torch.abs(x1 - x2)
|
| 425 |
return dt
|
| 426 |
|
|
|
|
| 427 |
class DT2(nn.Module):
    """Element-wise Euclidean distance between points (x1, y1) and (x2, y2)."""

    def __init__(self):
        super().__init__()

    def forward(self, x1, y1, x2, y2):
        """Return sqrt((x1 - x2)^2 + (y1 - y2)^2), computed element-wise."""
        delta_x = x1 - x2
        delta_y = y1 - y2
        squared_distance = delta_x * delta_x + delta_y * delta_y
        return torch.sqrt(squared_distance)
|
| 435 |
|
|
|
|
| 436 |
class GicLoss(nn.Module):
|
| 437 |
def __init__(self, opt):
|
| 438 |
super(GicLoss, self).__init__()
|
|
|
|
| 461 |
|
| 462 |
return torch.sum(torch.abs(dtleft - dtright) + torch.abs(dtup - dtdown))
|
| 463 |
|
|
|
|
| 464 |
class GMM(nn.Module):
|
| 465 |
""" Geometric Matching Module
|
| 466 |
"""
|
|
|
|
| 474 |
self.l2norm = FeatureL2Norm()
|
| 475 |
self.correlation = FeatureCorrelation()
|
| 476 |
self.regression = FeatureRegression(
|
| 477 |
+
input_nc=192, output_dim=2*opt.grid_size**2)
|
| 478 |
self.gridGen = TpsGridGen(
|
| 479 |
+
opt.fine_height, opt.fine_width, grid_size=opt.grid_size)
|
| 480 |
|
| 481 |
def forward(self, inputA, inputB):
|
| 482 |
featureA = self.extractionA(inputA)
|
|
|
|
| 489 |
grid = self.gridGen(theta)
|
| 490 |
return grid, theta
|
| 491 |
|
|
|
|
| 492 |
def save_checkpoint(model, save_path):
    """Write ``model.state_dict()`` to ``save_path`` with ``torch.save``.

    The parent directory of ``save_path`` is created when it is missing.

    Args:
        model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        save_path: destination file path for the serialized state dict.
    """
    parent_dir = os.path.dirname(save_path)
    # A bare filename has an empty dirname and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when one is named.
    if parent_dir:
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
|
|
|
|
|
|
|
|
|
|
| 496 |
|
| 497 |
def load_checkpoint(model, checkpoint_path):
    """Load a saved state dict from ``checkpoint_path`` into ``model``.

    Does nothing when ``checkpoint_path`` does not exist. Tensors are mapped
    to the CPU while loading (``map_location=torch.device('cpu')``).

    Args:
        model: object exposing ``load_state_dict()`` (e.g. an ``nn.Module``).
        checkpoint_path: path of the file to read with ``torch.load``.
    """
    if os.path.exists(checkpoint_path):
        cpu_device = torch.device('cpu')
        state = torch.load(checkpoint_path, map_location=cpu_device)
        model.load_state_dict(state)
|
|
|