"""Training entry point: builds a shared global Net and SharedAdam optimizer,
then runs one Worker process per CPU core and collects episode rewards."""
import os
import random

import numpy as np
import torch
import torch.multiprocessing as mp

from .net import Net
from .shared_adam import SharedAdam
from .worker import Worker

def _set_seed(seed: int = 100) -> None:
    """Seed all sources of randomness so training runs are reproducible."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # When running on the CuDNN backend, two further options must be set
        # for deterministic behaviour
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)

def train(
    env,
    max_ep,
    model_checkpoint_dir,
    gamma=0.0,
    seed=100,
    pretrained_model_path=None,
    save=False,
    min_reward=9.9,
    every_n_save=100,
):
    # Keep each worker process single-threaded so CPU cores aren't oversubscribed
    os.environ["OMP_NUM_THREADS"] = "1"
    os.makedirs(model_checkpoint_dir, exist_ok=True)
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(env.words[0])
    # Set global seeds for reproducibility
    _set_seed(seed)
    gnet = Net(n_s, n_a, words_list, word_width)  # global network
    if pretrained_model_path:
        gnet.load_state_dict(torch.load(pretrained_model_path))
    gnet.share_memory()  # share the global parameters in multiprocessing
    opt = SharedAdam(
        gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)
    )  # global optimizer
    # Shared state used by all workers: global episode counter, running
    # episode reward, result queue, and win-episode tracker
    global_ep, global_ep_r, res_queue, win_ep = (
        mp.Value("i", 0),
        mp.Value("d", 0.0),
        mp.Queue(),
        mp.Value("i", 0),
    )
    # parallel training: one worker per CPU core
    workers = [
        Worker(
            max_ep,
            gnet,
            opt,
            global_ep,
            global_ep_r,
            res_queue,
            i,
            env,
            n_s,
            n_a,
            words_list,
            word_width,
            win_ep,
            model_checkpoint_dir,
            gamma,
            pretrained_model_path,
            save,
            min_reward,
            every_n_save,
        )
        for i in range(mp.cpu_count())
    ]
    for w in workers:
        w.start()
    res = []  # record episode rewards to plot
    while True:
        r = res_queue.get()
        if r is None:  # a None sentinel signals that training is finished
            break
        res.append(r)
    for w in workers:
        w.join()
    if save:
        torch.save(
            gnet.state_dict(),
            os.path.join(model_checkpoint_dir, f"model_{env.unwrapped.spec.id}.pth"),
        )
    return global_ep, win_ep, gnet, res
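

# Example driver (a minimal sketch, not part of the module above). The
# environment class name and import path are assumptions; substitute the
# actual Wordle env this package provides. The __main__ guard is required
# because train() spawns torch.multiprocessing worker processes.
if __name__ == "__main__":
    from wordle_env import WordleEnv  # hypothetical import

    env = WordleEnv()
    global_ep, win_ep, gnet, res = train(
        env,
        max_ep=5000,
        model_checkpoint_dir="checkpoints",
        save=True,
    )
    print(f"episodes run: {global_ep.value}, wins recorded: {win_ep.value}")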