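# NOTE(review): the next three lines look like file-listing page residue, not module content.
# leonepson's picture
# Upload 254 files
# 5960497 verified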
import os
import click
import numpy as np
import json
from mpi4py import MPI
from baselines import logger
from baselines.common import set_global_seeds, tf_util
from baselines.common.mpi_moments import mpi_moments
import baselines.her.experiment.config as config
from baselines.her.rollout import RolloutWorker
def mpi_average(value):
    """Reduce ``value`` through ``mpi_moments`` and return its first result.

    A non-list ``value`` is wrapped in a single-element list. When no element
    of the list is truthy (which includes the empty list), ``[0.]`` is used
    instead. The list is converted to a numpy array, handed to
    ``mpi_moments``, and element 0 of what that returns is given back
    (presumably the mean across MPI workers -- confirm against
    ``mpi_moments``).
    """
    samples = value if isinstance(value, list) else [value]
    if not any(samples):
        samples = [0.]
    moments = mpi_moments(np.array(samples))
    return moments[0]
def train(*, policy, rollout_worker, evaluator,
          n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
          save_path, demo_file, **kwargs):
    """Run the epoch loop: collect rollouts, optimise, evaluate, log and save.

    Args:
        policy: object providing ``bc_loss``, ``init_demo_buffer``,
            ``store_episode``, ``train``, ``update_target_net`` and ``logs``.
        rollout_worker: worker used for training rollouts
            (``clear_history``, ``generate_rollouts``, ``logs``).
        evaluator: worker used for test rollouts; also supplies
            ``current_success_rate`` and ``save_policy``.
        n_epochs: number of epochs to run.
        n_test_rollouts: evaluator rollouts generated per epoch.
        n_cycles: training rollout/optimise cycles per epoch.
        n_batches: ``policy.train()`` calls per cycle.
        policy_save_interval: when > 0, a periodic snapshot is written on
            every epoch that is a multiple of this value.
        save_path: directory for policy pickles; falsy disables saving.
        demo_file: passed to ``policy.init_demo_buffer`` when
            ``policy.bc_loss == 1``.
        **kwargs: accepted and ignored.

    Returns:
        The ``policy`` object that was passed in.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    is_root = rank == 0

    if save_path:
        best_policy_path = os.path.join(save_path, 'policy_best.pkl')
        latest_policy_path = os.path.join(save_path, 'policy_latest.pkl')
        periodic_policy_path = os.path.join(save_path, 'policy_{}.pkl')

    logger.info("Training...")
    best_success_rate = -1

    # Initialise the demo buffer when training with demonstrations.
    if policy.bc_loss == 1:
        policy.init_demo_buffer(demo_file)

    def _record(entries):
        # Every rank must call mpi_average for every entry, in the same order.
        for name, raw in entries:
            logger.record_tabular(name, mpi_average(raw))

    # num_timesteps = n_epochs * n_cycles * rollout_length * number of rollout workers
    for epoch in range(n_epochs):
        # --- training phase ---
        rollout_worker.clear_history()
        for _cycle in range(n_cycles):
            policy.store_episode(rollout_worker.generate_rollouts())
            for _batch in range(n_batches):
                policy.train()
            policy.update_target_net()

        # --- evaluation phase ---
        evaluator.clear_history()
        for _rollout in range(n_test_rollouts):
            evaluator.generate_rollouts()

        # --- logging ---
        logger.record_tabular('epoch', epoch)
        _record(evaluator.logs('test'))
        _record(rollout_worker.logs('train'))
        _record(policy.logs())

        if is_root:
            logger.dump_tabular()

        # --- checkpointing (root rank only) ---
        # Ties count as "better", so the best/latest files are refreshed on >=.
        success_rate = mpi_average(evaluator.current_success_rate())
        if is_root and success_rate >= best_success_rate and save_path:
            best_success_rate = success_rate
            logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
            evaluator.save_policy(best_policy_path)
            evaluator.save_policy(latest_policy_path)

        if is_root and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_path:
            policy_path = periodic_policy_path.format(epoch)
            logger.info('Saving periodic policy to {} ...'.format(policy_path))
            evaluator.save_policy(policy_path)

        # Sanity check that workers are seeded differently: every rank draws a
        # random number, rank 0's draw is broadcast, and non-root ranks assert
        # that their own draw differs from it.
        local_uniform = np.random.uniform(size=(1,))
        root_uniform = local_uniform.copy()
        MPI.COMM_WORLD.Bcast(root_uniform, root=0)
        if not is_root:
            assert local_uniform[0] != root_uniform[0]

    return policy
def learn(*, network, env, total_timesteps,
          seed=None,
          eval_env=None,
          replay_strategy='future',
          policy_save_interval=5,
          clip_return=True,
          demo_file=None,
          override_params=None,
          load_path=None,
          save_path=None,
          **kwargs
          ):
    """Configure HER/DDPG for ``env`` and run training.

    Args:
        network: accepted for interface compatibility; not read by this function.
        env: training environment; ``env.spec.id`` and ``env.num_envs`` are read.
        total_timesteps: step budget; the epoch count is derived from it by
            integer division by ``n_cycles``, the rollout length ``T`` and the
            rollout batch size.
        seed: base seed; each rank is seeded with ``seed + 1000000 * rank``.
            ``None`` passes ``None`` to ``set_global_seeds``.
        eval_env: environment for evaluation rollouts; falls back to ``env``.
        replay_strategy: stored as ``params['replay_strategy']``.
        policy_save_interval: forwarded to ``train``.
        clip_return: forwarded to ``config.configure_ddpg``.
        demo_file: when not ``None``, sets ``params['bc_loss'] = 1``; also
            forwarded to ``train``.
        override_params: mapping merged over the default/env-specific params.
        load_path: when not ``None``, passed to ``tf_util.load_variables``.
        save_path: forwarded to ``train``.
        **kwargs: merged into the prepared params (after
            ``config.prepare_params``).

    Returns:
        Whatever ``train`` returns (the trained policy).
    """
    override_params = override_params or {}

    if MPI is not None:
        rank = MPI.COMM_WORLD.Get_rank()
        num_cpu = MPI.COMM_WORLD.Get_size()
    else:
        # Without MPI behave as a single process so the names used below
        # are always bound.
        rank = 0
        num_cpu = 1

    # Seed everything.
    rank_seed = seed + 1000000 * rank if seed is not None else None
    set_global_seeds(rank_seed)

    # Prepare params. Work on a copy: the assignments and updates below would
    # otherwise mutate the module-level defaults and leak into later calls.
    params = dict(config.DEFAULT_PARAMS)
    env_name = env.spec.id
    params['env_name'] = env_name
    params['replay_strategy'] = replay_strategy
    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name])  # merge env-specific parameters in
    params.update(**override_params)  # makes it possible to override any parameter

    # Persist the raw parameters next to the logs before they are transformed.
    with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
        json.dump(params, f)

    params = config.prepare_params(params)
    params['rollout_batch_size'] = env.num_envs

    if demo_file is not None:
        params['bc_loss'] = 1
    params.update(kwargs)

    config.log_params(params, logger=logger)

    if num_cpu == 1:
        logger.warn()
        logger.warn('*** Warning ***')
        logger.warn(
            'You are running HER with just a single MPI worker. This will work, but the ' +
            'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' +
            'were obtained with --num_cpu 19. This makes a significant difference and if you ' +
            'are looking to reproduce those results, be aware of this. Please also refer to ' +
            'https://github.com/openai/baselines/issues/314 for further details.')
        logger.warn('****************')
        logger.warn()

    dims = config.configure_dims(params)
    policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
    if load_path is not None:
        tf_util.load_variables(load_path)

    # Settings that differ between the training worker and the evaluator.
    rollout_params = {
        'exploit': False,
        'use_target_net': False,
        'use_demo_states': True,
        'compute_Q': False,
    }
    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'use_demo_states': False,
        'compute_Q': True,
    }

    # Settings both workers take straight from params (including 'T').
    for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
        rollout_params[name] = params[name]
        eval_params[name] = params[name]

    eval_env = eval_env or env

    rollout_worker = RolloutWorker(env, policy, dims, logger, monitor=True, **rollout_params)
    evaluator = RolloutWorker(eval_env, policy, dims, logger, **eval_params)

    n_cycles = params['n_cycles']
    n_epochs = total_timesteps // n_cycles // rollout_worker.T // rollout_worker.rollout_batch_size

    return train(
        save_path=save_path, policy=policy, rollout_worker=rollout_worker,
        evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
        n_cycles=n_cycles, n_batches=params['n_batches'],
        policy_save_interval=policy_save_interval, demo_file=demo_file)
# Command-line entry point: click parses the options declared below and
# main() forwards them, unchanged, to learn() as keyword arguments.
# NOTE(review): learn() declares `network` as a required keyword-only argument
# and reads `env.spec.id`, whereas this command supplies no `network` and
# passes `env` as a string -- confirm this entry point is still expected to work.
# (A comment is used instead of a docstring because click would surface a
# docstring as the command's help text.)
@click.command()
@click.option(
    '--env',
    type=str,
    default='FetchReach-v1',
    help='the name of the OpenAI Gym environment that you want to train on',
)
@click.option(
    '--total_timesteps',
    type=int,
    default=int(5e5),
    help='the number of timesteps to run',
)
@click.option(
    '--seed',
    type=int,
    default=0,
    help='the random seed used to seed both the environment and the training code',
)
@click.option(
    '--policy_save_interval',
    type=int,
    default=5,
    help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.',
)
@click.option(
    '--replay_strategy',
    type=click.Choice(['future', 'none']),
    default='future',
    help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.',
)
@click.option(
    '--clip_return',
    type=int,
    default=1,
    help='whether or not returns should be clipped',
)
@click.option(
    '--demo_file',
    type=str,
    default='PATH/TO/DEMO/DATA/FILE.npz',
    help='demo data file path',
)
def main(**kwargs):
    learn(**kwargs)


# Run the CLI only when this file is executed as a script.
if __name__ == '__main__':
    main()