Spaces:
Build error
Build error
| import os | |
| import torch | |
| from dotenv import load_dotenv | |
| from huggingface_hub import hf_hub_download | |
| from wordle_env.state import update_from_mask | |
| from .net import GreedyNet | |
| from .utils import v_wrap | |
| load_dotenv() | |
| MODEL_NAME = os.getenv("RS_WORDLE_MODEL_NAME") | |
| HF_MODEL_REPO_NAME = os.getenv("HF_MODEL_REPO_NAME") | |
| MODEL_CHECKPOINT_DIR = "checkpoints" | |
def get_play_model_path():
    """Return the local filesystem path of the configured play-model checkpoint."""
    checkpoint_path = os.path.join(MODEL_CHECKPOINT_DIR, MODEL_NAME)
    return checkpoint_path
def get_net(env, pretrained_model_path):
    """Build a GreedyNet sized for *env* and load pretrained weights into it.

    If the checkpoint is not present locally, it is downloaded from the
    Hugging Face Hub repo named by HF_MODEL_REPO_NAME into
    MODEL_CHECKPOINT_DIR first.

    :param env: unwrapped environment exposing ``observation_space``,
        ``action_space`` and a ``words`` candidate list.
    :param pretrained_model_path: local path to the saved state dict.
    :return: a GreedyNet with the pretrained weights loaded.
    """
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    # assumes every word in env.words has the same length — TODO confirm
    word_width = len(env.words[0])
    net = GreedyNet(n_s, n_a, words_list, word_width)
    if not os.path.exists(pretrained_model_path):
        pretrained_model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_NAME, local_dir=MODEL_CHECKPOINT_DIR
        )
    # map_location="cpu" so checkpoints saved on a GPU machine still load
    # on CPU-only hosts; the caller can move the net to a device afterwards.
    net.load_state_dict(torch.load(pretrained_model_path, map_location="cpu"))
    return net
def get_initial_state(env):
    """Reset *env* and return its starting observation."""
    return env.reset()
def suggest(env, words, states, pretrained_model_path) -> str:
    """Return the next word the pretrained model suggests.

    Replays the game history (guesses and their feedback masks) onto a
    fresh state, then asks the network for the best next action.

    :param env: Wordle environment (wrapped or unwrapped).
    :param words: guesses played so far.
    :param states: per-guess feedback masks (iterables of digit characters).
    :param pretrained_model_path: path to the pretrained weights.
    :return: the suggested word from the environment's word list.
    """
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    for guess, feedback in zip(words, states):
        digits = [int(ch) for ch in feedback]
        state = update_from_mask(state, guess.upper(), digits)
    action = net.choose_action(v_wrap(state[None, :]))
    return env.words[action]
def play(env, pretrained_model_path, goal_word=None):
    """Let the pretrained agent play one full game from a fresh state.

    :param env: Wordle environment (wrapped or unwrapped).
    :param pretrained_model_path: path to the pretrained weights.
    :param goal_word: optional target word to force on the environment.
    :return: ``(win, outcomes)`` — whether the agent won, and the list of
        guesses it made, in order.
    """
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    if goal_word:
        env.set_goal_word(goal_word)
    outcomes = []
    win = False
    for _ in range(env.max_turns):
        action = net.choose_action(v_wrap(state[None, :]))
        state, reward, done, _ = env.step(action)
        outcomes.append(env.words[action])
        if not done:
            continue
        # positive terminal reward means the goal word was guessed
        if reward > 0:
            win = True
        break
    return win, outcomes