import numpy as np
import torch


class Trainer:
    """Drive a supervised-style training loop with periodic evaluation.

    The trainer steps a training environment, asks the environment for the
    target action at every step, and maximizes the model's log-probability of
    that target. Every ``log_every`` steps it evaluates the model on a set of
    test environments and dumps the accumulated statistics via the logger.

    NOTE(review): the logger is assumed to expose ``store(dict)``,
    ``log_tabular(key, value=None, **options)`` and ``dump_tabular()``; that
    interface is inferred from the calls made here - confirm against the
    logger implementation.
    """

    def __init__(self, model, optim, logger):
        """Keep references to the model, optimizer and logger.

        Args:
            model: Policy object; this class calls ``model(obs)``,
                ``model.sample``, ``model.log_prob``, ``model.get_action`` and
                reads ``model.model.weight``.
            optim: Optimizer providing ``zero_grad()`` and ``step()``.
            logger: Statistics sink (see the class docstring).
        """
        self.model = model
        self.optim = optim
        self.logger = logger
        # Number of training iterations performed across all ``train`` calls.
        self.train_it = 0

    def test_agent(self, model, logger, test_envs, num_episodes):
        """Roll out ``model`` on each test environment and log the results.

        For every environment, ``num_episodes`` episodes are played until the
        environment reports ``done``. The per-episode return and length are
        stored under ``"<env.name>/return"`` and ``"<env.name>/ep_length"``,
        the mean of ``test_env.pop_successes()`` is logged under
        ``"<env.name>/success"``, and the mean over all environments'
        successes is logged under ``"average_success"``.

        Args:
            model: Object providing ``get_action(obs)``.
            logger: Statistics sink (see the class docstring).
            test_envs: Iterable of environments exposing ``name``, ``reset()``,
                ``step(action)`` (returning a 4-tuple) and ``pop_successes()``.
            num_episodes: Number of episodes to play per environment.
        """
        all_successes = []
        for test_env in test_envs:
            key_prefix = f"{test_env.name}/"

            # Evaluation never backpropagates, so keep autograd from tracking
            # anything the model computes during the rollouts.
            with torch.no_grad():
                for _ in range(num_episodes):
                    obs = test_env.reset()
                    done = False
                    episode_return = 0
                    episode_len = 0
                    while not done:
                        action = model.get_action(obs)
                        obs, reward, done, _ = test_env.step(action)
                        episode_return += reward
                        episode_len += 1
                    logger.store(
                        {
                            key_prefix + "return": episode_return,
                            key_prefix + "ep_length": episode_len,
                        }
                    )

            logger.log_tabular(key_prefix + "return", with_min_and_max=True)
            logger.log_tabular(key_prefix + "ep_length", average_only=True)
            env_success = test_env.pop_successes()
            all_successes += env_success
            logger.log_tabular(key_prefix + "success", np.mean(env_success))

        logger.log_tabular("average_success", np.mean(all_successes))

    def log(self, logger, step, model):
        """Log training statistics and the model weights, then dump the table.

        Args:
            logger: Statistics sink (see the class docstring).
            step: Value logged under ``"total_env_steps"``.
            model: Object whose ``model.weight`` tensor is flattened and logged
                element by element as ``"weight<i>"``.

        Returns:
            Whatever ``logger.dump_tabular()`` returns.
        """
        logger.log_tabular("total_env_steps", step)
        logger.log_tabular("train/loss", average_only=True)
        logger.log_tabular("train/action", average_only=True)
        for index, weight in enumerate(model.model.weight.flatten()):
            logger.log_tabular(f"weight{index}", weight.item())

        return logger.dump_tabular()

    def update(self, env, probs, model, optim, logger):
        """Take one gradient step towards the environment's target action.

        The loss is the negative mean of the log-probability that
        ``model.log_prob`` assigns to the target action.

        Args:
            env: Environment providing ``get_target_action()``.
            probs: Model output for the current observation; it is passed
                through to ``model.log_prob`` unchanged.
            model: Object providing ``log_prob(probs, target)``, which returns
                an ``(action, log_prob)`` pair.
            optim: Optimizer providing ``zero_grad()`` and ``step()``.
            logger: Statistics sink; receives ``"train/action"`` and
                ``"train/loss"``.
        """
        target = torch.as_tensor([env.get_target_action()], dtype=torch.float32)
        action, log_prob = model.log_prob(probs, target)

        optim.zero_grad()
        loss = -torch.mean(log_prob)
        loss.backward()
        optim.step()

        logger.store({"train/action": action})
        logger.store({"train/loss": loss.item()})

    def train(self, env, test_envs, steps, log_every, num_eval_eps):
        """Run ``steps`` training iterations on ``env``.

        On every ``log_every``-th iteration the model is evaluated and the
        statistics are dumped *before* that iteration's update is applied.

        Args:
            env: Training environment exposing ``reset()``, ``step(action)``
                (returning a 4-tuple) and ``get_target_action()``.
            test_envs: Environments handed to ``test_agent``.
            steps: Number of training iterations to run.
            log_every: Evaluation/logging period in iterations; must be
                non-zero because it is used as a modulus.
            num_eval_eps: Episodes per test environment for each evaluation.
        """
        obs = env.reset()
        for timestep in range(steps):
            self.train_it += 1
            if (timestep + 1) % log_every == 0:
                self.test_agent(self.model, self.logger, test_envs, num_eval_eps)
                self.log(self.logger, self.train_it, self.model)

            output = self.model(obs)
            # Only the sampled action is needed here; the log-probability used
            # for the loss is recomputed against the target inside ``update``.
            action, _ = self.model.sample(output)
            self.update(env, output, self.model, self.optim, self.logger)

            obs, _, done, _ = env.step(action)
            if done:
                obs = env.reset()