import sys
import numpy as np
sys.path.insert(0, 'mast3r')  # make the local mast3r checkout importable
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.extractor import ResidualBlock
from mast3r.model import AsymmetricMASt3R
def resize_tensor(tensor, target_size=512, ratio=16):
    # Input tensor has shape (B, C, H, W)
    _, _, H, W = tensor.shape
    # Scale so the longer of H and W becomes target_size, keeping aspect ratio
    if H > W:
        new_H = target_size
        new_W = int(W * (target_size / H))
    else:
        new_W = target_size
        new_H = int(H * (target_size / W))
    # Round both sides up to the nearest multiple of ratio
    new_W = int(np.ceil(new_W / ratio) * ratio)
    new_H = int(np.ceil(new_H / ratio) * ratio)
    # Rescale with bicubic interpolation
    resized_tensor = F.interpolate(tensor, size=(new_H, new_W), mode='bicubic', align_corners=False)
return resized_tensor
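# Example (sketch): a (1, 3, 480, 640) input with the defaults maps 640 to 512
# and 480 to int(480 * 512 / 640) = 384; both are already multiples of 16, so
# the output is (1, 3, 384, 512).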
def resize_to_quarter(tensor, original_size, ratio):
    # Downsample to 1/ratio of the original spatial size
    # (1/4 when ratio=4, hence the name)
    quarter_H = original_size[0] // ratio
    quarter_W = original_size[1] // ratio
    # Shrink with bilinear interpolation
resized_tensor = F.interpolate(tensor, size=(quarter_H, quarter_W), mode='bilinear', align_corners=False)
return resized_tensor
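# Example (sketch): with original_size=(480, 640) and ratio=4 (downsample=2 in
# Mast3rExtractor below gives 2**2 = 4), the output spatial size is (120, 160).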
class Mast3rExtractor(nn.Module):
def __init__(self, model_name, output_dim=128, norm_fn='batch', downsample=2):
super(Mast3rExtractor, self).__init__()
self.norm_fn = norm_fn
self.downsample = downsample
        # Two separate norm layers: reusing a single norm module at two points
        # in a Sequential would share its parameters (and, for BatchNorm, its
        # running statistics) across both conv stages
        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)
            self.norm2 = nn.BatchNorm2d(64)
        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)
            self.norm2 = nn.InstanceNorm2d(64)
        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
        self.layer1 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=7, stride=1, padding=3),
            self.norm1,
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            self.norm2,
            nn.ReLU(inplace=True),
        )
self.in_planes = 64
self.layer2 = self._make_layer(128, stride=1)
# output convolution
self.conv = nn.Conv2d(128, output_dim, kernel_size=1)
self.mast3r = AsymmetricMASt3R.from_pretrained(model_name).to('cuda')
        # Freeze all MASt3R parameters; the backbone serves as a fixed feature extractor
for param in self.mast3r.parameters():
param.requires_grad = False
def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, self.in_planes, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=1)
        layer3 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2, layer3)
self.in_planes = dim
return nn.Sequential(*layers)
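    # With in_planes=64 and dim=128 (the call in __init__), this builds three
    # residual blocks: 64 -> 64, 64 -> 128, 128 -> 128, all at stride 1.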
def forward(self, image1, image2, dual_inp=False):
        # Resize so the longer side is 512 and both dims are multiples of 16
B, _, H, W = image1.shape
image1 = resize_tensor(image1)
image2 = resize_tensor(image2)
        # Pack inputs into the view-dict format expected by MASt3R;
        # true_shape needs one (H, W) row per batch element
        _, _, H1, W1 = image1.shape
        true_shape = torch.tensor([[H1, W1]], dtype=torch.int32).repeat(B, 1).to(image1.device)
        view1 = dict(img=image1, true_shape=true_shape, idx=B, instance=str(B))
        view2 = dict(img=image2, true_shape=true_shape, idx=B, instance=str(B))
        # Run the frozen MASt3R two-view forward pass
with torch.no_grad():
pred1, pred2 = self.mast3r(view1, view2)
        # Fetch per-pixel features from both heads.
        # Channels: pts3d 3 + conf 1 + desc 24 + desc_conf 1 = 29, plus the
        # 3 RGB channels concatenated below = 32 inputs for layer1
        fea1 = [pred1['pts3d'], pred1['conf'][..., None], pred1['desc'], pred1['desc_conf'][..., None]]
        fea1 = torch.cat(fea1, dim=-1).permute((0, 3, 1, 2))
        fea1 = torch.cat([image1, fea1], dim=1)
        fea2 = [pred2['pts3d_in_other_view'], pred2['conf'][..., None], pred2['desc'], pred2['desc_conf'][..., None]]
        fea2 = torch.cat(fea2, dim=-1).permute((0, 3, 1, 2))
        fea2 = torch.cat([image2, fea2], dim=1)
        # Stack both views along the batch dimension so the conv head runs once
        x = torch.cat([fea1, fea2], dim=0)
        # Downsample features to the original input resolution over 2**downsample
x = resize_to_quarter(x, (H,W), 2**self.downsample)
        # Trainable conv head: 32 -> 64 -> 128 -> output_dim channels
x = self.layer1(x)
x = self.layer2(x)
x = self.conv(x)
        # Split the stacked batch back into one feature map per view
        x = x.split(split_size=B, dim=0)
return x
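
# Minimal usage sketch (an assumption-laden smoke test, not part of the
# training code): assumes a CUDA device, the mast3r repo on sys.path, and a
# downloadable checkpoint; the checkpoint name below is the public MASt3R
# release and may differ from the one actually used here.
if __name__ == "__main__":
    model = Mast3rExtractor(
        "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric",
        output_dim=128, norm_fn='batch', downsample=2).to('cuda').eval()
    # MASt3R expects images normalized to roughly [-1, 1]
    img1 = torch.rand(1, 3, 480, 640, device='cuda') * 2 - 1
    img2 = torch.rand(1, 3, 480, 640, device='cuda') * 2 - 1
    with torch.no_grad():
        fea1, fea2 = model(img1, img2)
    print(fea1.shape, fea2.shape)  # expected: (1, 128, 120, 160) each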