# Source: FlowMo-WM public code release by cccat6
# ("Initial FlowMo-WM public code release", commit 604e535, verified).
"""TD-MPC2 baseline training entry point."""
import torch.nn.functional as F
def train(config):
    """Run one optimization step of the TD-MPC2 baseline on a single batch.

    Args:
        config: Mapping-like object that must provide:
            ``"model"`` -- object exposing ``rollout(images, actions,
                future_actions)``, ``encode(images, actions)`` (returning a
                2-item sequence whose first item is passed to ``value``) and
                ``value(z)``.
            ``"optimizer"`` -- optimizer over the model's parameters; only
                ``zero_grad(set_to_none=True)`` and ``step()`` are called.
            ``"batch"`` -- 4-item sequence
                ``(images, actions, future_actions, targets)``; ``targets``
                must be compatible with the rollout output for ``F.mse_loss``.

    Returns:
        The scalar training loss as a Python ``float``.
    """
    model = config["model"]
    optimizer = config["optimizer"]
    images, actions, future_actions, targets = config["batch"]

    # Supervised rollout objective: the only term that contributes to the loss value.
    pred = model.rollout(images, actions, future_actions)

    # Extra encode/value forward pass whose output is multiplied by 0.0 below.
    # It adds nothing to the loss value, but it keeps whatever parameters
    # model.encode / model.value touch in the autograd graph. As a result they
    # receive zero gradients instead of staying None after
    # zero_grad(set_to_none=True).
    # NOTE(review): presumably for DDP's unused-parameter check or optimizer-state
    # consistency -- confirm before removing. Caveat: if the value output is
    # inf/NaN, 0.0 * value is NaN and the whole loss becomes NaN.
    z, _ = model.encode(images, actions)
    value = model.value(z).mean()
    loss = F.mse_loss(pred, targets) + 0.0 * value

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return float(loss.detach())