Add possibility to save and load models
Also add an evaluation task to evaluate saved models
- .gitignore +4 -1
- a3c/discrete_A3C.py +9 -6
- a3c/utils.py +6 -0
- main.py +29 -4
.gitignore
CHANGED

@@ -113,4 +113,7 @@ GitHub.sublime-settings
 !.vscode/tasks.json
 !.vscode/launch.json
 !.vscode/extensions.json
-.history
+.history
+
+# PyTorch model files
+*.pth
a3c/discrete_A3C.py
CHANGED

@@ -6,14 +6,14 @@ View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.i
 """
 import os
 import torch.multiprocessing as mp
-from .utils import v_wrap, push_and_pull, record
+from .utils import v_wrap, push_and_pull, record, save_model
 from .shared_adam import SharedAdam
 from .net import Net
 
 GAMMA = 0.65
 
 class Worker(mp.Process):
-    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
+    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
         super(Worker, self).__init__()
         self.max_ep = max_ep
         self.name = 'w%02i' % name
@@ -22,6 +22,7 @@ class Worker(mp.Process):
         self.word_list = words_list
         self.lnet = Net(N_S, N_A, words_list, word_width)  # local network
         self.env = env.unwrapped
+        self.model_checkpoint_dir = model_checkpoint_dir
 
     def run(self):
         while self.g_ep.value < self.max_ep:
@@ -40,16 +41,18 @@ class Worker(mp.Process):
                     # sync
                     push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                     goal_word = self.word_list[self.env.goal_word]
-                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
+                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
+                    save_model(self.gnet, self.model_checkpoint_dir, self.g_ep.value, self.g_ep_r.value)
                     buffer_s, buffer_a, buffer_r = [], [], []
                     break
                 s = s_
         self.res_queue.put(None)
 
 
-def train(env, max_ep):
+def train(env, max_ep, model_checkpoint_dir):
     os.environ["OMP_NUM_THREADS"] = "1"
-
+    if not os.path.exists(model_checkpoint_dir):
+        os.makedirs(model_checkpoint_dir)
     n_s = env.observation_space.shape[0]
     n_a = env.action_space.n
     words_list = env.words
@@ -60,7 +63,7 @@ def train(env, max_ep):
     global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
 
     # parallel training
-    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
+    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep, model_checkpoint_dir) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []  # record episode reward to plot
     while True:
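Note: with this change, train() owns checkpoint handling: it creates the directory if needed and passes it to every Worker, which calls save_model on the shared global network after each finished episode. A minimal sketch of invoking it directly, assuming the env id and directory name below (main.py builds the directory from the env spec instead):

import gym
from a3c.discrete_A3C import train
from wordle_env.wordle import WordleEnvBase  # imported for its side effect of registering the env ids, mirroring main.py

env = gym.make('WordleEnv100FullAction-v0')
# train() creates the checkpoint directory itself if it is missing.
global_ep, win_ep, gnet, res = train(env, max_ep=100000,
                                     model_checkpoint_dir='checkpoints/WordleEnv100FullAction-v0')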
a3c/utils.py
CHANGED

@@ -1,6 +1,7 @@
 """
 Functions that are used multiple times
 """
+import os
 from torch import nn
 import torch
 import numpy as np
@@ -46,6 +47,11 @@ def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
     lnet.load_state_dict(gnet.state_dict())
 
 
+def save_model(gnet, dir, episode, reward):
+    if reward >= 9 and episode % 100 == 0:
+        torch.save(gnet.state_dict(), os.path.join(dir, f'model_{episode}.pth'))
+
+
 def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, action_number, winning_ep):
     with global_ep.get_lock():
         global_ep.value += 1
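Note: save_model only writes a checkpoint when the running reward has reached at least 9 and the global episode count is a multiple of 100, so files named model_<episode>.pth appear sparsely. A minimal sketch of restoring one into a fresh network, mirroring what evaluate_checkpoints in main.py does (the checkpoint path here is illustrative):

import gym
import torch
from a3c.net import Net
from wordle_env.wordle import WordleEnvBase  # registers the Wordle env ids, as in main.py

env = gym.make('WordleEnv100FullAction-v0')
net = Net(env.observation_space.shape[0], env.action_space.n,
          env.words, len(env.words[0]))
# Hypothetical file; save_model names checkpoints model_<episode>.pth.
net.load_state_dict(torch.load('checkpoints/WordleEnv100FullAction-v0/model_100.pth'))
net.eval()  # inference mode for evaluation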
main.py
CHANGED

@@ -1,10 +1,29 @@
 import sys
+import os
 import gym
+import torch
 import matplotlib.pyplot as plt
 from a3c.discrete_A3C import train
 from a3c.utils import v_wrap
+from a3c.net import Net
 from wordle_env.wordle import WordleEnvBase
 
+def evaluate_checkpoints(dir, env):
+    n_s = env.observation_space.shape[0]
+    n_a = env.action_space.n
+    words_list = env.words
+    word_width = len(env.words[0])
+    net = Net(n_s, n_a, words_list, word_width)
+    results = {}
+    print(dir)
+    for checkpoint in os.listdir(dir):
+        checkpoint_path = os.path.join(dir, checkpoint)
+        if os.path.isfile(checkpoint_path):
+            net.load_state_dict(torch.load(checkpoint_path))
+            wins, guesses = evaluate(net, env)
+            results[checkpoint] = wins, guesses
+    return dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
+
 
 def evaluate(net, env):
     print("Evaluation mode")
@@ -21,9 +40,9 @@ def evaluate(net, env):
         # else:
         #     print("Lost!", goal_word, outcomes)
         n_guesses += len(outcomes)
-
     print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
           f"{n_guesses / N} including losses.")
+    return n_wins/N*100, n_win_guesses/n_wins
 
 def play(net, env):
     state = env.reset()
@@ -51,7 +70,13 @@ def print_results(global_ep, win_ep, res):
 if __name__ == "__main__":
     max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
     env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
+    evaluation = True if len(sys.argv) > 3 and sys.argv[3] == 'evaluation' else False
     env = gym.make(env_id)
-    global_ep, win_ep, gnet, res = train(env, max_ep)
-    print_results(global_ep, win_ep, res)
-    evaluate(gnet, env)
+    model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
+    if not evaluation:
+        global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
+        print_results(global_ep, win_ep, res)
+        evaluate(gnet, env)
+    else:
+        results = evaluate_checkpoints(model_checkpoint_dir, env)
+        print(results)