Spaces:
Runtime error
Runtime error
Update networks.py
Browse files- networks.py +47 -40
networks.py
CHANGED
|
@@ -182,12 +182,12 @@ class TpsGridGen(nn.Module):
|
|
| 182 |
Li_block = self.Li[:self.N, :self.N]
|
| 183 |
|
| 184 |
# Compute weights
|
| 185 |
-
W_X = torch.bmm(Li_block.expand(batch_size,
|
| 186 |
-
W_Y = torch.bmm(Li_block.expand(batch_size,
|
| 187 |
|
| 188 |
# Prepare grid tensors
|
| 189 |
-
grid_X = self.grid_X.expand(batch_size,
|
| 190 |
-
grid_Y = self.grid_Y.expand(batch_size,
|
| 191 |
|
| 192 |
# Compute transformed coordinates
|
| 193 |
points_X = self.transform_points(grid_X, W_X, Q_X)
|
|
@@ -197,30 +197,36 @@ class TpsGridGen(nn.Module):
|
|
| 197 |
|
| 198 |
def transform_points(self, grid, W, Q):
|
| 199 |
batch_size, h, w, _ = grid.size()
|
|
|
|
| 200 |
|
| 201 |
-
# Flatten grid to (batch_size, H*W,
|
| 202 |
-
grid_flat = grid.view(batch_size,
|
| 203 |
|
| 204 |
# Prepare control points
|
| 205 |
-
P = torch.cat([self.P_X_base, self.P_Y_base], 1).
|
|
|
|
| 206 |
|
| 207 |
# Compute distance between grid points and control points
|
| 208 |
-
|
|
|
|
|
|
|
| 209 |
|
| 210 |
# Compute U (radial basis function)
|
| 211 |
-
dist_squared = torch.sum(torch.pow(delta, 2),
|
| 212 |
dist_squared[dist_squared == 0] = 1 # Avoid log(0)
|
| 213 |
U = torch.mul(dist_squared, torch.log(dist_squared))
|
| 214 |
|
| 215 |
# Compute affine transformation
|
| 216 |
A = torch.cat([
|
| 217 |
-
torch.ones(batch_size,
|
| 218 |
-
grid_flat
|
| 219 |
-
|
| 220 |
-
], 2)
|
| 221 |
|
| 222 |
# Combine affine and non-affine components
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
| 224 |
return points.view(batch_size, h, w, 1)
|
| 225 |
|
| 226 |
class GMM(nn.Module):
|
|
@@ -361,36 +367,37 @@ def load_checkpoint(model, checkpoint_path, strict=True):
|
|
| 361 |
|
| 362 |
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
| 363 |
|
| 364 |
-
#
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
updated_state_dict = {}
|
| 375 |
-
for key, value in state_dict.items():
|
| 376 |
-
# Handle buffer name changes
|
| 377 |
-
for old_name, new_name in buffer_mapping.items():
|
| 378 |
-
if key.startswith(old_name):
|
| 379 |
-
key = key.replace(old_name, new_name)
|
| 380 |
|
| 381 |
-
#
|
| 382 |
-
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
-
# Load
|
| 386 |
-
model.load_state_dict(
|
| 387 |
|
| 388 |
-
# Print warnings
|
| 389 |
model_keys = set(model.state_dict().keys())
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
unexpected = checkpoint_keys - set(updated_state_dict.keys())
|
| 394 |
|
| 395 |
if missing:
|
| 396 |
print(f"Missing keys: {sorted(missing)}")
|
|
|
|
| 182 |
Li_block = self.Li[:self.N, :self.N]
|
| 183 |
|
| 184 |
# Compute weights
|
| 185 |
+
W_X = torch.bmm(Li_block.expand(batch_size, -1, -1), Q_X)
|
| 186 |
+
W_Y = torch.bmm(Li_block.expand(batch_size, -1, -1), Q_Y)
|
| 187 |
|
| 188 |
# Prepare grid tensors
|
| 189 |
+
grid_X = self.grid_X.expand(batch_size, -1, -1, -1)
|
| 190 |
+
grid_Y = self.grid_Y.expand(batch_size, -1, -1, -1)
|
| 191 |
|
| 192 |
# Compute transformed coordinates
|
| 193 |
points_X = self.transform_points(grid_X, W_X, Q_X)
|
|
|
|
| 197 |
|
| 198 |
def transform_points(self, grid, W, Q):
|
| 199 |
batch_size, h, w, _ = grid.size()
|
| 200 |
+
n_points = h * w
|
| 201 |
|
| 202 |
+
# Flatten grid to (batch_size, H*W, 1)
|
| 203 |
+
grid_flat = grid.view(batch_size, n_points, 1)
|
| 204 |
|
| 205 |
# Prepare control points
|
| 206 |
+
P = torch.cat([self.P_X_base, self.P_Y_base], 1).t().unsqueeze(0) # (1, 2, N)
|
| 207 |
+
P = P.expand(batch_size, -1, -1) # (B, 2, N)
|
| 208 |
|
| 209 |
# Compute distance between grid points and control points
|
| 210 |
+
grid_expanded = grid_flat.expand(-1, -1, self.N) # (B, H*W, N)
|
| 211 |
+
P_expanded = P.expand(n_points, -1, -1).permute(1, 0, 2) # (B, H*W, N)
|
| 212 |
+
delta = grid_expanded - P_expanded
|
| 213 |
|
| 214 |
# Compute U (radial basis function)
|
| 215 |
+
dist_squared = torch.sum(torch.pow(delta, 2), dim=1, keepdim=True) # (B, H*W, 1)
|
| 216 |
dist_squared[dist_squared == 0] = 1 # Avoid log(0)
|
| 217 |
U = torch.mul(dist_squared, torch.log(dist_squared))
|
| 218 |
|
| 219 |
# Compute affine transformation
|
| 220 |
A = torch.cat([
|
| 221 |
+
torch.ones(batch_size, n_points, 1, device=grid.device),
|
| 222 |
+
grid_flat.view(batch_size, n_points, 1)
|
| 223 |
+
], dim=2)
|
|
|
|
| 224 |
|
| 225 |
# Combine affine and non-affine components
|
| 226 |
+
affine = torch.bmm(A, Q.view(batch_size, 1, 3).permute(0, 2, 1))
|
| 227 |
+
non_affine = torch.bmm(U.permute(0, 2, 1), W).permute(0, 2, 1)
|
| 228 |
+
points = affine + non_affine
|
| 229 |
+
|
| 230 |
return points.view(batch_size, h, w, 1)
|
| 231 |
|
| 232 |
class GMM(nn.Module):
|
|
|
|
| 367 |
|
| 368 |
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
| 369 |
|
| 370 |
+
# Initialize TPS grid parameters if missing
|
| 371 |
+
if 'gridGen.P_X_base' not in state_dict:
|
| 372 |
+
print("Initializing TPS grid parameters...")
|
| 373 |
+
grid_size = model.gridGen.grid_size
|
| 374 |
+
axis_coords = np.linspace(-1, 1, grid_size)
|
| 375 |
+
P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
|
| 376 |
+
P_X = torch.FloatTensor(P_X.reshape(-1, 1))
|
| 377 |
+
P_Y = torch.FloatTensor(P_Y.reshape(-1, 1))
|
| 378 |
+
state_dict['gridGen.P_X_base'] = P_X
|
| 379 |
+
state_dict['gridGen.P_Y_base'] = P_Y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
+
# Compute Li
|
| 382 |
+
Li = model.gridGen.compute_L_inverse(P_X, P_Y)
|
| 383 |
+
state_dict['gridGen.Li'] = Li
|
| 384 |
+
|
| 385 |
+
# Create grid
|
| 386 |
+
grid_X, grid_Y = np.meshgrid(
|
| 387 |
+
np.linspace(-1, 1, model.gridGen.out_w),
|
| 388 |
+
np.linspace(-1, 1, model.gridGen.out_h)
|
| 389 |
+
)
|
| 390 |
+
state_dict['gridGen.grid_X'] = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)
|
| 391 |
+
state_dict['gridGen.grid_Y'] = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)
|
| 392 |
|
| 393 |
+
# Load state dict
|
| 394 |
+
model.load_state_dict(state_dict, strict=strict)
|
| 395 |
|
| 396 |
+
# Print warnings
|
| 397 |
model_keys = set(model.state_dict().keys())
|
| 398 |
+
ckpt_keys = set(state_dict.keys())
|
| 399 |
+
missing = model_keys - ckpt_keys
|
| 400 |
+
unexpected = ckpt_keys - model_keys
|
|
|
|
| 401 |
|
| 402 |
if missing:
|
| 403 |
print(f"Missing keys: {sorted(missing)}")
|