# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any, Callable, Dict, List, Optional, Union, Iterable
import lightning.pytorch as pl
import torch
from pathlib import Path
import os
import re
from loguru import logger
from lightning.pytorch.utilities.consolidate_checkpoint import (
_format_checkpoint,
_load_distributed_checkpoint,
)
from glob import glob
from sam3d_objects.data.utils import get_child, set_child
def rename_checkpoint_weights_using_suffix_matching(
    checkpoint_path_in,
    checkpoint_path_out,
    model: torch.nn.Module,
    strict: bool = True,
    keys: Optional[List[Any]] = (),
):
    """Rename checkpoint weights so their names line up with *model*'s names.

    Each parameter/buffer name of *model* is greedily matched to the stored
    weight name sharing the longest common suffix, then the state dict is
    rewritten under the matched names and saved to *checkpoint_path_out*.

    Args:
        checkpoint_path_in: path of the checkpoint to read (``torch.load``).
        checkpoint_path_out: path where the renamed checkpoint is saved.
        model: module providing the target parameter/buffer names.
        strict: if True, require model and checkpoint to hold the same
            number of entries before attempting the match.
        keys: path of nested keys locating the state dict inside the
            loaded checkpoint (passed to ``get_child``/``set_child``).

    Raises:
        RuntimeError: on entry-count mismatch (strict mode) or when some
            model name could not be suffix-matched.
    """
    # extract model names (parameters and buffers)
    param_names = [n for n, _ in model.named_parameters()]
    buffer_names = [n for n, _ in model.named_buffers()]
    model_names = param_names + buffer_names
    # load stored weights
    state = torch.load(checkpoint_path_in, weights_only=False)
    model_state = get_child(state, *keys)
    model_state_names = list(model_state.keys())
    # sort reversed names: suffix matching becomes prefix matching
    model_names_rev = sorted([n[::-1] for n in model_names])
    model_state_names_rev = sorted([n[::-1] for n in model_state_names])
    if strict and len(model_names) != len(model_state_names):
        raise RuntimeError(
            f"model and state don't have the same number of parameters ({len(model_names)} != {len(model_state_names)}), cannot match them (set strict = False to relax constraint)"
        )

    def common_prefix_length(str_0: str, str_1: str) -> int:
        # Length of the longest common prefix of the two strings.
        # BUGFIX: the previous version returned the index of the last
        # compared character, under-counting by one whenever one string was
        # a full prefix of the other (so exact matches tied with matches
        # diverging at the last char), and raised UnboundLocalError when
        # either string was empty.
        limit = min(len(str_0), len(str_1))
        for count in range(limit):
            if str_0[count] != str_1[count]:
                return count
        return limit

    # attempt to match every model name to the largest suffix-matched weight
    name_mapping = {}
    i, j = 0, 0
    last_n = 0
    while i < len(model_names_rev):
        if j < len(model_state_names_rev):
            n = common_prefix_length(model_names_rev[i], model_state_names_rev[j])
        else:
            n = 0
        if n >= last_n:
            # candidate j matches at least as well as the previous one: advance
            last_n = n
            j += 1
        else:
            # match quality dropped: commit the best candidate seen (j - 1)
            last_n = 0
            name_mapping[model_names_rev[i][::-1]] = model_state_names_rev[j - 1][::-1]
            i += 1
        if not j < len(model_state_names_rev) + 1:
            break
    # not all names might have been matched
    if i < len(model_names):
        raise RuntimeError("could not suffix match parameter names")
    for k, v in name_mapping.items():
        logger.debug(f"{k} <- {v}")
    # rename weights according to matches and save to disk
    model_state_out = {k: model_state[v] for k, v in name_mapping.items()}
    set_child(state, model_state_out, *keys)
    torch.save(state, checkpoint_path_out)
def remove_prefix_state_dict_fn(prefix: str):
    """Build a state-dict transform stripping *prefix* from matching keys.

    Keys that do not start with *prefix* are kept unchanged.
    """
    cut = len(prefix)

    def state_dict_fn(state_dict):
        renamed = {}
        for name, value in state_dict.items():
            if name.startswith(prefix):
                name = name[cut:]
            renamed[name] = value
        return renamed

    return state_dict_fn
def add_prefix_state_dict_fn(prefix: str):
    """Build a state-dict transform prepending *prefix* to every key."""

    def state_dict_fn(state_dict):
        prefixed = {}
        for name, value in state_dict.items():
            prefixed[prefix + name] = value
        return prefixed

    return state_dict_fn
def filter_and_remove_prefix_state_dict_fn(prefix: str):
    """Build a transform keeping only keys under *prefix*, with it stripped.

    Keys that do not start with *prefix* are dropped entirely.
    """
    start = len(prefix)

    def state_dict_fn(state_dict):
        kept = {}
        for name, value in state_dict.items():
            if name.startswith(prefix):
                kept[name[start:]] = value
        return kept

    return state_dict_fn
def get_last_checkpoint(path: str):
    """Return the path of the most recent checkpoint found in *path*.

    Checkpoints are expected to be named ``epoch=<E>-step=<S>.ckpt``; the
    "last" one is the greatest ``(epoch, step)`` pair (numeric comparison).

    Raises:
        RuntimeError: if no file matching the pattern is found.
    """
    checkpoints = glob(os.path.join(path, "epoch=*-step=*.ckpt"))
    # BUGFIX: the dot before "ckpt" was previously unescaped (matched any char)
    prog = re.compile(r"epoch=(\d+)-step=(\d+)\.ckpt")
    checkpoints_to_sort = []
    for checkpoint in checkpoints:
        match = prog.match(os.path.basename(checkpoint))
        if match is not None:
            # reuse the match instead of re-running the regex
            n_epoch, n_step = (int(g) for g in match.groups())
            checkpoints_to_sort.append((n_epoch, n_step, checkpoint))
    if not checkpoints_to_sort:
        raise RuntimeError(f"no checkpoint has been found at path : {path}")
    # max over (epoch, step, path) tuples picks the latest checkpoint
    return max(checkpoints_to_sort)[2]
def load_sharded_checkpoint(path: str, device: Optional[str]):
    """Load a sharded (distributed) checkpoint stored under directory *path*.

    Only CPU loading is supported; any other *device* raises RuntimeError.
    """
    if device == "cpu":
        raw = _load_distributed_checkpoint(Path(path))
        return _format_checkpoint(raw)
    raise RuntimeError(
        f'loading sharded weights on device "{device}" is not available, please use the "cpu" device instead'
    )
def load_model_from_checkpoint(
    model: Union[pl.LightningModule, torch.nn.Module],
    checkpoint_path: str,
    strict: bool = True,
    device: Optional[str] = None,
    freeze: bool = False,
    eval: bool = False,
    map_name: Union[Dict[str, str], None] = None,
    remove_name: Union[List[str], None] = None,
    state_dict_key: Union[None, str, Iterable[str]] = "state_dict",
    state_dict_fn: Optional[Callable[[Any], Any]] = None,
):
    """Load weights from *checkpoint_path* into *model* and return it.

    *checkpoint_path* may be a single file (``torch.load``) or a directory
    holding a sharded checkpoint. The state dict is located via
    *state_dict_key*, then optionally pruned (*remove_name*), renamed
    (*map_name*), and transformed (*state_dict_fn*) before loading.

    Args:
        model: module to load into; LightningModules get their
            ``on_load_checkpoint`` hook called with the raw checkpoint.
        checkpoint_path: checkpoint file or sharded-checkpoint directory.
        strict: forwarded to ``load_state_dict``.
        device: optional device for ``map_location`` and a final ``.to()``.
        freeze: disable gradients on all parameters (implies eval mode).
        eval: put the model in eval mode.
            NOTE(review): this parameter shadows the ``eval`` builtin, but
            renaming it would break keyword callers.
        map_name: mapping of source -> destination state-dict key names.
        remove_name: state-dict keys to delete before loading.
        state_dict_key: key (or nested key path) of the state dict inside
            the checkpoint; ``None`` treats the checkpoint itself as the
            state dict.
        state_dict_fn: arbitrary final transform applied to the state dict.

    Raises:
        FileNotFoundError: if *checkpoint_path* is neither file nor dir.
    """
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(
            checkpoint_path,
            map_location=device,
            weights_only=False,
        )
    elif os.path.isdir(checkpoint_path):
        # a directory is assumed to hold a sharded checkpoint
        checkpoint = load_sharded_checkpoint(checkpoint_path, device=device)
    else:
        # neither a file nor a directory: the path does not exist
        raise FileNotFoundError(checkpoint_path)
    # let LightningModules react to the raw checkpoint before extraction
    if isinstance(model, pl.LightningModule):
        model.on_load_checkpoint(checkpoint)
    # drill down to the state dictionary
    state_dict = checkpoint
    if state_dict_key is not None:
        keys = (state_dict_key,) if isinstance(state_dict_key, str) else state_dict_key
        state_dict = get_child(state_dict, *keys)
    # drop unwanted entries
    for name in remove_name or ():
        del state_dict[name]
    # remap entry names (assign-then-delete so dst == src removes the key,
    # matching the historical behavior)
    for src, dst in (map_name or {}).items():
        if src not in state_dict:
            continue
        state_dict[dst] = state_dict[src]
        del state_dict[src]
    # apply custom changes to the dict
    if state_dict_fn is not None:
        state_dict = state_dict_fn(state_dict)
    model.load_state_dict(state_dict, strict=strict)
    if device is not None:
        model = model.to(device)
    if freeze:
        for param in model.parameters():
            param.requires_grad = False
        eval = True
    if eval:
        model.eval()
    return model
|