BFZD233
initial
f06f310
import re
import os
import sys
import logging
import numpy as np
from collections import OrderedDict
# logging.basicConfig(level=logging.INFO,
# format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',)
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.utils.plane import convert2patch
class Geometry_MLP(nn.Module):
    """Regress two per-pixel coefficients from patches of points with a tiny MLP.

    ``forward`` groups the upsampled points into patches of
    ``factor * factor`` samples (``factor = 2 ** args.n_downsample``). It maps
    every sample to a 2-vector proposal with ``self.reg``. It then fuses the
    proposals of each patch with max- or mean-pooling, as selected by
    ``args.geo_fusion``.
    """

    def __init__(self, args):
        super(Geometry_MLP, self).__init__()
        self.args = args
        # Two stacked linear layers with no nonlinearity in between; kept as-is
        # so existing checkpoints (reg.0.*, reg.1.*) still load.
        self.reg = nn.Sequential(
            nn.Linear(3, 3),
            nn.Linear(3, 2),
        )
        fusion = args.geo_fusion.lower()
        if fusion == "max":
            self.fusion = nn.AdaptiveMaxPool1d(1)
        elif fusion == "mean":
            self.fusion = nn.AdaptiveAvgPool1d(1)
        else:
            # ValueError subclasses Exception, so existing `except Exception`
            # handlers keep working.
            raise ValueError(f"{args.geo_fusion} is not supported")

    def forward(self, img_coord, flow_up):
        """Return fused coefficients of shape (B, 2, H, W).

        Args:
            img_coord: tensor concatenated with ``flow_up`` along dim 1.
            flow_up: tensor concatenated with ``img_coord`` along dim 1.

        NOTE(review): ``convert2patch`` is assumed (per the original shape
        comment) to return a 5-D tensor (B, C, factor*factor, H, W) with at
        least 3 channels. The permute below requires rank 5; confirm against
        core.utils.plane.
        """
        factor = 2 ** self.args.n_downsample
        fit_points = torch.cat([img_coord, flow_up], dim=1)
        fit_points = convert2patch(fit_points,
                                   patch_size=factor,
                                   div_last=False)
        # Move channels last so nn.Linear acts on the 3 point coordinates.
        A = fit_points[:, :3].permute((0, 2, 3, 4, 1))  # (B, L, H, W, 3)
        ab_proposals = self.reg(A)  # (B, L, H, W, 2)
        B, L, H, W, C = ab_proposals.shape
        # Pool over the L patch samples: (B, H*W*C, L) -> (B, H*W*C, 1).
        # reshape (not view) so a non-contiguous regressor output cannot fail.
        ab = self.fusion(ab_proposals.reshape(B, L, -1).transpose(-1, -2))
        ab = ab.reshape(B, H, W, C).permute((0, 3, 1, 2))  # (B, 2, H, W)
        return ab
class Geometry_Conv(nn.Module):
    """Predict five extra geometry channels with a small convolutional stack.

    The stack contains two stride-2 convolutions, so its output is spatially
    about 4x smaller than its input. ``disparity`` must already be at that
    reduced resolution for the final channel concatenation to succeed.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # (in_channels, out_channels, kernel_size, padding, stride) per conv.
        conv_specs = (
            (3, 4, 3, 1, 1),
            (4, 8, 3, 1, 2),
            (8, 5, 3, 1, 2),
            (5, 5, 1, 0, 1),
        )
        layers = []
        for position, (c_in, c_out, k, pad, step) in enumerate(conv_specs):
            # A LeakyReLU separates consecutive convs (none after the last),
            # keeping the Sequential indices 0, 2, 4, 6 for the conv weights.
            if position > 0:
                layers.append(nn.LeakyReLU(inplace=True))
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=k, padding=pad, stride=step))
        self.reg = nn.Sequential(*layers)

    def forward(self, img_coord, disparity_up, disparity):
        """Return ``disparity`` followed by the 5 regressed channels along dim 1.

        Args:
            img_coord: coordinate tensor, 2 channels per the original comments.
            disparity_up: upsampled disparity. It is detached, so no gradient
                flows back through it from this module.
            disparity: low-resolution disparity prepended to the output.
        """
        stacked = torch.cat((img_coord, disparity_up.detach()), dim=1)
        extra = self.reg(stacked)
        return torch.cat((disparity, extra), dim=1)
class Geometry_Conv_Split(nn.Module):
    """Shared conv encoder followed by separate plane and curvature heads.

    The encoder halves the spatial size once (one stride-2 conv). Each head
    halves it again, so the head outputs are about 4x smaller than the input.
    The output channel layout is [disparity, plane head (2), curvature head (3)].
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.encode = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1, stride=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(4, 8, kernel_size=3, padding=1, stride=2),
            nn.LeakyReLU(inplace=True),
        )

        def make_head(out_channels):
            # Both heads share this layout; only the final channel count differs.
            return nn.Sequential(
                nn.Conv2d(8, 4, kernel_size=3, padding=1, stride=2),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(4, out_channels, kernel_size=1, padding=0, stride=1),
            )

        self.decode_plane = make_head(2)
        self.decode_curvature = make_head(3)

    def forward(self, img_coord, disparity_up, disparity):
        """Return a 6-channel tensor: disparity, plane head, curvature head.

        Unlike ``Geometry_Conv``, ``disparity_up`` is not detached here.
        """
        latent = self.encode(torch.cat([img_coord, disparity_up], dim=1))
        return torch.cat(
            [disparity, self.decode_plane(latent), self.decode_curvature(latent)],
            dim=1,
        )
class LBPEncoder(nn.Module):
    """
    Computes the modified Local Binary Patterns (LBP) of an image using custom neighbor offsets.

    For each configured offset ``(dy, dx)`` the encoder produces one output
    channel equal to ``sigmoid(img[y + dy, x + dx] - img[y, x])``. Borders are
    handled with replicate padding. The comparison kernel is fixed (not trained).
    """

    # "(dy, dx)" pairs; whitespace is tolerated around each number and the comma.
    _OFFSET_PATTERN = re.compile(r'\(\s*(-?\d+)\s*,\s*(-?\d+)\s*\)')

    def __init__(self, args):
        super(LBPEncoder, self).__init__()
        self.args = args
        self.lbp_neighbor_offsets = self._parse_offsets(self.args.lbp_neighbor_offsets)
        self._build_lbp_kernel()
        self.sigmoid = nn.Sigmoid()

    def _build_lbp_kernel(self):
        """Create ``self.lbp_conv``: one frozen +1/-1 difference filter per offset."""
        self.num_neighbors = len(self.lbp_neighbor_offsets)
        # The kernel must be large enough to reach the farthest offset.
        self.max_offset = int(np.abs(self.lbp_neighbor_offsets).max())
        self.kernel_size = 2 * self.max_offset + 1
        self.padding = self.max_offset
        # Single input channel fanned out to one output channel per neighbor.
        self.lbp_conv = nn.Conv2d(
            in_channels=1,
            out_channels=self.num_neighbors,
            kernel_size=self.kernel_size,
            padding=self.padding,
            padding_mode="replicate",
            bias=False,
            groups=1,
        )
        self.lbp_weight = torch.zeros(self.num_neighbors, 1,
                                      self.kernel_size, self.kernel_size).float()
        center_y, center_x = self.max_offset, self.max_offset
        for idx, (dy, dx) in enumerate(self.lbp_neighbor_offsets):
            y, x = center_y + dy, center_x + dx
            if not (0 <= y < self.kernel_size and 0 <= x < self.kernel_size):
                # Unreachable by construction (kernel size derives from the
                # largest offset); kept as a defensive check.
                raise ValueError(f"Offset ({dy}, {dx}) is out of kernel bounds.")
            # conv2d is a cross-correlation: +1 at the neighbor and -1 at the
            # center yields neighbor - center.
            # NOTE(review): an offset of (0, 0) ends with -1 at the center
            # (i.e. -img), not 0; confirm whether such offsets are expected.
            self.lbp_weight[idx, 0, y, x] = 1.0
            self.lbp_weight[idx, 0, center_y, center_x] = -1.0
        # Frozen parameter: still saved in state_dict, never updated by optimizers.
        self.lbp_conv.weight = nn.Parameter(self.lbp_weight, requires_grad=False)

    def _parse_offsets(self, offsets_str):
        """
        Parses a string to extract neighbor offsets.
        Parameters:
            offsets_str (str): String defining neighbor offsets, e.g., "(-1,-1), (1,1), (-1,1), (1,-1)"
                or with inner spaces such as "( -1 , 0 )".
        Returns:
            np.ndarray: Integer array of shape (K, 2) holding the (dy, dx) offsets.
        Raises:
            ValueError: if no "(dy, dx)" pair is found in the string.
        """
        matches = self._OFFSET_PATTERN.findall(offsets_str)
        if not matches:
            raise ValueError(offsets_str + ": not supported format, please check it!")
        return np.array([(int(y), int(x)) for y, x in matches])

    def forward(self, img):
        """
        Parameters:
            img (torch.Tensor): Grayscale image tensor of shape [N, 1, H, W].
        Returns:
            torch.Tensor: Modified LBP image of shape [N, K, H, W], K = number of offsets,
            with values in (0, 1).
        """
        with torch.no_grad():
            # Padding keeps the spatial size: differences is [N, K, H, W].
            differences = self.lbp_conv(img)
        # Soft binarization: sigmoid maps each difference into (0, 1).
        encoding = self.sigmoid(differences)
        return encoding