File size: 441 Bytes
604e535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""LeWorldModel baseline training entry point."""

import torch.nn.functional as F


def train(config):
    """Run one optimization step of the baseline model and report its loss.

    Args:
        config: Mapping providing three entries:
            ``"model"``     -- object exposing
                               ``rollout(images, actions, future_actions)``;
            ``"optimizer"`` -- optimizer exposing ``zero_grad`` and ``step``
                               (a ``torch.optim.Optimizer`` presumably —
                               confirm against callers);
            ``"batch"``     -- 4-item sequence unpacked as
                               ``(images, actions, future_actions, targets)``.
            Tensor shapes are not checked here; the rollout output is
            assumed to be shape-compatible with ``targets`` — TODO confirm.

    Returns:
        The mean-squared-error loss of this step as a plain Python ``float``.
    """
    net = config["model"]
    opt = config["optimizer"]
    batch = config["batch"]
    images, actions, future_actions, targets = batch

    # Forward pass: predict the targets from the observed and future inputs.
    prediction = net.rollout(images, actions, future_actions)
    step_loss = F.mse_loss(prediction, targets)

    # Drop gradients left over from the previous step (freed rather than
    # zero-filled), then backpropagate and apply the update.
    opt.zero_grad(set_to_none=True)
    step_loss.backward()
    opt.step()

    # Hand back a detached Python scalar so callers never hold the graph.
    return step_loss.item()