File size: 193 Bytes
604e535
 
 
 
 
 
1
2
3
4
5
6
7
"""TD-MPC2 latent rollout interface."""


def rollout(model, batch, horizon):
    """Run ``model.rollout`` on a batch with the third element clipped.

    Args:
        model: Any object exposing a ``rollout`` callable that accepts three
            positional arguments.
        batch: An iterable that unpacks into exactly three items. The
            original naming suggests these are, in order, images, actions
            and future actions -- presumably batched tensors/arrays; confirm
            against the caller.
        horizon: Upper bound used to slice the third item along its second
            axis (``[:, :horizon]``).

    Returns:
        Whatever ``model.rollout`` returns; the value is passed through
        unchanged.
    """
    frames, past_actions, upcoming_actions = batch
    # Resolve the bound method first, then slice, mirroring the order in
    # which a single-expression call would evaluate its pieces.
    run_rollout = model.rollout
    clipped_actions = upcoming_actions[:, :horizon]
    return run_rollout(frames, past_actions, clipped_actions)