Spaces:
Runtime error
Runtime error
File size: 2,557 Bytes
4b714e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import torch
import wandb
from apple.envs.discrete_apple import get_apple_env
from apple.logger import EpochLogger
from apple.models.categorical_policy import CategoricalPolicy
from apple.training.reinforce_trainer import ReinforceTrainer
from apple.training.trainer import Trainer
from apple.utils import set_seed
from input_args import apple_parse_args
def main(
    run_kind: str,
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
    lr: float = 1e-3,
    pretrain_steps: int = 0,
    steps: int = 10000,
    log_every: int = 1,
    num_eval_eps: int = 1,
    pretrain: str = "phase1",
    finetune: str = "full",
    log_to_wandb: bool = False,
    wandbcommit: int = 1,
    verbose: bool = False,
    output_dir: str = "logs/apple",
    gamma: float = 1.0,
    update_every: int = 10,
    seed: int = 0,
    **kwargs,
) -> None:
    """Run a two-stage (pretrain, then finetune) experiment on the apple envs.

    The function seeds the run, creates an ``EpochLogger``, builds one
    environment for the ``pretrain`` task and one for the ``finetune`` task,
    plus evaluation environments for the tasks ``"full"``, ``"phase1"`` and
    ``"phase2"``. A single ``CategoricalPolicy`` and SGD optimizer are then
    trained first on the pretrain environment for ``pretrain_steps`` and
    afterwards on the finetune environment for ``steps``.

    Args:
        run_kind: Which trainer to use: ``"reinforce"`` (``ReinforceTrainer``)
            or ``"bc"`` (``Trainer``). Also used as the logger's ``exp_name``.
        c, start_x, goal_x, time_limit, bias_in_state, position_in_state,
            apple_in_state: Forwarded unchanged to ``get_apple_env`` for every
            environment created here; their exact semantics are defined by
            that factory.
        lr: Learning rate of the SGD optimizer.
        pretrain_steps: Step budget passed to the first ``train`` call.
        steps: Step budget passed to the second ``train`` call.
        log_every, num_eval_eps: Forwarded to both ``train`` calls.
        pretrain: Task name of the environment used in the first stage.
        finetune: Task name of the environment used in the second stage.
        log_to_wandb, wandbcommit, verbose, output_dir: Forwarded to
            ``EpochLogger``.
        gamma: Forwarded to ``ReinforceTrainer``; unused for ``"bc"``.
        update_every: Forwarded to ``ReinforceTrainer.train``; unused for
            ``"bc"``.
        seed: Value passed to ``set_seed``.
        **kwargs: Accepted and ignored, so that a full parsed argument
            namespace can be splatted into this function.

    Raises:
        ValueError: If ``run_kind`` is not one of the supported values.
    """
    # Fail fast, before seeding or creating the logger: previously an unknown
    # run_kind built everything and then silently returned without training.
    supported_run_kinds = ("reinforce", "bc")
    if run_kind not in supported_run_kinds:
        raise ValueError(
            f"Unknown run_kind {run_kind!r}; expected one of {supported_run_kinds}."
        )

    set_seed(seed)

    logger = EpochLogger(
        exp_name=run_kind,
        output_dir=output_dir,
        log_to_wandb=log_to_wandb,
        wandbcommit=wandbcommit,
        verbose=verbose,
    )

    # The same settings are shared by the training and evaluation envs;
    # only the task name differs between them.
    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )
    env_phase1 = get_apple_env(pretrain, **env_kwargs)
    env_phase2 = get_apple_env(finetune, **env_kwargs)
    test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]

    # First argument is the observation vector length of the pretrain env.
    # NOTE(review): the second argument (1) is presumably the policy's output
    # size -- confirm against CategoricalPolicy.
    model = CategoricalPolicy(env_phase1.observation_space.shape[0], 1)
    optim = torch.optim.SGD(model.parameters(), lr=lr)

    # The same model/optimizer pair is reused across both stages, so the
    # second stage continues from the parameters left by the first.
    if run_kind == "reinforce":
        trainer = ReinforceTrainer(model, optim, logger, gamma=gamma)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, update_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, update_every, num_eval_eps)
    else:  # run_kind == "bc", guaranteed by the validation above
        trainer = Trainer(model, optim, logger)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, num_eval_eps)
if __name__ == "__main__":
    # Script entry point: parse the command line, optionally start a
    # Weights & Biases run, then hand every parsed option to main().
    cli_args = apple_parse_args()

    if cli_args.log_to_wandb:
        # The parsed namespace itself is recorded as the run configuration.
        wandb_settings = wandb.Settings(start_method="fork")
        wandb.init(
            entity="gmum",
            project="apple",
            config=cli_args,
            settings=wandb_settings,
        )

    # main() accepts **kwargs, so options it does not name are ignored.
    main(**vars(cli_args))
|