File size: 7,416 Bytes
19be62c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#!/usr/bin/env python3
# Copyright (C) 2025-present Naver Corporation. All rights reserved.
import os
import torch
import numpy as np
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader

from must3r.model import *
from must3r.model.blocks.attention import toggle_memory_efficient_attention
from must3r.engine.inference import inference, postprocess, concat_preds
from must3r.tools.geometry import apply_log_to_norm
from must3r.datasets import *  # noqa

import must3r.tools.path_to_dust3r  # noqa
from dust3r.losses import L21
from dust3r.utils.geometry import geotrf
torch.multiprocessing.set_sharing_strategy('file_system')


def get_args_parser():
    """Build the command-line argument parser for MUSt3R evaluation.

    Returns:
        argparse.ArgumentParser: parser (created with ``add_help=False`` so it
        can be composed with a parent parser) holding model/checkpoint,
        memory-scheduling, dataset and batching options.
    """
    parser = argparse.ArgumentParser('MUSt3R eval', add_help=False)
    parser.add_argument('--output', default=None)

    # model and criterion
    parser.add_argument('--encoder', default=None, type=str)
    parser.add_argument('--decoder', default=None)

    parser.add_argument('--init_num_views', default=2, type=int,
                        help="number of views to use when initializing the memory")
    parser.add_argument('--batch_num_views', default=1, type=int,
                        help="number of views to use at once when updating the memory")
    parser.add_argument('--max_batch_size', default=None, type=int,
                        help="max batch size for encoder/renderer")

    parser.add_argument('--render_once', action='store_true', default=False)

    parser.add_argument('--loss_in_log', action='store_true', default=False,
                        help="apply loss in log")
    parser.add_argument('--chkpt', required=True, type=str, help="path to weights")

    parser.add_argument('--eval_memory_num_views', default=None, nargs='+', type=int,
                        help="number of views to use when updating the memory")

    parser.add_argument('--verbose', action='store_true', default=False)
    # dataset
    parser.add_argument('--dataset',
                        required=True,
                        type=str, help="test set")
    # FIX: help text was a copy-paste of --max_batch_size's description.
    parser.add_argument('--num_workers', default=8, type=int,
                        help="number of dataloader worker processes")

    # FIX: closed the unbalanced parenthesis in the help string.
    parser.add_argument('--batch_size', default=8, type=int,
                        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)")
    return parser


if __name__ == "__main__":
    # Evaluation entry point: loads a MUSt3R checkpoint, runs inference over a
    # test dataset, and reports per-image and global pointmap regression losses
    # for one or more memory sizes (number of views inserted into the memory).
    device = 'cuda'
    toggle_memory_efficient_attention(True)
    parser = get_args_parser()
    args = parser.parse_args()

    if args.output is not None:
        # Ensure the directory of the results file exists before appending to it.
        os.makedirs(os.path.dirname(args.output), exist_ok=True)

    criterion = L21  # regression loss imported from dust3r
    print('Loading pretrained: ', args.chkpt)
    encoder, decoder = load_model(
        args.chkpt, encoder=args.encoder, decoder=args.decoder, device='cuda')
    pointmaps_activation = get_pointmaps_activation(decoder)
    # NOTE(review): eval() on a CLI string executes arbitrary code — only run
    # this script with trusted --dataset arguments.
    dataset = eval(args.dataset)
    dataset.set_epoch(0)

    num_views_all = len(dataset[0])  # number of views per dataset sample
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    with torch.no_grad():
        # Memory sizes to evaluate: either every size from init_num_views up to
        # the full view count, or the explicit list given on the command line.
        if args.eval_memory_num_views is None:
            num_views_dec_all = list(range(args.init_num_views, num_views_all + 1))
        else:
            num_views_dec_all = args.eval_memory_num_views

        for num_views_dec in num_views_dec_all:
            losses_firstpass = [[] for _ in range(num_views_all)]  # loss for each image, seen and unseen
            losses_imgs = [[] for _ in range(num_views_all)]  # loss for each image, seen and unseen
            losses_all = []  # one global loss per batch element (all views pooled)
            for views in tqdm(dataloader):
                assert len(views) == num_views_all

                # DATA PREPARATION
                imgs = [b['img'] for b in views]
                imgs = torch.stack(imgs, dim=1).to(device)  # B, nimgs, 3, H, W
                B, _, three, H, W, = imgs.shape

                true_shape = [b['true_shape'] for b in views]
                true_shape = torch.stack(true_shape, dim=1).to(device)

                gt_c2w = [b['camera_pose'] for b in views]
                gt_c2w = torch.stack(gt_c2w, dim=1).to(device)  # B, nimgs, 4, 4
                gt_w2c = torch.linalg.inv(gt_c2w)

                # World-to-camera transform of the first view: ground truth is
                # expressed in the first camera's coordinate frame.
                in_camera0 = gt_w2c[:, 0]

                gt_pts = [b['pts3d'] for b in views]
                gt_pts = torch.stack(gt_pts, dim=1).to(device)
                gt_pts = geotrf(in_camera0, gt_pts)  # B, nimgs, H, W, 3

                if args.loss_in_log:
                    # NOTE(review): gt_pts_log is computed but never used below;
                    # the losses are taken on gt_pts regardless — confirm intent.
                    gt_pts_log = apply_log_to_norm(gt_pts, dim=-1)

                gt_valid = [b['valid_mask'] for b in views]  # B, H, W
                gt_valid = torch.stack(gt_valid, dim=1).to(device)  # B, nimgs, H, W

                # Schedule of per-step view counts: init_num_views first, then
                # chunks of batch_num_views until num_views_dec views are in memory.
                mem_batches = [min(args.init_num_views, num_views_dec)]
                while (sum_b := sum(mem_batches)) != num_views_dec:
                    size_b = min(args.batch_num_views, num_views_dec - sum_b)
                    mem_batches.append(size_b)

                if args.render_once:
                    # Render only the views that were never inserted into memory.
                    to_render = list(range(num_views_dec, num_views_all))
                else:
                    to_render = None
                x_out_0, x_out = inference(encoder, decoder, imgs, true_shape, mem_batches,
                                           verbose=args.verbose, max_bs=args.max_batch_size,
                                           to_render=to_render)
                x_out_0 = postprocess(x_out_0, pointmaps_activation=pointmaps_activation)
                x_out = postprocess(x_out, pointmaps_activation=pointmaps_activation)
                if to_render is not None:
                    # Merge first-pass predictions with the rendered-only views.
                    x_out = concat_preds(x_out_0, x_out)

                x_out_0, x_out = x_out_0['pts3d'], x_out['pts3d']
                if x_out_0 is not None:
                    # apply the loss on first-pass predictions (views seen while
                    # the memory was being built), restricted to valid pixels
                    x_out_0_v = x_out_0.view(B, num_views_dec, H, W, three)
                    for b in range(B):
                        for i in range(num_views_dec):
                            loss_i = criterion(gt_pts[b, i][gt_valid[b, i]], x_out_0_v[b, i][gt_valid[b, i]])
                            losses_firstpass[i].append(loss_i.cpu())

                # apply the loss per view, then globally per batch element
                x_out_v = x_out.view(B, num_views_all, H, W, three)
                for b in range(B):
                    for i in range(num_views_all):
                        loss_i = criterion(gt_pts[b, i][gt_valid[b, i]], x_out_v[b, i][gt_valid[b, i]])
                        losses_imgs[i].append(loss_i.cpu())
                for b in range(B):
                    loss_value = criterion(gt_pts[b][gt_valid[b]], x_out_v[b][gt_valid[b]])
                    losses_all.append(loss_value.cpu())

            # Aggregate mean/median statistics for this memory size.
            result_str = f'{num_views_dec=}\n'
            if len(losses_firstpass[0]) > 0:
                for i in range(num_views_dec):
                    result_str += (f'first pass {i} - mean = {np.mean(losses_firstpass[i])}, '
                                   f'median = {np.median(losses_firstpass[i])}\n')
            for i in range(num_views_all):
                result_str += f'{i} - mean = {np.mean(losses_imgs[i])}, median = {np.median(losses_imgs[i])}\n'
            result_str += f'global - mean = {np.mean(losses_all)}, median = {np.median(losses_all)}\n'

            print(result_str)
            if args.output is not None:
                # Append mode: results for successive memory sizes accumulate.
                with open(args.output, 'a') as fid:
                    fid.write(result_str)