| """TD-MPC2 baseline training entry point.""" | |
| import torch.nn.functional as F | |
def train(config):
    """Run a single optimisation step and report the resulting loss.

    Args:
        config: Mapping that provides three entries:
            ``"model"`` -- object exposing ``rollout``, ``encode`` and
            ``value`` methods;
            ``"optimizer"`` -- object exposing ``zero_grad`` and ``step``;
            ``"batch"`` -- iterable that unpacks into exactly four items
            ``(images, actions, future_actions, targets)``.

    Returns:
        The loss of this step converted to a Python ``float``.

    Raises:
        KeyError: If ``config`` lacks one of the entries listed above.
    """
    net = config["model"]
    opt = config["optimizer"]
    obs, acts, next_acts, labels = config["batch"]

    # Forward passes, in this order: rollout prediction first, then the
    # latent encoding that feeds the value head.
    rollout_out = net.rollout(obs, acts, next_acts)
    latent, _unused = net.encode(obs, acts)
    mean_value = net.value(latent).mean()

    # The value estimate is scaled by 0.0, so for finite values it adds
    # nothing to the numeric loss while still being connected to the graph.
    # NOTE(review): presumably intentional so the encode/value path receives
    # (zero) gradients -- confirm before changing.
    prediction_error = F.mse_loss(rollout_out, labels)
    muted_value = 0.0 * mean_value
    total = prediction_error + muted_value

    # Drop stale gradients, backpropagate, then apply the update.
    opt.zero_grad(set_to_none=True)
    total.backward()
    opt.step()

    scalar = total.detach()
    return float(scalar)