File size: 525 Bytes
604e535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""TD-MPC2 baseline training entry point."""

import torch.nn.functional as F


def train(config):
    """Run one optimisation step of the baseline model and return its loss.

    Args:
        config: Mapping with the keys ``"model"``, ``"optimizer"`` and
            ``"batch"``.
            - ``"batch"`` must unpack into
              ``(images, actions, future_actions, targets)``.
            - ``"model"`` must provide ``rollout(images, actions,
              future_actions)``, ``encode(images, actions)`` returning a
              2-tuple whose first element is passed to ``value``, and
              ``value(z)`` returning a tensor.

    Returns:
        The scalar training loss as a Python ``float``: the MSE between the
        rollout prediction and ``targets``.
    """
    model = config["model"]
    optimizer = config["optimizer"]
    images, actions, future_actions, targets = config["batch"]

    pred = model.rollout(images, actions, future_actions)

    # The value prediction is weighted by zero, so it never changes the loss
    # value. NOTE(review): presumably it is kept only so the value-head
    # parameters stay in the autograd graph and receive (zero) gradients —
    # confirm the intent.
    z, _ = model.encode(images, actions)
    value = model.value(z).mean()

    # ``0.0 * inf`` and ``0.0 * nan`` are NaN, which would poison the reported
    # loss whenever the value head emits a non-finite prediction.
    # ``nan_to_num`` leaves finite values untouched, so the result is
    # identical to ``0.0 * value`` in the finite case. The gradient reaching
    # the value head is still exactly zero.
    value_anchor = value.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) * 0.0

    loss = F.mse_loss(pred, targets) + value_anchor
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.item()