import os
import os.path as osp

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from .depth_anything.dpt import DepthAnything


class DepthAnythingWrapper(nn.Module):
    """A wrapper module for a frozen DepthAnything model.

    Normalizes input with ImageNet statistics (UniMatch-style), resizes the
    spatial dims to multiples of 14 (the ViT patch size), predicts depth,
    resizes the prediction back to the original resolution, mirror-inverts it,
    and clips it to a physical range.

    Args:
        encoder (str): one of 'vitl', 'vitb', 'vits'.
        checkpoint_path (str): path to the pretrained checkpoint (.pth file).
    """

    # Backbone configurations, keyed by encoder name.
    MODEL_CONFIGS = {
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    }

    def __init__(self, encoder: str, checkpoint_path: str):
        super().__init__()
        assert encoder in self.MODEL_CONFIGS, f"Unsupported encoder: {encoder}"

        # Load the DepthAnything backbone and its pretrained weights.
        self.depth_model = DepthAnything(self.MODEL_CONFIGS[encoder])
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        self.depth_model.load_state_dict(state_dict)
        self.depth_model.eval()

        # Freeze all parameters: this wrapper is inference-only.
        for param in self.depth_model.parameters():
            param.requires_grad = False

    def normalize_images(self, images: torch.Tensor) -> torch.Tensor:
        """Normalize images to match the pretrained UniMatch model.

        Args:
            images (torch.Tensor): tensor of shape (B, V, C, H, W) or
                (B, C, H, W) with values in [0, 1] and C == 3.

        Returns:
            torch.Tensor: normalized tensor with the same shape as the input.
        """
        # Number of leading dims before the channel dim; the mean/std are
        # broadcast as (..., 3, 1, 1) so both 4-D and 5-D inputs work.
        extras = images.dim() - 3
        shape = [1] * extras + [3, 1, 1]
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(*shape)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(*shape)
        return (images - mean) / std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict a depth map for the input image batch.

        Args:
            x (torch.Tensor): input of shape (B, C, H, W) or (B, V, C, H, W),
                values in [0, 1].

        Returns:
            torch.Tensor: depth of shape (B, 1, H, W) for 4-D input or
            (B, V, 1, H, W) for 5-D input, clipped to [0.5, 200] meters.
        """
        x = self.normalize_images(x)

        # Flatten an optional view dimension into the batch dimension.
        # BUG FIX: the original checked `merged.dim() == 5` AFTER the view()
        # (always False), so multi-view outputs were never un-merged.
        is_multiview = x.dim() == 5
        if is_multiview:
            B, V, C, H, W = x.shape
            x = x.view(B * V, C, H, W)
        else:
            B, C, H, W = x.shape

        # The ViT backbone needs spatial dims divisible by the 14-px patch
        # size; clamp to at least one patch so tiny inputs don't resize to 0.
        h14 = max(14, (H // 14) * 14)
        w14 = max(14, (W // 14) * 14)
        x_resized = F.interpolate(x, size=(h14, w14), mode='bilinear', align_corners=True)

        # Depth prediction (model is frozen; no gradients needed).
        with torch.no_grad():
            d = self.depth_model(x_resized)
        if d.dim() == 3:
            d = d.unsqueeze(1)  # (N, H, W) -> (N, 1, H, W)

        # Resize prediction back to the original resolution.
        d = F.interpolate(d, size=(H, W), mode='bilinear', align_corners=True)

        # Mirror the prediction around its min/max so near/far are swapped.
        # NOTE(review): min/max are global over the whole (merged) batch, not
        # per-sample — confirm this is intended.
        depth_min = d.min()
        depth_max = d.max()
        d = depth_min + depth_max - d

        # Clip to physical range [0.5, 200] meters.
        d = torch.clamp(d, 0.5, 200.0)

        # Restore the view dimension for multi-view inputs.
        if is_multiview:
            d = d.view(B, V, 1, H, W)
        return d


def visualize_depth(depth, output_path=None, cmap='viridis', vmin=None, vmax=None):
    """Visualize a depth map as a pseudo-color image.

    Args:
        depth: numpy array with the depth data, (H, W) or squeezable to it.
        output_path (str): save path; if omitted, the figure is shown instead.
        cmap (str): matplotlib colormap name ('jet', 'viridis', 'plasma', ...).
        vmin, vmax (float): color-mapping range; defaults to the data range
            (NaNs ignored).
    """
    # Ensure the depth map is a 2-D array.
    if depth.ndim > 2:
        depth = depth.squeeze()

    plt.figure(figsize=(10, 8))

    # Default the color range to the (NaN-safe) data range.
    if vmin is None:
        vmin = np.nanmin(depth)
    if vmax is None:
        vmax = np.nanmax(depth)

    plt.imshow(depth, cmap=cmap, vmin=vmin, vmax=vmax)
    cbar = plt.colorbar()
    cbar.set_label('Depth (meters)', fontsize=12)
    plt.title('Depth Map Visualization', fontsize=14)
    plt.axis('off')

    if output_path:
        # BUG FIX: os.makedirs('') raises for bare filenames; only create a
        # directory when the path actually has one.
        out_dir = osp.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(output_path, bbox_inches='tight', dpi=200)
        plt.close()
        print(f"深度图已保存至: {output_path}")
    else:
        plt.show()


def test_depth_wrapper():
    """Smoke-test the wrapper on a single image and save visualizations."""
    # Configuration.
    encoder = 'vitb'
    checkpoint_path = 'pretrained/pretrained_weights/depth_anything_v2_vitb.pth'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_path = "/mnt/pfs/users/chaojun.ni/wangweijie_mnt/yeqing/BEV-Splat/checkpoints/dl3dv-256x448-depthsplat-base-randview2-6/images/dl3dv_14eb48a50e37df548894ab6d8cd628a21dae14bbe6c462e894616fc5962e6c49/color/000028_gt.png"

    # Instantiate and move to device.
    model = DepthAnythingWrapper(encoder, checkpoint_path).to(device)
    model.eval()

    # 1. Load the image; keep the original size for later comparison.
    image = Image.open(image_path).convert('RGB')
    original_size = image.size

    # Manual preprocessing (avoids a torchvision.transforms dependency):
    # resize, scale to [0, 1], HWC -> CHW, add batch dim.
    target_size = (448, 256)
    resized_image = image.resize(target_size, Image.BILINEAR)
    image_array = np.array(resized_image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)
    x = image_tensor.unsqueeze(0).to(device)
    print(f"输入张量尺寸: {x.shape} (设备: {x.device})")

    # Forward inference.
    with torch.no_grad():
        depth = model(x)

    # Check the output.
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {depth.shape} (应为 [B,1,H,W])")

    # Print depth-map statistics (drop batch/channel dims -> (H, W)).
    depth_np = depth.squeeze().cpu().numpy()
    print("\n深度图统计:")
    print(f"最小值: {depth_np.min():.4f} m, 最大值: {depth_np.max():.4f} m")
    print(f"均值: {depth_np.mean():.4f} m, 标准差: {depth_np.std():.4f}")

    # =================== Depth-map visualization ===================
    # 1. Save the original image.
    input_dir = './visualization'
    os.makedirs(input_dir, exist_ok=True)
    input_image_path = osp.join(input_dir, 'input_image.jpg')
    image.save(input_image_path)
    print(f"原始图像已保存至: {input_image_path}")

    # 2. Save a grayscale depth map (normalized to 0-255).
    depth_min = depth_np.min()
    depth_max = depth_np.max()
    depth_normalized = (depth_np - depth_min) / (depth_max - depth_min + 1e-6)
    depth_grayscale = (depth_normalized * 255).astype(np.uint8)
    depth_pil = Image.fromarray(depth_grayscale)
    # Restore the original resolution.
    depth_pil = depth_pil.resize(original_size, Image.NEAREST)
    depth_grayscale_path = osp.join(input_dir, 'depth_grayscale.jpg')
    depth_pil.save(depth_grayscale_path)
    print(f"灰度深度图已保存至: {depth_grayscale_path}")

    # 3. Save a pseudo-color depth map (jet colormap, no axes/margins).
    depth_colored_path = osp.join(input_dir, 'depth_colormap.png')
    plt.figure(figsize=(10, 8))
    plt.imshow(depth_np, cmap='jet', vmin=depth_np.min(), vmax=depth_np.max())
    plt.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig(depth_colored_path, bbox_inches='tight', pad_inches=0, dpi=200)
    plt.close()
    print(f"伪彩色深度图已保存至: {depth_colored_path}")

    # 4. Save a side-by-side comparison of the image and its depth map.
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    axes[0].imshow(image)
    axes[0].set_title('Original Image', fontsize=12)
    axes[0].axis('off')
    depth_colored = plt.imread(depth_colored_path)
    axes[1].imshow(depth_colored)
    axes[1].set_title('Depth Map (Jet Colormap)', fontsize=12)
    axes[1].axis('off')
    comparison_path = osp.join(input_dir, 'depth_comparison.png')
    plt.savefig(comparison_path, bbox_inches='tight', dpi=150)
    plt.close()
    print(f"图像对比图已保存至: {comparison_path}")

    print("\n可视化已完成,结果保存在:", input_dir)


if __name__ == '__main__':
    test_depth_wrapper()