Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
class Trainer:
    """Simple policy-gradient training loop.

    Interacts with a training env one step at a time, performs a supervised
    policy update toward the env's target action after every step, and
    periodically evaluates the model on a set of test envs.
    """

    def __init__(self, model, optim, logger):
        self.model = model
        self.optim = optim
        self.logger = logger
        # Count of training iterations so far; used as the logging x-axis.
        self.train_it = 0

    def test_agent(self, model, logger, test_envs, num_episodes):
        """Evaluate `model` for `num_episodes` episodes in each test env.

        Logs per-env return/episode-length statistics and success rate,
        plus the success rate averaged over all envs.
        """
        avg_success = []
        for test_env in test_envs:
            key_prefix = f"{test_env.name}/"
            for _ in range(num_episodes):
                obs, done = test_env.reset(), False
                episode_return, episode_len = 0, 0
                while not done:
                    action = model.get_action(obs)
                    obs, reward, done, _ = test_env.step(action)
                    episode_return += reward
                    episode_len += 1
                logger.store({
                    key_prefix + "return": episode_return,
                    key_prefix + "ep_length": episode_len,
                })
            # Summarize once per env, after all of its episodes.
            logger.log_tabular(key_prefix + "return", with_min_and_max=True)
            logger.log_tabular(key_prefix + "ep_length", average_only=True)
            # NOTE(review): pop_successes() presumably drains per-episode
            # success flags accumulated during the episodes above — confirm.
            env_success = test_env.pop_successes()
            avg_success.extend(env_success)
            logger.log_tabular(key_prefix + "success", np.mean(env_success))
        logger.log_tabular("average_success", np.mean(avg_success))

    def log(self, logger, step, model):
        """Dump the accumulated metrics for this logging period.

        Also logs every scalar weight of `model.model` individually so the
        (small) policy's parameters can be tracked over training.
        """
        logger.log_tabular("total_env_steps", step)
        logger.log_tabular("train/loss", average_only=True)
        logger.log_tabular("train/action", average_only=True)
        for idx, weight in enumerate(model.model.weight.flatten()):
            logger.log_tabular(f"weight{idx}", weight.item())
        return logger.dump_tabular()

    def update(self, env, probs, model, optim, logger):
        """One gradient step: maximize log-likelihood of the env's target action.

        `probs` is the model output for the current observation; gradients
        flow through it back into the model's parameters.
        """
        target = torch.as_tensor([env.get_target_action()], dtype=torch.float32)
        action, log_prob = model.log_prob(probs, target)
        optim.zero_grad()
        # Negative log-likelihood: minimizing it pushes probability mass
        # toward the target action.
        loss = -torch.mean(log_prob)
        loss.backward()
        optim.step()
        # Single store call; both metrics are dumped later by log().
        logger.store({"train/action": action, "train/loss": loss.item()})

    def train(self, env, test_envs, steps, log_every, num_eval_eps):
        """Run the main training loop for `steps` environment steps.

        Every `log_every` steps the model is evaluated on `test_envs`
        (num_eval_eps episodes each) and the metrics are dumped.
        """
        obs = env.reset()
        for timestep in range(steps):
            self.train_it += 1
            # Evaluate/log before the update so the metrics reflect the
            # model state at this step count.
            if (timestep + 1) % log_every == 0:
                self.test_agent(self.model, self.logger, test_envs, num_eval_eps)
                self.log(self.logger, self.train_it, self.model)
            output = self.model(obs)
            action, _ = self.model.sample(output)
            self.update(env, output, self.model, self.optim, self.logger)
            obs, _, done, _ = env.step(action)
            if done:
                obs = env.reset()