"""LeWorldModel baseline model definition.""" import torch from experiments.leworldmodel.src.config import default_config from experiments.shared.src.models.image_world_models import LeWorldModelImage def build_model(config): return LeWorldModelImage(config) def load_model(checkpoint_path, config=None): cfg = default_config() if config is None else config model = build_model(cfg) model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) return model