File size: 480 Bytes
604e535 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | """PlaNet RSSM model definition."""
import torch
from experiments.planet.src.config import default_config
from experiments.shared.src.models.image_world_models import RSSMImageWorldModel
def build_model(config):
    """Construct a fresh RSSM image world model.

    Args:
        config: Model configuration, handed unchanged to the
            ``RSSMImageWorldModel`` constructor.

    Returns:
        A newly created ``RSSMImageWorldModel`` instance.
    """
    world_model = RSSMImageWorldModel(config)
    return world_model
def load_model(checkpoint_path, config=None, *, map_location="cpu", strict=True):
    """Build an RSSM world model and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Anything ``torch.load`` accepts (a path or file-like
            object). The loaded object is passed straight to
            ``load_state_dict``, so it is expected to be the model's state dict.
        config: Model configuration. When ``None``, ``default_config()`` is
            used.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"``
            (the previously hard-coded value), so tensors are remapped to CPU
            unless the caller asks otherwise.
        strict: Forwarded to ``load_state_dict``. When ``True`` (default,
            matching the previous behavior), missing or unexpected keys raise
            an error.

    Returns:
        The model produced by ``build_model`` with the checkpoint's weights
        loaded into it.
    """
    cfg = default_config() if config is None else config
    model = build_model(cfg)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources. On torch versions where
    # weights_only is not the default, consider passing weights_only=True here,
    # since a plain state dict does not need full unpickling.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict, strict=strict)
    return model
|