"""TD-MPC2 baseline model definition.

Provides helpers to construct the TD-MPC2 image world model and to restore
it from a saved state-dict checkpoint.
"""

import torch

from experiments.shared.src.models.image_world_models import TDMPC2ImageWorldModel
from experiments.tdmpc2.src.config import default_config


def build_model(config):
    """Construct a fresh ``TDMPC2ImageWorldModel`` from ``config``.

    Args:
        config: Configuration object passed straight to the model constructor.

    Returns:
        The newly constructed model.
    """
    model = TDMPC2ImageWorldModel(config)
    return model


def load_model(checkpoint_path, config=None):
    """Build a model and restore its parameters from ``checkpoint_path``.

    Args:
        checkpoint_path: Path (or file-like object) accepted by ``torch.load``
            that holds a state dict for the model.
        config: Optional configuration. When ``None``, ``default_config()``
            is used to build the model.

    Returns:
        The model with the checkpoint's state dict loaded. No device move or
        train/eval mode change is applied here.
    """
    if config is None:
        config = default_config()

    model = build_model(config)

    # Tensors are mapped to CPU so checkpoints saved on GPU load anywhere.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load trusted checkpoints. Consider weights_only=True once the
    # minimum supported torch version allows it.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model