import torch
import wandb

from apple.envs.discrete_apple import get_apple_env
from apple.logger import EpochLogger
from apple.models.categorical_policy import CategoricalPolicy
from apple.training.reinforce_trainer import ReinforceTrainer
from apple.training.trainer import Trainer
from apple.utils import set_seed
from input_args import apple_parse_args

# Values of ``run_kind`` that ``main`` knows how to train.
_RUN_KINDS = ("reinforce", "bc")


def main(
    run_kind: str,
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
    lr: float = 1e-3,
    pretrain_steps: int = 0,
    steps: int = 10000,
    log_every: int = 1,
    num_eval_eps: int = 1,
    pretrain: str = "phase1",
    finetune: str = "full",
    log_to_wandb: bool = False,
    wandbcommit: int = 1,
    verbose: bool = False,
    output_dir: str = "logs/apple",
    gamma: float = 1.0,
    update_every: int = 10,
    seed: int = 0,
    **kwargs,
):
    """Train a categorical policy on the apple environment in two phases.

    The same trainer (and therefore the same model and optimizer) is first
    trained on the ``pretrain`` task for ``pretrain_steps`` steps, then on
    the ``finetune`` task for ``steps`` steps. After each phase it is
    evaluated on the "full", "phase1" and "phase2" tasks.

    Args:
        run_kind: ``"reinforce"`` uses ``ReinforceTrainer``; ``"bc"`` uses
            ``Trainer``. Also used as the logger's experiment name.
        c, start_x, goal_x, time_limit, bias_in_state, position_in_state,
        apple_in_state: Forwarded unchanged to ``get_apple_env`` for every
            environment built here.
        lr: Learning rate for the ``torch.optim.SGD`` optimizer.
        pretrain_steps: Number of steps for the first (pretrain) phase.
        steps: Number of steps for the second (finetune) phase.
        log_every, num_eval_eps: Forwarded to ``trainer.train``.
        pretrain: Task name for the first phase's environment.
        finetune: Task name for the second phase's environment.
        log_to_wandb, wandbcommit, verbose, output_dir: Forwarded to
            ``EpochLogger``.
        gamma, update_every: Only used when ``run_kind == "reinforce"``.
        seed: Passed to ``set_seed`` before anything else is built.
        **kwargs: Ignored. This absorbs any extra CLI arguments, since the
            script calls ``main(**vars(args))``.

    Raises:
        ValueError: If ``run_kind`` is not one of ``_RUN_KINDS``.
    """
    # Fail fast. Without this check an unknown run_kind silently builds
    # everything and then returns without training.
    if run_kind not in _RUN_KINDS:
        raise ValueError(
            f"Unknown run_kind {run_kind!r}; expected one of {_RUN_KINDS}"
        )

    set_seed(seed)
    logger = EpochLogger(
        exp_name=run_kind,
        output_dir=output_dir,
        log_to_wandb=log_to_wandb,
        wandbcommit=wandbcommit,
        verbose=verbose,
    )

    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )
    env_phase1 = get_apple_env(pretrain, **env_kwargs)
    env_phase2 = get_apple_env(finetune, **env_kwargs)
    test_envs = [
        get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]
    ]

    # The input size is taken from the phase-1 env's observation space.
    # NOTE(review): the meaning of the second argument (1) is defined in
    # CategoricalPolicy; confirm there.
    model = CategoricalPolicy(env_phase1.observation_space.shape[0], 1)
    optim = torch.optim.SGD(model.parameters(), lr=lr)

    if run_kind == "reinforce":
        trainer = ReinforceTrainer(model, optim, logger, gamma=gamma)
        trainer.train(
            env_phase1, test_envs, pretrain_steps, log_every, update_every, num_eval_eps
        )
        trainer.train(env_phase2, test_envs, steps, log_every, update_every, num_eval_eps)
    else:  # run_kind == "bc" (validated above)
        trainer = Trainer(model, optim, logger)
        # Trainer.train is called without the update_every argument.
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, num_eval_eps)


if __name__ == "__main__":
    args = apple_parse_args()
    if args.log_to_wandb:
        wandb.init(
            entity="gmum",
            project="apple",
            config=args,
            settings=wandb.Settings(start_method="fork"),
        )
    main(**vars(args))