# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any, Callable, Dict, List, Optional, Union, Iterable
import lightning.pytorch as pl
import torch
from pathlib import Path
import os
import re
from loguru import logger
from lightning.pytorch.utilities.consolidate_checkpoint import (
    _format_checkpoint,
    _load_distributed_checkpoint,
)
from glob import glob
from glob import escape as glob_escape

from sam3d_objects.data.utils import get_child, set_child


# File names produced for "epoch=<E>-step=<S>.ckpt" checkpoints. The dot is
# escaped and the pattern is used with ``fullmatch`` so that only exact names
# are accepted.
_CHECKPOINT_NAME_PATTERN = re.compile(r"epoch=(\d+)-step=(\d+)\.ckpt")


def _common_prefix_length(str_0: str, str_1: str) -> int:
    """Return the number of leading characters shared by ``str_0`` and ``str_1``.

    When one string is a prefix of the other the full length of the shorter
    string is returned; two empty strings (or one empty string) yield 0.
    """
    # os.path.commonprefix works character-by-character on plain strings.
    return len(os.path.commonprefix([str_0, str_1]))


def rename_checkpoint_weights_using_suffix_matching(
    checkpoint_path_in,
    checkpoint_path_out,
    model: torch.nn.Module,
    strict: bool = True,
    keys: Optional[Iterable[Any]] = (),
):
    """Rename checkpoint weights to match ``model`` names by longest common suffix.

    The checkpoint at ``checkpoint_path_in`` is loaded, the sub-dictionary
    selected by ``keys`` (via ``get_child``; presumably a nested key lookup --
    confirm in ``sam3d_objects.data.utils``) is re-keyed so that each model
    parameter/buffer name receives the stored entry whose name shares the
    longest suffix with it, and the result is written to
    ``checkpoint_path_out``.

    Args:
        checkpoint_path_in: Path of the checkpoint to read.
        checkpoint_path_out: Path the renamed checkpoint is saved to.
        model: Module whose ``named_parameters`` and ``named_buffers`` provide
            the target names.
        strict: When True, the model and the stored state must contain the same
            number of entries.
        keys: Keys passed to ``get_child``/``set_child`` to reach the weights
            inside the loaded object. ``None`` or empty means the root object.

    Raises:
        RuntimeError: If ``strict`` and the entry counts differ, or if some
            model names could not be matched.
    """
    # ``keys`` is unpacked twice (get_child and set_child); materialise it so a
    # one-shot iterator is not exhausted by the first use. ``None`` means root.
    keys = () if keys is None else tuple(keys)

    # extract model names
    param_names = [n for n, _ in model.named_parameters()]
    buffer_names = [n for n, _ in model.named_buffers()]
    model_names = param_names + buffer_names

    # load stored weights
    # NOTE(security): weights_only=False unpickles arbitrary objects; only use
    # this on trusted checkpoint files.
    state = torch.load(checkpoint_path_in, weights_only=False)
    model_state = get_child(state, *keys)
    model_state_names = list(model_state.keys())

    # Sort the reversed names so that names sharing a suffix become neighbours
    # and suffix matching turns into prefix matching on the reversed strings.
    model_names_rev = sorted(n[::-1] for n in model_names)
    model_state_names_rev = sorted(n[::-1] for n in model_state_names)

    if strict and len(model_names) != len(model_state_names):
        raise RuntimeError(
            f"model and state don't have the same number of parameters ({len(model_names)} != {len(model_state_names)}), cannot match them (set strict = False to relax constraint)"
        )

    # Greedy two-pointer walk: for model name i, advance through the stored
    # names while the shared-prefix length does not decrease, then assign the
    # last stored name before the decrease. The stored index only moves
    # forward, so every stored name is assigned at most once.
    name_mapping = {}
    i, j = 0, 0
    last_n = 0
    while i < len(model_names_rev):
        if j < len(model_state_names_rev):
            n = _common_prefix_length(model_names_rev[i], model_state_names_rev[j])
        else:
            # past the end of the stored names: force the current best match
            n = 0
        if n >= last_n:
            last_n = n
            j += 1
        else:
            last_n = 0
            name_mapping[model_names_rev[i][::-1]] = model_state_names_rev[j - 1][::-1]
            i += 1
        # stored names exhausted without a pending match
        if not j < len(model_state_names_rev) + 1:
            break

    # not all names might have been matched
    if i < len(model_names):
        raise RuntimeError("could not suffix match parameter names")

    for k, v in name_mapping.items():
        logger.debug(f"{k} <- {v}")

    # rename weights according to matches and save to disk
    model_state_out = {k: model_state[v] for k, v in name_mapping.items()}
    set_child(state, model_state_out, *keys)
    torch.save(state, checkpoint_path_out)


def remove_prefix_state_dict_fn(prefix: str):
    """Return a function that strips ``prefix`` from state-dict keys that have it.

    Keys without the prefix are kept unchanged.
    """
    n = len(prefix)

    def state_dict_fn(state_dict):
        return {
            (key[n:] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    return state_dict_fn


def add_prefix_state_dict_fn(prefix: str):
    """Return a function that prepends ``prefix`` to every state-dict key."""

    def state_dict_fn(state_dict):
        return {prefix + key: value for key, value in state_dict.items()}

    return state_dict_fn


def filter_and_remove_prefix_state_dict_fn(prefix: str):
    """Return a function keeping only keys starting with ``prefix``, with it stripped."""
    n = len(prefix)

    def state_dict_fn(state_dict):
        return {
            key[n:]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    return state_dict_fn


def get_last_checkpoint(path: str):
    """Return the path of the ``epoch=<E>-step=<S>.ckpt`` file with the highest (E, S).

    Args:
        path: Directory to search (not recursive).

    Raises:
        RuntimeError: If no matching checkpoint file exists in ``path``.
    """
    # Escape the directory so characters like "[" in it are not treated as
    # glob wildcards; only the file-name part is a pattern.
    pattern = os.path.join(glob_escape(path), "epoch=*-step=*.ckpt")
    candidates = []
    for checkpoint in glob(pattern):
        match = _CHECKPOINT_NAME_PATTERN.fullmatch(os.path.basename(checkpoint))
        if match is None:
            # e.g. versioned duplicates such as "epoch=1-step=2-v1.ckpt"
            continue
        n_epoch, n_step = (int(group) for group in match.groups())
        candidates.append((n_epoch, n_step, checkpoint))

    if not candidates:
        raise RuntimeError(f"no checkpoint has been found at path : {path}")
    return max(candidates)[2]


def load_sharded_checkpoint(path: str, device: Optional[str] = None):
    """Load and consolidate a Lightning distributed (sharded) checkpoint directory.

    Args:
        path: Directory containing the sharded checkpoint.
        device: Must be ``None`` or a CPU device; consolidation only happens
            on CPU.

    Raises:
        RuntimeError: If ``device`` is not a CPU device.
    """
    # ``None`` (the default of load_model_from_checkpoint) means "no explicit
    # device" and is treated as CPU; torch.device also accepts torch.device
    # objects and strings like "cpu".
    if device is not None and torch.device(device).type != "cpu":
        raise RuntimeError(
            f'loading sharded weights on device "{device}" is not available, please use the "cpu" device instead'
        )
    checkpoint = _load_distributed_checkpoint(Path(path))
    checkpoint = _format_checkpoint(checkpoint)
    return checkpoint


def load_model_from_checkpoint(
    model: Union[pl.LightningModule, torch.nn.Module],
    checkpoint_path: str,
    strict: bool = True,
    device: Optional[str] = None,
    freeze: bool = False,
    eval: bool = False,
    map_name: Union[Dict[str, str], None] = None,
    remove_name: Union[List[str], None] = None,
    state_dict_key: Union[None, str, Iterable[str]] = "state_dict",
    state_dict_fn: Optional[Callable[[Any], Any]] = None,
):
    """Load weights from a checkpoint file or sharded directory into ``model``.

    Processing order: load -> ``on_load_checkpoint`` hook (LightningModule
    only) -> select sub-dict via ``state_dict_key`` -> delete ``remove_name``
    keys -> rename keys via ``map_name`` -> ``state_dict_fn`` ->
    ``load_state_dict`` -> move to ``device`` -> freeze / eval.

    Args:
        model: Module receiving the weights.
        checkpoint_path: A checkpoint file, or a directory holding a sharded
            checkpoint.
        strict: Forwarded to ``model.load_state_dict``. When False, missing and
            unexpected keys are logged as warnings.
        device: ``map_location`` for ``torch.load`` and target of ``model.to``;
            ``None`` leaves placement unchanged.
        freeze: Set ``requires_grad=False`` on all parameters (implies ``eval``).
        eval: Put the model in eval mode.
        map_name: ``{src: dst}`` renames applied in insertion order; missing
            ``src`` keys are skipped.
        remove_name: Keys to delete; a missing key raises ``KeyError``.
        state_dict_key: Key (or sequence of keys) passed to ``get_child`` to
            reach the state dict; ``None`` uses the loaded object itself.
        state_dict_fn: Optional transform applied to the state dict last.

    Returns:
        The model (possibly a new object returned by ``model.to``).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is neither a file nor a directory.
    """
    logger.info(f"Loading checkpoint from {checkpoint_path}")

    if os.path.isfile(checkpoint_path):
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load trusted checkpoints.
        checkpoint = torch.load(
            checkpoint_path,
            map_location=device,
            weights_only=False,
        )
    elif os.path.isdir(checkpoint_path):  # sharded
        checkpoint = load_sharded_checkpoint(checkpoint_path, device=device)
    else:  # if neither a file nor a directory, path does not exist
        raise FileNotFoundError(checkpoint_path)

    if isinstance(model, pl.LightningModule):
        model.on_load_checkpoint(checkpoint)

    # get state dictionary
    state_dict = checkpoint
    if state_dict_key is not None:
        if isinstance(state_dict_key, str):
            state_dict_key = (state_dict_key,)
        state_dict = get_child(state_dict, *state_dict_key)

    # remove names
    if remove_name is not None:
        for name in remove_name:
            del state_dict[name]

    # remap names; pop-then-assign keeps the entry when src == dst (the old
    # assign-then-delete sequence removed it)
    if map_name is not None:
        for src, dst in map_name.items():
            if src not in state_dict:
                continue
            state_dict[dst] = state_dict.pop(src)

    # apply custom changes to dict
    if state_dict_fn is not None:
        state_dict = state_dict_fn(state_dict)

    result = model.load_state_dict(state_dict, strict=strict)
    if not strict:
        # In non-strict mode mismatches are otherwise silently ignored.
        missing_keys = getattr(result, "missing_keys", None)
        unexpected_keys = getattr(result, "unexpected_keys", None)
        if missing_keys:
            logger.warning(f"missing keys when loading checkpoint: {missing_keys}")
        if unexpected_keys:
            logger.warning(
                f"unexpected keys when loading checkpoint: {unexpected_keys}"
            )

    if device is not None:
        model = model.to(device)

    if freeze:
        for param in model.parameters():
            param.requires_grad = False
        eval = True

    if eval:
        model.eval()

    return model