"""PlaNet RSSM model definition.""" import torch from experiments.planet.src.config import default_config from experiments.shared.src.models.image_world_models import RSSMImageWorldModel def build_model(config): return RSSMImageWorldModel(config) def load_model(checkpoint_path, config=None): cfg = default_config() if config is None else config model = build_model(cfg) model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) return model