File size: 431 Bytes
604e535 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | """PlaNet RSSM training entry point."""
import torch.nn.functional as F
def train(config):
    """Run one optimization step on a single batch and return the loss.

    Args:
        config: Mapping with keys ``"model"`` (an ``nn.Module``-like object
            exposing ``rollout(images, actions, future_actions)``),
            ``"optimizer"`` (a ``torch.optim.Optimizer``) and ``"batch"``
            (a 4-tuple ``(images, actions, future_actions, targets)``).
            NOTE(review): tensor shapes are not fixed here; the only
            requirement enforced is that ``rollout`` output matches
            ``targets`` exactly.

    Returns:
        The scalar MSE loss for this batch as a Python ``float``.

    Raises:
        ValueError: If the rollout prediction and ``targets`` differ in
            shape (``F.mse_loss`` would otherwise silently broadcast them
            and produce a meaningless loss).
    """
    model = config["model"]
    optimizer = config["optimizer"]
    images, actions, future_actions, targets = config["batch"]

    # Ensure dropout / batch-norm layers are in training mode even if the
    # model was previously switched to eval mode (e.g. by an eval loop).
    model.train()

    pred = model.rollout(images, actions, future_actions)
    if pred.shape != targets.shape:
        raise ValueError(
            f"rollout prediction shape {tuple(pred.shape)} does not match "
            f"target shape {tuple(targets.shape)}"
        )
    loss = F.mse_loss(pred, targets)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.item()
|