| """PlaNet RSSM training entry point.""" | |
| import torch.nn.functional as F | |
def train(config):
    """Run one optimisation step of the model on a single batch.

    Args:
        config: Mapping with keys ``"model"`` (an object exposing
            ``rollout(images, actions, future_actions)``), ``"optimizer"``
            (an object with ``zero_grad(set_to_none=...)`` and ``step()``,
            e.g. a ``torch.optim.Optimizer``) and ``"batch"`` (a 4-item
            sequence ``images, actions, future_actions, targets``).

    Returns:
        The mean-squared-error loss for this batch as a Python ``float``.

    Raises:
        ValueError: if the rollout prediction and ``targets`` differ in shape.
    """
    from torch import nn  # local import: the module itself only imports ``F``

    model = config["model"]
    optimizer = config["optimizer"]
    images, actions, future_actions, targets = config["batch"]

    # A training step needs dropout/normalisation layers in training mode; a
    # preceding evaluation or data-collection phase may have left the model
    # in eval mode. Only nn.Module carries this flag, so plain objects that
    # merely expose ``rollout`` are still accepted as before.
    if isinstance(model, nn.Module):
        model.train()

    pred = model.rollout(images, actions, future_actions)
    # F.mse_loss broadcasts mismatched shapes with only a warning, which gives
    # a silently wrong loss (e.g. an off-by-one in the time dimension that
    # happens to broadcast). Fail loudly instead.
    if pred.shape != targets.shape:
        raise ValueError(
            f"rollout prediction shape {tuple(pred.shape)} does not match "
            f"target shape {tuple(targets.shape)}"
        )
    loss = F.mse_loss(pred, targets)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # .item() yields a plain Python float detached from the autograd graph.
    return loss.item()