Wen-Ting Wang committed on
Commit
cf72ffa
·
1 Parent(s): 406b978

feat: Deploy Hangman AI Demo to Hugging Face Spaces

Browse files

- Create Gradio app for interactive hangman demo
- Add requirements.txt with necessary dependencies
- Include README with proper metadata for HF Spaces
- Fix short_description to meet 60 character limit

.gitignore ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+ MANIFEST
23
+
24
+ # PyInstaller
25
+ *.manifest
26
+ *.spec
27
+
28
+ # Installer logs
29
+ pip-log.txt
30
+ pip-delete-this-directory.txt
31
+
32
+ # Unit test / coverage reports
33
+ htmlcov/
34
+ .tox/
35
+ .coverage
36
+ .coverage.*
37
+ .cache
38
+ nosetests.xml
39
+ coverage.xml
40
+ *.cover
41
+ .hypothesis/
42
+ .pytest_cache/
43
+
44
+ # Jupyter Notebook
45
+ .ipynb_checkpoints
46
+
47
+ # pyenv
48
+ .python-version
49
+
50
+ # Environments
51
+ .env
52
+ .venv
53
+ env/
54
+ venv/
55
+ ENV/
56
+ env.bak/
57
+ venv.bak/
58
+
59
+ # IDE
60
+ .vscode/
61
+ .idea/
62
+ *.swp
63
+ *.swo
64
+ *~
65
+
66
+ # OS
67
+ .DS_Store
68
+ .DS_Store?
69
+ ._*
70
+ .Spotlight-V100
71
+ .Trashes
72
+ ehthumbs.db
73
+ Thumbs.db
74
+
75
+ # Hugging Face specific
76
+ *.bin
77
+ *.safetensors
78
+ *.h5
79
+ *.ckpt
80
+ *.pth
81
+ *.pt
82
+ *.pkl
83
+ *.pickle
84
+
85
+ # Model checkpoints and data
86
+ checkpoints/
87
+ models/
88
+ data/
89
+ logs/
90
+ runs/
91
+ wandb/
92
+
93
+ # Temporary files
94
+ *.tmp
95
+ *.temp
96
+ *.log
hangman/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (413 Bytes)
 
hangman/__pycache__/char_transformer.cpython-312.pyc DELETED
Binary file (3.47 kB)
 
hangman/__pycache__/hangman_core.cpython-312.pyc DELETED
Binary file (47.9 kB)
 
hangman/rl/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (385 Bytes)
 
hangman/rl/__pycache__/envs.cpython-312.pyc DELETED
Binary file (7.15 kB)
 
hangman/rl/__pycache__/models.cpython-312.pyc DELETED
Binary file (14 kB)
 
hangman/rl/__pycache__/utils.cpython-312.pyc DELETED
Binary file (3.48 kB)
 
hangman/rl/eval.py DELETED
@@ -1,178 +0,0 @@
1
- import os
2
- import time
3
- import csv
4
- from argparse import Namespace
5
-
6
- import torch
7
-
8
- from .envs import BatchEnv
9
- from .priors import (
10
- CandCache,
11
- ig_exact_pick,
12
- candidate_letter_probs,
13
- pos_present_probs,
14
- )
15
-
16
-
17
@torch.no_grad()
def greedy_rollout(win_env: BatchEnv, model, device, N=1000, priors=None, log_stride: int = 256, use_cand_priors: bool = False):
    """Estimate the greedy (argmax) win rate of ``model`` over at least ``N`` episodes.

    A fresh ``BatchEnv`` is built from ``win_env``'s configuration (batch capped
    at 256), so ``win_env`` itself is never stepped.

    Args:
        win_env: Template environment; only its ``buckets``, ``tries_init``,
            ``batch``, ``len_choices`` and ``max_len`` attributes are read.
        model: Called as ``model(pat_idx, tried, lens, length_priors,
            tries_norm, cand_priors=...)``; may return action scores or a
            ``(logits, value)`` tuple.
        device: Device the model inputs are moved to.
        N: Number of finished episodes to collect before stopping.
        priors: Mapping from word length to a 26-element letter prior. Missing
            lengths, or ``priors=None``, fall back to an all-zero prior.
        log_stride: Progress is printed whenever the finished-episode count is
            a multiple of this value, and once the target is reached.
        use_cand_priors: When true, candidate letter probabilities are computed
            for every state and passed to the model as ``cand_priors``.

    Returns:
        ``wins / finished_episodes`` (0.0 when no episode finished).
    """
    # The signature defaults to priors=None; treat that as "no priors" instead
    # of crashing on None.get below.
    length_priors = {} if priors is None else priors
    zero_prior = [0.0] * 26
    wins = 0
    total = 0
    B = min(win_env.batch, 256)
    env = BatchEnv(win_env.buckets, win_env.tries_init, B, win_env.len_choices.copy(), win_env.max_len)
    env.reset()
    model.eval()
    if hasattr(model, "remove_noise"):
        model.remove_noise()
    start = time.perf_counter()
    local_cache = CandCache(1024) if use_cand_priors else None
    # Loop-invariant: the length cap used to look up the length prior.
    prior_len_cap = getattr(model, "max_len", win_env.max_len)
    try:
        while total < N:
            pat_idx, tried, lens, tries = env.observe()
            B_now = pat_idx.size(0)
            lp = torch.zeros((B_now, 26), dtype=torch.float32)
            for i, patt in enumerate(env.patterns):
                L = min(len(patt), prior_len_cap)
                lp[i, :] = torch.tensor(length_priors.get(L, zero_prior))
            if local_cache is not None:
                cp = torch.zeros((B_now, 26), dtype=torch.float32)
                for i in range(B_now):
                    L = min(len(env.patterns[i]), win_env.max_len)
                    tried_bits = int(env.tried_mask_bits[i])
                    cp[i, :] = candidate_letter_probs(L, env.patterns[i], tried_bits, win_env.buckets, local_cache)
            else:
                cp = None
            tn = (tries.float() / win_env.tries_init).unsqueeze(1)
            out = model(
                pat_idx.to(device),
                tried.to(device),
                lens.to(device),
                lp.to(device),
                tn.to(device),
                cand_priors=(cp.to(device) if cp is not None else None),
            )
            # Actor-critic models return (logits, value); Q-networks return scores.
            scores = out[0] if isinstance(out, tuple) else out
            actions = scores.argmax(dim=1).cpu()
            env.step(actions)
            finished = env.done.clone()
            if finished.any():
                wins += int(env.won.sum().item())
                total += int(finished.sum().item())
                if (total % max(1, log_stride) == 0) or (total >= N):
                    wr = wins / max(1, total)
                    elapsed = (time.perf_counter() - start) / 60.0
                    print(f"[eval] episodes={total}/{N} win-rate={wr:.3f} | {elapsed:.1f} min")
                env.force_reset_done()
    finally:
        # Put the model back into training mode even if the rollout raised.
        model.train()
        if hasattr(model, "resample_noise"):
            model.resample_noise()
    return wins / max(1, total)
75
-
76
-
77
def _mask_tried_letters(scores, tried_bits):
    """Set the score of every already-guessed letter to -1.0 in place; return ``scores``."""
    for j in range(26):
        if (tried_bits >> j) & 1:
            scores[j] = -1.0
    return scores


def _solver_action(mode, patt, tried_bits, buckets, priors, pos_priors):
    """Pick one letter index for a single game state using heuristic ``mode``."""
    L = len(patt)
    if mode == "igx":
        return ig_exact_pick(tried_bits, L, patt, buckets)
    if mode == "ig":
        vec = candidate_letter_probs(L, patt, tried_bits, buckets, CandCache(1))
        p = vec.clamp(0, 1)
        scores = p * (1 - p)  # peaks for letters present in about half the candidates
    elif mode == "pos":
        scores = pos_present_probs(L, patt, pos_priors)
    elif mode == "len":
        scores = torch.tensor(priors.get(L, [0.0] * 26), dtype=torch.float32)
    else:  # cand
        scores = candidate_letter_probs(L, patt, tried_bits, buckets, CandCache(1))
    return int(_mask_tried_letters(scores, tried_bits).argmax().item())


@torch.no_grad()
def run_solver(args, buckets, priors, pos_priors):
    """Play ``args.solver_eval_N`` episodes with a heuristic solver and return its win rate.

    Args:
        args: Namespace providing ``tries``, ``batch_env``, ``max_len``,
            ``solver_eval_N`` and ``solver_mode`` ("igx", "ig", "pos", "len";
            anything else selects the candidate-probability solver). When
            ``csv_log`` is truthy, progress rows are appended to
            ``<args.out_dir>/metrics.csv``.
        buckets: Mapping from word length to the words of that length.
        priors: Mapping from word length to a 26-element letter prior.
        pos_priors: Positional priors forwarded to ``pos_present_probs``.

    Returns:
        ``wins / finished_episodes`` as a float.
    """
    lens = sorted(buckets.keys())
    env = BatchEnv(buckets, args.tries, args.batch_env, lens.copy(), args.max_len)
    env.reset()
    total = 0
    wins = 0
    start = time.perf_counter()
    csv_fp = None
    csv_writer = None
    try:
        if getattr(args, "csv_log", False):
            os.makedirs(args.out_dir, exist_ok=True)
            csv_path = os.path.join(args.out_dir, "metrics.csv")
            new_file = not os.path.exists(csv_path)
            csv_fp = open(csv_path, "a", newline="")
            csv_writer = csv.writer(csv_fp)
            if new_file:
                csv_writer.writerow(["mode", "episodes", "wins", "win_rate", "minutes"])
        while total < args.solver_eval_N:
            B = env.batch
            actions = torch.zeros(B, dtype=torch.long)
            for i in range(B):
                actions[i] = _solver_action(
                    args.solver_mode,
                    env.patterns[i],
                    int(env.tried_mask_bits[i]),
                    buckets,
                    priors,
                    pos_priors,
                )
            env.step(actions)
            finished = env.done.clone()
            if finished.any():
                wins += int(env.won.sum().item())
                total += int(finished.sum().item())
                env.force_reset_done()
                if total % 512 == 0 or total >= args.solver_eval_N:
                    wr = wins / max(1, total)
                    elapsed = (time.perf_counter() - start) / 60.0
                    print(f"[solver {args.solver_mode}] episodes={total}/{args.solver_eval_N} win-rate={wr:.3f} | {elapsed:.1f} min")
                    if csv_writer is not None:
                        csv_writer.writerow([args.solver_mode, total, wins, f"{wr:.6f}", f"{elapsed:.3f}"])
                        csv_fp.flush()
        final_wr = wins / max(1, total)
        print(f"[done][SOLVER:{args.solver_mode}] win-rate={final_wr:.3f} over {total} episodes")
        return final_wr
    finally:
        # Close the metrics file on every exit path, including exceptions.
        if csv_fp is not None:
            csv_fp.close()
148
-
149
-
150
def run_solver_sweep(args, buckets, priors, pos_priors):
    """Run ``run_solver`` for every mode in ``args.sweep_modes`` and summarise the win rates.

    ``args.sweep_modes`` is a comma-separated list of solver modes; each mode is
    evaluated ``args.sweep_repeats`` times on a copy of ``args``. The mean and
    best win rate per mode are printed and, when ``args.csv_log`` is truthy,
    appended to ``<args.out_dir>/solver_sweep.csv``.
    """
    modes = []
    for token in str(args.sweep_modes).split(","):
        token = token.strip()
        if token:
            modes.append(token)

    results = {mode: [] for mode in modes}
    for _ in range(int(args.sweep_repeats)):
        for mode in modes:
            # Work on a copy so the caller's namespace keeps its own solver_mode.
            run_args = Namespace(**vars(args))
            run_args.solver_mode = mode
            results[mode].append(float(run_solver(run_args, buckets, priors, pos_priors)))

    summary_rows = []
    print("\n[solver sweep] summary:")
    for m in modes:
        vals = results[m]
        mean_wr = sum(vals) / max(1, len(vals))
        best_wr = max(vals) if vals else 0.0
        print(f" - {m:>4}: mean={mean_wr:.3f} best={best_wr:.3f} over {len(vals)} run(s)")
        summary_rows.append((m, len(vals), mean_wr, best_wr))

    if not getattr(args, "csv_log", False):
        return
    os.makedirs(args.out_dir, exist_ok=True)
    path = os.path.join(args.out_dir, "solver_sweep.csv")
    write_header = not os.path.exists(path)
    with open(path, "a", newline="") as fp:
        writer = csv.writer(fp)
        if write_header:
            writer.writerow(["mode", "repeats", "mean_win_rate", "best_win_rate"])
        for mode, repeats, mean_wr, best_wr in summary_rows:
            writer.writerow([mode, repeats, f"{mean_wr:.6f}", f"{best_wr:.6f}"])
    print(f"[solver sweep] written summary to {path}")
178
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hangman/rl/ppo.py DELETED
@@ -1,21 +0,0 @@
1
- import torch
2
-
3
-
4
def compute_gae(rewards, values, dones, gamma, lam):
    """Compute Generalized Advantage Estimation over a rollout.

    Args:
        rewards: Tensor [T, B] of per-step rewards.
        values: Tensor [T + 1, B] of state values; ``values[t]`` is V(s_t) and
            the final row ``values[T]`` is the bootstrap value of the state
            reached after the last step.
        dones: Tensor [T, B]; a non-zero entry marks that the episode ended at
            step ``t``, which cuts both the bootstrap and the advantage trace.
        gamma: Discount factor.
        lam: GAE lambda.

    Returns:
        ``(advantages, returns)``, both [T, B], with
        ``returns = advantages + values[:-1]``.

    Raises:
        ValueError: If ``values`` does not have exactly one more row than
            ``rewards``. A [T, B] ``values`` would otherwise broadcast
            silently when T == 2 and fail with a shape error for other T.
    """
    T = rewards.size(0)
    if values.size(0) != T + 1:
        raise ValueError(
            f"values must have T + 1 = {T + 1} rows (including the bootstrap value), "
            f"got {values.size(0)}"
        )
    adv = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(values[-1])
    for t in reversed(range(T)):
        mask = 1.0 - dones[t].float()
        # TD residual: r_t + gamma * V(s_{t+1}) - V(s_t), no bootstrap past a terminal step.
        delta = rewards[t] + gamma * values[t + 1] * mask - values[t]
        lastgaelam = delta + gamma * lam * mask * lastgaelam
        adv[t] = lastgaelam
    returns = adv + values[:-1]
    return adv, returns
21
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hangman/rl/replay.py DELETED
@@ -1,86 +0,0 @@
1
- from typing import Tuple
2
-
3
- import torch
4
-
5
- from .utils import enc_pattern
6
- from .priors import cand_priors_batch
7
-
8
-
9
class Replay:
    """Fixed-capacity ring buffer of transitions.

    Items are appended until ``cap`` is reached; after that each push
    overwrites the slot at ``pos`` and advances ``pos`` cyclically, so the
    oldest stored transition is replaced first.
    """

    def __init__(self, cap: int):
        """Create an empty buffer holding at most ``cap`` transitions.

        Raises:
            ValueError: If ``cap`` is not positive; such a buffer could never
                accept a push.
        """
        if cap <= 0:
            raise ValueError(f"cap must be a positive integer, got {cap!r}")
        self.cap = cap
        self.buf = []  # stored transition tuples
        self.pos = 0   # next slot to overwrite once the buffer is full

    def push(self, s, a, r, sp, done, word, won: bool):
        """Store one transition ``(s, a, r, sp, done, word, won)``."""
        item = (s, a, r, sp, done, word, won)
        if len(self.buf) < self.cap:
            self.buf.append(item)
        else:
            self.buf[self.pos] = item
            self.pos = (self.pos + 1) % self.cap

    def sample(self, n):
        """Return ``n`` distinct stored transitions chosen uniformly at random.

        Raises:
            ValueError: If ``n`` exceeds the number of stored transitions
                (raised by ``random.sample``).
        """
        import random
        return random.sample(self.buf, n)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buf)
29
-
30
-
31
class SuccessReplay(Replay):
    """Separate ``Replay`` type for a second buffer; adds no behaviour of its own."""
33
-
34
-
35
def _encode_states(states, device, max_len, priors_dict):
    """Encode ``(pattern, tried_mask, tries_left, ...)`` states into model input tensors.

    Returns ``(pat_idx, tried, lens, pri, tries)`` where ``tried`` is a (B, 26)
    0/1 float mask unpacked from each state's bitmask and ``pri`` holds the
    length prior looked up by ``min(len(pattern), max_len)``.
    """
    B = len(states)
    zero_prior = [0.0] * 26
    lengths = [min(len(st[0]), max_len) for st in states]
    pat_idx = torch.tensor([enc_pattern(st[0], max_len) for st in states], dtype=torch.long, device=device)
    # Unpack all bitmasks on the host and upload once, instead of issuing
    # 26 * B scalar writes into a (possibly CUDA) tensor.
    tried = torch.tensor(
        [[float((int(st[1]) >> j) & 1) for j in range(26)] for st in states],
        dtype=torch.float32,
        device=device,
    )
    lens = torch.tensor(lengths, dtype=torch.long, device=device)
    pri = torch.zeros((B, 26), dtype=torch.float32, device=device)
    for i, L in enumerate(lengths):
        pri[i, :] = torch.tensor(priors_dict.get(L, zero_prior), dtype=torch.float32, device=device)
    tries = torch.tensor([st[2] for st in states], dtype=torch.float32, device=device)  # raw count
    return pat_idx, tried, lens, pri, tries


def tensorize_batch(batch, device, max_len, priors_dict, buckets, cand_cache, cand_frac: float = 1.0):
    """Convert a list of replay transitions into tensors on ``device``.

    Args:
        batch: Non-empty sequence of ``(s, a, r, sp, done, word, won)`` tuples,
            where ``s`` and ``sp`` start with ``(pattern, tried_mask, tries_left)``.
        device: Target device for every returned tensor.
        max_len: Maximum pattern length used for encoding and prior lookup.
        priors_dict: Mapping from word length to a 26-element letter prior;
            missing lengths use zeros.
        buckets, cand_cache: Forwarded to ``cand_priors_batch``.
        cand_frac: Candidate priors are only computed when this is >= 1.0;
            otherwise ``None`` is returned in their place.

    Returns:
        ``(pat_idx0, tried0, lens0, pri0, cand0, tries0, pat_idx1, tried1,
        lens1, pri1, cand1, tries1, a, r, done)``; the ``0`` tensors describe
        ``s`` and the ``1`` tensors describe ``sp``.
    """
    s, a, r, sp, done, _w, _won = zip(*batch)
    use_cand = cand_frac >= 1.0

    pat_idx0, tried0, lens0, pri0, tries0 = _encode_states(s, device, max_len, priors_dict)
    cand0 = cand_priors_batch(s, buckets, cand_cache, max_len, device) if use_cand else None

    pat_idx1, tried1, lens1, pri1, tries1 = _encode_states(sp, device, max_len, priors_dict)
    cand1 = cand_priors_batch(sp, buckets, cand_cache, max_len, device) if use_cand else None

    a = torch.tensor(a, dtype=torch.long, device=device)
    r = torch.tensor(r, dtype=torch.float32, device=device)
    done = torch.tensor(done, dtype=torch.bool, device=device)
    return (
        pat_idx0,
        tried0,
        lens0,
        pri0,
        cand0,
        tries0,
        pat_idx1,
        tried1,
        lens1,
        pri1,
        cand1,
        tries1,
        a,
        r,
        done,
    )
86
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hangman/rl/seed_bc.py DELETED
@@ -1,144 +0,0 @@
1
- import random
2
- from typing import List
3
-
4
- import torch
5
- import torch.nn.functional as F
6
-
7
- try:
8
- from tqdm import tqdm
9
- except Exception: # pragma: no cover - fallback in limited envs
10
- def tqdm(x, **k):
11
- return x
12
-
13
- from .envs import BatchEnv
14
- from .priors import (
15
- candidate_letter_probs,
16
- teacher_actions_batch,
17
- )
18
- from .utils import enc_pattern, atomic_save
19
-
20
-
21
def seed_expert(replay, success_replay, buckets, tries, episodes,
                priors, pos_priors, cand_cache, teacher_mode: str, max_len: int):
    """Fill the replay buffers with transitions produced by a heuristic teacher.

    Plays ``episodes`` single-environment games. Every transition is pushed to
    ``replay``; the final transition of each won game is also pushed to
    ``success_replay``. A summary line is printed at the end.

    Args:
        replay, success_replay: Buffers exposing ``push(s, a, r, sp, done, word, won)``.
        buckets: Mapping from word length to the words of that length.
        tries: Number of tries each game starts with.
        episodes: Number of games to play.
        priors, pos_priors, cand_cache: Forwarded to ``teacher_actions_batch``.
        teacher_mode: Heuristic used by the teacher.
        max_len: Maximum pattern length passed to the environment.
    """
    env = BatchEnv(buckets, tries, 1, sorted(buckets.keys()), max_len)
    env.reset()
    pushed = 0
    won_count = 0
    for _ in tqdm(range(episodes), desc="[seed] heuristic"):
        # Start a fresh game by overwriting slot 0 of the environment by hand.
        env.words[0] = random.choice(buckets[random.choice(env.len_choices)])
        L = len(env.words[0])
        env.patterns[0] = "_" * L
        env.tried_mask_bits[0] = 0
        env.tries_left[0] = tries
        env.done[0] = False
        env.won[0] = False

        while not env.done[0]:
            # Snapshot the mask as a plain int, as every other call site does.
            # NOTE(review): if tried_mask_bits is tensor-backed, the un-cast
            # element would alias env state and change after env.step — confirm.
            tried_bits = int(env.tried_mask_bits[0])
            s = (env.patterns[0], tried_bits, int(env.tries_left[0].item()), L)
            a = teacher_actions_batch([s], buckets, teacher_mode, priors, pos_priors, cand_cache)[0]
            r = env.step(torch.tensor([a]))[0].item()
            sp = (env.patterns[0], int(env.tried_mask_bits[0]), int(env.tries_left[0].item()), L)
            done = bool(env.done[0].item())
            won = bool(env.won[0].item())
            replay.push(s, a, float(r), sp, done, env.words[0], won)
            if done and won:
                success_replay.push(s, a, float(r), sp, done, env.words[0], won)
                won_count += 1
            pushed += 1
        env.force_reset_done()
    print(f"[seed] added transitions={pushed}, winning_episodes={won_count}")
52
-
53
-
54
def behavior_cloning(model, replay, success_replay, priors, device, max_len, tries_init: int,
                     steps=3000, bs=512, success_frac=0.5, lr=5e-4, wd=1e-4,
                     bc_ckpt_path: str = "", save_every: int = 250):
    """Pretrain ``model`` to imitate the actions stored in the replay buffers.

    Each step samples up to ``bs`` transitions (at most ``success_frac * bs``
    from ``success_replay``, the remainder from ``replay``) and minimises the
    cross-entropy between the model output and the stored actions with AdamW
    and gradient-norm clipping at 1.0.

    Args:
        model: Called as ``model(pat_idx, tried, lens, priors, tries_norm)``
            and expected to return per-letter scores.
        replay, success_replay: Buffers exposing ``sample(n)`` and ``__len__``.
        priors: Mapping from word length to a 26-element letter prior.
        device: Device for the training tensors.
        max_len: Maximum pattern length used for encoding and prior lookup.
        tries_init: Initial try count, used to normalise remaining tries.
        steps: Number of optimisation steps; nothing happens when <= 0 or when
            ``replay`` is empty.
        bs: Batch size.
        success_frac: Target fraction of each batch drawn from ``success_replay``.
        lr, wd: AdamW learning rate and weight decay.
        bc_ckpt_path: When non-empty, checkpoints are written here every
            ``save_every`` steps and once more on exit, including on error.
        save_every: Checkpoint period in steps.
    """
    if steps <= 0 or len(replay) == 0:
        return
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    model.train()
    zero_prior = [0.0] * 26
    pbar = tqdm(range(steps), desc="[bc] pretrain", leave=False)
    try:
        for t in pbar:
            n_succ = min(int(success_frac * bs), len(success_replay))
            n_base = max(1, bs - n_succ)
            batch = replay.sample(min(n_base, len(replay)))
            if n_succ > 0:
                batch += success_replay.sample(n_succ)

            s, a, _r, _sp, _done, _w, _won = zip(*batch)
            B = len(batch)
            pat_idx = torch.tensor([enc_pattern(si[0], max_len) for si in s], dtype=torch.long, device=device)
            # Unpack the tried-letter bitmasks on the host and upload once,
            # rather than writing 26 * B scalars into a device tensor per step.
            tried = torch.tensor(
                [[float((int(si[1]) >> j) & 1) for j in range(26)] for si in s],
                dtype=torch.float32,
                device=device,
            )
            lens = torch.tensor([min(len(si[0]), max_len) for si in s], dtype=torch.long, device=device)
            pri = torch.zeros((B, 26), dtype=torch.float32, device=device)
            for i, si in enumerate(s):
                L = min(len(si[0]), max_len)
                pri[i, :] = torch.tensor(priors.get(L, zero_prior), device=device)
            tries_norm = torch.tensor([si[2] for si in s], dtype=torch.float32, device=device) / float(tries_init)
            tries_norm = tries_norm.unsqueeze(1)
            a = torch.tensor(a, dtype=torch.long, device=device)

            logits = model(pat_idx, tried, lens, pri, tries_norm)
            loss = F.cross_entropy(logits, a)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            if bc_ckpt_path and ((t + 1) % max(1, save_every) == 0):
                atomic_save({"model": model.state_dict()}, bc_ckpt_path)
    finally:
        # Always leave a checkpoint behind, even when training is interrupted.
        if bc_ckpt_path:
            atomic_save({"model": model.state_dict()}, bc_ckpt_path)
            print(f"[bc] checkpoint saved to {bc_ckpt_path}")
101
-
102
-
103
def actor_bc_pretrain(ac, optimizer, buckets, tries, priors, pos_priors, cand_cache,
                      max_len, steps: int, B: int, device, teacher_mode: str = "igx", cand_frac: float = 0.25):
    """Pretrain the actor head of ``ac`` to imitate a heuristic teacher online.

    For ``steps`` iterations a batch of ``B`` environments is observed, the
    teacher's action for every state is computed, the actor logits are fitted
    to those actions with cross-entropy (gradient norm clipped at 1.0), and the
    environments are advanced with the teacher's actions. Candidate priors are
    filled in for a random ``cand_frac`` share of the batch (at least one row
    when ``cand_frac > 0``); the remaining rows stay zero.
    """
    if steps <= 0:
        return
    env_bc = BatchEnv(buckets, tries, B, sorted(buckets.keys()), max_len)
    env_bc.reset()
    for _ in tqdm(range(steps), desc="[ppo-bc] pretrain", leave=False):
        pat_idx, tried, lens_t, tries_t = env_bc.observe()
        n_envs = pat_idx.size(0)

        # Length prior for every environment.
        length_prior = torch.zeros((n_envs, 26), dtype=torch.float32)
        for row, pattern in enumerate(env_bc.patterns):
            capped_len = min(len(pattern), max_len)
            length_prior[row, :] = torch.tensor(priors.get(capped_len, [0.0] * 26))

        # Candidate prior for a random subset of environments.
        cand_prior = torch.zeros((n_envs, 26), dtype=torch.float32)
        if cand_frac > 0.0:
            share = min(1.0, max(0.0, cand_frac))
            n_pick = max(1, int(n_envs * share))
            for row in random.sample(range(n_envs), n_pick):
                capped_len = min(len(env_bc.patterns[row]), max_len)
                mask_bits = int(env_bc.tried_mask_bits[row])
                cand_prior[row, :] = candidate_letter_probs(
                    capped_len, env_bc.patterns[row], mask_bits, buckets, cand_cache
                )

        tries_norm = (tries_t.float() / float(tries)).unsqueeze(1)

        # Teacher labels for the current states.
        states_list = [
            (
                env_bc.patterns[row],
                int(env_bc.tried_mask_bits[row]),
                int(tries_t[row].item()),
                len(env_bc.patterns[row]),
            )
            for row in range(n_envs)
        ]
        teacher = teacher_actions_batch(states_list, buckets, teacher_mode, priors, pos_priors, cand_cache)
        targets = torch.tensor(teacher, dtype=torch.long, device=device)

        logits, _ = ac(
            pat_idx.to(device),
            tried.to(device),
            lens_t.to(device),
            length_prior.to(device),
            tries_norm.to(device),
            cand_priors=cand_prior.to(device),
        )
        loss = F.cross_entropy(logits, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(ac.parameters(), 1.0)
        optimizer.step()

        env_bc.step(targets.cpu())
        env_bc.force_reset_done()
144
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hangman/rl/train_bc.py DELETED
@@ -1,139 +0,0 @@
1
- import argparse
2
- import os
3
- import time
4
-
5
- import torch
6
-
7
- from .models import DuelingQNet
8
- from .replay import Replay, SuccessReplay
9
- from .seed_bc import seed_expert, behavior_cloning
10
- from .eval import greedy_rollout
11
- from .priors import build_length_priors, build_positional_priors, CandCache
12
- from .utils import load_dict, by_len, set_seed
13
-
14
-
15
def parse_args():
    """Build the command-line parser for BC pretraining and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding the data, seeding/BC, model, eval
        and misc options declared below.
    """
    p = argparse.ArgumentParser("Behavior Cloning pretraining for Hangman")
    # data
    p.add_argument("--dict_path", type=str, default="data/words_250000_train.txt", help="Path to word list")
    p.add_argument("--len_lo", type=int, default=4, help="Min word length")
    p.add_argument("--len_hi", type=int, default=12, help="Max word length")
    p.add_argument("--max_len", type=int, default=35, help="Model max sequence length")
    p.add_argument("--tries", type=int, default=6, help="Initial tries for env")
    # seeding and BC
    p.add_argument("--seed_episodes", type=int, default=5000, help="How many heuristic episodes to seed replay")
    p.add_argument("--teacher_mode", type=str, default="igx", choices=["igx", "ig", "cand", "pos", "len"], help="Heuristic teacher policy")
    p.add_argument("--replay_cap", type=int, default=200_000, help="Replay buffer capacity")
    p.add_argument("--success_cap", type=int, default=50_000, help="Success replay capacity")
    p.add_argument("--bc_steps", type=int, default=5000, help="Behavior cloning optimization steps")
    p.add_argument("--bc_bs", type=int, default=512, help="Behavior cloning batch size")
    p.add_argument("--bc_lr", type=float, default=5e-4, help="Learning rate for BC")
    p.add_argument("--bc_wd", type=float, default=1e-4, help="Weight decay for BC")
    p.add_argument("--success_frac", type=float, default=0.5, help="Fraction of success samples in BC batches")
    p.add_argument("--save_every", type=int, default=500, help="Checkpoint frequency (steps)")
    p.add_argument("--out_dir", type=str, default="runs/bc", help="Output directory for checkpoints and logs")
    # model
    p.add_argument("--d_model", type=int, default=128)
    p.add_argument("--nhead", type=int, default=4)
    p.add_argument("--nlayers", type=int, default=2)
    p.add_argument("--ff_mult", type=int, default=4)
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--noisy", action="store_true", help="Use NoisyNet layers in dueling head")
    # The device default is resolved when the parser is built: CUDA if available.
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    # eval
    p.add_argument("--eval_N", type=int, default=2000, help="Episodes for quick greedy eval")
    p.add_argument("--eval_stride", type=int, default=256, help="Logging stride for eval")
    p.add_argument("--eval_use_cand_priors", action="store_true", help="Fuse candidate priors during eval")
    # misc
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()
50
-
51
-
52
def main():
    """Run the full BC pipeline: seed replay, pretrain the Q-network, then evaluate.

    Steps: parse the CLI arguments, load and bucket the dictionary, build
    length and positional priors, seed the replay buffers with a heuristic
    teacher, train a ``DuelingQNet`` by behaviour cloning (checkpointing to
    ``<out_dir>/bc_dueling_qnet.pt``) and print a greedy-rollout win rate.

    Raises:
        FileNotFoundError: If ``--dict_path`` does not exist.
    """
    args = parse_args()
    set_seed(int(args.seed))

    # Explicit check instead of ``assert``: assertions are removed under -O.
    if not os.path.exists(args.dict_path):
        raise FileNotFoundError(f"Dictionary file not found: {args.dict_path}")
    words = load_dict(args.dict_path)
    buckets = by_len(words, args.len_lo, args.len_hi)
    priors = build_length_priors(buckets)
    pos_priors = build_positional_priors(buckets, args.max_len)
    cand_cache = CandCache(100_000)

    device = torch.device(args.device)

    # replay buffers
    replay = Replay(cap=int(args.replay_cap))
    success_replay = SuccessReplay(cap=int(args.success_cap))

    # expert seeding
    print(f"[setup] Seeding replay with heuristic='{args.teacher_mode}', episodes={args.seed_episodes}…")
    seed_expert(
        replay,
        success_replay,
        buckets,
        tries=int(args.tries),
        episodes=int(args.seed_episodes),
        priors=priors,
        pos_priors=pos_priors,
        cand_cache=cand_cache,
        teacher_mode=str(args.teacher_mode),
        max_len=int(args.max_len),
    )

    # model
    model = DuelingQNet(
        d_model=int(args.d_model),
        nhead=int(args.nhead),
        nlayers=int(args.nlayers),
        ff_mult=int(args.ff_mult),
        max_len=int(args.max_len),
        dropout=float(args.dropout),
        use_noisy=bool(args.noisy),
    ).to(device)

    os.makedirs(args.out_dir, exist_ok=True)
    ckpt_path = os.path.join(args.out_dir, "bc_dueling_qnet.pt")

    # behavior cloning
    print(f"[train] Starting BC: steps={args.bc_steps}, bs={args.bc_bs}, success_frac={args.success_frac}")
    t0 = time.perf_counter()
    behavior_cloning(
        model,
        replay,
        success_replay,
        priors,
        device,
        max_len=int(args.max_len),
        tries_init=int(args.tries),
        steps=int(args.bc_steps),
        bs=int(args.bc_bs),
        success_frac=float(args.success_frac),
        lr=float(args.bc_lr),
        wd=float(args.bc_wd),
        bc_ckpt_path=ckpt_path,
        save_every=int(args.save_every),
    )
    dt_min = (time.perf_counter() - t0) / 60.0
    print(f"[train] BC finished in {dt_min:.2f} min. Checkpoint saved to {ckpt_path}")

    # quick greedy rollout eval
    print("[eval] Running greedy rollout eval…")
    from .envs import BatchEnv
    env = BatchEnv(buckets, int(args.tries), batch=64, len_choices=sorted(buckets.keys()), max_len=int(args.max_len))
    env.reset()
    wr = greedy_rollout(
        env,
        model,
        device=device,
        N=int(args.eval_N),
        priors=priors,
        log_stride=int(args.eval_stride),
        use_cand_priors=bool(args.eval_use_cand_priors),
    )
    print(f"[done] Greedy eval win-rate={wr:.3f} over N={args.eval_N}")
135
-
136
-
137
- if __name__ == "__main__":
138
- main()
139
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hangman/utils.py DELETED
@@ -1,129 +0,0 @@
1
- import re
2
- import random
3
- import string
4
- import pandas as pd
5
- import torch
6
-
7
- RNG = random.Random(0) # reproducible sampling
8
-
9
def get_char_mapping():
    """Return the character vocabulary: ``'PAD'`` -> 0, ``'a'..'z'`` -> 1..26, ``'_'`` -> 27."""
    mapping = {'PAD': 0}
    for position, letter in enumerate(string.ascii_lowercase, start=1):
        mapping[letter] = position
    mapping['_'] = 27
    return mapping
12
-
13
def read_data(path="words_250000_train.txt", encoding="utf-8"):
    """Load the word list, lower-cased, keeping purely alphabetic (a-z) words only.

    Args:
        path: Word-list file with one word per line. Defaults to the file name
            that used to be hard-coded, so existing ``read_data()`` calls are
            unaffected.
        encoding: Text encoding of the file; given explicitly so the result
            does not depend on the platform locale.

    Returns:
        List of words in file order; blank lines and words containing anything
        other than ``a``-``z`` after lower-casing are dropped.
    """
    with open(path, "r", encoding=encoding) as f:
        stripped = (line.strip().lower() for line in f)
        # The pattern requires at least one letter, so blank lines fall out too.
        return [w for w in stripped if re.fullmatch(r"[a-z]+", w)]
19
-
20
def create_intermediate_data(words):
    """Build a per-word feature frame and drop words unsuitable for masking.

    Columns: ``0`` (word), ``1`` (length), ``'vowels_present'`` (set of vowels
    in the word), ``'vowels_count'`` and ``'unique_char_count'``. A word is
    kept when it is longer than 3 characters, has at least 3 distinct
    characters and contains at least one vowel.
    """
    frame = pd.DataFrame({0: words})
    frame[1] = frame[0].str.len()
    frame['vowels_present'] = frame[0].apply(lambda w: set(w) & {'a','e','i','o','u'})
    frame['vowels_count'] = frame['vowels_present'].str.len()
    frame['unique_char_count'] = frame[0].apply(lambda w: len(set(w)))

    rejected = frame['unique_char_count'].isin([0,1,2]) | (frame[1] <= 3)
    has_vowel = frame['vowels_count'] != 0
    return frame[~rejected & has_vowel]
29
-
30
def loop_for_permutation(unique_letters, word, all_perm, k):
    """Append to ``all_perm`` a copy of ``word`` with ``k + 1`` randomly chosen letters hidden.

    The letters are drawn from ``unique_letters`` with the module-level ``RNG``
    and every occurrence of each drawn letter is replaced by ``"_"``.
    """
    hidden = RNG.sample(unique_letters, k+1)
    pattern = word
    for letter in hidden:
        pattern = pattern.replace(letter, "_")
    all_perm.append(pattern)
37
-
38
def permute_all(word, vowel_permutation_loop=False):
    """Return distinct masked variants of ``word`` with progressively more letters hidden.

    For ``k = 0 .. hi - 1`` one variant hiding ``k + 1`` randomly chosen
    distinct letters is generated, where ``hi`` is ``len(unique) - 1`` when
    ``vowel_permutation_loop`` is true and ``len(unique) - 2`` otherwise, so at
    least one (respectively two) distinct letters always stay visible.

    Args:
        word: String to mask.
        vowel_permutation_loop: Allow one more letter to be hidden (used when
            ``word`` is a vowels-only string).

    Returns:
        Sorted list of unique masked patterns; empty when ``word`` has too few
        distinct letters.
    """
    # Sort the distinct letters: set iteration order for strings changes with
    # hash randomisation, which would make the seeded RNG draw different
    # letters in every process.
    uniq = sorted(set(word))
    all_perm = []
    hi = (len(uniq)-1) if vowel_permutation_loop else (len(uniq)-2)
    for k in range(max(0, hi)):
        loop_for_permutation(uniq, word, all_perm, k)
    return sorted(set(all_perm))
46
-
47
def permute_consonents(word):
    """Return patterns of ``word`` with every consonant hidden and only some vowels shown.

    The vowels of ``word`` are masked as a group by ``permute_all`` and each
    masked vowel string is written back at the original vowel positions of an
    otherwise all-``"_"`` pattern.
    """
    vowel_slots = [pos for pos, ch in enumerate(word) if ch in "aeiou"]
    vowel_string = "".join(word[pos] for pos in vowel_slots)  # only the vowels
    patterns = []
    for masked_vowels in permute_all(vowel_string, vowel_permutation_loop=True):
        cells = ["_"] * len(word)
        for slot, ch in zip(vowel_slots, masked_vowels):
            cells[slot] = ch
        patterns.append("".join(cells))
    return patterns
59
-
60
def create_masked_dictionary(df_aug):
    """Map every word in column ``0`` of ``df_aug`` to its list of unique masked patterns.

    Patterns come from ``permute_all`` followed by ``permute_consonents``.
    Progress is printed for rows whose index label is a multiple of 10000.
    """
    masked_dictionary = {}
    for row_label, word in df_aug[0].items():
        general_patterns = permute_all(word)
        vowel_only_patterns = permute_consonents(word)
        masked_dictionary[word] = list(set(general_patterns + vowel_only_patterns))
        if row_label % 10000 == 0:
            print(f"Iteration {row_label} completed")
    return masked_dictionary
68
-
69
def get_vowel_prob(df_vowel, vowel):
    """Return the fraction of words in column ``0`` of ``df_vowel`` that contain ``vowel``.

    An empty frame yields ``0.0``.
    """
    if len(df_vowel) == 0:
        return 0.0
    contains_vowel = df_vowel[0].str.contains(vowel)
    return contains_vowel.mean()
72
-
73
def get_vowel_prior(df_aug):
    """Compute, for every word length, how often each vowel occurs in words of that length.

    Args:
        df_aug: Frame with the word in column ``0`` and its length in column ``1``.

    Returns:
        Dict mapping each length from 1 to the maximum length in ``df_aug`` to
        a DataFrame with columns ``"vowel"`` and ``"p"``, sorted by ``"p"``
        descending. An empty ``df_aug`` yields an empty dict instead of
        failing on the NaN maximum.
    """
    if df_aug.empty:
        return {}
    prior = {}
    max_len = int(df_aug[1].max())
    for L in range(1, max_len+1):
        df_v = df_aug[df_aug[1] == L]
        probs = [get_vowel_prob(df_v, v) for v in "aeiou"]
        prior[L] = pd.DataFrame({"vowel": list("aeiou"), "p": probs}).sort_values("p", ascending=False)
    return prior
81
-
82
def save_vowel_prior(vowel_prior, path="prior_probabilities.pkl"):
    """Pickle ``vowel_prior`` to ``path``, overwriting any existing file."""
    import pickle
    with open(path, "wb") as sink:
        pickle.Pickler(sink).dump(vowel_prior)
86
-
87
- # ---------- ENCODING (align target to Hangman) ----------
88
-
89
- CMAP = get_char_mapping()
90
-
91
def idx(c):
    """Return the 0-based class index of character ``c`` (its ``CMAP`` value minus one)."""
    zero_based = CMAP[c] - 1
    return zero_based
93
-
94
def encode_input(pattern, max_len=35):
    """Encode a masked pattern as a right-aligned, left-padded list of ``CMAP`` ids.

    Args:
        pattern: String of lowercase letters and ``"_"``.
        max_len: Length of the returned list.

    Returns:
        List of ``max_len`` ints: PAD (0) on the left, then one ``CMAP`` id per
        character. A pattern longer than ``max_len`` keeps its last
        ``max_len`` characters. Previously this happened only by accident
        through negative-index wrap-around, and patterns longer than
        ``2 * max_len`` raised ``IndexError``.

    Raises:
        KeyError: If ``pattern`` contains a character missing from ``CMAP``.
    """
    visible = pattern[-max_len:] if max_len > 0 else ""
    padding = [0] * (max_len - len(visible))
    return padding + [CMAP[ch] for ch in visible]  # 0 is PAD, '_'=27, letters 1..26
100
-
101
def encode_output_for_pattern(word, pattern):
    """Return a 26-slot 0/1 list marking the letters of ``word`` that ``pattern`` still hides.

    Position ``i`` of ``word`` counts as hidden when ``pattern[i]`` is ``'_'``.
    """
    labels = [0]*26
    for position, letter in enumerate(word):
        if pattern[position] != '_':
            continue
        labels[idx(letter)] = 1
    return labels
108
-
109
def encode_words(masked_dictionary):
    """Encode every ``(word, pattern)`` pair into parallel input and label lists.

    Returns ``(X, Y)`` where ``X[i]`` is ``encode_input(pattern)`` and ``Y[i]``
    is ``encode_output_for_pattern(word, pattern)``, in dictionary order.
    """
    pairs = [
        (encode_input(pat), encode_output_for_pattern(word, pat))
        for word, patterns in masked_dictionary.items()
        for pat in patterns
    ]
    X = [encoded for encoded, _ in pairs]
    Y = [labels for _, labels in pairs]
    return X, Y
116
-
117
def convert_to_tensor(X, Y):
    """Convert encoded inputs to an int64 tensor and labels to a float32 tensor."""
    inputs = torch.tensor(X, dtype=torch.long)
    # float32 labels, as expected by BCEWithLogitsLoss
    targets = torch.tensor(Y, dtype=torch.float32)
    return inputs, targets
121
-
122
def get_datasets():
    """Build the training tensors from the word list.

    Reads the words, filters them, generates masked patterns, computes and
    saves the per-length vowel prior, then returns the encoded
    ``(inputs, labels)`` tensors.
    """
    frame = create_intermediate_data(read_data())
    patterns_by_word = create_masked_dictionary(frame)
    save_vowel_prior(get_vowel_prior(frame))
    inputs, labels = encode_words(patterns_by_word)
    return convert_to_tensor(inputs, labels)