sato2ru commited on
Commit
4e29ac0
Β·
verified Β·
1 Parent(s): 46b0f52

replicate railway

Browse files
Files changed (1) hide show
  1. app.py +119 -94
app.py CHANGED
@@ -1,8 +1,9 @@
1
- from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  import json, math, torch, numpy as np
5
  from collections import Counter
 
6
  import torch.nn as nn
7
  from huggingface_hub import hf_hub_download
8
 
@@ -16,19 +17,22 @@ app.add_middleware(
16
  allow_headers=["*"],
17
  )
18
 
19
- # ── Load assets ───────────────────────────────────────────────────────────────
20
- print("Loading model...")
21
- config = json.load(open(hf_hub_download(HF_REPO_ID, "config.json")))
 
22
  ANSWERS = json.load(open(hf_hub_download(HF_REPO_ID, "answers.json")))
23
  ALLOWED = json.load(open(hf_hub_download(HF_REPO_ID, "allowed.json")))
 
24
  LETTERS = "abcdefghijklmnopqrstuvwxyz"
25
  L2I = {c: i for i, c in enumerate(LETTERS)}
26
- INPUT_DIM = config["input_dim"]
27
- OUTPUT_DIM = config["output_dim"]
28
- OPENING = config["opening_guess"]
29
  WIN_PATTERN = (2, 2, 2, 2, 2)
30
 
31
- # ── Model ─────────────────────────────────────────────────────────────────────
 
32
  class WordleNet(nn.Module):
33
  def __init__(self):
34
  super().__init__()
@@ -41,17 +45,50 @@ class WordleNet(nn.Module):
41
  )
42
  def forward(self, x): return self.net(x)
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  model = WordleNet()
45
  model.load_state_dict(
46
  torch.load(hf_hub_download(HF_REPO_ID, "model_weights.pt"), map_location="cpu")
47
  )
48
  model.eval()
49
- print("Model loaded βœ…")
50
 
51
- # ── Helpers ───────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
52
  def get_pattern(guess, answer):
53
- pattern = [0] * 5
54
- counts = Counter(answer)
55
  for i in range(5):
56
  if guess[i] == answer[i]:
57
  pattern[i] = 2
@@ -68,105 +105,82 @@ def filter_words(words, guess, pattern):
68
  def entropy_score(guess, possible):
69
  buckets = Counter(get_pattern(guess, w) for w in possible)
70
  n = len(possible)
71
- return sum(-(c / n) * math.log2(c / n) for c in buckets.values())
72
 
73
  def encode_board(history):
74
  vec = np.zeros(INPUT_DIM, dtype=np.float32)
75
  for word, pattern in history:
76
  for pos, (letter, state) in enumerate(zip(word, pattern)):
77
- vec[L2I[letter] * 15 + pos * 3 + state] = 1.0
78
  return vec
79
 
80
- def is_consistent(word, history):
81
- for guess, pattern in history:
82
- green_letters = {letter for letter, state in zip(guess, pattern) if state == 2}
83
- for pos, (letter, state) in enumerate(zip(guess, pattern)):
84
- if state == 2:
85
- if word[pos] != letter:
86
- return False
87
- elif state == 1:
88
- if letter not in word or word[pos] == letter:
89
- return False
90
- else:
91
- if letter not in green_letters and letter in word:
92
- return False
93
- return True
94
-
95
- def model_suggest(history, possible):
96
- if not possible: return None
97
- if len(possible) == 1: return possible[0]
98
- if not history: return OPENING
99
-
100
- already_guessed = {w for w, _ in history}
101
- possible_not_guessed = [w for w in possible if w not in already_guessed]
102
-
103
- if len(possible) <= 6:
104
- ambiguous = set()
105
- for pos in range(5):
106
- letters_at_pos = {w[pos] for w in possible}
107
- if len(letters_at_pos) > 1:
108
- ambiguous.update(letters_at_pos)
109
-
110
- best_word, best_score = None, -1
111
- for g in ALLOWED:
112
- if g in already_guessed:
113
- continue
114
- if not is_consistent(g, history):
115
- continue
116
- if g in possible and len(possible) > 2:
117
- continue
118
- score = len(set(g) & ambiguous) * 2 + entropy_score(g, possible)
119
- if score > best_score:
120
- best_score, best_word = score, g
121
-
122
- if not best_word:
123
- best_word = possible_not_guessed[0] if possible_not_guessed else possible[0]
124
- return best_word
125
 
126
  state = torch.tensor(encode_board(history)).unsqueeze(0)
127
  with torch.no_grad():
128
- logits = model(state)[0]
129
 
130
- top50 = [ALLOWED[i] for i in logits.topk(50).indices.tolist()]
131
- valid = [w for w in top50
132
- if w not in already_guessed and is_consistent(w, history)]
 
 
 
 
 
133
 
134
- if not valid:
135
- return max(possible_not_guessed or possible,
136
- key=lambda w: entropy_score(w, possible))
137
 
138
- return max(valid[:10], key=lambda w: entropy_score(w, possible))
139
 
 
 
 
 
 
140
 
141
- def top_suggestions(history, possible, n=5):
142
- if not possible: return []
 
 
 
 
 
 
 
 
143
 
144
- already_guessed = {w for w, _ in history}
 
 
 
145
 
 
 
 
 
146
  if not history:
147
- candidates = [OPENING] + [w for w in ALLOWED if w != OPENING][:30]
148
  else:
149
- state = torch.tensor(encode_board(history)).unsqueeze(0)
150
- with torch.no_grad():
151
- logits = model(state)[0]
152
- candidates = [ALLOWED[i] for i in logits.topk(50).indices.tolist()]
153
-
154
- candidates = [w for w in candidates
155
- if w not in already_guessed and is_consistent(w, history)]
156
 
157
  possible_set = set(possible)
158
- scored = [
159
- {
160
- "word": w,
161
- "entropy": round(entropy_score(w, possible), 3),
162
- "is_possible": w in possible_set,
163
- }
164
- for w in candidates
165
- ]
166
- scored.sort(key=lambda x: (-x["entropy"], not x["is_possible"]))
167
  return scored[:n]
168
 
169
- # ── Models ────────────────────────────────────────────────────────────────────
 
170
  class GuessEntry(BaseModel):
171
  word: str
172
  pattern: list[int]
@@ -181,14 +195,20 @@ class SuggestResponse(BaseModel):
181
  bits_remaining: float
182
  solved: bool
183
  message: str
 
184
 
185
- # ── Routes ────────────────────────────────────────────────────────────────────
 
186
  @app.get("/")
187
  def root():
188
  return {"status": "ok", "opener": OPENING}
189
 
190
  @app.post("/suggest", response_model=SuggestResponse)
191
- def suggest(req: SuggestRequest):
 
 
 
 
192
  possible = list(ANSWERS)
193
 
194
  for entry in req.history:
@@ -196,12 +216,12 @@ def suggest(req: SuggestRequest):
196
  pattern = tuple(entry.pattern)
197
  if len(word) != 5:
198
  raise HTTPException(400, f"Word must be 5 letters: {word}")
199
- if len(pattern) != 5 or not all(p in (0, 1, 2) for p in pattern):
200
  raise HTTPException(400, "Pattern must be 5 values of 0, 1, or 2")
201
  if pattern == WIN_PATTERN:
202
  return SuggestResponse(
203
  suggestion=word, top_suggestions=[], possible_count=1,
204
- bits_remaining=0.0, solved=True,
205
  message=f"Solved in {len(req.history)} guesses!"
206
  )
207
  possible = filter_words(possible, word, pattern)
@@ -210,10 +230,10 @@ def suggest(req: SuggestRequest):
210
  raise HTTPException(422, "No possible words remaining. Check your pattern input.")
211
 
212
  history_tuples = [(e.word.lower(), tuple(e.pattern)) for e in req.history]
213
- suggestion = model_suggest(history_tuples, possible)
214
  if not suggestion:
215
  suggestion = possible[0]
216
- top_suggs = top_suggestions(history_tuples, possible)
217
  bits = math.log2(len(possible)) if len(possible) > 1 else 0.0
218
 
219
  return SuggestResponse(
@@ -222,9 +242,14 @@ def suggest(req: SuggestRequest):
222
  possible_count=len(possible),
223
  bits_remaining=round(bits, 2),
224
  solved=False,
 
225
  message=f"{len(possible)} words remaining β€” try {suggestion.upper()}"
226
  )
227
 
228
  @app.get("/opener")
229
  def get_opener():
230
  return {"word": OPENING}
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Query
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  import json, math, torch, numpy as np
5
  from collections import Counter
6
+ from typing import Optional
7
  import torch.nn as nn
8
  from huggingface_hub import hf_hub_download
9
 
 
17
  allow_headers=["*"],
18
  )
19
 
20
+ # ── Load word lists & configs ─────────────────────────────────────────────────
21
+ print("Loading configs and word lists...")
22
+ config = json.load(open(hf_hub_download(HF_REPO_ID, "config.json")))
23
+ rl_config = json.load(open(hf_hub_download(HF_REPO_ID, "rl_config.json")))
24
  ANSWERS = json.load(open(hf_hub_download(HF_REPO_ID, "answers.json")))
25
  ALLOWED = json.load(open(hf_hub_download(HF_REPO_ID, "allowed.json")))
26
+ WORD2IDX = {w: i for i, w in enumerate(ALLOWED)}
27
  LETTERS = "abcdefghijklmnopqrstuvwxyz"
28
  L2I = {c: i for i, c in enumerate(LETTERS)}
29
+ INPUT_DIM = config["input_dim"]
30
+ OUTPUT_DIM = config["output_dim"]
31
+ OPENING = config["opening_guess"]
32
  WIN_PATTERN = (2, 2, 2, 2, 2)
33
 
34
+
35
+ # ── Model architecture ────────────────────────────────────────────────────────
36
  class WordleNet(nn.Module):
37
  def __init__(self):
38
  super().__init__()
 
45
  )
46
  def forward(self, x): return self.net(x)
47
 
48
+
49
+ class RLWordleNet(nn.Module):
50
+ """Same encoder as WordleNet but with BatchNorm-safe single-sample forward."""
51
+ def __init__(self):
52
+ super().__init__()
53
+ h = rl_config["hidden"]
54
+ self.encoder = nn.Sequential(
55
+ nn.Linear(INPUT_DIM, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(0.3),
56
+ nn.Linear(h, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Dropout(0.3),
57
+ nn.Linear(h, 256), nn.BatchNorm1d(256), nn.ReLU(),
58
+ )
59
+ self.policy_head = nn.Linear(256, OUTPUT_DIM)
60
+
61
+ def forward(self, x):
62
+ single = x.shape[0] == 1
63
+ if single:
64
+ x = x.repeat(2, 1)
65
+ feat = self.encoder(x)
66
+ if single:
67
+ feat = feat[:1]
68
+ return self.policy_head(feat)
69
+
70
+
71
+ # ── Load weights ──────────────────────────────────────────────────────────────
72
+ print("Loading supervised model...")
73
  model = WordleNet()
74
  model.load_state_dict(
75
  torch.load(hf_hub_download(HF_REPO_ID, "model_weights.pt"), map_location="cpu")
76
  )
77
  model.eval()
 
78
 
79
+ print("Loading RL model...")
80
+ rl_model = RLWordleNet()
81
+ rl_model.load_state_dict(
82
+ torch.load(hf_hub_download(HF_REPO_ID, "rl_model_weights.pt"), map_location="cpu"), strict=False
83
+ )
84
+ rl_model.eval()
85
+ print("Both models loaded.")
86
+
87
+
88
+ # ── Core logic ────────────────────────────────────────────────────────────────
89
  def get_pattern(guess, answer):
90
+ pattern = [0]*5
91
+ counts = Counter(answer)
92
  for i in range(5):
93
  if guess[i] == answer[i]:
94
  pattern[i] = 2
 
105
  def entropy_score(guess, possible):
106
  buckets = Counter(get_pattern(guess, w) for w in possible)
107
  n = len(possible)
108
+ return sum(-(c/n)*math.log2(c/n) for c in buckets.values())
109
 
110
  def encode_board(history):
111
  vec = np.zeros(INPUT_DIM, dtype=np.float32)
112
  for word, pattern in history:
113
  for pos, (letter, state) in enumerate(zip(word, pattern)):
114
+ vec[L2I[letter]*15 + pos*3 + state] = 1.0
115
  return vec
116
 
117
+ def get_logits(history, possible, use_rl=False):
118
+ """Get top-20 model words using constraint-filtered mask."""
119
+ active_model = rl_model if use_rl else model
120
+ possible_set = set(possible)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  state = torch.tensor(encode_board(history)).unsqueeze(0)
123
  with torch.no_grad():
124
+ logits = active_model(state)[0]
125
 
126
+ if use_rl:
127
+ mask = torch.full((OUTPUT_DIM,), float('-inf'))
128
+ for i, w in enumerate(ALLOWED):
129
+ if w in possible_set:
130
+ mask[i] = 0.0
131
+ if mask.max() == float('-inf'):
132
+ mask[:] = 0.0
133
+ logits = logits + mask
134
 
135
+ return [ALLOWED[i] for i in logits.topk(20).indices.tolist()]
 
 
136
 
 
137
 
138
+ def model_suggest(history, possible, use_rl=False):
139
+ if not possible: return None
140
+ if len(possible) == 1: return possible[0]
141
+ if not history: return OPENING
142
+ guessed = {w for w, _ in history}
143
 
144
+ model_words = get_logits(history, possible, use_rl)
145
+
146
+ if len(possible) <= 6:
147
+ best, best_worst = None, float('inf')
148
+ for g in list(possible) + model_words:
149
+ if g in guessed: continue
150
+ worst = max(Counter(get_pattern(g, w) for w in possible).values())
151
+ if worst < best_worst:
152
+ best_worst, best = worst, g
153
+ return best or possible[0]
154
 
155
+ candidates = list(dict.fromkeys(model_words + list(possible)))
156
+ candidates = [w for w in candidates if w not in guessed]
157
+ if not candidates: return possible[0]
158
+ return max(candidates, key=lambda w: entropy_score(w, possible))
159
 
160
+
161
+ def top_suggestions(history, possible, use_rl=False, n=5):
162
+ if not possible: return []
163
+ guessed = {w for w, _ in history}
164
  if not history:
165
+ candidates = [OPENING] + [w for w in ALLOWED if w != OPENING][:20]
166
  else:
167
+ model_words = get_logits(history, possible, use_rl)
168
+ candidates = list(dict.fromkeys(model_words + list(possible)))
 
 
 
 
 
169
 
170
  possible_set = set(possible)
171
+ candidates = [w for w in candidates if w in possible_set and w not in guessed]
172
+
173
+ # fallback β€” if all possible words were guessed, show from full possible
174
+ if not candidates:
175
+ candidates = [w for w in possible if w not in guessed]
176
+
177
+ scored = [{"word": w, "entropy": round(entropy_score(w, possible), 3), "is_possible": True}
178
+ for w in candidates]
179
+ scored.sort(key=lambda x: -x["entropy"])
180
  return scored[:n]
181
 
182
+
183
+ # ── Schemas ───────────────────────────────────────────────────────────────────
184
  class GuessEntry(BaseModel):
185
  word: str
186
  pattern: list[int]
 
195
  bits_remaining: float
196
  solved: bool
197
  message: str
198
+ model_used: str
199
 
200
+
201
+ # ── Endpoints ─────────────────────────────────────────────────────────────────
202
  @app.get("/")
203
  def root():
204
  return {"status": "ok", "opener": OPENING}
205
 
206
  @app.post("/suggest", response_model=SuggestResponse)
207
+ def suggest(
208
+ req: SuggestRequest,
209
+ model: str = Query(default="supervised", pattern="^(supervised|rl)$")
210
+ ):
211
+ use_rl = model == "rl"
212
  possible = list(ANSWERS)
213
 
214
  for entry in req.history:
 
216
  pattern = tuple(entry.pattern)
217
  if len(word) != 5:
218
  raise HTTPException(400, f"Word must be 5 letters: {word}")
219
+ if len(pattern) != 5 or not all(p in (0,1,2) for p in pattern):
220
  raise HTTPException(400, "Pattern must be 5 values of 0, 1, or 2")
221
  if pattern == WIN_PATTERN:
222
  return SuggestResponse(
223
  suggestion=word, top_suggestions=[], possible_count=1,
224
+ bits_remaining=0.0, solved=True, model_used=model,
225
  message=f"Solved in {len(req.history)} guesses!"
226
  )
227
  possible = filter_words(possible, word, pattern)
 
230
  raise HTTPException(422, "No possible words remaining. Check your pattern input.")
231
 
232
  history_tuples = [(e.word.lower(), tuple(e.pattern)) for e in req.history]
233
+ suggestion = model_suggest(history_tuples, possible, use_rl=use_rl)
234
  if not suggestion:
235
  suggestion = possible[0]
236
+ top_suggs = top_suggestions(history_tuples, possible, use_rl=use_rl)
237
  bits = math.log2(len(possible)) if len(possible) > 1 else 0.0
238
 
239
  return SuggestResponse(
 
242
  possible_count=len(possible),
243
  bits_remaining=round(bits, 2),
244
  solved=False,
245
+ model_used=model,
246
  message=f"{len(possible)} words remaining β€” try {suggestion.upper()}"
247
  )
248
 
249
  @app.get("/opener")
250
  def get_opener():
251
  return {"word": OPENING}
252
+
253
+ if __name__ == "__main__":
254
+ import uvicorn
255
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)