File size: 2,557 Bytes
4b714e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch

import wandb

from apple.envs.discrete_apple import get_apple_env
from apple.logger import EpochLogger
from apple.models.categorical_policy import CategoricalPolicy
from apple.training.reinforce_trainer import ReinforceTrainer
from apple.training.trainer import Trainer
from apple.utils import set_seed
from input_args import apple_parse_args


def main(
    run_kind: str,
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
    lr: float = 1e-3,
    pretrain_steps: int = 0,
    steps: int = 10000,
    log_every: int = 1,
    num_eval_eps: int = 1,
    pretrain: str = "phase1",
    finetune: str = "full",
    log_to_wandb: bool = False,
    wandbcommit: int = 1,
    verbose: bool = False,
    output_dir: str = "logs/apple",
    gamma: float = 1.0,
    update_every: int = 10,
    seed: int = 0,
    **kwargs,
):
    """Pretrain on one apple-env task, then fine-tune on another.

    Builds a pretrain env (``pretrain`` task) and a fine-tune env
    (``finetune`` task), trains a ``CategoricalPolicy`` with SGD on the
    first for ``pretrain_steps`` and on the second for ``steps``, and
    evaluates on the "full"/"phase1"/"phase2" test envs throughout.

    Args:
        run_kind: Which trainer to use; either ``"reinforce"`` or ``"bc"``.
        c, start_x, goal_x, time_limit, bias_in_state, position_in_state,
        apple_in_state: Forwarded verbatim to ``get_apple_env``.
        lr: SGD learning rate.
        pretrain_steps: Training steps on the pretrain env.
        steps: Training steps on the fine-tune env.
        log_every: Logging period (in steps).
        num_eval_eps: Evaluation episodes per logging point.
        pretrain: Task name for the pretrain env.
        finetune: Task name for the fine-tune env.
        log_to_wandb, wandbcommit, verbose, output_dir: Logger settings.
        gamma: Discount factor (REINFORCE trainer only).
        update_every: Update period (REINFORCE trainer only).
        seed: Global RNG seed, applied via ``set_seed``.
        **kwargs: Ignored; absorbs extra CLI args from ``vars(args)``.

    Raises:
        ValueError: If ``run_kind`` is not ``"reinforce"`` or ``"bc"``.
    """
    # Validate up front so a typo'd run_kind fails loudly instead of
    # silently constructing envs/model and then training nothing
    # (the original if/elif had no else branch).
    if run_kind not in ("reinforce", "bc"):
        raise ValueError(
            f"Unknown run_kind: {run_kind!r}; expected 'reinforce' or 'bc'"
        )

    set_seed(seed)

    logger = EpochLogger(
        exp_name=run_kind,
        output_dir=output_dir,
        log_to_wandb=log_to_wandb,
        wandbcommit=wandbcommit,
        verbose=verbose,
    )

    # Shared environment configuration for every env instance below.
    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )

    env_phase1 = get_apple_env(pretrain, **env_kwargs)
    env_phase2 = get_apple_env(finetune, **env_kwargs)
    # Fixed evaluation suite, independent of the chosen train tasks.
    test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]

    model = CategoricalPolicy(env_phase1.observation_space.shape[0], 1)
    optim = torch.optim.SGD(model.parameters(), lr=lr)

    if run_kind == "reinforce":
        trainer = ReinforceTrainer(model, optim, logger, gamma=gamma)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, update_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, update_every, num_eval_eps)
    else:  # run_kind == "bc", guaranteed by the validation above
        trainer = Trainer(model, optim, logger)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, num_eval_eps)


if __name__ == "__main__":
    parsed = apple_parse_args()

    # Spin up the Weights & Biases run before training when requested;
    # "fork" start method avoids spawn-related pickling issues.
    if parsed.log_to_wandb:
        wandb.init(
            settings=wandb.Settings(start_method="fork"),
            entity="gmum",
            project="apple",
            config=parsed,
        )

    main(**vars(parsed))