Spaces:
Build error
Build error
Add play mode
Browse files
From a word, a state, and a saved model, the model returns the probable goal word.
- a3c/eval.py +1 -15
- a3c/play.py +48 -0
- main.py +21 -1
a3c/eval.py
CHANGED
|
@@ -2,6 +2,7 @@ import os
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
from .net import GreedyNet
|
|
|
|
| 5 |
from .utils import v_wrap
|
| 6 |
|
| 7 |
|
|
@@ -38,18 +39,3 @@ def evaluate(net, env):
|
|
| 38 |
print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
|
| 39 |
f"{n_guesses / N} including losses.")
|
| 40 |
return n_wins/N*100, n_win_guesses/n_wins
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def play(net, env):
|
| 44 |
-
state = env.reset()
|
| 45 |
-
outcomes = []
|
| 46 |
-
win = False
|
| 47 |
-
for i in range(env.max_turns):
|
| 48 |
-
action = net.choose_action(v_wrap(state[None, :]))
|
| 49 |
-
state, reward, done, _ = env.step(action)
|
| 50 |
-
outcomes.append((env.words[action], reward))
|
| 51 |
-
if done:
|
| 52 |
-
if reward >= 0:
|
| 53 |
-
win = True
|
| 54 |
-
break
|
| 55 |
-
return win, outcomes
|
|
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
from .net import GreedyNet
|
| 5 |
+
from .play import play
|
| 6 |
from .utils import v_wrap
|
| 7 |
|
| 8 |
|
|
|
|
| 39 |
print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
|
| 40 |
f"{n_guesses / N} including losses.")
|
| 41 |
return n_wins/N*100, n_win_guesses/n_wins
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
a3c/play.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .net import GreedyNet
|
| 3 |
+
from .utils import v_wrap
|
| 4 |
+
from wordle_env.state import update_from_mask
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def suggest(
|
| 8 |
+
env,
|
| 9 |
+
words,
|
| 10 |
+
states,
|
| 11 |
+
pretrained_model_path
|
| 12 |
+
) -> str:
|
| 13 |
+
"""
|
| 14 |
+
Given a list of words and masks, return the next suggested word
|
| 15 |
+
|
| 16 |
+
:param pretrained_model_path: Path of the saved model state dict loaded into the network
|
| 17 |
+
:param env:
|
| 18 |
+
:param words: Words played so far; ``states`` holds the matching outcome mask (a string of digits) for each word
|
| 19 |
+
:return: The word from the environment's word list that the network chooses next
|
| 20 |
+
"""
|
| 21 |
+
n_s = env.observation_space.shape[0]
|
| 22 |
+
n_a = env.action_space.n
|
| 23 |
+
env = env.unwrapped
|
| 24 |
+
state = env.reset()
|
| 25 |
+
words_list = env.words
|
| 26 |
+
word_width = len(env.words[0])
|
| 27 |
+
net = GreedyNet(n_s, n_a, words_list, word_width)
|
| 28 |
+
net.load_state_dict(torch.load(pretrained_model_path))
|
| 29 |
+
for word, mask in zip(words, states):
|
| 30 |
+
word = word.upper()
|
| 31 |
+
mask = list(map(int, mask))
|
| 32 |
+
state = update_from_mask(state, word, mask)
|
| 33 |
+
return env.words[net.choose_action(v_wrap(state[None, :]))]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def play(net, env):
|
| 37 |
+
state = env.reset()
|
| 38 |
+
outcomes = []
|
| 39 |
+
win = False
|
| 40 |
+
for i in range(env.max_turns):
|
| 41 |
+
action = net.choose_action(v_wrap(state[None, :]))
|
| 42 |
+
state, reward, done, _ = env.step(action)
|
| 43 |
+
outcomes.append((env.words[action], reward))
|
| 44 |
+
if done:
|
| 45 |
+
if reward >= 0:
|
| 46 |
+
win = True
|
| 47 |
+
break
|
| 48 |
+
return win, outcomes
|
main.py
CHANGED
|
@@ -8,6 +8,7 @@ import time
|
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
from a3c.train import train
|
| 10 |
from a3c.eval import evaluate, evaluate_checkpoints
|
|
|
|
| 11 |
from wordle_env.wordle import WordleEnvBase
|
| 12 |
|
| 13 |
|
|
@@ -27,6 +28,15 @@ def evaluation_mode(args, env, model_checkpoint_dir):
|
|
| 27 |
print(results)
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def print_results(global_ep, win_ep, res):
|
| 31 |
print("Jugadas:", global_ep.value)
|
| 32 |
print("Ganadas:", win_ep.value)
|
|
@@ -49,7 +59,7 @@ if __name__ == "__main__":
|
|
| 49 |
parser_train.add_argument(
|
| 50 |
"--games", "-g", help="Number of games to train", type=int, required=True)
|
| 51 |
parser_train.add_argument(
|
| 52 |
-
"--model_name", "-
|
| 53 |
parser_train.add_argument(
|
| 54 |
"--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
|
| 55 |
parser_train.add_argument(
|
|
@@ -64,6 +74,16 @@ if __name__ == "__main__":
|
|
| 64 |
'eval', help='Evaluate saved models for the enviroment')
|
| 65 |
parser_eval.set_defaults(func=evaluation_mode)
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
args = parser.parse_args()
|
| 68 |
env_id = args.enviroment
|
| 69 |
env = gym.make(env_id)
|
|
|
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
from a3c.train import train
|
| 10 |
from a3c.eval import evaluate, evaluate_checkpoints
|
| 11 |
+
from a3c.play import suggest
|
| 12 |
from wordle_env.wordle import WordleEnvBase
|
| 13 |
|
| 14 |
|
|
|
|
| 28 |
print(results)
|
| 29 |
|
| 30 |
|
| 31 |
+
def play_mode(args, env, model_checkpoint_dir):
|
| 32 |
+
print("Play mode")
|
| 33 |
+
words = [ word.strip() for word in args.words.split(',') ]
|
| 34 |
+
states = [ state.strip() for state in args.states.split(',') ]
|
| 35 |
+
pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
|
| 36 |
+
word = suggest(env, words, states, pretrained_model_path)
|
| 37 |
+
print(word)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
def print_results(global_ep, win_ep, res):
|
| 41 |
print("Jugadas:", global_ep.value)
|
| 42 |
print("Ganadas:", win_ep.value)
|
|
|
|
| 59 |
parser_train.add_argument(
|
| 60 |
"--games", "-g", help="Number of games to train", type=int, required=True)
|
| 61 |
parser_train.add_argument(
|
| 62 |
+
"--model_name", "-m", help="If want to train from a pretrained model, the name of the pretrained model file")
|
| 63 |
parser_train.add_argument(
|
| 64 |
"--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
|
| 65 |
parser_train.add_argument(
|
|
|
|
| 74 |
'eval', help='Evaluate saved models for the enviroment')
|
| 75 |
parser_eval.set_defaults(func=evaluation_mode)
|
| 76 |
|
| 77 |
+
parser_play = subparsers.add_parser(
|
| 78 |
+
'play', help='Give the model a word and the state result and the model will try to predict the goal word')
|
| 79 |
+
parser_play.add_argument(
|
| 80 |
+
"--words", "-w", help="List of words played in the wordle game", required=True)
|
| 81 |
+
parser_play.add_argument(
|
| 82 |
+
"--states", "-st", help="List of states returned by playing each of the words", required=True)
|
| 83 |
+
parser_play.add_argument(
|
| 84 |
+
"--model_name", "-m", help="Name of the pretrained model file thich will play the game", required=True)
|
| 85 |
+
parser_play.set_defaults(func=play_mode)
|
| 86 |
+
|
| 87 |
args = parser.parse_args()
|
| 88 |
env_id = args.enviroment
|
| 89 |
env = gym.make(env_id)
|