Spaces:
Runtime error
Runtime error
| import torch | |
| import wandb | |
| from apple.envs.discrete_apple import get_apple_env | |
| from apple.logger import EpochLogger | |
| from apple.models.categorical_policy import CategoricalPolicy | |
| from apple.training.reinforce_trainer import ReinforceTrainer | |
| from apple.training.trainer import Trainer | |
| from apple.utils import set_seed | |
| from input_args import apple_parse_args | |
def main(
    run_kind: str,
    c: float = 1.0,
    start_x: float = 0.0,
    goal_x: float = 50.0,
    time_limit: int = 200,
    bias_in_state: bool = True,
    position_in_state: bool = False,
    apple_in_state: bool = True,
    lr: float = 1e-3,
    pretrain_steps: int = 0,
    steps: int = 10000,
    log_every: int = 1,
    num_eval_eps: int = 1,
    pretrain: str = "phase1",
    finetune: str = "full",
    log_to_wandb: bool = False,
    wandbcommit: int = 1,
    verbose: bool = False,
    output_dir: str = "logs/apple",
    gamma: float = 1.0,
    update_every: int = 10,
    seed: int = 0,
    **kwargs,
) -> None:
    """Train a categorical policy on the apple task in two consecutive phases.

    The same model and optimizer are first trained on the ``pretrain`` task
    and then on the ``finetune`` task; both phases are evaluated on the
    "full", "phase1" and "phase2" environments.

    Args:
        run_kind: Training procedure, either ``"reinforce"`` or ``"bc"``.
            Also used as the logger's experiment name.
        c, start_x, goal_x, time_limit, bias_in_state, position_in_state,
        apple_in_state: Forwarded unchanged to ``get_apple_env`` as keyword
            arguments for every environment created here.
        lr: Learning rate of the SGD optimizer.
        pretrain_steps: Step argument of the first ``trainer.train`` call
            (on the ``pretrain`` environment).
        steps: Step argument of the second ``trainer.train`` call
            (on the ``finetune`` environment).
        log_every, num_eval_eps: Forwarded to ``trainer.train``.
        pretrain: Task name of the first-phase environment.
        finetune: Task name of the second-phase environment.
        log_to_wandb, wandbcommit, verbose, output_dir: Forwarded to
            ``EpochLogger``.
        gamma: Forwarded to ``ReinforceTrainer``; unused for ``"bc"``.
        update_every: Forwarded to ``ReinforceTrainer.train``; unused for
            ``"bc"``.
        seed: Value passed to ``set_seed``.
        **kwargs: Ignored. Accepted so the full parsed-argument namespace
            can be splatted into this function.

    Raises:
        ValueError: If ``run_kind`` is not one of the supported kinds.
    """
    # Fail fast: previously an unknown run kind fell through the if/elif chain
    # and the script finished "successfully" without training anything.
    supported_kinds = ("reinforce", "bc")
    if run_kind not in supported_kinds:
        raise ValueError(
            f"Unsupported run_kind {run_kind!r}; expected one of {supported_kinds}."
        )

    set_seed(seed)

    logger = EpochLogger(
        exp_name=run_kind,
        output_dir=output_dir,
        log_to_wandb=log_to_wandb,
        wandbcommit=wandbcommit,
        verbose=verbose,
    )

    env_kwargs = dict(
        start_x=start_x,
        goal_x=goal_x,
        c=c,
        time_limit=time_limit,
        bias_in_state=bias_in_state,
        position_in_state=position_in_state,
        apple_in_state=apple_in_state,
    )
    env_phase1 = get_apple_env(pretrain, **env_kwargs)
    env_phase2 = get_apple_env(finetune, **env_kwargs)
    test_envs = [get_apple_env(task, **env_kwargs) for task in ["full", "phase1", "phase2"]]

    # Input size follows the phase-1 observation vector; the second argument
    # is presumably the policy's output size -- confirm against CategoricalPolicy.
    model = CategoricalPolicy(env_phase1.observation_space.shape[0], 1)
    optim = torch.optim.SGD(model.parameters(), lr=lr)

    if run_kind == "reinforce":
        trainer = ReinforceTrainer(model, optim, logger, gamma=gamma)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, update_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, update_every, num_eval_eps)
    else:  # run_kind == "bc", guaranteed by the validation above
        trainer = Trainer(model, optim, logger)
        trainer.train(env_phase1, test_envs, pretrain_steps, log_every, num_eval_eps)
        trainer.train(env_phase2, test_envs, steps, log_every, num_eval_eps)
if __name__ == "__main__":
    # Script entry point: parse CLI options, optionally start a Weights &
    # Biases run, then hand every parsed option to main() as keyword arguments.
    cli_args = apple_parse_args()

    if cli_args.log_to_wandb:
        wandb_init_kwargs = {
            "entity": "gmum",
            "project": "apple",
            "config": cli_args,
            "settings": wandb.Settings(start_method="fork"),
        }
        wandb.init(**wandb_init_kwargs)

    main(**vars(cli_args))