# Uploaded by cccat6 — "Initial FlowMo-WM public code release"
# (commit 604e535, verified).
"""PlaNet RSSM training entry point."""
import torch.nn.functional as F
def train(config):
    """Run one optimisation step on a single batch.

    Args:
        config: Mapping that provides ``"model"``, ``"optimizer"`` and
            ``"batch"``. The batch must unpack into exactly four items:
            ``(images, actions, future_actions, targets)``. The model must
            expose ``rollout(images, actions, future_actions)``.

    Returns:
        The batch's mean-squared-error loss (``F.mse_loss`` with its default
        ``"mean"`` reduction) as a plain Python ``float``.
    """
    net, opt = config["model"], config["optimizer"]
    obs, acts, future_acts, goal = config["batch"]

    # Roll the model forward and score the prediction against the targets.
    prediction = net.rollout(obs, acts, future_acts)
    batch_loss = F.mse_loss(prediction, goal)

    # Drop stale gradients (set to None), backpropagate, then apply the update.
    opt.zero_grad(set_to_none=True)
    batch_loss.backward()
    opt.step()

    # Detach before converting so the returned scalar holds no graph reference.
    return float(batch_loss.detach())