BFZD233
initial
f06f310
import os
import sys
import numpy as np
sys.path.insert(0,'Depth-Anything-V2')
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from core.extractor import ResidualBlock
from depth_anything_v2.dpt import DepthAnythingV2
from core.utils.utils import sv_intermediate_results
from huggingface_hub import hf_hub_download
def resize_tensor(tensor, target_size=512, ratio=16):
    """Resize a batch so its longer side is ``target_size``, snapped up to ``ratio``.

    The longer spatial side is set to ``target_size`` and the shorter side is
    scaled by the same factor (truncated to an integer). Both sides are then
    rounded up to the nearest multiple of ``ratio`` and the tensor is
    resampled with bicubic interpolation.

    Args:
        tensor: Input of shape (B, C, H, W).
        target_size: Length of the longer spatial side before rounding.
        ratio: Both output sides are multiples of this value.

    Returns:
        The resampled tensor of shape (B, C, new_H, new_W).
    """
    # Input layout is (B, C, H, W); only the spatial sizes matter here.
    _, _, H, W = tensor.shape

    # Pin the longer side to target_size and scale the other one to match.
    if H > W:
        new_H = target_size
        new_W = int(W * (target_size / H))
    else:
        new_W = target_size
        new_H = int(H * (target_size / W))

    # Round up to a multiple of ratio. For very elongated inputs the shorter
    # side can truncate to 0 above, which would ask interpolate for an empty
    # output, so keep at least one ratio-sized step. The result is converted
    # to a plain Python int rather than a numpy scalar.
    new_W = int(max(np.ceil(new_W / ratio), 1) * ratio)
    new_H = int(max(np.ceil(new_H / ratio), 1) * ratio)

    return F.interpolate(tensor, size=(new_H, new_W), mode='bicubic', align_corners=False)
def resize_to_quarter(tensor, original_size, ratio):
    """Bilinearly resample ``tensor`` to ``original_size`` floor-divided by ``ratio``.

    Despite the name, the reduction factor is whatever ``ratio`` is; the
    result is a quarter of ``original_size`` only when ``ratio`` is 4.

    Args:
        tensor: Input of shape (B, C, h, w); its current spatial size is ignored.
        original_size: Indexable whose first two entries are the reference
            height and width.
        ratio: Integer divisor applied to both reference sides.

    Returns:
        Tensor of shape (B, C, original_size[0] // ratio, original_size[1] // ratio).
    """
    ref_h = original_size[0]
    ref_w = original_size[1]
    target_hw = (ref_h // ratio, ref_w // ratio)
    return F.interpolate(tensor, size=target_hw, mode='bilinear', align_corners=False)
class DepthAnyExtractor(nn.Module):
    """Multi-scale feature heads on top of a frozen Depth Anything V2 backbone.

    The backbone produces a depth map and a 128-channel feature map. The
    features are resampled to ``1 / 2**downsample`` of the input resolution,
    refined by ``layer1`` and passed through ``layer2`` and ``layer3`` (each
    built with stride 2). One list of output heads is attached per level.

    Args:
        model_dir: Not used. The checkpoint is always fetched through
            ``hf_hub_download``. Kept so existing call sites keep working.
        output_dim: Iterable of indexables. For each entry ``dim``, ``dim[2]``
            is the channel count of the ``outputs08`` head, ``dim[1]`` of the
            ``outputs16`` head and ``dim[0]`` of the ``outputs32`` head.
            NOTE: the default ``[128]`` does not satisfy this (an int cannot
            be indexed), so callers have to pass an explicit value.
        norm_fn: Normalisation name forwarded to ``ResidualBlock``.
        downsample: Backbone outputs are resampled to ``1 / 2**downsample``
            of the input height and width.
        args: Optional namespace. If it has a truthy ``vis_inter``, the
            resampled backbone features are written to ``args.sv_root``.
    """

    def __init__(self, model_dir, output_dim=[128], norm_fn='batch', downsample=2, args=None):
        super().__init__()
        self.args = args
        self.norm_fn = norm_fn
        self.downsample = downsample

        # One head per entry of output_dim and per pyramid level. Modules are
        # created in the same order as before so seeded initialisation and
        # parameter ordering are unchanged.
        self.outputs08 = nn.ModuleList([
            nn.Sequential(
                ResidualBlock(128, 128, self.norm_fn, stride=1),
                nn.Conv2d(128, dim[2], 3, padding=1))
            for dim in output_dim])

        self.outputs16 = nn.ModuleList([
            nn.Sequential(
                ResidualBlock(128, 128, self.norm_fn, stride=1),
                nn.Conv2d(128, dim[1], 3, padding=1))
            for dim in output_dim])

        self.outputs32 = nn.ModuleList([
            nn.Conv2d(128, dim[0], 3, padding=1)
            for dim in output_dim])

        self.layer1 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.in_planes = 128
        self.layer2 = self._make_layer(128, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        self.depth_anything = self._load_frozen_depth_anything()

        # Normalisation constants are registered as non-persistent buffers:
        # they follow the module across .to()/.cuda()/.half() and device
        # replicas without adding keys to the state dict, and construction no
        # longer requires a CUDA device.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.register_buffer('mean', torch.tensor(mean).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('std', torch.tensor(std).view(1, 3, 1, 1), persistent=False)

    @staticmethod
    def _load_frozen_depth_anything():
        """Build the ViT-L Depth Anything V2 model, load its weights and freeze it."""
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
        }
        encoder = "vitl"
        depth_anything = DepthAnythingV2(**model_configs[encoder])
        checkpoint_path = hf_hub_download(
            repo_id="BFZD/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching",
            filename="dav2_models/depth_anything_v2_vitl.pth",
        )
        depth_anything.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

        # Freeze every backbone parameter; only the heads are trainable.
        for param in depth_anything.parameters():
            param.requires_grad = False
        return depth_anything

    def _make_layer(self, dim, stride=1):
        """Stack two residual blocks; only the first one applies ``stride``."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        self.in_planes = dim
        return nn.Sequential(layer1, layer2)

    def forward(self, image, dual_inp=False, num_layers=3):
        """Run the backbone and return per-level head outputs.

        Args:
            image: Tensor of shape (B, 3, H, W). The normalisation below maps
                values through ``(x + 1) / 2``, so the input is presumably in
                [-1, 1] -- TODO confirm against the data pipeline.
            dual_inp: Not supported when ``num_layers`` is 1 or 2; ignored
                for any other ``num_layers``.
            num_layers: 1 or 2 returns early with only the first one or two
                levels; any other value computes all three levels.

        Returns:
            ``(outputs08,)`` when ``num_layers == 1``;
            ``(outputs08, outputs16)`` when ``num_layers == 2``;
            otherwise ``((outputs08, outputs16, outputs32), depth)``, where
            ``depth`` is the backbone depth map resampled to
            ``1 / 2**downsample`` of the input resolution. Each ``outputsXX``
            is a list with one tensor per head.

        Raises:
            NotImplementedError: If ``dual_inp`` is true and ``num_layers``
                is 1 or 2.
        """
        # These paths used to reference a variable that is never assigned and
        # died with a NameError after the backbone had already run. Fail up
        # front with an explicit message instead.
        if dual_inp and num_layers in (1, 2):
            raise NotImplementedError(
                "dual_inp=True is not supported when num_layers is 1 or 2: "
                "this extractor never computes the auxiliary feature map.")

        _, _, H, W = image.shape

        # Resize so both sides are multiples of the 14-pixel patch size.
        img = resize_tensor(image, target_size=518, ratio=14)
        # Map through (x + 1) / 2, then apply the per-channel mean/std.
        img = ((img + 1) / 2 - self.mean) / self.std

        # The backbone is frozen, so no graph is recorded for it. Example
        # shapes noted by the original author: depth [1, 1, 518, 756],
        # features [1, 128, 296, 432].
        with torch.no_grad():
            depth, depth_fea = self.depth_anything(img)

        # Bring both outputs to 1 / 2**downsample of the input resolution.
        depth = resize_to_quarter(depth, (H, W), 2**self.downsample)
        x = resize_to_quarter(depth_fea, (H, W), 2**self.downsample)

        if getattr(self.args, "vis_inter", False):
            sv_intermediate_results(x, "depthAnything_features", self.args.sv_root)

        x = self.layer1(x)
        outputs08 = [f(x) for f in self.outputs08]
        if num_layers == 1:
            return (outputs08,)

        # layer2 is built with stride 2.
        y = self.layer2(x)
        outputs16 = [f(y) for f in self.outputs16]
        if num_layers == 2:
            return (outputs08, outputs16)

        # layer3 is built with stride 2 as well.
        z = self.layer3(y)
        outputs32 = [f(z) for f in self.outputs32]

        return (outputs08, outputs16, outputs32), depth
class DepthMatchExtractor(nn.Module):
    """Matching-feature head on top of a frozen Depth Anything V2 backbone.

    The backbone's 128-channel feature map is resampled to
    ``1 / 2**downsample`` of the input resolution, refined by ``layer1`` and
    ``layer2`` (both stride 1) and projected to ``output_dim`` channels with
    a 1x1 convolution.

    Args:
        model_dir: Not used. The checkpoint is always fetched through
            ``hf_hub_download``. Kept so existing call sites keep working.
        output_dim: Channel count of the returned feature map.
        norm_fn: Normalisation name forwarded to ``ResidualBlock``.
        downsample: Backbone features are resampled to ``1 / 2**downsample``
            of the input height and width.
    """

    def __init__(self, model_dir, output_dim=256, norm_fn='batch', downsample=2):
        super().__init__()
        self.norm_fn = norm_fn
        self.downsample = downsample

        self.layer1 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.in_planes = 128
        self.layer2 = self._make_layer(128, stride=1)
        self.conv = nn.Conv2d(128, output_dim, kernel_size=1)

        # The backbone is no longer forced onto 'cuda' here. As a registered
        # submodule it moves together with this module, which lets the class
        # be constructed on hosts without a GPU.
        self.depth_anything = self._load_frozen_depth_anything()

        # Non-persistent buffers follow the module across devices and dtypes
        # without adding keys to the state dict.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.register_buffer('mean', torch.tensor(mean).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('std', torch.tensor(std).view(1, 3, 1, 1), persistent=False)

    @staticmethod
    def _load_frozen_depth_anything():
        """Build the ViT-L Depth Anything V2 model, load its weights and freeze it."""
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
        }
        encoder = "vitl"
        depth_anything = DepthAnythingV2(**model_configs[encoder])
        checkpoint_path = hf_hub_download(
            repo_id="BFZD/Diving-into-the-Fusion-of-Monocular-Priors-for-Generalized-Stereo-Matching",
            filename="dav2_models/depth_anything_v2_vitl.pth",
        )
        depth_anything.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

        # Freeze every backbone parameter; only the head is trainable.
        for param in depth_anything.parameters():
            param.requires_grad = False
        return depth_anything

    def _make_layer(self, dim, stride=1):
        """Stack two residual blocks; only the first one applies ``stride``."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        self.in_planes = dim
        return nn.Sequential(layer1, layer2)

    def forward(self, x, dual_inp=False, num_layers=3):
        """Extract matching features for one batch or for a list of batches.

        Args:
            x: Either a tensor of shape (B, 3, H, W), or a list/tuple of such
                tensors, which are concatenated along the batch axis and
                processed in a single pass. The normalisation below maps
                values through ``(x + 1) / 2``, so inputs are presumably in
                [-1, 1] -- TODO confirm against the data pipeline.
            dual_inp: Accepted for signature compatibility; not used.
            num_layers: Accepted for signature compatibility; not used.

        Returns:
            A tensor with ``output_dim`` channels at ``1 / 2**downsample`` of
            the input resolution. When ``x`` was a list or tuple, a tuple of
            such tensors split along the batch axis in chunks of the first
            element's batch size.
        """
        # Lists/tuples are processed as one concatenated batch and split again
        # at the end using the batch size of the first element.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        _, _, H, W = x.shape

        # Resize so both sides are multiples of the 14-pixel patch size.
        x = resize_tensor(x, target_size=518, ratio=14)
        # Map through (x + 1) / 2, then apply the per-channel mean/std.
        x = ((x + 1) / 2 - self.mean) / self.std

        # The backbone is frozen, so no graph is recorded for it. Only the
        # feature map is used here; the depth prediction is discarded.
        with torch.no_grad():
            _, depth_fea = self.depth_anything(x)

        # Bring the features to 1 / 2**downsample of the input resolution.
        x = resize_to_quarter(depth_fea, (H, W), 2**self.downsample)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.conv(x)

        if is_list:
            x = x.split(split_size=batch_dim, dim=0)

        return x