|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from .depth_anything.dpt import DepthAnything |
|
|
from PIL import Image |
|
|
import os |
|
|
import os.path as osp |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
|
|
|
class DepthAnythingWrapper(nn.Module):
    """
    A wrapper module for a frozen DepthAnything model.

    The wrapper normalizes inputs with ImageNet statistics, resizes them to
    multiples of 14 (the ViT patch size), runs depth prediction under
    ``torch.no_grad()``, resizes the prediction back to the input resolution,
    flips the relative prediction so larger values mean farther away, and
    clamps the result to [0.5, 200].

    Args:
        encoder (str): one of 'vitl', 'vitb', 'vits'.
        checkpoint_path (str): path to the pretrained checkpoint (.pth file).

    Raises:
        ValueError: if ``encoder`` is not one of the supported variants.
    """

    def __init__(self, encoder: str, checkpoint_path: str):
        super().__init__()

        # Architecture hyper-parameters for each supported ViT backbone.
        model_configs = {
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        }
        # Validate explicitly: `assert` would be stripped under `python -O`.
        if encoder not in model_configs:
            raise ValueError(f"Unsupported encoder: {encoder}")

        self.depth_model = DepthAnything(model_configs[encoder])
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        self.depth_model.load_state_dict(state_dict)
        self.depth_model.eval()

        # Freeze the backbone: the wrapper is inference-only.
        for param in self.depth_model.parameters():
            param.requires_grad = False

    def normalize_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Normalize images with ImageNet mean/std statistics.

        Args:
            images (torch.Tensor): input tensor of shape (B, V, C, H, W) or
                (B, C, H, W), values expected in [0, 1].

        Returns:
            torch.Tensor: normalized tensor with same shape as input.
        """
        # Build a broadcastable (.., 3, 1, 1) shape matching the input rank.
        extras = images.dim() - 3
        shape = [1] * extras + [3, 1, 1]
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(*shape)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(*shape)
        return (images - mean) / std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict a depth map for each input view.

        Args:
            x (torch.Tensor): input tensor of shape (B, C, H, W) or
                (B, V, C, H, W), values in [0, 1].

        Returns:
            torch.Tensor: depth of shape (B, 1, H, W) for 4-D input, or
                (B, V, 1, H, W) for 5-D input, clamped to [0.5, 200].
        """
        x = self.normalize_images(x)

        # Remember multi-view-ness BEFORE flattening; checking the flattened
        # tensor's dim() later would always be 4 and skip the un-flatten
        # (this was a bug: 5-D inputs came back as (B*V, 1, H, W)).
        is_multiview = x.dim() == 5
        if is_multiview:
            B, V, C, H, W = x.shape
            x = x.view(B * V, C, H, W)
        else:
            B, C, H, W = x.shape

        # Snap to the nearest lower multiple of 14 (ViT patch size), but never
        # below 14 so tiny inputs don't produce a zero-sized resize.
        h14 = max((H // 14) * 14, 14)
        w14 = max((W // 14) * 14, 14)

        x_resized = F.interpolate(x, size=(h14, w14), mode='bilinear', align_corners=True)

        # Frozen backbone: no gradients needed for the prediction.
        with torch.no_grad():
            d = self.depth_model(x_resized)
        if d.dim() == 3:
            # Add a channel axis: (N, h, w) -> (N, 1, h, w).
            d = d.unsqueeze(1)

        # Back to the caller's spatial resolution.
        d_final = F.interpolate(d, size=(H, W), mode='bilinear', align_corners=True)

        # DepthAnything predicts relative depth where near is large; mirror the
        # values within [min, max] so larger output means farther away.
        depth_min = d_final.min()
        depth_max = d_final.max()
        d = depth_min + depth_max - d_final

        d = torch.clamp(d, 0.5, 200.0)

        if is_multiview:
            d = d.view(B, V, 1, H, W)
        return d
|
|
|
|
|
|
|
|
def visualize_depth(depth, output_path=None, cmap='viridis', vmin=None, vmax=None):
    """
    Visualize a depth map with a labeled colorbar.

    Args:
        depth: numpy array of depth values, (H, W) or squeezable to 2-D.
        output_path (str, optional): path to save the figure; if omitted the
            figure is shown interactively instead of being saved.
        cmap (str): matplotlib colormap name ('jet', 'viridis', 'plasma', ...).
        vmin, vmax (float, optional): color-mapping range; defaults to the
            NaN-ignoring min/max of the data.
    """
    # Collapse singleton dims, e.g. (1, 1, H, W) -> (H, W).
    if depth.ndim > 2:
        depth = depth.squeeze()

    plt.figure(figsize=(10, 8))

    # NaN-aware range so invalid pixels don't hijack the color scale.
    if vmin is None:
        vmin = np.nanmin(depth)
    if vmax is None:
        vmax = np.nanmax(depth)

    plt.imshow(depth, cmap=cmap, vmin=vmin, vmax=vmax)

    cbar = plt.colorbar()
    cbar.set_label('Depth (meters)', fontsize=12)

    plt.title('Depth Map Visualization', fontsize=14)
    plt.axis('off')

    if output_path:
        # osp.dirname('') for a bare filename would make os.makedirs('')
        # raise FileNotFoundError, so only create a directory when one exists.
        out_dir = osp.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(output_path, bbox_inches='tight', dpi=200)
        plt.close()
        print(f"深度图已保存至: {output_path}")
    else:
        plt.show()
|
|
|
|
|
def test_depth_wrapper():
    """Smoke-test DepthAnythingWrapper on one image and dump visualizations."""
    encoder = 'vitb'
    checkpoint_path = 'pretrained/pretrained_weights/depth_anything_v2_vitb.pth'
    image_path = "/mnt/pfs/users/chaojun.ni/wangweijie_mnt/yeqing/BEV-Splat/checkpoints/dl3dv-256x448-depthsplat-base-randview2-6/images/dl3dv_14eb48a50e37df548894ab6d8cd628a21dae14bbe6c462e894616fc5962e6c49/color/000028_gt.png"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    wrapper = DepthAnythingWrapper(encoder, checkpoint_path).to(device)
    wrapper.eval()

    # Load the test image and keep its original size for later upsampling.
    pil_image = Image.open(image_path).convert('RGB')
    original_size = pil_image.size
    resized = pil_image.resize((448, 256), Image.BILINEAR)

    # HWC uint8 -> CHW float batch of one, values scaled to [0, 1].
    pixels = np.array(resized).astype(np.float32) / 255.0
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)
    print(f"输入张量尺寸: {batch.shape} (设备: {batch.device})")

    with torch.no_grad():
        pred = wrapper(batch)

    print(f"输入形状: {batch.shape}")
    print(f"输出形状: {pred.shape} (应为 [B,1,H,W])")

    depth_arr = pred.squeeze().cpu().numpy()
    print("\n深度图统计:")
    print(f"最小值: {depth_arr.min():.4f} m, 最大值: {depth_arr.max():.4f} m")
    print(f"均值: {depth_arr.mean():.4f} m, 标准差: {depth_arr.std():.4f}")

    out_dir = './visualization'
    os.makedirs(out_dir, exist_ok=True)

    # Keep a copy of the (unresized) input image next to the outputs.
    input_image_path = osp.join(out_dir, 'input_image.jpg')
    pil_image.save(input_image_path)
    print(f"原始图像已保存至: {input_image_path}")

    # Grayscale depth: min-max normalize to 8 bits, then upsample back to the
    # original resolution with nearest-neighbour to avoid blending values.
    lo = depth_arr.min()
    hi = depth_arr.max()
    gray = (((depth_arr - lo) / (hi - lo + 1e-6)) * 255).astype(np.uint8)
    gray_image = Image.fromarray(gray).resize(original_size, Image.NEAREST)
    depth_grayscale_path = osp.join(out_dir, 'depth_grayscale.jpg')
    gray_image.save(depth_grayscale_path)
    print(f"灰度深度图已保存至: {depth_grayscale_path}")

    # Pseudo-colour rendering saved without axes or padding.
    depth_colored_path = osp.join(out_dir, 'depth_colormap.png')
    plt.figure(figsize=(10, 8))
    plt.imshow(depth_arr, cmap='jet', vmin=depth_arr.min(), vmax=depth_arr.max())
    plt.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig(depth_colored_path, bbox_inches='tight', pad_inches=0, dpi=200)
    plt.close()
    print(f"伪彩色深度图已保存至: {depth_colored_path}")

    # Side-by-side comparison of the input and the colour-mapped depth.
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    axes[0].imshow(pil_image)
    axes[0].set_title('Original Image', fontsize=12)
    axes[0].axis('off')
    axes[1].imshow(plt.imread(depth_colored_path))
    axes[1].set_title('Depth Map (Jet Colormap)', fontsize=12)
    axes[1].axis('off')

    comparison_path = osp.join(out_dir, 'depth_comparison.png')
    plt.savefig(comparison_path, bbox_inches='tight', dpi=150)
    plt.close()

    print(f"图像对比图已保存至: {comparison_path}")
    print("\n可视化已完成,结果保存在:", out_dir)
|
|
|
|
|
# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    test_depth_wrapper()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|