| from depth_anything_v2.dpt import DepthAnythingV2 | |
| import torch | |
| import torch.nn.functional as F | |
# Pick the compute device in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Constructor keyword arguments for DepthAnythingV2, keyed by encoder variant.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}

# Encoder variant to load; the checkpoint filename is derived from it.
encoder = 'vitb'
model_path = f'checkpoints/depth_anything_v2_{encoder}.pth'

model = DepthAnythingV2(**model_configs[encoder])
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# The model is used purely as a fixed feature extractor: switch to eval mode
# and turn off gradient tracking for all of its parameters.
model = model.to(device).eval().requires_grad_(False)
def _encoder_token_features(x: torch.Tensor, model, run_device) -> list:
    """Return the per-layer token features the model's encoder yields for x."""
    x = x.to(run_device)
    # Tile single-channel input to three channels before it reaches the encoder.
    if x.shape[1] == 1:
        x = x.repeat(1, 3, 1, 1)
    # Fixed square resolution fed to the encoder (518 = 37 * 14; presumably
    # chosen to be a multiple of the ViT patch size -- confirm).
    target_size = 518
    x = F.interpolate(
        x, size=(target_size, target_size), mode='bilinear', align_corners=False
    )
    layers = model.pretrained.get_intermediate_layers(
        x,
        model.intermediate_layer_idx[model.encoder],
        return_class_token=True,
    )
    # Each entry is a two-item pair; only the first item is compared. The
    # second is presumably the class token (return_class_token=True).
    return [tokens for tokens, _ in layers]


def depth_loss(A: torch.Tensor, B: torch.Tensor, model=model) -> torch.Tensor:
    """Return the mean per-layer MSE between the encoder features of A and B.

    Both inputs are resized to 518x518 with bilinear interpolation, passed
    through ``model.pretrained.get_intermediate_layers`` for the layers listed
    in ``model.intermediate_layer_idx[model.encoder]``, and the token features
    of matching layers are compared with ``F.mse_loss``.

    Args:
        A: Image batch laid out as (batch, channels, height, width). A
            single-channel batch is tiled to three channels.
        B: Second image batch, treated the same way as ``A``.
        model: Model whose ``pretrained`` encoder extracts the features.
            Defaults to the module-level model.

    Returns:
        The per-layer MSE values summed and divided by the number of layers.
    """
    # Run on the device the supplied model actually lives on, so a model that
    # was placed somewhere other than the module-level default still works.
    first_param = next(model.parameters(), None)
    run_device = device if first_param is None else first_param.device

    features_A = _encoder_token_features(A, model, run_device)
    features_B = _encoder_token_features(B, model, run_device)

    total_loss = 0.0
    for tokens_A, tokens_B in zip(features_A, features_B):
        total_loss += F.mse_loss(tokens_A, tokens_B)
    return total_loss / len(features_A)
if __name__ == '__main__':
    # Smoke test: feed two random single-channel batches through depth_loss
    # and print the resulting scalar.
    sample_shape = (4, 1, 1024, 1024)  # (batch, channels, height, width)
    sample_a = torch.randn(*sample_shape).to(device)
    sample_b = torch.randn(*sample_shape).to(device)
    loss = depth_loss(sample_a, sample_b, model)
    print(f"Computed depth loss: {loss.item():.6f}")