import torch
import torch.nn as nn
import torch.nn.functional as F
from .depth_anything.dpt import DepthAnything
from PIL import Image
import os
import os.path as osp
import matplotlib.pyplot as plt
import numpy as np
class DepthAnythingWrapper(nn.Module):
    """
    A wrapper around a frozen DepthAnything model.

    Normalizes inputs with ImageNet statistics (UniMatch-style), resizes the
    spatial dims down to multiples of 14 (the ViT patch size), predicts depth,
    resizes the prediction back to the original resolution, mirrors the values
    within their [min, max] range, and clamps to [0.5, 200] meters.

    Args:
        encoder (str): one of 'vitl', 'vitb', 'vits'.
        checkpoint_path (str): path to the pretrained checkpoint (.pth file).
    """

    def __init__(self, encoder: str, checkpoint_path: str):
        super().__init__()
        # Per-encoder architecture hyperparameters.
        model_configs = {
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
        }
        assert encoder in model_configs, f"Unsupported encoder: {encoder}"
        # Build the backbone and load pretrained weights on CPU; the caller
        # moves the module to the target device afterwards.
        self.depth_model = DepthAnything(model_configs[encoder])
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        self.depth_model.load_state_dict(state_dict)
        self.depth_model.eval()
        # Freeze all parameters: this wrapper is inference-only.
        for param in self.depth_model.parameters():
            param.requires_grad = False

    def normalize_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Normalize images with ImageNet mean/std to match the pretrained model.

        Args:
            images (torch.Tensor): tensor of shape (..., 3, H, W) — e.g.
                (B, C, H, W) or (B, V, C, H, W) — values expected in [0, 1].

        Returns:
            torch.Tensor: normalized tensor with the same shape as the input.
        """
        # Build a broadcastable (..., 3, 1, 1) shape covering any number of
        # leading batch/view dimensions.
        extras = images.dim() - 3
        shape = [1] * extras + [3, 1, 1]
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(*shape)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(*shape)
        return (images - mean) / std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict an inverted, clamped depth map.

        Args:
            x (torch.Tensor): input of shape (B, C, H, W) or (B, V, C, H, W),
                values in [0, 1].

        Returns:
            torch.Tensor: depth of shape (B, 1, H, W) or (B, V, 1, H, W),
                matching the input layout, clamped to [0.5, 200] meters.
        """
        # Apply UniMatch-style normalization.
        x = self.normalize_images(x)
        # Merge an optional view dimension into the batch dimension and
        # remember the original layout so we can restore it at the end.
        has_views = x.dim() == 5
        if has_views:
            B, V, C, H, W = x.shape
            merged = x.view(B * V, C, H, W)
        else:
            B, C, H, W = x.shape
            merged = x
        # The ViT backbone requires spatial dims divisible by the 14px patch.
        h14 = (H // 14) * 14
        w14 = (W // 14) * 14
        x_resized = F.interpolate(merged, size=(h14, w14), mode='bilinear', align_corners=True)
        # Depth prediction (model is frozen, no gradients needed).
        with torch.no_grad():
            d = self.depth_model(x_resized)
        if d.dim() == 3:
            d = d.unsqueeze(1)  # -> (N, 1, h14, w14)
        # Resize the prediction back to the original resolution.
        d_final = F.interpolate(d, size=(H, W), mode='bilinear', align_corners=True)
        # Mirror values within [min, max]: the raw output is inverse-depth-like,
        # so flipping it makes larger values mean "farther".
        depth_min = d_final.min()
        depth_max = d_final.max()
        d = depth_min + depth_max - d_final
        # Clip to the physically plausible range [0.5, 200] meters.
        d = torch.clamp(d, 0.5, 200.0)
        # BUG FIX: the original checked `merged.dim() == 5`, which is always
        # False after the view-merge above, so 5D inputs came back as
        # (B*V, 1, H, W). Use the layout flag captured before the merge.
        if has_views:
            d = d.view(B, V, 1, H, W)
        return d
def visualize_depth(depth, output_path=None, cmap='viridis', vmin=None, vmax=None):
    """
    Visualize a depth map as a pseudo-color image.

    Args:
        depth: numpy array with depth data; extra singleton dims are squeezed
            down to (H, W).
        output_path (str): where to save the figure; if omitted the figure is
            shown interactively instead of saved.
        cmap (str): matplotlib colormap name ('jet', 'viridis', 'plasma', ...).
        vmin, vmax (float): color-mapping range; defaults to the NaN-aware
            data min/max.
    """
    # Collapse shapes like (1, H, W) or (1, 1, H, W) down to 2D.
    if depth.ndim > 2:
        depth = depth.squeeze()
    plt.figure(figsize=(10, 8))
    # Default the color range to the data range, ignoring NaNs.
    if vmin is None:
        vmin = np.nanmin(depth)
    if vmax is None:
        vmax = np.nanmax(depth)
    # Pseudo-color rendering with a colorbar legend.
    plt.imshow(depth, cmap=cmap, vmin=vmin, vmax=vmax)
    cbar = plt.colorbar()
    cbar.set_label('Depth (meters)', fontsize=12)
    plt.title('Depth Map Visualization', fontsize=14)
    plt.axis('off')  # hide axes for a cleaner image
    if output_path:
        # BUG FIX: osp.dirname() returns '' for a bare filename and
        # os.makedirs('') raises FileNotFoundError — only create the
        # directory when there actually is one.
        out_dir = osp.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(output_path, bbox_inches='tight', dpi=200)
        plt.close()
        print(f"深度图已保存至: {output_path}")
    else:
        plt.show()
def test_depth_wrapper() -> None:
    """Smoke test: run DepthAnythingWrapper on one image and save visualizations.

    NOTE(review): relies on hard-coded checkpoint and image paths below, so it
    only runs in the original author's environment — confirm paths before use.
    """
    # Configuration
    encoder = 'vitb'
    checkpoint_path = 'pretrained/pretrained_weights/depth_anything_v2_vitb.pth'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_path = "/mnt/pfs/users/chaojun.ni/wangweijie_mnt/yeqing/BEV-Splat/checkpoints/dl3dv-256x448-depthsplat-base-randview2-6/images/dl3dv_14eb48a50e37df548894ab6d8cd628a21dae14bbe6c462e894616fc5962e6c49/color/000028_gt.png"
    # Instantiate the wrapper and move it to the target device
    model = DepthAnythingWrapper(encoder, checkpoint_path).to(device)
    model.eval()
    # 1. Load the image
    image = Image.open(image_path).convert('RGB')
    original_size = image.size # keep the original size for later comparison
    # Manual preprocessing — replaces torchvision.transforms
    # 1. Resize
    target_size = (448, 256)
    resized_image = image.resize(target_size, Image.BILINEAR)
    # 2. Convert to a numpy array and normalize to [0, 1]
    image_array = np.array(resized_image).astype(np.float32) / 255.0
    # 3. Convert to a PyTorch tensor and reorder dims (H, W, C) -> (C, H, W)
    image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)
    # 4. Add the batch dimension
    x = image_tensor.unsqueeze(0).to(device)
    print(f"输入张量尺寸: {x.shape} (设备: {x.device})")
    # Forward inference
    with torch.no_grad():
        depth = model(x)
    # Check the output
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {depth.shape} (应为 [B,1,H,W])")
    # Print depth-map statistics
    depth_np = depth.squeeze().cpu().numpy() # drop batch and channel dims -> (H, W)
    print("\n深度图统计:")
    print(f"最小值: {depth_np.min():.4f} m, 最大值: {depth_np.max():.4f} m")
    print(f"均值: {depth_np.mean():.4f} m, 标准差: {depth_np.std():.4f}")
    # =================== Depth-map visualization ===================
    # 1. Save the original image
    input_dir = './visualization'
    os.makedirs(input_dir, exist_ok=True)
    input_image_path = osp.join(input_dir, 'input_image.jpg')
    image.save(input_image_path)
    print(f"原始图像已保存至: {input_image_path}")
    # 2. Save a grayscale depth map
    # Normalize the depth map to the 0-255 range
    depth_min = depth_np.min()
    depth_max = depth_np.max()
    depth_normalized = (depth_np - depth_min) / (depth_max - depth_min + 1e-6)
    depth_grayscale = (depth_normalized * 255).astype(np.uint8)
    depth_pil = Image.fromarray(depth_grayscale)
    # Restore the original size (if needed)
    depth_pil = depth_pil.resize(original_size, Image.NEAREST)
    # Save the grayscale image
    depth_grayscale_path = osp.join(input_dir, 'depth_grayscale.jpg')
    depth_pil.save(depth_grayscale_path)
    print(f"灰度深度图已保存至: {depth_grayscale_path}")
    # 3. Create a pseudo-color depth map (jet colormap)
    depth_colored_path = osp.join(input_dir, 'depth_colormap.png')
    # Pseudo-color figure without axes
    plt.figure(figsize=(10, 8))
    plt.imshow(depth_np, cmap='jet', vmin=depth_np.min(), vmax=depth_np.max())
    plt.axis('off') # hide axes
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove all margins
    plt.savefig(depth_colored_path, bbox_inches='tight', pad_inches=0, dpi=200)
    plt.close()
    print(f"伪彩色深度图已保存至: {depth_colored_path}")
    # 4. Side-by-side comparison of the original image and the depth map
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    # Original image
    axes[0].imshow(image)
    axes[0].set_title('Original Image', fontsize=12)
    axes[0].axis('off')
    # Pseudo-color depth map
    depth_colored = plt.imread(depth_colored_path)
    axes[1].imshow(depth_colored)
    axes[1].set_title('Depth Map (Jet Colormap)', fontsize=12)
    axes[1].axis('off')
    # Save the comparison figure
    comparison_path = osp.join(input_dir, 'depth_comparison.png')
    plt.savefig(comparison_path, bbox_inches='tight', dpi=150)
    plt.close()
    print(f"图像对比图已保存至: {comparison_path}")
    print("\n可视化已完成,结果保存在:", input_dir)
# Entry point: run the end-to-end smoke test when executed as a script.
if __name__ == '__main__':
    test_depth_wrapper()
# from .depth_anything.dpt import DepthAnything
# import torch
# import torch.nn.functional as F
# model_configs = {
# 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
# 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
# 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
# }
# encoder = 'vitb'
# checkpoint_path = f'/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/src/depth_anything/checkpoints/depth_anything_{encoder}14.pth'
# # 1. 创建模型并加载权重
# depth_anything = DepthAnything(model_configs[encoder])
# state_dict = torch.load(checkpoint_path, map_location='cpu')
# depth_anything.load_state_dict(state_dict)
# depth_anything.eval()
# # 2. 构造随机输入并做标准化
# test_input = torch.randn(6, 256, 448, 3).float() # (B, H, W, C)
# test_input = test_input.permute(0, 3, 1, 2) # -> (B, C, H, W)
# test_input = test_input / 255.0
# mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
# std = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
# test_input = (test_input - mean) / std
# # 3. GPU 加速
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# depth_anything = depth_anything.to(device)
# test_input = test_input.to(device)
# # —— 新增部分 —— #
# # 4. 记录原始 H/W
# ori_h, ori_w = test_input.shape[-2:] # 256, 448
# # 5. 计算向下取整到 14 的倍数
# resize_h = (ori_h // 14) * 14 # 252
# resize_w = (ori_w // 14) * 14 # 448
# # 6. 对输入做插值
# test_input_resized = F.interpolate(
# test_input,
# size=(resize_h, resize_w),
# mode='bilinear',
# align_corners=True
# )
# # 7. 推理
# with torch.no_grad():
# depth_pred_resized = depth_anything(test_input_resized) # (6,1,252,448)
# if depth_pred_resized.dim() == 3:
# depth_pred_resized = depth_pred_resized.unsqueeze(1) # -> (B,1,H,W)
# # 8. 将预测结果插值回原始大小
# depth_pred = F.interpolate(
# depth_pred_resized,
# size=(ori_h, ori_w),
# mode='bilinear',
# align_corners=True
# ) # (6,1,256,448)
# # —— 恢复后续流程 —— #
# print("预测深度图形状:", depth_pred.shape) # (6,1,256,448)
# depth_maps = depth_pred.squeeze(1).cpu().numpy() # (6,256,448)
# print("第一幅深度图统计:")
# print(f"最小值: {depth_maps[0].min():.4f}, 最大值: {depth_maps[0].max():.4f}")
# print(f"均值: {depth_maps[0].mean():.4f}, 标准差: {depth_maps[0].std():.4f}")