Wen-Ting Wang
commited on
Commit
·
cf72ffa
1
Parent(s):
406b978
feat: Deploy Hangman AI Demo to Hugging Face Spaces
Browse files- Create Gradio app for interactive hangman demo
- Add requirements.txt with necessary dependencies
- Include README with proper metadata for HF Spaces
- Fix short_description to meet 60 character limit
- .gitignore +96 -0
- hangman/__pycache__/__init__.cpython-312.pyc +0 -0
- hangman/__pycache__/char_transformer.cpython-312.pyc +0 -0
- hangman/__pycache__/hangman_core.cpython-312.pyc +0 -0
- hangman/rl/__pycache__/__init__.cpython-312.pyc +0 -0
- hangman/rl/__pycache__/envs.cpython-312.pyc +0 -0
- hangman/rl/__pycache__/models.cpython-312.pyc +0 -0
- hangman/rl/__pycache__/utils.cpython-312.pyc +0 -0
- hangman/rl/eval.py +0 -178
- hangman/rl/ppo.py +0 -21
- hangman/rl/replay.py +0 -86
- hangman/rl/seed_bc.py +0 -144
- hangman/rl/train_bc.py +0 -139
- hangman/utils.py +0 -129
.gitignore
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
MANIFEST
|
| 23 |
+
|
| 24 |
+
# PyInstaller
|
| 25 |
+
*.manifest
|
| 26 |
+
*.spec
|
| 27 |
+
|
| 28 |
+
# Installer logs
|
| 29 |
+
pip-log.txt
|
| 30 |
+
pip-delete-this-directory.txt
|
| 31 |
+
|
| 32 |
+
# Unit test / coverage reports
|
| 33 |
+
htmlcov/
|
| 34 |
+
.tox/
|
| 35 |
+
.coverage
|
| 36 |
+
.coverage.*
|
| 37 |
+
.cache
|
| 38 |
+
nosetests.xml
|
| 39 |
+
coverage.xml
|
| 40 |
+
*.cover
|
| 41 |
+
.hypothesis/
|
| 42 |
+
.pytest_cache/
|
| 43 |
+
|
| 44 |
+
# Jupyter Notebook
|
| 45 |
+
.ipynb_checkpoints
|
| 46 |
+
|
| 47 |
+
# pyenv
|
| 48 |
+
.python-version
|
| 49 |
+
|
| 50 |
+
# Environments
|
| 51 |
+
.env
|
| 52 |
+
.venv
|
| 53 |
+
env/
|
| 54 |
+
venv/
|
| 55 |
+
ENV/
|
| 56 |
+
env.bak/
|
| 57 |
+
venv.bak/
|
| 58 |
+
|
| 59 |
+
# IDE
|
| 60 |
+
.vscode/
|
| 61 |
+
.idea/
|
| 62 |
+
*.swp
|
| 63 |
+
*.swo
|
| 64 |
+
*~
|
| 65 |
+
|
| 66 |
+
# OS
|
| 67 |
+
.DS_Store
|
| 68 |
+
.DS_Store?
|
| 69 |
+
._*
|
| 70 |
+
.Spotlight-V100
|
| 71 |
+
.Trashes
|
| 72 |
+
ehthumbs.db
|
| 73 |
+
Thumbs.db
|
| 74 |
+
|
| 75 |
+
# Hugging Face specific
|
| 76 |
+
*.bin
|
| 77 |
+
*.safetensors
|
| 78 |
+
*.h5
|
| 79 |
+
*.ckpt
|
| 80 |
+
*.pth
|
| 81 |
+
*.pt
|
| 82 |
+
*.pkl
|
| 83 |
+
*.pickle
|
| 84 |
+
|
| 85 |
+
# Model checkpoints and data
|
| 86 |
+
checkpoints/
|
| 87 |
+
models/
|
| 88 |
+
data/
|
| 89 |
+
logs/
|
| 90 |
+
runs/
|
| 91 |
+
wandb/
|
| 92 |
+
|
| 93 |
+
# Temporary files
|
| 94 |
+
*.tmp
|
| 95 |
+
*.temp
|
| 96 |
+
*.log
|
hangman/__pycache__/__init__.cpython-312.pyc
DELETED
|
Binary file (413 Bytes)
|
|
|
hangman/__pycache__/char_transformer.cpython-312.pyc
DELETED
|
Binary file (3.47 kB)
|
|
|
hangman/__pycache__/hangman_core.cpython-312.pyc
DELETED
|
Binary file (47.9 kB)
|
|
|
hangman/rl/__pycache__/__init__.cpython-312.pyc
DELETED
|
Binary file (385 Bytes)
|
|
|
hangman/rl/__pycache__/envs.cpython-312.pyc
DELETED
|
Binary file (7.15 kB)
|
|
|
hangman/rl/__pycache__/models.cpython-312.pyc
DELETED
|
Binary file (14 kB)
|
|
|
hangman/rl/__pycache__/utils.cpython-312.pyc
DELETED
|
Binary file (3.48 kB)
|
|
|
hangman/rl/eval.py
DELETED
|
@@ -1,178 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import time
|
| 3 |
-
import csv
|
| 4 |
-
from argparse import Namespace
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
from .envs import BatchEnv
|
| 9 |
-
from .priors import (
|
| 10 |
-
CandCache,
|
| 11 |
-
ig_exact_pick,
|
| 12 |
-
candidate_letter_probs,
|
| 13 |
-
pos_present_probs,
|
| 14 |
-
)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
@torch.no_grad()
|
| 18 |
-
def greedy_rollout(win_env: BatchEnv, model, device, N=1000, priors=None, log_stride: int = 256, use_cand_priors: bool = False):
|
| 19 |
-
wins = 0
|
| 20 |
-
total = 0
|
| 21 |
-
B = min(win_env.batch, 256)
|
| 22 |
-
env = BatchEnv(win_env.buckets, win_env.tries_init, B, win_env.len_choices.copy(), win_env.max_len)
|
| 23 |
-
env.reset()
|
| 24 |
-
model.eval()
|
| 25 |
-
if hasattr(model, "remove_noise"):
|
| 26 |
-
model.remove_noise()
|
| 27 |
-
start = time.perf_counter()
|
| 28 |
-
local_cache = CandCache(1024) if use_cand_priors else None
|
| 29 |
-
while total < N:
|
| 30 |
-
pat_idx, tried, lens, tries = env.observe()
|
| 31 |
-
B_now = pat_idx.size(0)
|
| 32 |
-
lp = torch.zeros((B_now, 26), dtype=torch.float32)
|
| 33 |
-
for i, patt in enumerate(env.patterns):
|
| 34 |
-
L = min(len(patt), model.max_len if hasattr(model, "max_len") else win_env.max_len)
|
| 35 |
-
lp[i, :] = torch.tensor(priors.get(L, [0.0] * 26))
|
| 36 |
-
if use_cand_priors and local_cache is not None:
|
| 37 |
-
cp = torch.zeros((B_now, 26), dtype=torch.float32)
|
| 38 |
-
for i in range(B_now):
|
| 39 |
-
L = min(len(env.patterns[i]), win_env.max_len)
|
| 40 |
-
tried_bits = int(env.tried_mask_bits[i])
|
| 41 |
-
cp[i, :] = candidate_letter_probs(L, env.patterns[i], tried_bits, win_env.buckets, local_cache)
|
| 42 |
-
else:
|
| 43 |
-
cp = None
|
| 44 |
-
tn = (tries.float() / win_env.tries_init).unsqueeze(1)
|
| 45 |
-
out = model(
|
| 46 |
-
pat_idx.to(device),
|
| 47 |
-
tried.to(device),
|
| 48 |
-
lens.to(device),
|
| 49 |
-
lp.to(device),
|
| 50 |
-
tn.to(device),
|
| 51 |
-
cand_priors=(cp.to(device) if cp is not None else None),
|
| 52 |
-
)
|
| 53 |
-
if isinstance(out, tuple):
|
| 54 |
-
logits, _v = out
|
| 55 |
-
actions = logits.argmax(dim=1).cpu()
|
| 56 |
-
else:
|
| 57 |
-
q = out
|
| 58 |
-
actions = q.argmax(dim=1).cpu()
|
| 59 |
-
env.step(actions)
|
| 60 |
-
finished = env.done.clone()
|
| 61 |
-
if finished.any():
|
| 62 |
-
batch_wins = int(env.won.sum().item())
|
| 63 |
-
batch_finished = int(finished.sum().item())
|
| 64 |
-
wins += batch_wins
|
| 65 |
-
total += batch_finished
|
| 66 |
-
if (total % max(1, log_stride) == 0) or (total >= N):
|
| 67 |
-
wr = wins / max(1, total)
|
| 68 |
-
elapsed = (time.perf_counter() - start) / 60.0
|
| 69 |
-
print(f"[eval] episodes={total}/{N} win-rate={wr:.3f} | {elapsed:.1f} min")
|
| 70 |
-
env.force_reset_done()
|
| 71 |
-
model.train()
|
| 72 |
-
if hasattr(model, "resample_noise"):
|
| 73 |
-
model.resample_noise()
|
| 74 |
-
return wins / max(1, total)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
@torch.no_grad()
|
| 78 |
-
def run_solver(args, buckets, priors, pos_priors):
|
| 79 |
-
lens = sorted(buckets.keys())
|
| 80 |
-
env = BatchEnv(buckets, args.tries, args.batch_env, lens.copy(), args.max_len)
|
| 81 |
-
env.reset()
|
| 82 |
-
total = 0
|
| 83 |
-
wins = 0
|
| 84 |
-
start = time.perf_counter()
|
| 85 |
-
csv_fp = None
|
| 86 |
-
csv_writer = None
|
| 87 |
-
if getattr(args, "csv_log", False):
|
| 88 |
-
os.makedirs(args.out_dir, exist_ok=True)
|
| 89 |
-
csv_path = os.path.join(args.out_dir, "metrics.csv")
|
| 90 |
-
new_file = not os.path.exists(csv_path)
|
| 91 |
-
csv_fp = open(csv_path, "a", newline="")
|
| 92 |
-
csv_writer = csv.writer(csv_fp)
|
| 93 |
-
if new_file:
|
| 94 |
-
csv_writer.writerow(["mode", "episodes", "wins", "win_rate", "minutes"])
|
| 95 |
-
while total < args.solver_eval_N:
|
| 96 |
-
B = env.batch
|
| 97 |
-
actions = torch.zeros(B, dtype=torch.long)
|
| 98 |
-
for i in range(B):
|
| 99 |
-
patt = env.patterns[i]
|
| 100 |
-
L = len(patt)
|
| 101 |
-
tried_bits = int(env.tried_mask_bits[i])
|
| 102 |
-
if args.solver_mode == "igx":
|
| 103 |
-
a = ig_exact_pick(tried_bits, L, patt, buckets)
|
| 104 |
-
elif args.solver_mode == "ig":
|
| 105 |
-
vec = candidate_letter_probs(L, patt, tried_bits, buckets, CandCache(1))
|
| 106 |
-
score = vec.clamp(0, 1) * (1 - vec.clamp(0, 1))
|
| 107 |
-
for j in range(26):
|
| 108 |
-
if (tried_bits >> j) & 1:
|
| 109 |
-
score[j] = -1.0
|
| 110 |
-
a = int(score.argmax().item())
|
| 111 |
-
elif args.solver_mode == "pos":
|
| 112 |
-
vec = pos_present_probs(L, patt, pos_priors)
|
| 113 |
-
for j in range(26):
|
| 114 |
-
if (tried_bits >> j) & 1:
|
| 115 |
-
vec[j] = -1.0
|
| 116 |
-
a = int(vec.argmax().item())
|
| 117 |
-
elif args.solver_mode == "len":
|
| 118 |
-
vec = torch.tensor(priors.get(L, [0.0]*26), dtype=torch.float32)
|
| 119 |
-
for j in range(26):
|
| 120 |
-
if (tried_bits >> j) & 1:
|
| 121 |
-
vec[j] = -1.0
|
| 122 |
-
a = int(vec.argmax().item())
|
| 123 |
-
else: # cand
|
| 124 |
-
vec = candidate_letter_probs(L, patt, tried_bits, buckets, CandCache(1))
|
| 125 |
-
for j in range(26):
|
| 126 |
-
if (tried_bits >> j) & 1:
|
| 127 |
-
vec[j] = -1.0
|
| 128 |
-
a = int(vec.argmax().item())
|
| 129 |
-
actions[i] = a
|
| 130 |
-
env.step(actions)
|
| 131 |
-
finished = env.done.clone()
|
| 132 |
-
if finished.any():
|
| 133 |
-
wins += int(env.won.sum().item())
|
| 134 |
-
total += int(finished.sum().item())
|
| 135 |
-
env.force_reset_done()
|
| 136 |
-
if total % 512 == 0 or total >= args.solver_eval_N:
|
| 137 |
-
wr = wins / max(1, total)
|
| 138 |
-
elapsed = (time.perf_counter() - start) / 60.0
|
| 139 |
-
print(f"[solver {args.solver_mode}] episodes={total}/{args.solver_eval_N} win-rate={wr:.3f} | {elapsed:.1f} min")
|
| 140 |
-
if csv_writer:
|
| 141 |
-
csv_writer.writerow([args.solver_mode, total, wins, f"{wr:.6f}", f"{elapsed:.3f}"])
|
| 142 |
-
csv_fp.flush()
|
| 143 |
-
final_wr = wins / max(1, total)
|
| 144 |
-
print(f"[done][SOLVER:{args.solver_mode}] win-rate={final_wr:.3f} over {total} episodes")
|
| 145 |
-
if csv_fp:
|
| 146 |
-
csv_fp.close()
|
| 147 |
-
return final_wr
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
def run_solver_sweep(args, buckets, priors, pos_priors):
|
| 151 |
-
modes = [m.strip() for m in str(args.sweep_modes).split(",") if m.strip()]
|
| 152 |
-
results = {m: [] for m in modes}
|
| 153 |
-
for r in range(int(args.sweep_repeats)):
|
| 154 |
-
for m in modes:
|
| 155 |
-
a = Namespace(**vars(args))
|
| 156 |
-
a.solver_mode = m
|
| 157 |
-
wr = run_solver(a, buckets, priors, pos_priors)
|
| 158 |
-
results[m].append(float(wr))
|
| 159 |
-
summary_rows = []
|
| 160 |
-
print("\n[solver sweep] summary:")
|
| 161 |
-
for m in modes:
|
| 162 |
-
vals = results[m]
|
| 163 |
-
mean_wr = sum(vals) / max(1, len(vals))
|
| 164 |
-
best_wr = max(vals) if vals else 0.0
|
| 165 |
-
print(f" - {m:>4}: mean={mean_wr:.3f} best={best_wr:.3f} over {len(vals)} run(s)")
|
| 166 |
-
summary_rows.append((m, len(vals), mean_wr, best_wr))
|
| 167 |
-
if getattr(args, "csv_log", False):
|
| 168 |
-
os.makedirs(args.out_dir, exist_ok=True)
|
| 169 |
-
path = os.path.join(args.out_dir, "solver_sweep.csv")
|
| 170 |
-
new_file = not os.path.exists(path)
|
| 171 |
-
with open(path, "a", newline="") as fp:
|
| 172 |
-
w = csv.writer(fp)
|
| 173 |
-
if new_file:
|
| 174 |
-
w.writerow(["mode", "repeats", "mean_win_rate", "best_win_rate"])
|
| 175 |
-
for m, rpt, mean_wr, best_wr in summary_rows:
|
| 176 |
-
w.writerow([m, rpt, f"{mean_wr:.6f}", f"{best_wr:.6f}"])
|
| 177 |
-
print(f"[solver sweep] written summary to {path}")
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hangman/rl/ppo.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def compute_gae(rewards, values, dones, gamma, lam):
|
| 5 |
-
"""
|
| 6 |
-
rewards, values, dones: tensors [T, B]
|
| 7 |
-
returns advantages [T, B] and returns [T, B]
|
| 8 |
-
"""
|
| 9 |
-
T, B = rewards.size(0), rewards.size(1)
|
| 10 |
-
adv = torch.zeros_like(rewards)
|
| 11 |
-
lastgaelam = torch.zeros(B, device=rewards.device)
|
| 12 |
-
next_value = values[-1]
|
| 13 |
-
for t in reversed(range(T)):
|
| 14 |
-
mask = 1.0 - dones[t].float()
|
| 15 |
-
delta = rewards[t] + gamma * next_value * mask - values[t]
|
| 16 |
-
lastgaelam = delta + gamma * lam * mask * lastgaelam
|
| 17 |
-
adv[t] = lastgaelam
|
| 18 |
-
next_value = values[t]
|
| 19 |
-
returns = adv + values[:-1]
|
| 20 |
-
return adv, returns
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hangman/rl/replay.py
DELETED
|
@@ -1,86 +0,0 @@
|
|
| 1 |
-
from typing import Tuple
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
from .utils import enc_pattern
|
| 6 |
-
from .priors import cand_priors_batch
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class Replay:
|
| 10 |
-
def __init__(self, cap: int):
|
| 11 |
-
self.cap = cap
|
| 12 |
-
self.buf = []
|
| 13 |
-
self.pos = 0
|
| 14 |
-
|
| 15 |
-
def push(self, s, a, r, sp, done, word, won: bool):
|
| 16 |
-
item = (s, a, r, sp, done, word, won)
|
| 17 |
-
if len(self.buf) < self.cap:
|
| 18 |
-
self.buf.append(item)
|
| 19 |
-
else:
|
| 20 |
-
self.buf[self.pos] = item
|
| 21 |
-
self.pos = (self.pos + 1) % self.cap
|
| 22 |
-
|
| 23 |
-
def sample(self, n):
|
| 24 |
-
import random
|
| 25 |
-
return random.sample(self.buf, n)
|
| 26 |
-
|
| 27 |
-
def __len__(self):
|
| 28 |
-
return len(self.buf)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class SuccessReplay(Replay):
|
| 32 |
-
pass
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def tensorize_batch(batch, device, max_len, priors_dict, buckets, cand_cache, cand_frac: float = 1.0):
|
| 36 |
-
s, a, r, sp, done, _w, _won = zip(*batch)
|
| 37 |
-
B = len(batch)
|
| 38 |
-
# s
|
| 39 |
-
pat_idx0 = torch.tensor([enc_pattern(si[0], max_len) for si in s], dtype=torch.long, device=device)
|
| 40 |
-
tried0 = torch.zeros((B, 26), dtype=torch.float32, device=device)
|
| 41 |
-
lens0 = torch.tensor([min(len(si[0]), max_len) for si in s], dtype=torch.long, device=device)
|
| 42 |
-
pri0 = torch.zeros((B, 26), dtype=torch.float32, device=device)
|
| 43 |
-
tries0 = torch.tensor([si[2] for si in s], dtype=torch.float32, device=device) # raw count
|
| 44 |
-
for i, si in enumerate(s):
|
| 45 |
-
m = si[1]
|
| 46 |
-
for j in range(26):
|
| 47 |
-
tried0[i, j] = 1.0 if ((m >> j) & 1) else 0.0
|
| 48 |
-
L = min(len(si[0]), max_len)
|
| 49 |
-
pri0[i, :] = torch.tensor(priors_dict.get(L, [0.0] * 26), dtype=torch.float32, device=device)
|
| 50 |
-
cand0 = cand_priors_batch(s, buckets, cand_cache, max_len, device) if cand_frac >= 1.0 else None
|
| 51 |
-
|
| 52 |
-
# sp
|
| 53 |
-
pat_idx1 = torch.tensor([enc_pattern(si[0], max_len) for si in sp], dtype=torch.long, device=device)
|
| 54 |
-
tried1 = torch.zeros((B, 26), dtype=torch.float32, device=device)
|
| 55 |
-
lens1 = torch.tensor([min(len(si[0]), max_len) for si in sp], dtype=torch.long, device=device)
|
| 56 |
-
pri1 = torch.zeros((B, 26), dtype=torch.float32, device=device)
|
| 57 |
-
tries1 = torch.tensor([si[2] for si in sp], dtype=torch.float32, device=device)
|
| 58 |
-
for i, si in enumerate(sp):
|
| 59 |
-
m = si[1]
|
| 60 |
-
for j in range(26):
|
| 61 |
-
tried1[i, j] = 1.0 if ((m >> j) & 1) else 0.0
|
| 62 |
-
L = min(len(si[0]), max_len)
|
| 63 |
-
pri1[i, :] = torch.tensor(priors_dict.get(L, [0.0] * 26), dtype=torch.float32, device=device)
|
| 64 |
-
cand1 = cand_priors_batch(sp, buckets, cand_cache, max_len, device) if cand_frac >= 1.0 else None
|
| 65 |
-
|
| 66 |
-
a = torch.tensor(a, dtype=torch.long, device=device)
|
| 67 |
-
r = torch.tensor(r, dtype=torch.float32, device=device)
|
| 68 |
-
done = torch.tensor(done, dtype=torch.bool, device=device)
|
| 69 |
-
return (
|
| 70 |
-
pat_idx0,
|
| 71 |
-
tried0,
|
| 72 |
-
lens0,
|
| 73 |
-
pri0,
|
| 74 |
-
cand0,
|
| 75 |
-
tries0,
|
| 76 |
-
pat_idx1,
|
| 77 |
-
tried1,
|
| 78 |
-
lens1,
|
| 79 |
-
pri1,
|
| 80 |
-
cand1,
|
| 81 |
-
tries1,
|
| 82 |
-
a,
|
| 83 |
-
r,
|
| 84 |
-
done,
|
| 85 |
-
)
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hangman/rl/seed_bc.py
DELETED
|
@@ -1,144 +0,0 @@
|
|
| 1 |
-
import random
|
| 2 |
-
from typing import List
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn.functional as F
|
| 6 |
-
|
| 7 |
-
try:
|
| 8 |
-
from tqdm import tqdm
|
| 9 |
-
except Exception: # pragma: no cover - fallback in limited envs
|
| 10 |
-
def tqdm(x, **k):
|
| 11 |
-
return x
|
| 12 |
-
|
| 13 |
-
from .envs import BatchEnv
|
| 14 |
-
from .priors import (
|
| 15 |
-
candidate_letter_probs,
|
| 16 |
-
teacher_actions_batch,
|
| 17 |
-
)
|
| 18 |
-
from .utils import enc_pattern, atomic_save
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def seed_expert(replay, success_replay, buckets, tries, episodes,
|
| 22 |
-
priors, pos_priors, cand_cache, teacher_mode: str, max_len: int):
|
| 23 |
-
env = BatchEnv(buckets, tries, 1, sorted(buckets.keys()), max_len)
|
| 24 |
-
env.reset()
|
| 25 |
-
pushed = 0
|
| 26 |
-
won_count = 0
|
| 27 |
-
for _ in tqdm(range(episodes), desc="[seed] heuristic"):
|
| 28 |
-
env.words[0] = random.choice(buckets[random.choice(env.len_choices)])
|
| 29 |
-
L = len(env.words[0])
|
| 30 |
-
env.patterns[0] = "_" * L
|
| 31 |
-
env.tried_mask_bits[0] = 0
|
| 32 |
-
env.tries_left[0] = tries
|
| 33 |
-
env.done[0] = False
|
| 34 |
-
env.won[0] = False
|
| 35 |
-
|
| 36 |
-
while not env.done[0]:
|
| 37 |
-
tried_bits = env.tried_mask_bits[0]
|
| 38 |
-
a = teacher_actions_batch([(env.patterns[0], tried_bits, int(env.tries_left[0].item()), L)],
|
| 39 |
-
buckets, teacher_mode, priors, pos_priors, cand_cache)[0]
|
| 40 |
-
s = (env.patterns[0], tried_bits, int(env.tries_left[0].item()), L)
|
| 41 |
-
r = env.step(torch.tensor([a]))[0].item()
|
| 42 |
-
sp = (env.patterns[0], int(env.tried_mask_bits[0]), int(env.tries_left[0].item()), L)
|
| 43 |
-
done = bool(env.done[0].item())
|
| 44 |
-
won = bool(env.won[0].item())
|
| 45 |
-
replay.push(s, a, float(r), sp, done, env.words[0], won)
|
| 46 |
-
if done and won:
|
| 47 |
-
success_replay.push(s, a, float(r), sp, done, env.words[0], won)
|
| 48 |
-
won_count += 1
|
| 49 |
-
pushed += 1
|
| 50 |
-
env.force_reset_done()
|
| 51 |
-
print(f"[seed] added transitions={pushed}, winning_episodes={won_count}")
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def behavior_cloning(model, replay, success_replay, priors, device, max_len, tries_init: int,
|
| 55 |
-
steps=3000, bs=512, success_frac=0.5, lr=5e-4, wd=1e-4,
|
| 56 |
-
bc_ckpt_path: str = "", save_every: int = 250):
|
| 57 |
-
if steps <= 0 or len(replay) == 0:
|
| 58 |
-
return
|
| 59 |
-
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
|
| 60 |
-
model.train()
|
| 61 |
-
pbar = tqdm(range(steps), desc="[bc] pretrain", leave=False)
|
| 62 |
-
try:
|
| 63 |
-
for t in pbar:
|
| 64 |
-
n_succ = min(int(success_frac * bs), len(success_replay))
|
| 65 |
-
n_base = max(1, bs - n_succ)
|
| 66 |
-
base_batch = replay.sample(min(n_base, len(replay)))
|
| 67 |
-
batch = base_batch
|
| 68 |
-
if n_succ > 0:
|
| 69 |
-
batch += success_replay.sample(n_succ)
|
| 70 |
-
|
| 71 |
-
s, a, _r, _sp, _done, _w, _won = zip(*batch)
|
| 72 |
-
B = len(batch)
|
| 73 |
-
pat_idx = torch.tensor([enc_pattern(si[0], max_len) for si in s], dtype=torch.long, device=device)
|
| 74 |
-
tried = torch.zeros((B, 26), dtype=torch.float32, device=device)
|
| 75 |
-
lens = torch.tensor([min(len(si[0]), max_len) for si in s], dtype=torch.long, device=device)
|
| 76 |
-
pri = torch.zeros((B, 26), dtype=torch.float32, device=device)
|
| 77 |
-
for i, si in enumerate(s):
|
| 78 |
-
m = si[1]
|
| 79 |
-
for j in range(26):
|
| 80 |
-
tried[i, j] = 1.0 if ((m >> j) & 1) else 0.0
|
| 81 |
-
L = min(len(si[0]), max_len)
|
| 82 |
-
pri[i, :] = torch.tensor(priors.get(L, [0.0] * 26), device=device)
|
| 83 |
-
tries_norm = torch.tensor([si[2] for si in s], dtype=torch.float32, device=device) / float(tries_init)
|
| 84 |
-
tries_norm = tries_norm.unsqueeze(1)
|
| 85 |
-
a = torch.tensor(a, dtype=torch.long, device=device)
|
| 86 |
-
|
| 87 |
-
logits = model(pat_idx, tried, lens, pri, tries_norm)
|
| 88 |
-
loss = F.cross_entropy(logits, a)
|
| 89 |
-
|
| 90 |
-
opt.zero_grad(set_to_none=True)
|
| 91 |
-
loss.backward()
|
| 92 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 93 |
-
opt.step()
|
| 94 |
-
|
| 95 |
-
if bc_ckpt_path and ((t + 1) % max(1, save_every) == 0):
|
| 96 |
-
atomic_save({"model": model.state_dict()}, bc_ckpt_path)
|
| 97 |
-
finally:
|
| 98 |
-
if bc_ckpt_path:
|
| 99 |
-
atomic_save({"model": model.state_dict()}, bc_ckpt_path)
|
| 100 |
-
print(f"[bc] checkpoint saved to {bc_ckpt_path}")
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def actor_bc_pretrain(ac, optimizer, buckets, tries, priors, pos_priors, cand_cache,
|
| 104 |
-
max_len, steps: int, B: int, device, teacher_mode: str = "igx", cand_frac: float = 0.25):
|
| 105 |
-
if steps <= 0:
|
| 106 |
-
return
|
| 107 |
-
env_bc = BatchEnv(buckets, tries, B, sorted(buckets.keys()), max_len)
|
| 108 |
-
env_bc.reset()
|
| 109 |
-
pbar = tqdm(range(steps), desc="[ppo-bc] pretrain", leave=False)
|
| 110 |
-
for t in pbar:
|
| 111 |
-
pat_idx, tried, lens_t, tries_t = env_bc.observe()
|
| 112 |
-
B_now = pat_idx.size(0)
|
| 113 |
-
lp = torch.zeros((B_now, 26), dtype=torch.float32)
|
| 114 |
-
for i, patt in enumerate(env_bc.patterns):
|
| 115 |
-
L = min(len(patt), max_len)
|
| 116 |
-
lp[i, :] = torch.tensor(priors.get(L, [0.0] * 26))
|
| 117 |
-
cp = torch.zeros((B_now, 26), dtype=torch.float32)
|
| 118 |
-
if cand_frac > 0.0:
|
| 119 |
-
import random as _rnd
|
| 120 |
-
k = max(1, int(B_now * min(1.0, max(0.0, cand_frac))))
|
| 121 |
-
idxs = _rnd.sample(range(B_now), k)
|
| 122 |
-
for i in idxs:
|
| 123 |
-
L = min(len(env_bc.patterns[i]), max_len)
|
| 124 |
-
tried_bits = int(env_bc.tried_mask_bits[i])
|
| 125 |
-
cp[i, :] = candidate_letter_probs(L, env_bc.patterns[i], tried_bits, buckets, cand_cache)
|
| 126 |
-
tries_norm = (tries_t.float() / float(tries)).unsqueeze(1)
|
| 127 |
-
states_list = []
|
| 128 |
-
for i in range(B_now):
|
| 129 |
-
patt = env_bc.patterns[i]
|
| 130 |
-
tried_bits = int(env_bc.tried_mask_bits[i])
|
| 131 |
-
L_here = len(patt)
|
| 132 |
-
states_list.append((patt, tried_bits, int(tries_t[i].item()), L_here))
|
| 133 |
-
ta = teacher_actions_batch(states_list, buckets, teacher_mode, priors, pos_priors, cand_cache)
|
| 134 |
-
ta_t = torch.tensor(ta, dtype=torch.long, device=device)
|
| 135 |
-
logits, _ = ac(pat_idx.to(device), tried.to(device), lens_t.to(device),
|
| 136 |
-
lp.to(device), tries_norm.to(device), cand_priors=cp.to(device))
|
| 137 |
-
loss = F.cross_entropy(logits, ta_t)
|
| 138 |
-
optimizer.zero_grad(set_to_none=True)
|
| 139 |
-
loss.backward()
|
| 140 |
-
torch.nn.utils.clip_grad_norm_(ac.parameters(), 1.0)
|
| 141 |
-
optimizer.step()
|
| 142 |
-
env_bc.step(ta_t.cpu())
|
| 143 |
-
env_bc.force_reset_done()
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hangman/rl/train_bc.py
DELETED
|
@@ -1,139 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
import time
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
from .models import DuelingQNet
|
| 8 |
-
from .replay import Replay, SuccessReplay
|
| 9 |
-
from .seed_bc import seed_expert, behavior_cloning
|
| 10 |
-
from .eval import greedy_rollout
|
| 11 |
-
from .priors import build_length_priors, build_positional_priors, CandCache
|
| 12 |
-
from .utils import load_dict, by_len, set_seed
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def parse_args():
|
| 16 |
-
p = argparse.ArgumentParser("Behavior Cloning pretraining for Hangman")
|
| 17 |
-
# data
|
| 18 |
-
p.add_argument("--dict_path", type=str, default="data/words_250000_train.txt", help="Path to word list")
|
| 19 |
-
p.add_argument("--len_lo", type=int, default=4, help="Min word length")
|
| 20 |
-
p.add_argument("--len_hi", type=int, default=12, help="Max word length")
|
| 21 |
-
p.add_argument("--max_len", type=int, default=35, help="Model max sequence length")
|
| 22 |
-
p.add_argument("--tries", type=int, default=6, help="Initial tries for env")
|
| 23 |
-
# seeding and BC
|
| 24 |
-
p.add_argument("--seed_episodes", type=int, default=5000, help="How many heuristic episodes to seed replay")
|
| 25 |
-
p.add_argument("--teacher_mode", type=str, default="igx", choices=["igx", "ig", "cand", "pos", "len"], help="Heuristic teacher policy")
|
| 26 |
-
p.add_argument("--replay_cap", type=int, default=200_000, help="Replay buffer capacity")
|
| 27 |
-
p.add_argument("--success_cap", type=int, default=50_000, help="Success replay capacity")
|
| 28 |
-
p.add_argument("--bc_steps", type=int, default=5000, help="Behavior cloning optimization steps")
|
| 29 |
-
p.add_argument("--bc_bs", type=int, default=512, help="Behavior cloning batch size")
|
| 30 |
-
p.add_argument("--bc_lr", type=float, default=5e-4, help="Learning rate for BC")
|
| 31 |
-
p.add_argument("--bc_wd", type=float, default=1e-4, help="Weight decay for BC")
|
| 32 |
-
p.add_argument("--success_frac", type=float, default=0.5, help="Fraction of success samples in BC batches")
|
| 33 |
-
p.add_argument("--save_every", type=int, default=500, help="Checkpoint frequency (steps)")
|
| 34 |
-
p.add_argument("--out_dir", type=str, default="runs/bc", help="Output directory for checkpoints and logs")
|
| 35 |
-
# model
|
| 36 |
-
p.add_argument("--d_model", type=int, default=128)
|
| 37 |
-
p.add_argument("--nhead", type=int, default=4)
|
| 38 |
-
p.add_argument("--nlayers", type=int, default=2)
|
| 39 |
-
p.add_argument("--ff_mult", type=int, default=4)
|
| 40 |
-
p.add_argument("--dropout", type=float, default=0.1)
|
| 41 |
-
p.add_argument("--noisy", action="store_true", help="Use NoisyNet layers in dueling head")
|
| 42 |
-
p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
| 43 |
-
# eval
|
| 44 |
-
p.add_argument("--eval_N", type=int, default=2000, help="Episodes for quick greedy eval")
|
| 45 |
-
p.add_argument("--eval_stride", type=int, default=256, help="Logging stride for eval")
|
| 46 |
-
p.add_argument("--eval_use_cand_priors", action="store_true", help="Fuse candidate priors during eval")
|
| 47 |
-
# misc
|
| 48 |
-
p.add_argument("--seed", type=int, default=42)
|
| 49 |
-
return p.parse_args()
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def main():
|
| 53 |
-
args = parse_args()
|
| 54 |
-
set_seed(int(args.seed))
|
| 55 |
-
|
| 56 |
-
assert os.path.exists(args.dict_path), f"Dictionary file not found: {args.dict_path}"
|
| 57 |
-
words = load_dict(args.dict_path)
|
| 58 |
-
buckets = by_len(words, args.len_lo, args.len_hi)
|
| 59 |
-
priors = build_length_priors(buckets)
|
| 60 |
-
pos_priors = build_positional_priors(buckets, args.max_len)
|
| 61 |
-
cand_cache = CandCache(100_000)
|
| 62 |
-
|
| 63 |
-
device = torch.device(args.device)
|
| 64 |
-
|
| 65 |
-
# replay buffers
|
| 66 |
-
replay = Replay(cap=int(args.replay_cap))
|
| 67 |
-
success_replay = SuccessReplay(cap=int(args.success_cap))
|
| 68 |
-
|
| 69 |
-
# expert seeding
|
| 70 |
-
print(f"[setup] Seeding replay with heuristic='{args.teacher_mode}', episodes={args.seed_episodes}…")
|
| 71 |
-
seed_expert(
|
| 72 |
-
replay,
|
| 73 |
-
success_replay,
|
| 74 |
-
buckets,
|
| 75 |
-
tries=int(args.tries),
|
| 76 |
-
episodes=int(args.seed_episodes),
|
| 77 |
-
priors=priors,
|
| 78 |
-
pos_priors=pos_priors,
|
| 79 |
-
cand_cache=cand_cache,
|
| 80 |
-
teacher_mode=str(args.teacher_mode),
|
| 81 |
-
max_len=int(args.max_len),
|
| 82 |
-
)
|
| 83 |
-
|
| 84 |
-
# model
|
| 85 |
-
model = DuelingQNet(
|
| 86 |
-
d_model=int(args.d_model),
|
| 87 |
-
nhead=int(args.nhead),
|
| 88 |
-
nlayers=int(args.nlayers),
|
| 89 |
-
ff_mult=int(args.ff_mult),
|
| 90 |
-
max_len=int(args.max_len),
|
| 91 |
-
dropout=float(args.dropout),
|
| 92 |
-
use_noisy=bool(args.noisy),
|
| 93 |
-
).to(device)
|
| 94 |
-
|
| 95 |
-
os.makedirs(args.out_dir, exist_ok=True)
|
| 96 |
-
ckpt_path = os.path.join(args.out_dir, "bc_dueling_qnet.pt")
|
| 97 |
-
|
| 98 |
-
# behavior cloning
|
| 99 |
-
print(f"[train] Starting BC: steps={args.bc_steps}, bs={args.bc_bs}, success_frac={args.success_frac}")
|
| 100 |
-
t0 = time.perf_counter()
|
| 101 |
-
behavior_cloning(
|
| 102 |
-
model,
|
| 103 |
-
replay,
|
| 104 |
-
success_replay,
|
| 105 |
-
priors,
|
| 106 |
-
device,
|
| 107 |
-
max_len=int(args.max_len),
|
| 108 |
-
tries_init=int(args.tries),
|
| 109 |
-
steps=int(args.bc_steps),
|
| 110 |
-
bs=int(args.bc_bs),
|
| 111 |
-
success_frac=float(args.success_frac),
|
| 112 |
-
lr=float(args.bc_lr),
|
| 113 |
-
wd=float(args.bc_wd),
|
| 114 |
-
bc_ckpt_path=ckpt_path,
|
| 115 |
-
save_every=int(args.save_every),
|
| 116 |
-
)
|
| 117 |
-
dt_min = (time.perf_counter() - t0) / 60.0
|
| 118 |
-
print(f"[train] BC finished in {dt_min:.2f} min. Checkpoint saved to {ckpt_path}")
|
| 119 |
-
|
| 120 |
-
# quick greedy rollout eval
|
| 121 |
-
print("[eval] Running greedy rollout eval…")
|
| 122 |
-
from .envs import BatchEnv
|
| 123 |
-
env = BatchEnv(buckets, int(args.tries), batch=64, len_choices=sorted(buckets.keys()), max_len=int(args.max_len))
|
| 124 |
-
env.reset()
|
| 125 |
-
wr = greedy_rollout(
|
| 126 |
-
env,
|
| 127 |
-
model,
|
| 128 |
-
device=device,
|
| 129 |
-
N=int(args.eval_N),
|
| 130 |
-
priors=priors,
|
| 131 |
-
log_stride=int(args.eval_stride),
|
| 132 |
-
use_cand_priors=bool(args.eval_use_cand_priors),
|
| 133 |
-
)
|
| 134 |
-
print(f"[done] Greedy eval win-rate={wr:.3f} over N={args.eval_N}")
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
# Script entry point: run the behavior-cloning training pipeline when
# this module is executed directly (not on import).
if __name__ == "__main__":
    main()
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hangman/utils.py
DELETED
|
@@ -1,129 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
import random
|
| 3 |
-
import string
|
| 4 |
-
import pandas as pd
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
# Module-level RNG with a fixed seed so masked-pattern sampling is
# reproducible across runs (used by loop_for_permutation).
RNG = random.Random(0) # reproducible sampling
|
| 8 |
-
|
| 9 |
-
def get_char_mapping():
    """Return the character vocabulary: PAD=0, 'a'..'z'=1..26, '_'=27."""
    mapping = {'PAD': 0}
    for i, ch in enumerate(string.ascii_lowercase):
        mapping[ch] = i + 1
    mapping['_'] = 27
    return mapping
|
| 12 |
-
|
| 13 |
-
def read_data(path="words_250000_train.txt"):
    """Load and clean the training word list.

    Args:
        path: text file with one word per line. Defaults to the original
            hard-coded filename, so existing callers are unaffected.

    Returns:
        list[str]: lowercased words containing only letters a-z; blank
        lines and words with digits/punctuation are dropped.
    """
    with open(path, "r") as f:
        words = [w.strip().lower() for w in f if w.strip()]
    # keep pure alphabetic words only
    return [w for w in words if re.fullmatch(r"[a-z]+", w)]
|
| 19 |
-
|
| 20 |
-
def create_intermediate_data(words):
    """Build the augmented word DataFrame and drop unsuitable words.

    A word is kept iff: length > 3, at least 3 distinct characters, and
    at least one vowel. Column 0 is the word, column 1 its length; the
    vowel/uniqueness helper columns are retained on the result.
    """
    vowels = {'a', 'e', 'i', 'o', 'u'}
    df = pd.DataFrame({0: words})
    df[1] = df[0].str.len()
    df['vowels_present'] = df[0].apply(lambda w: set(w) & vowels)
    df['vowels_count'] = df['vowels_present'].str.len()
    df['unique_char_count'] = df[0].apply(lambda w: len(set(w)))
    # equivalent to ~(unique<=2 | len<=3) & (vowels != 0) via De Morgan
    keep = (df['unique_char_count'] >= 3) & (df[1] > 3) & (df['vowels_count'] != 0)
    return df[keep]
|
| 29 |
-
|
| 30 |
-
def loop_for_permutation(unique_letters, word, all_perm, k):
    """Hide k+1 randomly chosen letters of *word*, append the mask.

    Draws from the module-level seeded RNG so results are reproducible.
    Mutates *all_perm* in place (appends one masked pattern).
    """
    hidden = RNG.sample(unique_letters, k + 1)
    masked = word
    for letter in hidden:
        masked = masked.replace(letter, "_")
    all_perm.append(masked)
|
| 37 |
-
|
| 38 |
-
def permute_all(word, vowel_permutation_loop=False):
    """Generate masked variants of *word* by hiding random letter subsets.

    Hides between 1 and len(unique)-1 (or len(unique)-2 when
    vowel_permutation_loop is False) letters, so at least one character
    stays visible. Returns the de-duplicated pattern list.
    """
    uniq = list(set(word))
    upper = len(uniq) - 1 if vowel_permutation_loop else len(uniq) - 2
    patterns = []
    for k in range(max(0, upper)):
        loop_for_permutation(uniq, word, patterns, k)
    return list(set(patterns))
|
| 46 |
-
|
| 47 |
-
def permute_consonents(word):
    """Masks that hide every consonant and reveal only some vowels.

    Builds masked variants of the vowel-only subsequence, then places
    each surviving vowel back at its original position in *word*.
    """
    vowel_idx = [i for i, ch in enumerate(word) if ch in "aeiou"]
    vowels_only = "".join(word[i] for i in vowel_idx)
    results = []
    for partial in permute_all(vowels_only, vowel_permutation_loop=True):
        mask = ["_"] * len(word)
        for j, ch in enumerate(partial):
            mask[vowel_idx[j]] = ch
        results.append("".join(mask))
    return results
|
| 59 |
-
|
| 60 |
-
def create_masked_dictionary(df_aug):
    """Map each word in df_aug[0] to its de-duplicated masked patterns."""
    masked = {}
    for i, word in df_aug[0].items():
        patterns = set(permute_all(word) + permute_consonents(word))
        masked[word] = list(patterns)
        # NOTE(review): i is the original (filtered) DataFrame index, not a
        # sequential counter, so this progress log can fire irregularly.
        if i % 10000 == 0:
            print(f"Iteration {i} completed")
    return masked
|
| 68 |
-
|
| 69 |
-
def get_vowel_prob(df_vowel, vowel):
    """Fraction of words in df_vowel[0] containing *vowel* (0.0 if empty)."""
    if len(df_vowel) == 0:
        return 0.0
    return df_vowel[0].str.contains(vowel).mean()
|
| 72 |
-
|
| 73 |
-
def get_vowel_prior(df_aug):
    """Per-length vowel priors.

    Returns {length: DataFrame(vowel, p)} for lengths 1..max word length,
    each frame sorted by probability descending.
    """
    prior = {}
    for length in range(1, int(df_aug[1].max()) + 1):
        subset = df_aug[df_aug[1] == length]
        rows = {"vowel": list("aeiou"),
                "p": [get_vowel_prob(subset, v) for v in "aeiou"]}
        prior[length] = pd.DataFrame(rows).sort_values("p", ascending=False)
    return prior
|
| 81 |
-
|
| 82 |
-
def save_vowel_prior(vowel_prior, path="prior_probabilities.pkl"):
    """Serialize *vowel_prior* to *path* as a binary pickle.

    Args:
        vowel_prior: object returned by get_vowel_prior.
        path: destination file (overwritten if present).
    """
    import pickle
    with open(path, "wb") as fh:
        pickle.dump(vowel_prior, fh)
|
| 86 |
-
|
| 87 |
-
# ---------- ENCODING (align target to Hangman) ----------

# Shared char→index vocabulary (PAD=0, 'a'..'z'=1..26, '_'=27),
# built once at import time and used by idx/encode_input below.
CMAP = get_char_mapping()
|
| 90 |
-
|
| 91 |
-
def idx(c):
    """Return the 0-based alphabet index of letter *c* (a=0 … z=25)."""
    return CMAP[c] - 1
|
| 93 |
-
|
| 94 |
-
def encode_input(pattern, max_len=35):
    """Right-align *pattern* into a fixed-length index vector.

    Leading positions are PAD (0); '_' maps to 27, letters to 1..26.
    """
    offset = max_len - len(pattern)
    vec = [0] * max_len
    for i, ch in enumerate(pattern):
        vec[offset + i] = CMAP[ch]
    return vec
|
| 100 |
-
|
| 101 |
-
def encode_output_for_pattern(word, pattern):
    """26-dim multi-hot label over letters STILL HIDDEN by *pattern*.

    Only positions where pattern shows '_' contribute; already-revealed
    letters are not labeled.
    """
    y = [0] * 26
    for i, ch in enumerate(word):
        still_hidden = pattern[i] == '_'
        if still_hidden:
            y[idx(ch)] = 1
    return y
|
| 108 |
-
|
| 109 |
-
def encode_words(masked_dictionary):
    """Flatten the masked dictionary into parallel (X, Y) example lists."""
    pairs = [
        (encode_input(pat), encode_output_for_pattern(word, pat))
        for word, patterns in masked_dictionary.items()
        for pat in patterns
    ]
    X = [x for x, _ in pairs]
    Y = [y for _, y in pairs]
    return X, Y
|
| 116 |
-
|
| 117 |
-
def convert_to_tensor(X, Y):
    """Pack encoded examples into tensors.

    Inputs become int64 (embedding indices); targets become float32 so
    they can be fed to BCEWithLogitsLoss directly.
    """
    inputs = torch.tensor(X, dtype=torch.long)
    targets = torch.tensor(Y, dtype=torch.float32)
    return inputs, targets
|
| 121 |
-
|
| 122 |
-
def get_datasets():
    """Full pipeline: words → filtered df → masks → saved priors → tensors.

    Side effect: writes the vowel prior pickle via save_vowel_prior.
    Returns (X, Y) tensors ready for training.
    """
    df_aug = create_intermediate_data(read_data())
    masked = create_masked_dictionary(df_aug)
    save_vowel_prior(get_vowel_prior(df_aug))
    X, Y = encode_words(masked)
    return convert_to_tensor(X, Y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|