sato2ru commited on
Commit
ff59e62
Β·
verified Β·
1 Parent(s): 538243c

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +139 -0
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json, math, torch, gradio as gr
3
+ from collections import Counter
4
+ import numpy as np
5
+ from huggingface_hub import hf_hub_download
6
+ import torch.nn as nn
7
+
8
+ REPO_ID = "YOUR_USERNAME/wordle-solver" # ← update this
9
+
10
+ # ── Load assets from HF Hub ──────────────────────────────────────
11
+ config = json.load(open(hf_hub_download(REPO_ID, "config.json")))
12
+ ANSWERS = json.load(open(hf_hub_download(REPO_ID, "answers.json")))
13
+ ALLOWED = json.load(open(hf_hub_download(REPO_ID, "allowed.json")))
14
+ WORD2IDX = {w: i for i, w in enumerate(ALLOWED)}
15
+ LETTERS = "abcdefghijklmnopqrstuvwxyz"
16
+ L2I = {c: i for i, c in enumerate(LETTERS)}
17
+ INPUT_DIM = config["input_dim"]
18
+ OUTPUT_DIM = config["output_dim"]
19
+ OPENING = config["opening_guess"]
20
+ WIN_PATTERN = (2,2,2,2,2)
21
+
22
+ # ── Model ────────────────────────────────────────────────────────
23
+ class WordleNet(nn.Module):
24
+ def __init__(self):
25
+ super().__init__()
26
+ h = config["hidden"]
27
+ self.net = nn.Sequential(
28
+ nn.Linear(INPUT_DIM, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(0.3),
29
+ nn.Linear(h, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(0.3),
30
+ nn.Linear(h, 256), nn.BatchNorm1d(256), nn.ReLU(),
31
+ nn.Linear(256, OUTPUT_DIM)
32
+ )
33
+ def forward(self, x): return self.net(x)
34
+
35
+ model = WordleNet()
36
+ model.load_state_dict(torch.load(hf_hub_download(REPO_ID, "model_weights.pt"), map_location="cpu"))
37
+ model.eval()
38
+
39
+ # ── Helpers ──────────────────────────────────────────────────────
40
+ def get_pattern(guess, answer):
41
+ pattern = [0]*5
42
+ counts = Counter(answer)
43
+ for i in range(5):
44
+ if guess[i] == answer[i]: pattern[i] = 2; counts[guess[i]] -= 1
45
+ for i in range(5):
46
+ if pattern[i] == 0 and counts.get(guess[i],0) > 0:
47
+ pattern[i] = 1; counts[guess[i]] -= 1
48
+ return tuple(pattern)
49
+
50
+ def filter_words(words, guess, pattern):
51
+ return [w for w in words if get_pattern(guess, w) == pattern]
52
+
53
+ def entropy_score(guess, possible):
54
+ buckets = Counter(get_pattern(guess, w) for w in possible)
55
+ n = len(possible)
56
+ return sum(-(c/n)*math.log2(c/n) for c in buckets.values())
57
+
58
+ def encode_board(history):
59
+ vec = np.zeros(INPUT_DIM, dtype=np.float32)
60
+ for word, pattern in history:
61
+ for pos, (letter, state) in enumerate(zip(word, pattern)):
62
+ vec[L2I[letter]*15 + pos*3 + state] = 1.0
63
+ return vec
64
+
65
+ def model_suggest(history, possible):
66
+ if len(possible) == 1: return possible[0]
67
+ if not history: return OPENING
68
+ state = torch.tensor(encode_board(history)).unsqueeze(0)
69
+ with torch.no_grad():
70
+ logits = model(state)[0]
71
+ top5 = [ALLOWED[i] for i in logits.topk(5).indices.tolist()]
72
+ return max(top5, key=lambda w: entropy_score(w, possible))
73
+
74
+ # ── State ─────────────────────────────────────────────────────────
75
+ def init_state():
76
+ return {"possible": list(ANSWERS), "history": [], "done": False}
77
+
78
+ def render_board(history):
79
+ colours = {0: "⬜", 1: "🟨", 2: "🟩"}
80
+ rows = []
81
+ for word, pattern in history:
82
+ tiles = " ".join(f"{colours[s]}{c.upper()}" for c, s in zip(word, pattern))
83
+ rows.append(tiles)
84
+ return "
85
+ ".join(rows) if rows else "(no guesses yet)"
86
+
87
+ def process_guess(guess_input, pattern_input, state):
88
+ if state["done"]:
89
+ return render_board(state["history"]), "Game over β€” press Reset", state
90
+
91
+ guess = guess_input.strip().lower()
92
+ if len(guess) != 5:
93
+ return render_board(state["history"]), "⚠️ Guess must be 5 letters", state
94
+ if len(pattern_input) != 5 or not all(c in "012" for c in pattern_input):
95
+ return render_board(state["history"]), "⚠️ Pattern must be 5 digits (0/1/2)", state
96
+
97
+ pattern = tuple(int(c) for c in pattern_input)
98
+ state["history"].append((guess, pattern))
99
+
100
+ if pattern == WIN_PATTERN:
101
+ state["done"] = True
102
+ msg = f"πŸŽ‰ Solved in {len(state["history"])} turns!"
103
+ return render_board(state["history"]), msg, state
104
+
105
+ state["possible"] = filter_words(state["possible"], guess, pattern)
106
+ if not state["possible"]:
107
+ state["done"] = True
108
+ return render_board(state["history"]), "❌ No words left. Check your input.", state
109
+
110
+ suggestion = model_suggest(state["history"], state["possible"])
111
+ msg = f"Try: **{suggestion.upper()}** | {len(state["possible"])} words left"
112
+ return render_board(state["history"]), msg, state
113
+
114
+ def reset(_state):
115
+ s = init_state()
116
+ return render_board([]), f"Try: **{OPENING.upper()}** to start", s
117
+
118
+ # ── Gradio UI ───────────���─────────────────────────────────────────
119
+ with gr.Blocks(title="Wordle Solver", theme=gr.themes.Monochrome()) as demo:
120
+ gr.Markdown("# 🟩 Wordle Solver
121
+ Entropy-trained neural network. Enter each guess + the colour pattern.")
122
+ gr.Markdown("**Pattern key:** `0` = ⬜ grey · `1` = 🟨 yellow · `2` = 🟩 green")
123
+
124
+ state = gr.State(init_state())
125
+ board_out = gr.Textbox(label="Board", lines=7, interactive=False)
126
+ msg_out = gr.Markdown(f"Try: **{OPENING.upper()}** to start")
127
+
128
+ with gr.Row():
129
+ guess_in = gr.Textbox(label="Your guess", placeholder="crane", max_lines=1)
130
+ pattern_in = gr.Textbox(label="Pattern (5 digits)", placeholder="02100", max_lines=1)
131
+
132
+ with gr.Row():
133
+ submit_btn = gr.Button("Submit", variant="primary")
134
+ reset_btn = gr.Button("Reset")
135
+
136
+ submit_btn.click(process_guess, [guess_in, pattern_in, state], [board_out, msg_out, state])
137
+ reset_btn.click(reset, [state], [board_out, msg_out, state])
138
+
139
+ demo.launch()