# File size: 11,760 Bytes
# a6dd040
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
import torch
import torch.nn as nn
import torch.nn.functional as F
from .depth_anything.dpt import DepthAnything
from PIL import Image
import os
import os.path as osp
import matplotlib.pyplot as plt
import numpy as np

class DepthAnythingWrapper(nn.Module):
    """
    A wrapper module for the DepthAnything model with frozen parameters.

    It normalizes input with ImageNet statistics (UniMatch-style), resizes the
    spatial dims down to multiples of 14 (the ViT patch size), runs depth
    prediction, and resizes the output back to the original resolution.

    Args:
        encoder (str): one of 'vitl', 'vitb', 'vits'.
        checkpoint_path (str): path to the pretrained checkpoint (.pth file).
    """

    # Per-encoder DPT head configuration (feature width / per-stage channels).
    MODEL_CONFIGS = {
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    }

    def __init__(self, encoder: str, checkpoint_path: str):
        super().__init__()
        assert encoder in self.MODEL_CONFIGS, f"Unsupported encoder: {encoder}"

        # Load the DepthAnything model and its pretrained weights.
        # NOTE(review): torch.load without weights_only=True unpickles arbitrary
        # objects — fine for trusted checkpoints, consider weights_only on newer torch.
        self.depth_model = DepthAnything(self.MODEL_CONFIGS[encoder])
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        self.depth_model.load_state_dict(state_dict)
        self.depth_model.eval()

        # Freeze all parameters: the wrapper is inference-only.
        for param in self.depth_model.parameters():
            param.requires_grad = False

    def normalize_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Normalize images with ImageNet mean/std to match the pretrained model.

        Args:
            images (torch.Tensor): input of shape (B, V, C, H, W) or (B, C, H, W),
                values expected in [0, 1].
        Returns:
            torch.Tensor: normalized tensor with the same shape (and dtype) as input.
        """
        # Broadcast mean/std over any leading dims before the channel dim.
        extras = images.dim() - 3
        shape = [1] * extras + [3, 1, 1]
        # Match dtype so half-precision inputs are not silently promoted to float32.
        mean = torch.tensor([0.485, 0.456, 0.406],
                            device=images.device, dtype=images.dtype).reshape(*shape)
        std = torch.tensor([0.229, 0.224, 0.225],
                           device=images.device, dtype=images.dtype).reshape(*shape)
        return (images - mean) / std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the wrapper.

        Args:
            x (torch.Tensor): input of shape (B, C, H, W) or (B, V, C, H, W),
                values in [0, 1].
        Returns:
            torch.Tensor: depth map of shape (B, 1, H, W), or (B, V, 1, H, W)
                when a view dimension was supplied.
        """
        x = self.normalize_images(x)

        # Merge an optional view dimension into the batch dimension.
        multi_view = x.dim() == 5
        if multi_view:
            B, V, C, H, W = x.shape
            merged = x.reshape(B * V, C, H, W)
        else:
            B, C, H, W = x.shape
            merged = x

        # The ViT backbone requires spatial dims that are multiples of 14;
        # clamp to at least one patch so tiny inputs don't collapse to size 0.
        h14 = max(14, (H // 14) * 14)
        w14 = max(14, (W // 14) * 14)
        x_resized = F.interpolate(merged, size=(h14, w14), mode='bilinear', align_corners=True)

        # Depth prediction (frozen model, no gradients).
        with torch.no_grad():
            d = self.depth_model(x_resized)
            if d.dim() == 3:  # (N, H, W) -> (N, 1, H, W)
                d = d.unsqueeze(1)

        # Resize back to the original resolution.
        d_final = F.interpolate(d, size=(H, W), mode='bilinear', align_corners=True)

        # Mirror the predicted range: min + max - d flips near/far ordering
        # (DepthAnything outputs inverse-depth-like values — TODO confirm).
        depth_min = d_final.min()
        depth_max = d_final.max()
        d = depth_min + depth_max - d_final

        # Clip to the assumed physical range [0.5, 200] meters.
        d = torch.clamp(d, 0.5, 200.0)

        # Restore the view dimension. Bug fix: the original tested
        # `merged.dim() == 5` AFTER merging, which was always False, so
        # 5-D inputs were never un-merged back to (B, V, 1, H, W).
        if multi_view:
            d = d.view(B, V, 1, H, W)
        return d


def visualize_depth(depth, output_path=None, cmap='viridis', vmin=None, vmax=None):
    """
    Render a depth map as a pseudo-color figure.

    Args:
        depth: numpy array with the depth data, shape (H, W); extra singleton
            dims are squeezed away.
        output_path (str): where to save the figure; if omitted, the figure is
            shown instead of saved.
        cmap (str): matplotlib colormap name ('jet', 'viridis', 'plasma', ...).
        vmin, vmax (float): color-mapping range; default to the data's
            nan-aware min/max.
    """
    # Ensure the depth map is a 2-D array.
    if depth.ndim > 2:
        depth = depth.squeeze()

    plt.figure(figsize=(10, 8))

    # Default the color range to the data range, ignoring NaNs.
    if vmin is None: vmin = np.nanmin(depth)
    if vmax is None: vmax = np.nanmax(depth)

    # Draw the pseudo-color image.
    plt.imshow(depth, cmap=cmap, vmin=vmin, vmax=vmax)

    # Attach a labeled colorbar.
    cbar = plt.colorbar()
    cbar.set_label('Depth (meters)', fontsize=12)

    plt.title('Depth Map Visualization', fontsize=14)
    plt.axis('off')  # hide the axes

    # Save or show.
    if output_path:
        # Bug fix: osp.dirname(...) is '' for a bare filename, and
        # os.makedirs('') raises FileNotFoundError — only create a
        # directory when there actually is one in the path.
        out_dir = osp.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(output_path, bbox_inches='tight', dpi=200)
        plt.close()
        print(f"深度图已保存至: {output_path}")
    else:
        plt.show()

def test_depth_wrapper():
    # Configuration: encoder variant, checkpoint, compute device, sample image.
    encoder = 'vitb'
    checkpoint_path = 'pretrained/pretrained_weights/depth_anything_v2_vitb.pth'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_path = "/mnt/pfs/users/chaojun.ni/wangweijie_mnt/yeqing/BEV-Splat/checkpoints/dl3dv-256x448-depthsplat-base-randview2-6/images/dl3dv_14eb48a50e37df548894ab6d8cd628a21dae14bbe6c462e894616fc5962e6c49/color/000028_gt.png"
    
    # Instantiate the wrapper and move it to the device.
    model = DepthAnythingWrapper(encoder, checkpoint_path).to(device)
    model.eval()

    # 1. Load the image.
    image = Image.open(image_path).convert('RGB')
    original_size = image.size  # keep the original (W, H) for later comparison
    
    # Manual preprocessing — replaces torchvision.transforms.
    # 1. Resize. NOTE(review): PIL's resize takes (W, H), so this targets
    #    a 448x256 image — confirm that matches the model's expected input.
    target_size = (448, 256)
    resized_image = image.resize(target_size, Image.BILINEAR)
    
    # 2. Convert to a numpy array scaled to [0, 1].
    image_array = np.array(resized_image).astype(np.float32) / 255.0
    
    # 3. Convert to a PyTorch tensor and reorder dims (H, W, C) -> (C, H, W).
    image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)
    
    # 4. Add the batch dimension.
    x = image_tensor.unsqueeze(0).to(device)
    
    print(f"输入张量尺寸: {x.shape} (设备: {x.device})")
    
    # Forward inference.
    with torch.no_grad():
        depth = model(x)

    # Inspect the output.
    print(f"输入形状:  {x.shape}")
    print(f"输出形状:  {depth.shape}  (应为 [B,1,H,W])")

    # Print depth-map statistics.
    depth_np = depth.squeeze().cpu().numpy()  # drop batch and channel dims -> (H, W)
    
    print("\n深度图统计:")
    print(f"最小值: {depth_np.min():.4f} m, 最大值: {depth_np.max():.4f} m")
    print(f"均值:   {depth_np.mean():.4f} m, 标准差: {depth_np.std():.4f}")
    
    # =================== Depth-map visualization ===================
    # 1. Save the original image.
    input_dir = './visualization'
    os.makedirs(input_dir, exist_ok=True)
    input_image_path = osp.join(input_dir, 'input_image.jpg')
    image.save(input_image_path)
    print(f"原始图像已保存至: {input_image_path}")
    
    # 2. Save a grayscale depth image.
    # Normalize the depth map into the 0-255 range.
    depth_min = depth_np.min()
    depth_max = depth_np.max()
    depth_normalized = (depth_np - depth_min) / (depth_max - depth_min + 1e-6)
    depth_grayscale = (depth_normalized * 255).astype(np.uint8)
    depth_pil = Image.fromarray(depth_grayscale)
    
    # Restore the original size (if needed).
    depth_pil = depth_pil.resize(original_size, Image.NEAREST)
    
    # Save the grayscale image.
    depth_grayscale_path = osp.join(input_dir, 'depth_grayscale.jpg')
    depth_pil.save(depth_grayscale_path)
    print(f"灰度深度图已保存至: {depth_grayscale_path}")
    
   # 3. Create a pseudo-color depth map (jet colormap).
    depth_colored_path = osp.join(input_dir, 'depth_colormap.png')

    # Render a pseudo-color figure with no axes.
    plt.figure(figsize=(10, 8))
    plt.imshow(depth_np, cmap='jet', vmin=depth_np.min(), vmax=depth_np.max())
    plt.axis('off')  # hide the axes
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)  # strip all margins
    plt.savefig(depth_colored_path, bbox_inches='tight', pad_inches=0, dpi=200)
    plt.close()
    print(f"伪彩色深度图已保存至: {depth_colored_path}")
    
    # 4. Show the depth map next to the original image.
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    
    # Original image.
    axes[0].imshow(image)
    axes[0].set_title('Original Image', fontsize=12)
    axes[0].axis('off')
    
    # Pseudo-color depth map (re-read from the file saved above).
    depth_colored = plt.imread(depth_colored_path)
    axes[1].imshow(depth_colored)
    axes[1].set_title('Depth Map (Jet Colormap)', fontsize=12)
    axes[1].axis('off')
    
    # Save the side-by-side comparison.
    comparison_path = osp.join(input_dir, 'depth_comparison.png')
    plt.savefig(comparison_path, bbox_inches='tight', dpi=150)
    plt.close()
    
    print(f"图像对比图已保存至: {comparison_path}")
    print("\n可视化已完成,结果保存在:", input_dir)

if __name__ == '__main__':
    test_depth_wrapper()


# NOTE: The commented-out block below is legacy scratch code kept for reference;
# it is superseded by DepthAnythingWrapper / test_depth_wrapper above.
# from .depth_anything.dpt import DepthAnything
# import torch
# import torch.nn.functional as F

# model_configs = {
#     'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
#     'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
#     'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
# }

# encoder = 'vitb'
# checkpoint_path = f'/mnt/pfs/users/wangweijie/yeqing/BEV-Splat/src/depth_anything/checkpoints/depth_anything_{encoder}14.pth'

# # 1. 创建模型并加载权重
# depth_anything = DepthAnything(model_configs[encoder])
# state_dict = torch.load(checkpoint_path, map_location='cpu')
# depth_anything.load_state_dict(state_dict)
# depth_anything.eval()

# # 2. 构造随机输入并做标准化
# test_input = torch.randn(6, 256, 448, 3).float()  # (B, H, W, C)
# test_input = test_input.permute(0, 3, 1, 2)        # -> (B, C, H, W)
# test_input = test_input / 255.0
# mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
# std  = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
# test_input = (test_input - mean) / std

# # 3. GPU 加速
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# depth_anything = depth_anything.to(device)
# test_input = test_input.to(device)

# # —— 新增部分 —— #
# # 4. 记录原始 H/W
# ori_h, ori_w = test_input.shape[-2:]          # 256, 448

# # 5. 计算向下取整到 14 的倍数
# resize_h = (ori_h // 14) * 14  # 252
# resize_w = (ori_w // 14) * 14  # 448

# # 6. 对输入做插值
# test_input_resized = F.interpolate(
#     test_input,
#     size=(resize_h, resize_w),
#     mode='bilinear',
#     align_corners=True
# )

# # 7. 推理
# with torch.no_grad():
#     depth_pred_resized = depth_anything(test_input_resized)  # (6,1,252,448)
#     if depth_pred_resized.dim() == 3:
#         depth_pred_resized = depth_pred_resized.unsqueeze(1)  # -> (B,1,H,W)

# # 8. 将预测结果插值回原始大小
# depth_pred = F.interpolate(
#     depth_pred_resized,
#     size=(ori_h, ori_w),
#     mode='bilinear',
#     align_corners=True
# )  # (6,1,256,448)

# # —— 恢复后续流程 —— #
# print("预测深度图形状:", depth_pred.shape)  # (6,1,256,448)

# depth_maps = depth_pred.squeeze(1).cpu().numpy()  # (6,256,448)
# print("第一幅深度图统计:")
# print(f"最小值: {depth_maps[0].min():.4f}, 最大值: {depth_maps[0].max():.4f}")
# print(f"均值:   {depth_maps[0].mean():.4f}, 标准差: {depth_maps[0].std():.4f}")