DeBoneDiT / code /depth_loss.py
diaoquesang's picture
Upload 15 files
6107278 verified
Raw
History Blame Contribute Delete
2.43 kB
from depth_anything_v2.dpt import DepthAnythingV2
import torch
import torch.nn.functional as F
# Choose the compute backend once at import time: CUDA first, then Apple MPS,
# and finally the CPU when neither accelerator is reported as available.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Constructor keyword arguments for DepthAnythingV2, keyed by encoder variant.
model_configs = dict(
    vits=dict(encoder='vits', features=64, out_channels=[48, 96, 192, 384]),
    vitb=dict(encoder='vitb', features=128, out_channels=[96, 192, 384, 768]),
    vitl=dict(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]),
    vitg=dict(encoder='vitg', features=384, out_channels=[1536, 1536, 1536, 1536]),
)

# The variant name selects both the config entry and the checkpoint file.
encoder = 'vitb'
model_path = f'checkpoints/depth_anything_v2_{encoder}.pth'

model = DepthAnythingV2(**model_configs[encoder])
# NOTE(review): depending on the installed PyTorch version, torch.load may
# unpickle arbitrary objects -- only load checkpoints from a trusted source.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

# Freeze the network: place it on the chosen device, switch to inference mode
# and stop tracking gradients for its parameters.
model = model.to(device)
model.eval()
model.requires_grad_(False)
def depth_loss(
    A: torch.Tensor,
    B: torch.Tensor,
    model=model,
    *,
    target_size: int = 518,
) -> torch.Tensor:
    """Return the mean per-layer MSE between encoder features of ``A`` and ``B``.

    Each batch is moved to the device holding ``model``, single-channel inputs
    are tiled to three channels, and the result is bilinearly resized to
    ``target_size`` x ``target_size``.  The intermediate features returned by
    ``model.pretrained.get_intermediate_layers`` are then compared layer by
    layer with ``F.mse_loss`` and the per-layer losses are averaged.

    The function does not disable autograd itself, so gradients can still
    propagate to ``A`` and ``B`` when they require them.

    Args:
        A: First image batch, indexed as (batch, channels, height, width).
        B: Second image batch with the same layout as ``A``.
        model: Network exposing ``pretrained.get_intermediate_layers``,
            ``intermediate_layer_idx`` and ``encoder``.  Defaults to the
            module-level model.
        target_size: Side length both inputs are resized to before encoding.
            NOTE(review): the backbone presumably restricts valid sizes
            (e.g. to multiples of its patch size) -- confirm before changing.

    Returns:
        The average over layers of the mean-squared error between the
        first element of each feature pair for ``A`` and for ``B``.

    Raises:
        ValueError: If ``A`` or ``B`` is not a 4-D tensor.
    """
    for arg_name, tensor in (("A", A), ("B", B)):
        if tensor.dim() != 4:
            raise ValueError(
                f"{arg_name} must be a 4-D (batch, channels, height, width) "
                f"tensor, got shape {tuple(tensor.shape)}"
            )

    # Follow the device of the model that is actually used instead of the
    # module-level default, so a caller-supplied model living on another
    # device does not end up with inputs and weights on different devices.
    first_param = next(model.parameters(), None)
    target_device = device if first_param is None else first_param.device

    def _prepare(images: torch.Tensor) -> torch.Tensor:
        # Move, tile grayscale to three channels, and resize to a square.
        images = images.to(target_device)
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)
        return F.interpolate(
            images,
            size=(target_size, target_size),
            mode='bilinear',
            align_corners=False,
        )

    def _encode(images: torch.Tensor):
        # One (features, class_token) pair per requested intermediate layer.
        return model.pretrained.get_intermediate_layers(
            images,
            model.intermediate_layer_idx[model.encoder],
            return_class_token=True,
        )

    features_A = _encode(_prepare(A))
    features_B = _encode(_prepare(B))

    # Only the first element of each pair is compared; the class token that
    # return_class_token=True adds as the second element is ignored.
    layer_losses = [
        F.mse_loss(tokens_A, tokens_B)
        for (tokens_A, _), (tokens_B, _) in zip(features_A, features_B)
    ]
    return sum(layer_losses) / len(layer_losses)
if __name__ == '__main__':
    # Smoke test: run the loss on two independent random single-channel
    # batches and print the resulting value.
    input_shape = (4, 1, 1024, 1024)  # (batch, channels, height, width)
    first_batch = torch.randn(*input_shape).to(device)
    second_batch = torch.randn(*input_shape).to(device)
    loss = depth_loss(first_batch, second_batch, model)
    print(f"Computed depth loss: {loss.item():.6f}")