File size: 6,590 Bytes
6bc32b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any, Callable, Dict, List, Optional, Union, Iterable
import lightning.pytorch as pl
import torch
from pathlib import Path
import os
import re
from loguru import logger
from lightning.pytorch.utilities.consolidate_checkpoint import (
    _format_checkpoint,
    _load_distributed_checkpoint,
)
from glob import glob

from sam3d_objects.data.utils import get_child, set_child


def rename_checkpoint_weights_using_suffix_matching(
    checkpoint_path_in,
    checkpoint_path_out,
    model: torch.nn.Module,
    strict: bool = True,
    keys: Optional[List[Any]] = (),
):
    """Rename checkpoint weights to match ``model``'s parameter/buffer names.

    Each model name is greedily paired with the stored weight name sharing the
    longest common suffix, then the renamed checkpoint is written to disk.

    Args:
        checkpoint_path_in: path of the checkpoint to read.
        checkpoint_path_out: path where the renamed checkpoint is written.
        model: module whose parameter and buffer names are the rename targets.
        strict: if True, require the checkpoint to hold exactly as many entries
            as the model has parameters + buffers.
        keys: nested keys locating the state dict inside the loaded checkpoint
            (forwarded to ``get_child`` / ``set_child``).

    Raises:
        RuntimeError: if ``strict`` and the entry counts differ, or if some
            model name could not be suffix-matched.
    """
    # extract model names
    param_names = [n for n, _ in model.named_parameters()]
    buffer_names = [n for n, _ in model.named_buffers()]
    model_names = param_names + buffer_names

    # load stored weights
    state = torch.load(checkpoint_path_in, weights_only=False)

    model_state = get_child(state, *keys)
    model_state_names = list(model_state.keys())

    # sort reversed names (sort by suffix)
    model_names_rev = sorted([n[::-1] for n in model_names])
    model_state_names_rev = sorted([n[::-1] for n in model_state_names])

    if strict and len(model_names) != len(model_state_names):
        raise RuntimeError(
            f"model and state don't have the same number of parameters ({len(model_names)} != {len(model_state_names)}), cannot match them (set strict = False to relax constraint)"
        )

    def common_prefix_length(str_0: str, str_1: str) -> int:
        # Length of the longest common prefix of the two strings.
        # Fixed: the previous version raised NameError when either string was
        # empty (the loop variable was never bound) and under-counted by one
        # when one string was a prefix of the other.
        count = 0
        for c0, c1 in zip(str_0, str_1):
            if c0 != c1:
                break
            count += 1
        return count

    # attempt to match every model name to the largest suffix-matched weight
    name_mapping = {}
    i, j = 0, 0
    last_n = 0
    while i < len(model_names_rev):
        if j < len(model_state_names_rev):
            n = common_prefix_length(model_names_rev[i], model_state_names_rev[j])
        else:
            n = 0

        # advance j while the overlap keeps growing; once it shrinks, the
        # previous j held the best suffix match for model name i
        if n >= last_n:
            last_n = n
            j += 1
        else:
            last_n = 0
            name_mapping[model_names_rev[i][::-1]] = model_state_names_rev[j - 1][::-1]
            i += 1

        if not j < len(model_state_names_rev) + 1:
            break

    # not all names might have been matched
    if i < len(model_names):
        raise RuntimeError("could not suffix match parameter names")

    for k, v in name_mapping.items():
        logger.debug(f"{k} <- {v}")

    # rename weights according to matches and save to disk
    model_state_out = {k: model_state[v] for k, v in name_mapping.items()}
    set_child(state, model_state_out, *keys)
    torch.save(state, checkpoint_path_out)


def remove_prefix_state_dict_fn(prefix: str):
    """Return a state-dict transform that strips ``prefix`` from matching keys.

    Keys that do not start with ``prefix`` are kept unchanged.
    """
    cut = len(prefix)

    def state_dict_fn(state_dict):
        renamed = {}
        for key, value in state_dict.items():
            if key.startswith(prefix):
                key = key[cut:]
            renamed[key] = value
        return renamed

    return state_dict_fn


def add_prefix_state_dict_fn(prefix: str):
    """Return a state-dict transform that prepends ``prefix`` to every key."""

    def state_dict_fn(state_dict):
        renamed = {}
        for key, value in state_dict.items():
            renamed[prefix + key] = value
        return renamed

    return state_dict_fn


def filter_and_remove_prefix_state_dict_fn(prefix: str):
    """Return a state-dict transform keeping only keys under ``prefix``.

    Matching keys appear in the result with the prefix stripped; all other
    keys are dropped.
    """
    cut = len(prefix)

    def state_dict_fn(state_dict):
        kept = {}
        for key, value in state_dict.items():
            if key.startswith(prefix):
                kept[key[cut:]] = value
        return kept

    return state_dict_fn


def get_last_checkpoint(path: str):
    """Return the checkpoint file in ``path`` with the highest (epoch, step).

    Files are expected to be named ``epoch=<E>-step=<S>.ckpt``; they are
    ordered numerically by epoch first, then step.

    Args:
        path: directory to search for checkpoint files.

    Raises:
        RuntimeError: if no matching checkpoint file is found.
    """
    checkpoints = glob(os.path.join(path, "epoch=*-step=*.ckpt"))
    # escape the extension dot so an arbitrary character cannot match it
    prog = re.compile(r"epoch=(\d+)-step=(\d+)\.ckpt")

    checkpoints_to_sort = []
    for checkpoint in checkpoints:
        match = prog.match(os.path.basename(checkpoint))
        if match is not None:
            # reuse the match object instead of re-running the regex
            n_epoch, n_step = (int(g) for g in match.groups())
            checkpoints_to_sort.append((n_epoch, n_step, checkpoint))

    if not checkpoints_to_sort:
        raise RuntimeError(f"no checkpoint has been found at path : {path}")
    # max over (epoch, step, path) tuples == last element of the sorted list
    return max(checkpoints_to_sort)[2]


def load_sharded_checkpoint(path: str, device: Optional[str]):
    """Load a distributed (sharded) checkpoint directory into a single dict.

    Only CPU loading is supported; any other ``device`` raises RuntimeError.
    """
    if device == "cpu":
        raw = _load_distributed_checkpoint(Path(path))
        return _format_checkpoint(raw)
    raise RuntimeError(
        f'loading sharded weights on device "{device}" is not available, please use the "cpu" device instead'
    )


def load_model_from_checkpoint(
    model: Union[pl.LightningModule, torch.nn.Module],
    checkpoint_path: str,
    strict: bool = True,
    device: Optional[str] = None,
    freeze: bool = False,
    eval: bool = False,
    map_name: Union[Dict[str, str], None] = None,
    remove_name: Union[List[str], None] = None,
    state_dict_key: Union[None, str, Iterable[str]] = "state_dict",
    state_dict_fn: Optional[Callable[[Any], Any]] = None,
):
    """Load weights from a checkpoint file or sharded directory into ``model``.

    Args:
        model: module (plain or Lightning) to load weights into.
        checkpoint_path: file path, or directory path for a sharded checkpoint.
        strict: forwarded to ``model.load_state_dict``.
        device: map location for loading; if not None, the model is moved there.
        freeze: if True, disable gradients on all parameters (implies eval).
        eval: if True, put the model in eval mode.
        map_name: mapping of source -> destination key renames (missing
            sources are skipped).
        remove_name: keys to delete from the state dict before loading
            (a missing key raises KeyError).
        state_dict_key: key (or nested keys) locating the state dict inside
            the checkpoint; None uses the checkpoint itself.
        state_dict_fn: optional final transform applied to the state dict.

    Returns:
        The model with weights loaded.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is neither a file nor a
            directory.
    """
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    if os.path.isdir(checkpoint_path):
        # a directory holds a sharded checkpoint
        checkpoint = load_sharded_checkpoint(checkpoint_path, device=device)
    elif os.path.isfile(checkpoint_path):
        checkpoint = torch.load(
            checkpoint_path,
            map_location=device,
            weights_only=False,
        )
    else:
        # neither a file nor a directory: path does not exist
        raise FileNotFoundError(checkpoint_path)

    if isinstance(model, pl.LightningModule):
        model.on_load_checkpoint(checkpoint)

    # locate the state dictionary inside the checkpoint
    state_dict = checkpoint
    if state_dict_key is not None:
        keys = (state_dict_key,) if isinstance(state_dict_key, str) else state_dict_key
        state_dict = get_child(state_dict, *keys)

    # drop unwanted entries (KeyError on absent names, as before)
    if remove_name is not None:
        for name in remove_name:
            del state_dict[name]

    # rename entries; keep assign-then-delete order so that src == dst
    # removes the key, matching the previous behavior
    if map_name is not None:
        for src, dst in map_name.items():
            if src not in state_dict:
                continue
            state_dict[dst] = state_dict[src]
            del state_dict[src]

    # apply custom changes to dict
    if state_dict_fn is not None:
        state_dict = state_dict_fn(state_dict)

    model.load_state_dict(state_dict, strict=strict)

    if device is not None:
        model = model.to(device)

    if freeze:
        for param in model.parameters():
            param.requires_grad = False
        eval = True

    if eval:
        model.eval()

    return model