# apple/training/trainer.py
import numpy as np
import torch
class Trainer:
    """Runs the training loop, periodic evaluation, and logging for a model."""

    def __init__(self, model, optim, logger):
        self.model = model
        self.optim = optim
        self.logger = logger
        self.train_it = 0  # total number of training iterations so far
    def test_agent(self, model, logger, test_envs, num_episodes):
        """Roll out the model in each test env and log per-env and average success."""
        avg_success = []
        for test_env in test_envs:
            key_prefix = f"{test_env.name}/"
            for _ in range(num_episodes):
                obs, done, episode_return, episode_len = test_env.reset(), False, 0, 0
                while not done:
                    action = model.get_action(obs)
                    obs, reward, done, _ = test_env.step(action)
                    episode_return += reward
                    episode_len += 1
                # Accumulate per-episode stats; they are dumped once per env below.
                logger.store({key_prefix + "return": episode_return, key_prefix + "ep_length": episode_len})
            logger.log_tabular(key_prefix + "return", with_min_and_max=True)
            logger.log_tabular(key_prefix + "ep_length", average_only=True)
            env_success = test_env.pop_successes()
            avg_success += env_success
            logger.log_tabular(key_prefix + "success", np.mean(env_success))
        logger.log_tabular("average_success", np.mean(avg_success))
    def log(self, logger, step, model):
        """Dump the accumulated statistics for the current logging period."""
        logger.log_tabular("total_env_steps", step)
        logger.log_tabular("train/loss", average_only=True)
        logger.log_tabular("train/action", average_only=True)
        # Log every scalar weight of the model's (single) weight matrix individually.
        for i, w in enumerate(model.model.weight.flatten()):
            logger.log_tabular(f"weight{i}", w.item())
        return logger.dump_tabular()
    def update(self, env, probs, model, optim, logger):
        """One gradient step: maximize the log-likelihood of the env's target action."""
        target = torch.as_tensor([env.get_target_action()], dtype=torch.float32)
        action, log_prob = model.log_prob(probs, target)
        loss = -torch.mean(log_prob)  # negative log-likelihood of the target action
        optim.zero_grad()
        loss.backward()
        optim.step()
        logger.store({"train/action": action, "train/loss": loss.item()})
    def train(self, env, test_envs, steps, log_every, num_eval_eps):
        """Main loop: act in env, update the model each step, evaluate periodically."""
        obs = env.reset()
        for timestep in range(steps):
            self.train_it += 1
            # Periodically evaluate on the test envs and dump the accumulated log.
            if (timestep + 1) % log_every == 0:
                self.test_agent(self.model, self.logger, test_envs, num_eval_eps)
                self.log(self.logger, self.train_it, self.model)
            output = self.model(obs)
            action, _ = self.model.sample(output)
            # Supervised-style update toward the env's target action for this state.
            self.update(env, output, self.model, self.optim, self.logger)
            obs, _, done, _ = env.step(action)
            if done:
                obs = env.reset()
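

# Minimal usage sketch (not part of the original file). Env, TestEnv, PolicyModel,
# and Logger are hypothetical stand-ins, assumed to match the interfaces this
# Trainer calls: reset()/step()/get_target_action() and a .name attribute plus
# pop_successes() on the envs, store()/log_tabular()/dump_tabular() on the logger,
# and forward()/sample()/log_prob()/get_action() on the model.
#
# if __name__ == "__main__":
#     model = PolicyModel()                    # assumed torch.nn.Module
#     optim = torch.optim.Adam(model.parameters(), lr=1e-3)
#     logger = Logger()
#     trainer = Trainer(model, optim, logger)
#     trainer.train(env=Env(), test_envs=[TestEnv()], steps=10_000,
#                   log_every=500, num_eval_eps=10)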