# depthsplat/src/test/try_depthanything.py
# Uploaded by Yeqing0814 via huggingface_hub (commit a6dd040, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from .depth_anything.dpt import DepthAnything
from PIL import Image
import os
import os.path as osp
import matplotlib.pyplot as plt
import numpy as np
class DepthAnythingWrapper(nn.Module):
    """
    A wrapper module for DepthAnything model with frozen parameters.

    It normalizes input using UniMatch-style (ImageNet) normalization, resizes
    to multiples of 14 (the ViT patch size), performs depth prediction, and
    resizes the output back to the original resolution.

    Args:
        encoder (str): one of 'vitl', 'vitb', 'vits'.
        checkpoint_path (str): path to the pretrained checkpoint (.pth file).
    """
    def __init__(self, encoder: str, checkpoint_path: str):
        super().__init__()
        # model configurations
        model_configs = {
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
        }
        assert encoder in model_configs, f"Unsupported encoder: {encoder}"
        # load DepthAnything model
        self.depth_model = DepthAnything(model_configs[encoder])
        # NOTE(review): torch.load uses pickle; only load checkpoints from a
        # trusted source (consider weights_only=True on recent PyTorch).
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        self.depth_model.load_state_dict(state_dict)
        self.depth_model.eval()
        # freeze parameters
        for param in self.depth_model.parameters():
            param.requires_grad = False

    def normalize_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Normalize images to match the pretrained UniMatch model (ImageNet stats).

        Args:
            images (torch.Tensor): input tensor of shape (B, V, C, H, W) or
                (B, C, H, W), RGB, values in [0, 1].
        Returns:
            torch.Tensor: normalized tensor with same shape as input.
        """
        # Determine extra dims before channel dim so mean/std broadcast
        extras = images.dim() - 3
        shape = [1] * extras + [3, 1, 1]
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(*shape)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(*shape)
        return (images - mean) / std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the wrapper.

        Args:
            x (torch.Tensor): input tensor of shape (B, C, H, W) or
                (B, V, C, H, W), values in [0, 1].
        Returns:
            torch.Tensor: depth map of shape (B, 1, H, W) or (B, V, 1, H, W),
                clamped to [0.5, 200] meters.
        """
        # apply UniMatch-style normalization
        x = self.normalize_images(x)
        # BUGFIX: remember whether the input had a view dimension BEFORE
        # flattening it — the old code tested merged.dim() == 5 after the
        # flatten, which was always false, so 5-D inputs were never
        # un-merged back to (B, V, 1, H, W).
        has_view_dim = x.dim() == 5
        if has_view_dim:
            B, V, C, H, W = x.shape
            merged = x.view(B * V, C, H, W)
        else:
            B, C, H, W = x.shape
            merged = x
        # compute multiples of 14 (ViT patch size) by flooring
        h14 = (H // 14) * 14
        w14 = (W // 14) * 14
        # resize to multiples of 14
        x_resized = F.interpolate(merged, size=(h14, w14), mode='bilinear', align_corners=True)
        # depth prediction (frozen model, no gradients needed)
        with torch.no_grad():
            d = self.depth_model(x_resized)
        # some backbones return (N, H, W); normalize to (N, 1, H, W)
        if d.dim() == 3:
            d = d.unsqueeze(1)
        # resize back to original resolution
        d_final = F.interpolate(d, size=(H, W), mode='bilinear', align_corners=True)
        # mirror the prediction around its [min, max] range
        # (presumably the model outputs inverse depth / disparity, so this
        # flips it toward metric-like depth — TODO confirm)
        depth_min = d_final.min()
        depth_max = d_final.max()
        d = depth_min + depth_max - d_final
        # clip to physical range [0.5, 200] meters
        d = torch.clamp(d, 0.5, 200.0)
        # un-merge view dimension if the input had one
        if has_view_dim:
            d = d.view(B, V, 1, H, W)
        return d
def visualize_depth(depth, output_path=None, cmap='viridis', vmin=None, vmax=None):
    """
    Visualize a depth map as a pseudo-color image.

    Args:
        depth: numpy array of depth values, ideally (H, W); extra singleton
            dimensions are squeezed away first.
        output_path (str, optional): path to save the figure to. If omitted,
            the figure is shown interactively instead of saved.
        cmap (str): matplotlib colormap name ('jet', 'viridis', 'plasma', ...).
        vmin, vmax (float, optional): color-mapping range; defaults to the
            NaN-ignoring min/max of the data.
    """
    # ensure the depth map is a 2-D array
    if depth.ndim > 2:
        depth = depth.squeeze()
    plt.figure(figsize=(10, 8))
    # default color range: use nanmin/nanmax so invalid pixels don't break scaling
    if vmin is None: vmin = np.nanmin(depth)
    if vmax is None: vmax = np.nanmax(depth)
    # render the pseudo-color image
    plt.imshow(depth, cmap=cmap, vmin=vmin, vmax=vmax)
    # add a colorbar
    cbar = plt.colorbar()
    cbar.set_label('Depth (meters)', fontsize=12)
    # title; hide the axes
    plt.title('Depth Map Visualization', fontsize=14)
    plt.axis('off')
    # save or show
    if output_path:
        # BUGFIX: osp.dirname() returns '' for a bare filename, and
        # os.makedirs('') raises FileNotFoundError — only create the
        # directory when there actually is one.
        out_dir = osp.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(output_path, bbox_inches='tight', dpi=200)
        plt.close()
        print(f"深度图已保存至: {output_path}")
    else:
        plt.show()
def test_depth_wrapper():
    """
    End-to-end smoke test for DepthAnythingWrapper.

    Loads one RGB image, runs depth prediction, prints basic statistics, and
    saves grayscale / pseudo-color / side-by-side visualizations under
    ./visualization.

    NOTE(review): checkpoint_path and image_path are machine-specific paths —
    this test only runs on the original author's environment.
    """
    # configuration
    encoder = 'vitb'
    checkpoint_path = 'pretrained/pretrained_weights/depth_anything_v2_vitb.pth'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_path = "/mnt/pfs/users/chaojun.ni/wangweijie_mnt/yeqing/BEV-Splat/checkpoints/dl3dv-256x448-depthsplat-base-randview2-6/images/dl3dv_14eb48a50e37df548894ab6d8cd628a21dae14bbe6c462e894616fc5962e6c49/color/000028_gt.png"
    # instantiate the wrapper and move it to the target device
    model = DepthAnythingWrapper(encoder, checkpoint_path).to(device)
    model.eval()
    # 1. load the image
    image = Image.open(image_path).convert('RGB')
    original_size = image.size  # keep the original (W, H) for later resizing back
    # manual preprocessing — replaces torchvision.transforms
    # 1. resize to the fixed network input size (W, H)
    target_size = (448, 256)
    resized_image = image.resize(target_size, Image.BILINEAR)
    # 2. convert to a numpy array scaled to [0, 1]
    image_array = np.array(resized_image).astype(np.float32) / 255.0
    # 3. convert to a PyTorch tensor and reorder (H, W, C) -> (C, H, W)
    image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)
    # 4. add the batch dimension
    x = image_tensor.unsqueeze(0).to(device)
    print(f"输入张量尺寸: {x.shape} (设备: {x.device})")
    # forward inference
    with torch.no_grad():
        depth = model(x)
    # check the output
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {depth.shape} (应为 [B,1,H,W])")
    # print depth-map statistics
    depth_np = depth.squeeze().cpu().numpy()  # drop batch and channel dims -> (H, W)
    print("\n深度图统计:")
    print(f"最小值: {depth_np.min():.4f} m, 最大值: {depth_np.max():.4f} m")
    print(f"均值: {depth_np.mean():.4f} m, 标准差: {depth_np.std():.4f}")
    # =================== depth-map visualization ===================
    # 1. save the original image
    input_dir = './visualization'
    os.makedirs(input_dir, exist_ok=True)
    input_image_path = osp.join(input_dir, 'input_image.jpg')
    image.save(input_image_path)
    print(f"原始图像已保存至: {input_image_path}")
    # 2. save a grayscale depth map
    # normalize the depth map into the 0-255 range (epsilon guards div-by-zero)
    depth_min = depth_np.min()
    depth_max = depth_np.max()
    depth_normalized = (depth_np - depth_min) / (depth_max - depth_min + 1e-6)
    depth_grayscale = (depth_normalized * 255).astype(np.uint8)
    depth_pil = Image.fromarray(depth_grayscale)
    # restore the original size (if needed)
    depth_pil = depth_pil.resize(original_size, Image.NEAREST)
    # save the grayscale image
    depth_grayscale_path = osp.join(input_dir, 'depth_grayscale.jpg')
    depth_pil.save(depth_grayscale_path)
    print(f"灰度深度图已保存至: {depth_grayscale_path}")
    # 3. create a pseudo-color depth map (jet colormap)
    depth_colored_path = osp.join(input_dir, 'depth_colormap.png')
    # render the pseudo-color figure without axes
    plt.figure(figsize=(10, 8))
    plt.imshow(depth_np, cmap='jet', vmin=depth_np.min(), vmax=depth_np.max())
    plt.axis('off')  # hide axes
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)  # remove all margins
    plt.savefig(depth_colored_path, bbox_inches='tight', pad_inches=0, dpi=200)
    plt.close()
    print(f"伪彩色深度图已保存至: {depth_colored_path}")
    # 4. show the depth map next to the original image
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    # original image
    axes[0].imshow(image)
    axes[0].set_title('Original Image', fontsize=12)
    axes[0].axis('off')
    # pseudo-color depth map (re-read from the file saved above)
    depth_colored = plt.imread(depth_colored_path)
    axes[1].imshow(depth_colored)
    axes[1].set_title('Depth Map (Jet Colormap)', fontsize=12)
    axes[1].axis('off')
    # save the comparison figure
    comparison_path = osp.join(input_dir, 'depth_comparison.png')
    plt.savefig(comparison_path, bbox_inches='tight', dpi=150)
    plt.close()
    print(f"图像对比图已保存至: {comparison_path}")
    print("\n可视化已完成,结果保存在:", input_dir)
# Entry point: run the end-to-end smoke test when executed as a script.
if __name__ == '__main__':
    test_depth_wrapper()
# from .depth_anything.dpt import DepthAnything
# import torch
# import torch.nn.functional as F
# model_configs = {
# 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
# 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
# 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
# }
# encoder = 'vitb'
# checkpoint_path = f'/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/src/depth_anything/checkpoints/depth_anything_{encoder}14.pth'
# # 1. 创建模型并加载权重
# depth_anything = DepthAnything(model_configs[encoder])
# state_dict = torch.load(checkpoint_path, map_location='cpu')
# depth_anything.load_state_dict(state_dict)
# depth_anything.eval()
# # 2. 构造随机输入并做标准化
# test_input = torch.randn(6, 256, 448, 3).float() # (B, H, W, C)
# test_input = test_input.permute(0, 3, 1, 2) # -> (B, C, H, W)
# test_input = test_input / 255.0
# mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
# std = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
# test_input = (test_input - mean) / std
# # 3. GPU 加速
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# depth_anything = depth_anything.to(device)
# test_input = test_input.to(device)
# # —— 新增部分 —— #
# # 4. 记录原始 H/W
# ori_h, ori_w = test_input.shape[-2:] # 256, 448
# # 5. 计算向下取整到 14 的倍数
# resize_h = (ori_h // 14) * 14 # 252
# resize_w = (ori_w // 14) * 14 # 448
# # 6. 对输入做插值
# test_input_resized = F.interpolate(
# test_input,
# size=(resize_h, resize_w),
# mode='bilinear',
# align_corners=True
# )
# # 7. 推理
# with torch.no_grad():
# depth_pred_resized = depth_anything(test_input_resized) # (6,1,252,448)
# if depth_pred_resized.dim() == 3:
# depth_pred_resized = depth_pred_resized.unsqueeze(1) # -> (B,1,H,W)
# # 8. 将预测结果插值回原始大小
# depth_pred = F.interpolate(
# depth_pred_resized,
# size=(ori_h, ori_w),
# mode='bilinear',
# align_corners=True
# ) # (6,1,256,448)
# # —— 恢复后续流程 —— #
# print("预测深度图形状:", depth_pred.shape) # (6,1,256,448)
# depth_maps = depth_pred.squeeze(1).cpu().numpy() # (6,256,448)
# print("第一幅深度图统计:")
# print(f"最小值: {depth_maps[0].min():.4f}, 最大值: {depth_maps[0].max():.4f}")
# print(f"均值: {depth_maps[0].mean():.4f}, 标准差: {depth_maps[0].std():.4f}")