cccat6's picture
Initial FlowMo-WM public code release
604e535 verified
raw
history blame
563 Bytes
"""FlowMo model definition."""
import torch
from experiments.flowmo.src.config import default_config
from experiments.shared.src.models.image_world_models import FlowMoImageWorldModel
def build_model(config):
    """Construct a fresh FlowMo image world model from *config*.

    Args:
        config: Configuration object, forwarded unchanged to the
            ``FlowMoImageWorldModel`` constructor.

    Returns:
        A newly constructed ``FlowMoImageWorldModel`` instance (no weights
        are loaded here).
    """
    model = FlowMoImageWorldModel(config)
    return model
def load_model(checkpoint_path, config=None, *, map_location="cpu", strict=True):
    """Load a trained FlowMo checkpoint into a freshly built model.

    Args:
        checkpoint_path: Anything ``torch.load`` accepts (path or file-like
            object). The loaded object is passed directly to
            ``load_state_dict``, so it is expected to be a plain state dict.
        config: Model configuration; ``default_config()`` is used when
            ``None``.
        map_location: Device remapping forwarded to ``torch.load``. Defaults
            to ``"cpu"`` (the previous hard-coded behaviour) so GPU-saved
            checkpoints load on CPU-only hosts.
        strict: Forwarded to ``load_state_dict``. When True (the default,
            matching the previous behaviour), missing or unexpected keys
            raise an error.

    Returns:
        The model built by ``build_model(cfg)`` with the checkpoint weights
        loaded. This function does not call ``.eval()``; do so before
        inference if needed.
    """
    cfg = default_config() if config is None else config
    model = build_model(cfg)
    # NOTE(review): torch.load unpickles the file. On torch versions where
    # ``weights_only`` defaults to False this can execute arbitrary code, so
    # only load checkpoints from trusted sources (consider weights_only=True
    # on torch>=1.13 once the minimum supported version allows it).
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict, strict=strict)
    return model