File size: 563 Bytes
604e535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""FlowMo model definition."""

import torch

from experiments.flowmo.src.config import default_config
from experiments.shared.src.models.image_world_models import FlowMoImageWorldModel

def build_model(config):
    """Instantiate a FlowMo image world model from ``config``.

    Args:
        config: Configuration object, forwarded unchanged to the
            ``FlowMoImageWorldModel`` constructor.

    Returns:
        The newly constructed ``FlowMoImageWorldModel`` instance.
    """
    model = FlowMoImageWorldModel(config)
    return model


def load_model(checkpoint_path, config=None, *, map_location="cpu", strict=True):
    """Build a FlowMo world model and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Anything accepted by ``torch.load`` (path or
            file-like object). The loaded object is passed directly to
            ``load_state_dict``, so it must be the model's state dict.
        config: Model configuration. When ``None``, ``default_config()`` is
            used, so the checkpoint must match the default architecture.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            checkpoints saved on GPU can be loaded on CPU-only machines.
        strict: Forwarded to ``load_state_dict``. When ``True`` (the default),
            missing or unexpected keys raise an error.

    Returns:
        The model with the checkpoint weights loaded. This function does not
        call ``.eval()``; do so yourself before inference if required.
    """
    cfg = default_config() if config is None else config
    model = build_model(cfg)
    # SECURITY(review): ``torch.load`` unpickles the file, which can execute
    # arbitrary code on torch < 2.6 (where ``weights_only`` defaults to False).
    # Only load checkpoints from trusted sources, or pass ``weights_only=True``
    # once the minimum supported torch version allows it.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict, strict=strict)
    return model