Apple / run.py
New Author Name
init
4b714e2
import torch
import wandb
from apple.envs.discrete_apple import get_apple_env
from apple.logger import EpochLogger
from apple.models.categorical_policy import CategoricalPolicy
from apple.training.reinforce_trainer import ReinforceTrainer
from apple.training.trainer import Trainer
from apple.utils import set_seed
from input_args import apple_parse_args
def main(
    run_kind: str,
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
    lr: float = 1e-3,
    pretrain_steps: int = 0,
    steps: int = 10000,
    log_every: int = 1,
    num_eval_eps: int = 1,
    pretrain: str = "phase1",
    finetune: str = "full",
    log_to_wandb: bool = False,
    wandbcommit: int = 1,
    verbose: bool = False,
    output_dir: str = "logs/apple",
    gamma: float = 1.0,
    update_every: int = 10,
    seed: int = 0,
    **kwargs,
) -> None:
    """Run a two-stage (pretrain, then finetune) experiment on the apple environment.

    Args:
        run_kind: Which trainer to use: "reinforce" (ReinforceTrainer) or
            "bc" (Trainer). Also used as the logger's experiment name.
        c, start_x, goal_x, time_limit, bias_in_state, position_in_state,
            apple_in_state: Forwarded unchanged as keyword arguments to every
            ``get_apple_env`` call.
        lr: Learning rate of the SGD optimizer.
        pretrain_steps: Step count passed to the first ``train`` call (on the
            ``pretrain`` environment).
        steps: Step count passed to the second ``train`` call (on the
            ``finetune`` environment).
        log_every, num_eval_eps: Forwarded to both ``train`` calls.
        pretrain: Task name of the environment used in the first stage.
        finetune: Task name of the environment used in the second stage.
        log_to_wandb, wandbcommit, verbose, output_dir: Forwarded to
            ``EpochLogger``.
        gamma: Passed to ``ReinforceTrainer``; unused when ``run_kind`` is "bc".
        update_every: Passed to ``ReinforceTrainer.train``; unused when
            ``run_kind`` is "bc".
        seed: Value handed to ``set_seed`` before anything else is built.
        **kwargs: Accepted and ignored, so callers can pass a full set of
            parsed options (as the script entry point does).

    Raises:
        ValueError: If ``run_kind`` is neither "reinforce" nor "bc".
    """
    # Fail fast: previously an unknown run_kind built the logger, environments
    # and model and then returned without training or reporting anything.
    if run_kind not in ("reinforce", "bc"):
        raise ValueError(
            f"Unknown run_kind {run_kind!r}; expected 'reinforce' or 'bc'."
        )

    set_seed(seed)

    logger = EpochLogger(
        exp_name=run_kind,
        output_dir=output_dir,
        log_to_wandb=log_to_wandb,
        wandbcommit=wandbcommit,
        verbose=verbose,
    )

    # One shared configuration for the training and evaluation environments.
    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )
    env_phase1 = get_apple_env(pretrain, **env_kwargs)
    env_phase2 = get_apple_env(finetune, **env_kwargs)
    # Evaluation always covers these three tasks, whatever is being trained on.
    test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]

    # The first argument is the length of the first-stage observation vector.
    # NOTE(review): the meaning of the second argument (1) is defined by
    # CategoricalPolicy - confirm there.
    model = CategoricalPolicy(env_phase1.observation_space.shape[0], 1)
    optim = torch.optim.SGD(model.parameters(), lr=lr)

    # The same model and optimizer are carried over from the first stage to the second.
    if run_kind == "reinforce":
        trainer = ReinforceTrainer(model, optim, logger, gamma=gamma)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, update_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, update_every, num_eval_eps)
    else:  # run_kind == "bc", guaranteed by the validation above
        trainer = Trainer(model, optim, logger)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, num_eval_eps)
if __name__ == "__main__":
    # Script entry point: parse the command line, optionally open a W&B run,
    # then hand every parsed option to main() as a keyword argument.
    args = apple_parse_args()

    if args.log_to_wandb:
        # The parsed namespace is passed as the run's config.
        wandb_settings = wandb.Settings(start_method="fork")
        wandb.init(entity="gmum", project="apple", config=args, settings=wandb_settings)

    main(**vars(args))