|
|
""" Load a model from its checkpoint directory """ |
|
|
import glob |
|
|
import os |
|
|
|
|
|
import hydra |
|
|
import torch |
|
|
from pyaml_env import parse_config |
|
|
|
|
|
from pvnet.models.ensemble import Ensemble |
|
|
from pvnet.models.multimodal.unimodal_teacher import Model as UMTModel |
|
|
|
|
|
|
|
|
def get_model_from_checkpoints( |
|
|
checkpoint_dir_paths: list[str], |
|
|
val_best: bool = True, |
|
|
): |
|
|
"""Load a model from its checkpoint directory""" |
|
|
is_ensemble = len(checkpoint_dir_paths) > 1 |
|
|
|
|
|
model_configs = [] |
|
|
models = [] |
|
|
data_configs = [] |
|
|
|
|
|
for path in checkpoint_dir_paths: |
|
|
|
|
|
model_config = parse_config(f"{path}/model_config.yaml") |
|
|
|
|
|
model = hydra.utils.instantiate(model_config) |
|
|
|
|
|
if val_best: |
|
|
|
|
|
files = glob.glob(f"{path}/epoch*.ckpt") |
|
|
if len(files) != 1: |
|
|
raise ValueError( |
|
|
f"Found {len(files)} checkpoints @ {path}/epoch*.ckpt. Expected one." |
|
|
) |
|
|
|
|
|
checkpoint = torch.load(files[0], map_location="cpu", weights_only=False) |
|
|
else: |
|
|
checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu", weights_only=False) |
|
|
|
|
|
model.load_state_dict(state_dict=checkpoint["state_dict"]) |
|
|
|
|
|
if isinstance(model, UMTModel): |
|
|
model, model_config = model.convert_to_multimodal_model(model_config) |
|
|
|
|
|
|
|
|
data_config = f"{path}/data_config.yaml" |
|
|
|
|
|
if os.path.isfile(data_config): |
|
|
data_configs.append(data_config) |
|
|
else: |
|
|
data_configs.append(None) |
|
|
|
|
|
model_configs.append(model_config) |
|
|
models.append(model) |
|
|
|
|
|
if is_ensemble: |
|
|
model_config = { |
|
|
"_target_": "pvnet.models.ensemble.Ensemble", |
|
|
"model_list": model_configs, |
|
|
} |
|
|
model = Ensemble(model_list=models) |
|
|
data_config = data_configs[0] |
|
|
|
|
|
else: |
|
|
model_config = model_configs[0] |
|
|
model = models[0] |
|
|
data_config = data_configs[0] |
|
|
|
|
|
return model, model_config, data_config |
|
|
|