Spaces:
Build error
Build error
Add the possibility to save checkpoints of the model, and the condition on which the model is saved, as arguments
Browse files- a3c/train.py +4 -2
- a3c/worker.py +27 -3
- main.py +9 -8
a3c/train.py
CHANGED
|
@@ -6,7 +6,7 @@ from .net import Net
|
|
| 6 |
from .worker import Worker
|
| 7 |
|
| 8 |
|
| 9 |
-
def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None):
|
| 10 |
os.environ["OMP_NUM_THREADS"] = "1"
|
| 11 |
if not os.path.exists(model_checkpoint_dir):
|
| 12 |
os.makedirs(model_checkpoint_dir)
|
|
@@ -23,7 +23,7 @@ def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=Non
|
|
| 23 |
|
| 24 |
# parallel training
|
| 25 |
workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
|
| 26 |
-
words_list, word_width, win_ep, model_checkpoint_dir, gamma, pretrained_model_path) for i in range(mp.cpu_count())]
|
| 27 |
[w.start() for w in workers]
|
| 28 |
res = [] # record episode reward to plot
|
| 29 |
while True:
|
|
@@ -33,4 +33,6 @@ def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=Non
|
|
| 33 |
else:
|
| 34 |
break
|
| 35 |
[w.join() for w in workers]
|
|
|
|
|
|
|
| 36 |
return global_ep, win_ep, gnet, res
|
|
|
|
| 6 |
from .worker import Worker
|
| 7 |
|
| 8 |
|
| 9 |
+
def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None, save=False, min_reward=9.9, every_n_save=100):
|
| 10 |
os.environ["OMP_NUM_THREADS"] = "1"
|
| 11 |
if not os.path.exists(model_checkpoint_dir):
|
| 12 |
os.makedirs(model_checkpoint_dir)
|
|
|
|
| 23 |
|
| 24 |
# parallel training
|
| 25 |
workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
|
| 26 |
+
words_list, word_width, win_ep, model_checkpoint_dir, gamma, pretrained_model_path, save, min_reward, every_n_save) for i in range(mp.cpu_count())]
|
| 27 |
[w.start() for w in workers]
|
| 28 |
res = [] # record episode reward to plot
|
| 29 |
while True:
|
|
|
|
| 33 |
else:
|
| 34 |
break
|
| 35 |
[w.join() for w in workers]
|
| 36 |
+
if save:
|
| 37 |
+
torch.save(gnet.state_dict(), os.path.join(model_checkpoint_dir, f'model_{env.unwrapped.spec.id}.pth'))
|
| 38 |
return global_ep, win_ep, gnet, res
|
a3c/worker.py
CHANGED
|
@@ -11,7 +11,28 @@ from .utils import v_wrap
|
|
| 11 |
|
| 12 |
|
| 13 |
class Worker(mp.Process):
|
| 14 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
super(Worker, self).__init__()
|
| 16 |
self.max_ep = max_ep
|
| 17 |
self.name = 'w%02i' % name
|
|
@@ -25,6 +46,9 @@ class Worker(mp.Process):
|
|
| 25 |
self.env = env.unwrapped
|
| 26 |
self.gamma = gamma
|
| 27 |
self.model_checkpoint_dir = model_checkpoint_dir
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def run(self):
|
| 30 |
while self.g_ep.value < self.max_ep:
|
|
@@ -81,9 +105,9 @@ class Worker(mp.Process):
|
|
| 81 |
self.lnet.load_state_dict(self.gnet.state_dict())
|
| 82 |
|
| 83 |
def save_model(self):
|
| 84 |
-
if self.g_ep_r.value >=
|
| 85 |
torch.save(self.gnet.state_dict(), os.path.join(
|
| 86 |
-
self.model_checkpoint_dir, f'model_{
|
| 87 |
|
| 88 |
def record(self, ep_r, goal_word, action, action_number):
|
| 89 |
with self.g_ep.get_lock():
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class Worker(mp.Process):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
max_ep,
|
| 17 |
+
gnet,
|
| 18 |
+
opt,
|
| 19 |
+
global_ep,
|
| 20 |
+
global_ep_r,
|
| 21 |
+
res_queue,
|
| 22 |
+
name,
|
| 23 |
+
env,
|
| 24 |
+
N_S,
|
| 25 |
+
N_A,
|
| 26 |
+
words_list,
|
| 27 |
+
word_width,
|
| 28 |
+
winning_ep,
|
| 29 |
+
model_checkpoint_dir,
|
| 30 |
+
gamma=0.,
|
| 31 |
+
pretrained_model_path=None,
|
| 32 |
+
save=False,
|
| 33 |
+
min_reward=9.9,
|
| 34 |
+
every_n_save=100
|
| 35 |
+
):
|
| 36 |
super(Worker, self).__init__()
|
| 37 |
self.max_ep = max_ep
|
| 38 |
self.name = 'w%02i' % name
|
|
|
|
| 46 |
self.env = env.unwrapped
|
| 47 |
self.gamma = gamma
|
| 48 |
self.model_checkpoint_dir = model_checkpoint_dir
|
| 49 |
+
self.save = save
|
| 50 |
+
self.min_reward = min_reward
|
| 51 |
+
self.every_n_save = every_n_save
|
| 52 |
|
| 53 |
def run(self):
|
| 54 |
while self.g_ep.value < self.max_ep:
|
|
|
|
| 105 |
self.lnet.load_state_dict(self.gnet.state_dict())
|
| 106 |
|
| 107 |
def save_model(self):
|
| 108 |
+
if self.save and self.g_ep_r.value >= self.min_reward and self.g_ep.value % self.every_n_save == 0:
|
| 109 |
torch.save(self.gnet.state_dict(), os.path.join(
|
| 110 |
+
self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))
|
| 111 |
|
| 112 |
def record(self, ep_r, goal_word, action, action_number):
|
| 113 |
with self.g_ep.get_lock():
|
main.py
CHANGED
|
@@ -14,13 +14,8 @@ from wordle_env.wordle import WordleEnvBase
|
|
| 14 |
def training_mode(args, env, model_checkpoint_dir):
|
| 15 |
max_ep = args.games
|
| 16 |
start_time = time.time()
|
| 17 |
-
if args.model_name
|
| 18 |
-
|
| 19 |
-
model_checkpoint_dir, args.model_name)
|
| 20 |
-
global_ep, win_ep, gnet, res = train(
|
| 21 |
-
env, max_ep, model_checkpoint_dir, args.gamma, pretrained_model_path)
|
| 22 |
-
else:
|
| 23 |
-
global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma)
|
| 24 |
print("--- %.0f seconds ---" % (time.time() - start_time))
|
| 25 |
print_results(global_ep, win_ep, res)
|
| 26 |
evaluate(gnet, env)
|
|
@@ -56,7 +51,13 @@ if __name__ == "__main__":
|
|
| 56 |
parser_train.add_argument(
|
| 57 |
"--model_name", "-n", help="If want to train from a pretrained model, the name of the pretrained model file")
|
| 58 |
parser_train.add_argument(
|
| 59 |
-
"--gamma", help="Gamma hyperparameter value", type=float, default=0.)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
parser_train.set_defaults(func=training_mode)
|
| 61 |
|
| 62 |
parser_eval = subparsers.add_parser(
|
|
|
|
| 14 |
def training_mode(args, env, model_checkpoint_dir):
|
| 15 |
max_ep = args.games
|
| 16 |
start_time = time.time()
|
| 17 |
+
pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name) if args.model_name else args.model_name
|
| 18 |
+
global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma, pretrained_model_path, args.save, args.min_reward, args.every_n_save)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
print("--- %.0f seconds ---" % (time.time() - start_time))
|
| 20 |
print_results(global_ep, win_ep, res)
|
| 21 |
evaluate(gnet, env)
|
|
|
|
| 51 |
parser_train.add_argument(
|
| 52 |
"--model_name", "-n", help="If want to train from a pretrained model, the name of the pretrained model file")
|
| 53 |
parser_train.add_argument(
|
| 54 |
+
"--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
|
| 55 |
+
parser_train.add_argument(
|
| 56 |
+
"--save", '-s', help="Save instances of the model while training", action='store_true')
|
| 57 |
+
parser_train.add_argument(
|
| 58 |
+
"--min_reward", help="The minimun global reward value achieved for saving the model", type=float, default=9.9)
|
| 59 |
+
parser_train.add_argument(
|
| 60 |
+
"--every_n_save", help="Check every n training steps to save the model", type=int, default=100)
|
| 61 |
parser_train.set_defaults(func=training_mode)
|
| 62 |
|
| 63 |
parser_eval = subparsers.add_parser(
|