#!/usr/bin/env python3
# Copyright (C) 2025-present Naver Corporation. All rights reserved.
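#
# Example invocation (the checkpoint path and dataset expression below are
# placeholders, not shipped defaults; flags are those defined in get_args_parser):
#   python eval.py --chkpt checkpoints/must3r_model.pth \
#       --dataset "SomeTestSet(resolution=224)" \
#       --output results/eval.txt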
import os
import torch
import numpy as np
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
from must3r.model import *
from must3r.model.blocks.attention import toggle_memory_efficient_attention
from must3r.engine.inference import inference, postprocess, concat_preds
from must3r.tools.geometry import apply_log_to_norm
from must3r.datasets import * # noqa
import must3r.tools.path_to_dust3r # noqa
from dust3r.losses import L21
from dust3r.utils.geometry import geotrf
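
# use file-system backed tensor sharing so that many dataloader workers
# do not exhaust the per-process file-descriptor limit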
torch.multiprocessing.set_sharing_strategy('file_system')


def get_args_parser():
    parser = argparse.ArgumentParser('MUSt3R eval', add_help=False)
    parser.add_argument('--output', default=None)

    # model and criterion
    parser.add_argument('--encoder', default=None, type=str)
    parser.add_argument('--decoder', default=None)
    parser.add_argument('--init_num_views', default=2, type=int,
                        help="number of views to use when initializing the memory")
    parser.add_argument('--batch_num_views', default=1, type=int,
                        help="number of views to use at once when updating the memory")
    parser.add_argument('--max_batch_size', default=None, type=int,
                        help="max batch size for encoder/renderer")
    parser.add_argument('--render_once', action='store_true', default=False)
    parser.add_argument('--loss_in_log', action='store_true', default=False,
                        help="apply the loss in log space")
    parser.add_argument('--chkpt', required=True, type=str, help="path to weights")
    parser.add_argument('--eval_memory_num_views', default=None, nargs='+', type=int,
                        help="number of views to use when updating the memory")
    parser.add_argument('--verbose', action='store_true', default=False)

    # dataset
    parser.add_argument('--dataset', required=True, type=str, help="test set")
    parser.add_argument('--num_workers', default=8, type=int,
                        help="number of dataloader workers")
    parser.add_argument('--batch_size', default=8, type=int,
                        help="batch size per GPU")
    return parser


if __name__ == "__main__":
    device = 'cuda'
    toggle_memory_efficient_attention(True)

    parser = get_args_parser()
    args = parser.parse_args()
    if args.output is not None:
        os.makedirs(os.path.dirname(args.output), exist_ok=True)

    criterion = L21
    print('Loading pretrained:', args.chkpt)
    encoder, decoder = load_model(
        args.chkpt, encoder=args.encoder, decoder=args.decoder, device=device)
    pointmaps_activation = get_pointmaps_activation(decoder)

    # the --dataset string is evaluated as a Python expression constructing the test set
    dataset = eval(args.dataset)
    dataset.set_epoch(0)
    num_views_all = len(dataset[0])

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers)

    with torch.no_grad():
        if args.eval_memory_num_views is None:
            num_views_dec_all = list(range(args.init_num_views, num_views_all + 1))
        else:
            num_views_dec_all = args.eval_memory_num_views
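
        # sweep over memory sizes: for each num_views_dec, the memory is rebuilt
        # from the first num_views_dec views, then all num_views_all views are evaluated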
        for num_views_dec in num_views_dec_all:
            losses_firstpass = [[] for _ in range(num_views_all)]  # per-view losses, first-pass predictions
            losses_imgs = [[] for _ in range(num_views_all)]  # per-view losses, final renderings (seen and unseen)
            losses_all = []

            for views in tqdm(dataloader):
                assert len(views) == num_views_all

                # DATA PREPARATION
                imgs = [b['img'] for b in views]
                imgs = torch.stack(imgs, dim=1).to(device)
                B, _, three, H, W = imgs.shape
                true_shape = [b['true_shape'] for b in views]
                true_shape = torch.stack(true_shape, dim=1).to(device)

                gt_c2w = [b['camera_pose'] for b in views]
                gt_c2w = torch.stack(gt_c2w, dim=1).to(device)  # B, nimgs, 4, 4
                gt_w2c = torch.linalg.inv(gt_c2w)
                in_camera0 = gt_w2c[:, 0]

                # ground-truth points, expressed in the first camera's frame
                gt_pts = [b['pts3d'] for b in views]
                gt_pts = torch.stack(gt_pts, dim=1).to(device)
                gt_pts = geotrf(in_camera0, gt_pts)  # B, nimgs, H, W, 3
                if args.loss_in_log:
                    gt_pts_log = apply_log_to_norm(gt_pts, dim=-1)

                gt_valid = [b['valid_mask'] for b in views]  # nimgs x (B, H, W)
                gt_valid = torch.stack(gt_valid, dim=1).to(device)  # B, nimgs, H, W
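
                # Memory update schedule: seed the memory with init_num_views views,
                # then append batch_num_views views at a time until num_views_dec
                # views are in memory. E.g. init_num_views=2, batch_num_views=1,
                # num_views_dec=5 gives mem_batches = [2, 1, 1, 1].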
                mem_batches = [min(args.init_num_views, num_views_dec)]
                while (sum_b := sum(mem_batches)) != num_views_dec:
                    size_b = min(args.batch_num_views, num_views_dec - sum_b)
                    mem_batches.append(size_b)

                if args.render_once:
                    # render only the views that never entered the memory
                    to_render = list(range(num_views_dec, num_views_all))
                else:
                    to_render = None
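
                # x_out_0: predictions made while the views were being written to
                # memory (first pass); x_out: predictions rendered from the final
                # memory state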
                x_out_0, x_out = inference(encoder, decoder, imgs, true_shape, mem_batches,
                                           verbose=args.verbose, max_bs=args.max_batch_size,
                                           to_render=to_render)
                x_out_0 = postprocess(x_out_0, pointmaps_activation=pointmaps_activation)
                x_out = postprocess(x_out, pointmaps_activation=pointmaps_activation)
                if to_render is not None:
                    # reuse the first-pass predictions for the memory views
                    x_out = concat_preds(x_out_0, x_out)
                x_out_0, x_out = x_out_0['pts3d'], x_out['pts3d']

                if x_out_0 is not None:
                    # per-view loss on the first-pass predictions (memory views only)
                    x_out_0_v = x_out_0.view(B, num_views_dec, H, W, three)
                    for b in range(B):
                        for i in range(num_views_dec):
                            loss_i = criterion(gt_pts[b, i][gt_valid[b, i]], x_out_0_v[b, i][gt_valid[b, i]])
                            losses_firstpass[i].append(loss_i.cpu())

                # per-view loss on the final renderings (all views, seen and unseen)
                x_out_v = x_out.view(B, num_views_all, H, W, three)
                for b in range(B):
                    for i in range(num_views_all):
                        loss_i = criterion(gt_pts[b, i][gt_valid[b, i]], x_out_v[b, i][gt_valid[b, i]])
                        losses_imgs[i].append(loss_i.cpu())

                # global loss over all views of each scene
                for b in range(B):
                    loss_value = criterion(gt_pts[b][gt_valid[b]], x_out_v[b][gt_valid[b]])
                    losses_all.append(loss_value.cpu())

            result_str = f'{num_views_dec=}\n'
            if len(losses_firstpass[0]) > 0:
                for i in range(num_views_dec):
                    result_str += (f'first pass {i} - mean = {np.mean(losses_firstpass[i])}, '
                                   f'median = {np.median(losses_firstpass[i])}\n')
            for i in range(num_views_all):
                result_str += f'{i} - mean = {np.mean(losses_imgs[i])}, median = {np.median(losses_imgs[i])}\n'
            result_str += f'global - mean = {np.mean(losses_all)}, median = {np.median(losses_all)}\n'
            print(result_str)
            if args.output is not None:
                with open(args.output, 'a') as fid:
                    fid.write(result_str)