File size: 11,274 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import os
from collections import OrderedDict
from typing import Any

import torch

from optgs.misc.io import cyan


# Function to extract the step number from the filename
def extract_step(file_name):
    """Return the integer training step encoded in a checkpoint file name.

    The name is split on ``"-"`` and the second piece is kept; that piece is
    split on ``"_"`` and its second piece is kept; any ``".ckpt"`` text is
    removed and the remainder is converted with ``int``. For example,
    ``"epoch_0-step_1000.ckpt"`` yields ``1000``.

    Args:
        file_name: Checkpoint file name (not a full path).

    Returns:
        The step number as an ``int``.

    Raises:
        IndexError: If the name lacks the expected ``"-"`` or ``"_"`` separator.
        ValueError: If the extracted text is not a valid integer.
    """
    second_chunk = file_name.split("-")[1]
    step_token = second_chunk.split("_")[1]
    return int(step_token.replace(".ckpt", ""))


def find_latest_ckpt(ckpt_dir):
    """Return the path of the ``.ckpt`` file in ``ckpt_dir`` with the largest step.

    Files are ranked with ``extract_step``, so every ``.ckpt`` file name in the
    directory must follow the naming scheme that function parses.

    Args:
        ckpt_dir: Directory to scan. May be a ``str`` or any path-like object;
            previously a plain ``str`` failed on the final ``/`` join.

    Returns:
        ``pathlib.Path`` pointing at the checkpoint with the highest step.

    Raises:
        ValueError: If the directory contains no file ending in ``.ckpt``.
    """
    # Function-scope import so the block does not depend on the file's import list.
    from pathlib import Path

    # List all files in the directory that end with .ckpt
    ckpt_files = [f for f in os.listdir(ckpt_dir) if f.endswith(".ckpt")]

    if not ckpt_files:
        raise ValueError(f"No .ckpt files found in {ckpt_dir}.")

    # Find the file with the maximum step
    latest_ckpt_file = max(ckpt_files, key=extract_step)
    # Normalising through Path makes the join work for str inputs as well.
    return Path(ckpt_dir) / latest_ckpt_file


def no_resume_upsampler(pretrained_state_dict):
    """Return a copy of a state dict without its upsampler entries.

    Args:
        pretrained_state_dict: Mapping from parameter name to value.

    Returns:
        A new ``OrderedDict`` holding, in the original order, every entry whose
        key does not contain the substring ``'upsampler'``. The input mapping
        is not modified.
    """
    return OrderedDict(
        (name, tensor)
        for name, tensor in pretrained_state_dict.items()
        if 'upsampler' not in name
    )


def load_partial_state_dict(model, pretrained_state_dict):
    """Copy into ``model`` only the pretrained entries that fit it.

    An entry is used when its key exists in ``model.state_dict()`` and its
    shape equals the shape of the model's current value for that key. All
    other pretrained entries are skipped without any message, and model
    entries with no usable counterpart keep their current values.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        pretrained_state_dict: Mapping from parameter name to tensor.
    """
    merged_state = model.state_dict()
    for name, tensor in pretrained_state_dict.items():
        if name not in merged_state:
            continue
        if tensor.shape != merged_state[name].shape:
            continue
        merged_state[name] = tensor
    # merged_state now covers every model key, so the default load succeeds.
    model.load_state_dict(merged_state)


def _load_state_dict(path):
    """Load a checkpoint onto the CPU and unwrap its state dict.

    Args:
        path: Location of the checkpoint, passed straight to ``torch.load``.

    Returns:
        ``ckpt['state_dict']`` if that key is present, otherwise
        ``ckpt['model']`` if that key is present, otherwise the loaded object
        itself.
    """
    checkpoint = torch.load(path, map_location='cpu')
    # 'state_dict' takes precedence over 'model' when both are present.
    for wrapper_key in ('state_dict', 'model'):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]
    return checkpoint


def load_optimizer(cfg, scene_trainer, strict_load):
    """Load pretrained weights into ``scene_trainer.optimizer``.

    The checkpoint at ``cfg.checkpointing.pretrained_optimizer`` is loaded on
    the CPU and unwrapped from ``'state_dict'`` when that key exists. Two key
    layouts are handled:

    * keys starting with ``optimizer.`` (optionally preceded by
      ``scene_trainer.``): the ``optimizer.`` prefix is stripped and other
      keys are dropped;
    * otherwise, keys starting with ``encoder.``: the prefix is stripped and
      module attributes that were renamed are mapped to their new names.

    When ``cfg.scene_trainer.scene_optimizer.init_state_wo_features`` is
    truthy, every key containing ``"update_proj"`` is dropped before loading.
    A config that lacks the attribute is treated as ``False``.

    Args:
        cfg: Config exposing ``checkpointing.pretrained_optimizer`` and
            ``scene_trainer.scene_optimizer``.
        scene_trainer: Object whose ``optimizer`` attribute provides
            ``load_state_dict``.
        strict_load: Forwarded as ``strict`` to ``load_state_dict``.
    """
    ckpt_path = cfg.checkpointing.pretrained_optimizer
    pretrained_model = torch.load(ckpt_path, map_location='cpu')
    if 'state_dict' in pretrained_model:
        pretrained_model = pretrained_model['state_dict']
    # Strip a leading "scene_trainer." (Lightning checkpoint format). Only the
    # prefix is removed, so the same text elsewhere in a key is left alone.
    pretrained_model = {k.removeprefix("scene_trainer."): v for k, v in pretrained_model.items()}
    if any(k.startswith("optimizer.") for k in pretrained_model):
        # Unified repo format: keys are optimizer.*
        optimizer_state_dict = {
            k[len("optimizer."):]: v
            for k, v in pretrained_model.items()
            if k.startswith("optimizer.")
        }
    else:
        # Resplat repo format: keys are encoder.* (before init/opt split).
        # Strip encoder. prefix; init-related keys will be ignored via strict=False.
        optimizer_state_dict = {
            k[len("encoder."):]: v
            for k, v in pretrained_model.items()
            if k.startswith("encoder.")
        }
        # Rename module attributes that changed when the encoder was split.
        _ORIG_OPTIMIZER_ATTR_RENAMES = {
            "render_error_mv_attn": "update_error_attn",
        }
        renamed = {}
        for k, v in optimizer_state_dict.items():
            for old, new in _ORIG_OPTIMIZER_ATTR_RENAMES.items():
                # Match the attribute itself or anything nested under it,
                # but not a longer name that merely shares the same start.
                if k == old or k.startswith(old + "."):
                    k = new + k[len(old):]
                    break
            renamed[k] = v
        optimizer_state_dict = renamed

    # If init_state_wo_features is True, remove all feature-related parameters
    # from the optimizer state dict. The flag is read only through getattr so
    # configs that do not define it keep working.
    if getattr(cfg.scene_trainer.scene_optimizer, "init_state_wo_features", False):
        optimizer_state_dict = {k: v for k, v in optimizer_state_dict.items() if "update_proj" not in k}
    scene_trainer.optimizer.load_state_dict(optimizer_state_dict, strict=strict_load)
    print(cyan(f"Loaded pretrained optimizer: {ckpt_path}"))


def load_initializer(cfg, scene_trainer, strict_load):
    """Load pretrained weights into ``scene_trainer.initializer``.

    The checkpoint at ``cfg.checkpointing.pretrained_initializer`` is loaded
    on the CPU and unwrapped from ``'state_dict'`` when that key exists. Two
    key layouts are handled:

    * keys starting with ``initializer.`` (optionally preceded by
      ``scene_trainer.``): every key must use that prefix, which is stripped;
    * otherwise, keys starting with ``encoder.``: the prefix is stripped and
      other keys are dropped.

    Args:
        cfg: Config exposing ``checkpointing.pretrained_initializer``.
        scene_trainer: Object whose ``initializer`` attribute provides
            ``load_state_dict``.
        strict_load: Forwarded as ``strict`` to ``load_state_dict``.

    Raises:
        AssertionError: If some, but not all, keys start with ``initializer.``.
    """
    ckpt_path = cfg.checkpointing.pretrained_initializer
    pretrained_model = torch.load(ckpt_path, map_location='cpu')
    if 'state_dict' in pretrained_model:
        pretrained_model = pretrained_model['state_dict']
    # Strip a leading "scene_trainer." (Lightning checkpoint format). Only the
    # prefix is removed, so the same text elsewhere in a key is left alone.
    pretrained_model = {k.removeprefix("scene_trainer."): v for k, v in pretrained_model.items()}
    if any(k.startswith("initializer.") for k in pretrained_model):
        unexpected = [k for k in pretrained_model if not k.startswith("initializer.")]
        assert not unexpected, (
            f"Initializer checkpoint {ckpt_path} mixes 'initializer.*' keys "
            f"with other keys, e.g. {unexpected[:5]}"
        )
        # Current repo format: keys are initializer.*
        initializer_state_dict = {
            k[len("initializer."):]: v
            for k, v in pretrained_model.items()
            if k.startswith("initializer.")
        }
    else:
        # Resplat repo format: keys are encoder.* (before init/opt split)
        initializer_state_dict = {
            k[len("encoder."):]: v
            for k, v in pretrained_model.items()
            if k.startswith("encoder.")
        }
    scene_trainer.initializer.load_state_dict(initializer_state_dict, strict=strict_load)
    print(cyan(f"Loaded pretrained initializer: {ckpt_path}"))


def load_full_model(cfg, scene_trainer, strict_load):
    """Load a whole-model checkpoint into ``scene_trainer``.

    The checkpoint at ``cfg.checkpointing.pretrained_model`` is loaded on the
    CPU and unwrapped from ``'state_dict'`` when that key exists. If
    ``cfg.checkpointing.partial_load`` is truthy, the weights go through
    ``load_partial_state_dict``; otherwise they go through
    ``scene_trainer.load_state_dict``.

    Args:
        cfg: Config exposing ``checkpointing.pretrained_model`` and
            ``checkpointing.partial_load``.
        scene_trainer: Module receiving the weights.
        strict_load: Forwarded as ``strict`` on the non-partial path; unused
            on the partial path.
    """
    checkpoint = torch.load(cfg.checkpointing.pretrained_model, map_location='cpu')
    weights = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    if not cfg.checkpointing.partial_load:
        scene_trainer.load_state_dict(weights, strict=strict_load)
    else:
        print('partial load')
        load_partial_state_dict(scene_trainer, weights)

    print(cyan(f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}"))


def load_base_model(cfg, scene_trainer, strict_load: bool | Any):
    if cfg.checkpointing.pretrained_model is not None:
        load_full_model(cfg, scene_trainer, strict_load)
    else:
        # Load pretrained initializer if available
        if cfg.checkpointing.pretrained_initializer is not None:
            load_initializer(cfg, scene_trainer, strict_load)

        if cfg.checkpointing.pretrained_optimizer is not None and scene_trainer.optimizer is not None:
            load_optimizer(cfg, scene_trainer, strict_load)


def load_model_weights(cfg, scene_trainer, strict_load, mode: str):
    """Load every configured pretrained checkpoint into ``scene_trainer``.

    Steps, in order:

    1. Train mode only: load the monodepth checkpoint, optionally freeze
       ``encoder.depth_predictor.pretrained``, and load the mvdepth checkpoint.
    2. Load the base model through ``load_base_model``.
    3. Load the depth predictor checkpoint (always strict in test mode).
    4. Train mode only: load the scale predictor checkpoint and freeze it.
    5. Load the entries of the update-module checkpoint whose key contains
       ``"encoder.update"`` and whose shape matches the current model.
    6. Train mode only: call ``apply_freezes``.

    Each step is skipped when its ``cfg.checkpointing`` entry is ``None``.

    Args:
        cfg: Config exposing the ``checkpointing`` entries read below and,
            for the train-mode steps, ``model.encoder``.
        scene_trainer: Module receiving the weights.
        strict_load: Default ``strict`` value for ``load_state_dict`` calls.
            Some steps below force it to ``False`` for the rest of the call.
        mode: Either ``"train"`` or ``"test"``.

    Raises:
        AssertionError: If ``mode`` is neither ``"train"`` nor ``"test"``.
    """
    assert mode in ("train", "test")

    if mode == "train":
        # only load monodepth
        if cfg.checkpointing.pretrained_monodepth is not None:
            # NOTE(review): this override is never restored, so every later
            # load in this call is also non-strict. Confirm that is intended.
            strict_load = False
            pretrained_model = torch.load(cfg.checkpointing.pretrained_monodepth, map_location='cpu')
            if 'state_dict' in pretrained_model:
                pretrained_model = pretrained_model['state_dict']
            if cfg.model.encoder.separate_depth_color or cfg.model.encoder.separate_depth_gaussian_scale:
                scene_trainer.encoder.feature_extractor.load_state_dict(pretrained_model, strict=strict_load)
            else:
                scene_trainer.encoder.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)
            print(cyan(f"Loaded pretrained monodepth: {cfg.checkpointing.pretrained_monodepth}"))

        # freeze mono vit
        if cfg.checkpointing.freeze_mono_vit:
            print('freeze mono vit')
            for params in scene_trainer.encoder.depth_predictor.pretrained.parameters():
                params.requires_grad = False

        # load pretrained mvdepth
        if cfg.checkpointing.pretrained_mvdepth is not None:
            # This checkpoint must store its weights under the 'model' key.
            pretrained_model = torch.load(cfg.checkpointing.pretrained_mvdepth, map_location='cpu')['model']
            if cfg.model.encoder.separate_depth_color or cfg.model.encoder.separate_depth_gaussian_scale:
                scene_trainer.encoder.feature_extractor.load_state_dict(pretrained_model, strict=False)
            else:
                scene_trainer.encoder.depth_predictor.load_state_dict(pretrained_model, strict=False)
            print(cyan(f"Loaded pretrained mvdepth: {cfg.checkpointing.pretrained_mvdepth}"))

    # load full model (or separate initializer/optimizer checkpoints)
    load_base_model(cfg, scene_trainer, strict_load)

    # load pretrained depth
    if cfg.checkpointing.pretrained_depth is not None:
        pretrained_model = _load_state_dict(cfg.checkpointing.pretrained_depth)
        if mode == "train":
            if cfg.checkpointing.partial_load:
                print('partial load depth')
                load_partial_state_dict(scene_trainer.initializer.depth_predictor, pretrained_model)
            else:
                if cfg.checkpointing.no_resume_upsampler:
                    pretrained_model = no_resume_upsampler(pretrained_model)
                    # Upsampler keys were removed, so a strict load would fail.
                    # NOTE(review): this also relaxes the scale predictor load
                    # below. Confirm that is intended.
                    strict_load = False
                scene_trainer.initializer.depth_predictor.load_state_dict(pretrained_model, strict=strict_load)
        else:
            scene_trainer.initializer.depth_predictor.load_state_dict(pretrained_model, strict=True)
        print(cyan(f"Loaded pretrained depth: {cfg.checkpointing.pretrained_depth}"))

    # load pretrained scale predictor
    if mode == "train" and cfg.checkpointing.pretrained_scale_predictor is not None:
        pretrained_model = _load_state_dict(cfg.checkpointing.pretrained_scale_predictor)
        scene_trainer.encoder.scale_predictor.load_state_dict(pretrained_model, strict=strict_load)
        print(cyan(f"Loaded pretrained scale predictor: {cfg.checkpointing.pretrained_scale_predictor}"))

        print('freeze scale predictor')
        for params in scene_trainer.encoder.scale_predictor.parameters():
            params.requires_grad = False

    # load pretrained update module
    if cfg.checkpointing.resume_update_module is not None:
        pretrained_model = _load_state_dict(cfg.checkpointing.resume_update_module)

        # Build the model's state dict once; calling state_dict() inside the
        # filter would rebuild it for every checkpoint key.
        current_state = scene_trainer.state_dict()

        # Keep only "encoder.update" entries that exist in the model with the same shape.
        filtered_dict = {
            k: v for k, v in pretrained_model.items()
            if "encoder.update" in k and k in current_state
            and v.shape == current_state[k].shape
        }

        # Load them using strict=False so it skips missing/unmatched keys
        scene_trainer.load_state_dict(filtered_dict, strict=False)
        print(cyan(f"Loaded pretrained update module: {cfg.checkpointing.resume_update_module}"))

    if mode == "train":
        apply_freezes(cfg, scene_trainer)


def apply_freezes(cfg, scene_trainer):
    """Disable gradients on the parts of ``scene_trainer`` selected by ``cfg``.

    * ``cfg.scene_trainer.scene_initializer.freeze_depth`` (treated as
      ``False`` when absent): freeze ``initializer.depth_predictor``.
    * ``cfg.scene_trainer.train_scene_init`` falsy: freeze the whole
      initializer.
    * Only when ``cfg.scene_trainer.num_update_steps > 0``:
        - ``train_scene_opt`` falsy: freeze the whole optimizer;
        - ``scene_optimizer.train_global_update_only`` truthy: freeze every
          optimizer parameter whose name does not contain ``'global_update'``.

    Parameters are only ever switched to ``requires_grad = False``; nothing is
    re-enabled here.

    Args:
        cfg: Config exposing the ``scene_trainer`` entries listed above.
        scene_trainer: Object with ``initializer`` and ``optimizer`` modules.
    """
    def _disable_grad(parameters):
        # Turn off gradient tracking for each parameter in the iterable.
        for param in parameters:
            param.requires_grad = False

    trainer_cfg = cfg.scene_trainer

    if getattr(trainer_cfg.scene_initializer, 'freeze_depth', False):
        print('freeze depth')
        _disable_grad(scene_trainer.initializer.depth_predictor.parameters())

    if not trainer_cfg.train_scene_init:
        print('train refine only, freezing scene initializer')
        _disable_grad(param for _, param in scene_trainer.initializer.named_parameters())

    # The remaining rules only apply when update steps are in use.
    if not trainer_cfg.num_update_steps > 0:
        return

    if not trainer_cfg.train_scene_opt:
        print('train refine only, freezing scene optimizer')
        _disable_grad(param for _, param in scene_trainer.optimizer.named_parameters())

    if trainer_cfg.scene_optimizer.train_global_update_only:
        print('train global update only')
        _disable_grad(
            param
            for param_name, param in scene_trainer.optimizer.named_parameters()
            if 'global_update' not in param_name
        )