| from pathlib import Path |
| from omegaconf import OmegaConf |
| import torch |
| from ldm.util import instantiate_from_config |
| import logging |
| from contextlib import contextmanager |
|
|
| from contextlib import contextmanager |
| import logging |
|
|
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """
    Suppress logging for the duration of a ``with`` block.

    While the body runs, messages at or below ``highest_level`` are
    dropped; on exit (including on error) the previously active
    process-wide disable threshold is put back.

    :param highest_level: the highest logging level that should be silenced.
      The default only needs overriding if the application defines a custom
      level above CRITICAL.

    https://gist.github.com/simon-weber/7853144
    """
    # Remember the process-wide threshold so that nested or pre-existing
    # logging.disable() settings survive this context manager.
    log_manager = logging.root.manager
    saved_threshold = log_manager.disable

    logging.disable(highest_level)
    try:
        yield
    finally:
        # Runs even if the body raised, so logging is never left disabled.
        logging.disable(saved_threshold)
|
|
def load_training_dir(train_dir, device, epoch="last"):
    """Load a checkpoint and config from training directory.

    Searches ``train_dir`` recursively for exactly one checkpoint whose file
    name ends in ``{epoch}.ckpt`` and for at least one ``*-project.yaml``
    config file. When several configs match, the one whose path sorts last
    is selected.

    :param train_dir: root directory of the training run (str or Path).
    :param device: device handed through to ``load_model_from_config``.
    :param epoch: suffix of the checkpoint file name to match.
    :return: the result of ``load_model_from_config`` for the selected
        config and checkpoint.
    :raises AssertionError: if the number of matching checkpoints is not
        exactly one, or if no config file is found.
    """
    train_dir = Path(train_dir)

    ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
    assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"

    config = sorted(train_dir.rglob("*-project.yaml"))
    # Validate the config matches themselves; checking the checkpoint list
    # here would let an empty config list fall through to an IndexError.
    assert len(config) > 0, f"didn't find any config in {train_dir}"
    if len(config) > 1:
        print(f"found {len(config)} matching config files")
        config = config[-1]
        print(f"selecting {config}")
    else:
        config = config[0]

    config = OmegaConf.load(config)
    return load_model_from_config(config, ckpt[0], device)
|
|
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Loads a model from config and a ckpt.

    If ``config`` is a path (str or Path) it is first loaded with OmegaConf.
    The checkpoint is read onto the CPU, the model described by
    ``config.model`` is instantiated, the checkpoint's ``state_dict`` is
    loaded non-strictly, and the model is then moved to ``device`` and put
    in eval mode. Logging is suppressed while this happens.

    :param config: loaded config object exposing ``.model``, or a path to one.
    :param ckpt: path of the checkpoint file to load.
    :param device: device the model is moved to after loading.
    :param verbose: when true, print any missing / unexpected state-dict keys.
    :return: the instantiated model, on ``device`` and in eval mode.
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    with all_logging_disabled():
        print(f"Loading model from {ckpt}")
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)

        # strict=False tolerates key mismatches; report them on request.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose and len(missing) > 0:
            print("missing keys:")
            print(missing)
        if verbose and len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

        model.to(device)
        model.eval()
        # The conditioning sub-model keeps its own ``device`` attribute,
        # which has to be kept in step with the model's placement.
        model.cond_stage_model.device = device
    return model