"""TD-MPC2 latent rollout interface."""


def rollout(model, batch, horizon):
    """Roll ``model`` out over the first ``horizon`` steps of a planned action sequence.

    Unpacks ``batch`` and forwards it to ``model.rollout``. The only change is
    that ``future_actions`` is truncated with ``[:, :horizon]`` before the call.

    Args:
        model: Object exposing ``rollout(images, actions, future_actions)``.
        batch: A 3-item sequence ``(images, actions, future_actions)``.
        horizon: Number of future steps to keep. ``future_actions`` is sliced
            as ``[:, :horizon]``, so it must support 2-D slicing (e.g. a NumPy
            array or a torch tensor). Standard slicing rules apply:
            - A value larger than the available steps keeps all of them.
            - ``None`` keeps all of them.
            - A negative value drops steps from the end.

    Returns:
        Whatever ``model.rollout`` returns.

    Raises:
        ValueError: If ``batch`` does not unpack into exactly three items.
    """
    images, actions, future_actions = batch
    # NOTE(review): axis 1 is presumably the time axis, i.e. future_actions is
    # (batch, time, ...); confirm against the data loader.
    return model.rollout(images, actions, future_actions[:, :horizon])