Spaces:
Runtime error
Runtime error
Update networks.py
Browse files- networks.py +15 -10
networks.py
CHANGED
|
@@ -6,6 +6,15 @@ from torchvision import models
|
|
| 6 |
import os
|
| 7 |
import numpy as np
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
def weights_init_normal(m):
|
| 10 |
classname = m.__class__.__name__
|
| 11 |
if classname.find('Conv') != -1:
|
|
@@ -143,6 +152,7 @@ class TpsGridGen(nn.Module):
|
|
| 143 |
super(TpsGridGen, self).__init__()
|
| 144 |
self.out_h, self.out_w = out_h, out_w
|
| 145 |
self.reg_factor = reg_factor
|
|
|
|
| 146 |
|
| 147 |
# create grid in numpy
|
| 148 |
self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
|
|
@@ -173,7 +183,6 @@ class TpsGridGen(nn.Module):
|
|
| 173 |
def forward(self, theta):
|
| 174 |
warped_grid = self.apply_transformation(
|
| 175 |
theta, torch.cat((self.grid_X, self.grid_Y), 3))
|
| 176 |
-
|
| 177 |
return warped_grid
|
| 178 |
|
| 179 |
def compute_L_inverse(self, X, Y):
|
|
@@ -285,10 +294,6 @@ class TpsGridGen(nn.Module):
|
|
| 285 |
|
| 286 |
return torch.cat((points_X_prime, points_Y_prime), 3)
|
| 287 |
|
| 288 |
-
# Defines the Unet generator.
|
| 289 |
-
# |num_downs|: number of downsamplings in UNet. For example,
|
| 290 |
-
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
|
| 291 |
-
# at the bottleneck
|
| 292 |
class UnetGenerator(nn.Module):
|
| 293 |
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
| 294 |
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
|
@@ -313,9 +318,6 @@ class UnetGenerator(nn.Module):
|
|
| 313 |
def forward(self, input):
|
| 314 |
return self.model(input)
|
| 315 |
|
| 316 |
-
# Defines the submodule with skip connection.
|
| 317 |
-
# X -------------------identity---------------------- X
|
| 318 |
-
# |-- downsampling -- |submodule| -- upsampling --|
|
| 319 |
class UnetSkipConnectionBlock(nn.Module):
|
| 320 |
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
| 321 |
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
|
@@ -464,9 +466,12 @@ class GicLoss(nn.Module):
|
|
| 464 |
class GMM(nn.Module):
|
| 465 |
""" Geometric Matching Module
|
| 466 |
"""
|
| 467 |
-
|
| 468 |
-
def __init__(self, opt):
|
| 469 |
super(GMM, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
self.extractionA = FeatureExtraction(
|
| 471 |
22, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
|
| 472 |
self.extractionB = FeatureExtraction(
|
|
|
|
| 6 |
import os
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
+
# Configuration class to hold all parameters
|
| 10 |
+
class Options:
|
| 11 |
+
def __init__(self):
|
| 12 |
+
# Default values
|
| 13 |
+
self.fine_height = 256
|
| 14 |
+
self.fine_width = 192
|
| 15 |
+
self.grid_size = 3
|
| 16 |
+
self.use_dropout = False
|
| 17 |
+
|
| 18 |
def weights_init_normal(m):
|
| 19 |
classname = m.__class__.__name__
|
| 20 |
if classname.find('Conv') != -1:
|
|
|
|
| 152 |
super(TpsGridGen, self).__init__()
|
| 153 |
self.out_h, self.out_w = out_h, out_w
|
| 154 |
self.reg_factor = reg_factor
|
| 155 |
+
self.grid_size = grid_size
|
| 156 |
|
| 157 |
# create grid in numpy
|
| 158 |
self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
|
|
|
|
| 183 |
def forward(self, theta):
|
| 184 |
warped_grid = self.apply_transformation(
|
| 185 |
theta, torch.cat((self.grid_X, self.grid_Y), 3))
|
|
|
|
| 186 |
return warped_grid
|
| 187 |
|
| 188 |
def compute_L_inverse(self, X, Y):
|
|
|
|
| 294 |
|
| 295 |
return torch.cat((points_X_prime, points_Y_prime), 3)
|
| 296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
class UnetGenerator(nn.Module):
|
| 298 |
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
| 299 |
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
|
|
|
| 318 |
def forward(self, input):
|
| 319 |
return self.model(input)
|
| 320 |
|
|
|
|
|
|
|
|
|
|
| 321 |
class UnetSkipConnectionBlock(nn.Module):
|
| 322 |
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
| 323 |
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
|
|
|
| 466 |
class GMM(nn.Module):
|
| 467 |
""" Geometric Matching Module
|
| 468 |
"""
|
| 469 |
+
def __init__(self, opt=None):
|
|
|
|
| 470 |
super(GMM, self).__init__()
|
| 471 |
+
# Initialize default options if none provided
|
| 472 |
+
if opt is None:
|
| 473 |
+
opt = Options()
|
| 474 |
+
|
| 475 |
self.extractionA = FeatureExtraction(
|
| 476 |
22, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
|
| 477 |
self.extractionB = FeatureExtraction(
|