"""PlaNet RSSM training entry point.""" import torch.nn.functional as F def train(config): model = config["model"] optimizer = config["optimizer"] images, actions, future_actions, targets = config["batch"] pred = model.rollout(images, actions, future_actions) loss = F.mse_loss(pred, targets) optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step() return float(loss.detach())