| """TD-MPC2 baseline model definition.""" | |
| import torch | |
| from experiments.shared.src.models.image_world_models import TDMPC2ImageWorldModel | |
| from experiments.tdmpc2.src.config import default_config | |
def build_model(config):
    """Construct a fresh TD-MPC2 image world model.

    Args:
        config: Configuration object, forwarded unchanged to the
            ``TDMPC2ImageWorldModel`` constructor.

    Returns:
        A newly constructed ``TDMPC2ImageWorldModel`` instance.
    """
    world_model = TDMPC2ImageWorldModel(config)
    return world_model
def load_model(checkpoint_path, config=None):
    """Build a TD-MPC2 world model and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Path handed to ``torch.load``. The loaded object is
            passed directly to ``load_state_dict``, so the file must contain a
            state dict compatible with the model built from ``config``.
        config: Configuration used to build the model. When ``None``, the
            result of ``default_config()`` is used instead.

    Returns:
        The model produced by ``build_model`` with the checkpoint's state
        loaded into it. Checkpoint tensors are mapped to CPU at load time; no
        device move or ``.eval()`` call is applied here.
    """
    if config is None:
        config = default_config()
    model = build_model(config)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources (consider weights_only=True if the installed torch
    # version supports it and the checkpoint holds only tensors).
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model