Evangelinejy committed on
Commit
dd29cd2
·
verified ·
1 Parent(s): 21a0795

Upload SFT checkpoint: C6p5e18_200m_alpha0.200_beta0.100

Browse files
C6p5e18_200m_alpha0.200_beta0.100/config.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dtype": "float32",
9
+ "eos_token_id": 1,
10
+ "head_dim": 128,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 768,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 2304,
15
+ "layer_types": [
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention"
40
+ ],
41
+ "max_position_embeddings": 2048,
42
+ "max_window_layers": 24,
43
+ "model_type": "qwen3",
44
+ "num_attention_heads": 12,
45
+ "num_hidden_layers": 24,
46
+ "num_key_value_heads": 4,
47
+ "pad_token_id": 0,
48
+ "rms_norm_eps": 1e-06,
49
+ "rope_scaling": {
50
+ "factor": 2.0,
51
+ "original_max_position_embeddings": 1024,
52
+ "type": "yarn"
53
+ },
54
+ "rope_theta": 1000000,
55
+ "sliding_window": null,
56
+ "tie_word_embeddings": true,
57
+ "transformers_version": "4.57.0",
58
+ "use_cache": true,
59
+ "use_sliding_window": false,
60
+ "vocab_size": 84
61
+ }
C6p5e18_200m_alpha0.200_beta0.100/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 0,
3
+ "do_sample": true,
4
+ "eos_token_id": 1,
5
+ "max_new_tokens": 1024,
6
+ "transformers_version": "4.57.0"
7
+ }
C6p5e18_200m_alpha0.200_beta0.100/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f936fd76d26adac244e4a5dc00501ebc38da25515f68d108597ca9b4d99bcb2
3
+ size 812060488
C6p5e18_200m_alpha0.200_beta0.100/optimizer_states/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f936fd76d26adac244e4a5dc00501ebc38da25515f68d108597ca9b4d99bcb2
3
+ size 812060488
C6p5e18_200m_alpha0.200_beta0.100/optimizer_states/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55a552b5dfa01dc182a847802c16d60d0cb278d335cc5780f91dfa6243b31bee
3
+ size 1624285707
C6p5e18_200m_alpha0.200_beta0.100/optimizer_states/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5297d721e1b62cdfd52a4e2b71a431e805830499d92283cfe5d8317dc3e80f50
3
+ size 15017
C6p5e18_200m_alpha0.200_beta0.100/optimizer_states/scheduler.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e70999e57e5aa5c5681f05b503f3d5671e0cdba641bc4a5b8b1bcc7a8cecde6
3
+ size 1465
C6p5e18_200m_alpha0.200_beta0.100/optimizer_states/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 385, "epoch": 2}
C6p5e18_200m_alpha0.200_beta0.100/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<bos>",
3
+ "eos_token": "<eos>",
4
+ "pad_token": "<bos>",
5
+ "unk_token": "<unk>",
6
+ "env_token": null
7
+ }
C6p5e18_200m_alpha0.200_beta0.100/tokenizer.py ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Auto-generated self-contained HF tokenizer.
3
+ Do NOT edit manually -- regenerate via training.hf_tokenizer_utils.save_hf_tokenizer().
4
+ """
5
+ from __future__ import annotations
6
+
7
+ # --- BaseTokenizer (inlined) ---
8
+ # base_tokenizer.py
9
+ from abc import ABC, abstractmethod
10
+ from typing import List, Dict, Optional
11
+
12
class BaseTokenizer(ABC):
    """Abstract contract for the tokenizers used in pretraining.

    Concrete subclasses have to implement ``encode``, ``decode`` and
    ``get_vocab``.  The special-token accessors are optional hooks that
    return ``None`` unless a subclass overrides them.
    """

    # ---- mandatory interface ----
    @abstractmethod
    def encode(self, text: str) -> List[int]:
        """Turn a text/PGN string into a list of token IDs."""
        raise NotImplementedError

    @abstractmethod
    def decode(self, ids: List[int]) -> str:
        """Turn a list of token IDs back into a text/PGN string."""
        raise NotImplementedError

    @abstractmethod
    def get_vocab(self) -> Dict[str, int]:
        """Return the token -> id mapping (if available)."""
        raise NotImplementedError

    # ---- optional hooks (no special tokens by default) ----
    def bos_id(self) -> Optional[int]:
        """Return the BOS token ID, or ``None`` when there is none."""
        return None

    def eos_id(self) -> Optional[int]:
        """Return the EOS token ID, or ``None`` when there is none."""
        return None

    def pad_id(self) -> Optional[int]:
        """Return the PAD token ID, or ``None`` when there is none."""
        return None

    def get_vocab_size(self) -> int:
        """Return the number of entries in ``get_vocab()``."""
        vocab = self.get_vocab()
        return len(vocab)

    def __call__(self, text: str) -> List[int]:
        """Make the tokenizer callable; identical to ``encode(text)``."""
        return self.encode(text)
39
+
40
+ # --- Concrete tokenizer (inlined) ---
41
+ # lan_tokenizer_sft.py
42
+ """
43
+ LAN Tokenizer with SFT support (CoT format with <T> and <sep> tokens).
44
+
45
+ This extends the base LAN tokenizer with SFT-specific functionality:
46
+ - <T> token for marking thinking/CoT content
47
+ - <sep> token for separating candidate trajectories inside the thinking section
48
+ """
49
+ from typing import List, Dict, Optional, Tuple
50
+ import io
51
+ import chess, chess.pgn
52
+ from tokenizers import Tokenizer
53
+ from tokenizers.models import WordLevel
54
+ from tokenizers.pre_tokenizers import WhitespaceSplit
55
# --- Board / notation alphabet ---
# PGN game-termination markers.
_RESULT = {"1-0", "0-1", "1/2-1/2", "*"}
FILES = "abcdefgh"
RANKS = "12345678"
# All 64 square names in file-major order: a1, a2, ..., a8, b1, ..., h8.
SQUARES = [file_ + rank for file_ in FILES for rank in RANKS]
# Piece letters a pawn may promote to.
PROMOS = "QRBN"
DIGITS = set("0123456789")

# --- SFT special tokens for the CoT format ---
T_TOKEN = "<T>"
T_END_TOKEN = "</T>"
SEP_TOKEN = "<sep>"

# --- Environment interaction / reward special tokens ---
CALL_ENV_TOKEN = "<call_env>"
VERIFY_TOKEN = "<verify>"
REWARD_POS_TOKEN = "<+1>"
REWARD_NEG_TOKEN = "<-1>"
REWARD_ZERO_TOKEN = "<0>"

# Groups that can be switched on independently when the vocabulary is built.
ENV_TOKENS = [CALL_ENV_TOKEN]
REWARD_TOKENS = [
    VERIFY_TOKEN,
    REWARD_POS_TOKEN,
    REWARD_NEG_TOKEN,
    REWARD_ZERO_TOKEN,
]
75
+
76
def _vocab_with_sft(
    include_move_numbers: bool,
    keep_result: bool,
    bos: str,
    eos: str,
    unk: str,
    include_env_tokens: bool = False,
    include_reward_tokens: bool = False,
) -> Dict[str, int]:
    """Build the token -> id table, including the SFT special tokens.

    Tokens are numbered in order of first appearance, so the layout is:
    ``bos, eos, unk``, the piece letters, the 64 squares, promotion letters
    (already present as piece letters, hence they add no new IDs), the
    operator tokens, then the optional digit / result groups, the CoT
    markers and finally the optional env / reward tokens.

    Args:
        include_move_numbers: Add the ten digit tokens.
        keep_result: Add the PGN result markers (in sorted order).
        bos: Beginning-of-sequence token string.
        eos: End-of-sequence token string.
        unk: Unknown-token string.
        include_env_tokens: Add the tokens listed in ``ENV_TOKENS``.
        include_reward_tokens: Add the tokens listed in ``REWARD_TOKENS``.

    Returns:
        Mapping from token string to a dense, zero-based ID.
    """
    ordered: List[str] = [bos, eos, unk]
    ordered.extend("KQRBNP")
    ordered.extend(SQUARES)
    ordered.extend(PROMOS)
    ordered.extend(["x", "=", "+", "#", "O-O", "O-O-O", ".", "..."])

    if include_move_numbers:
        ordered.extend("0123456789")
    if keep_result:
        ordered.extend(sorted(_RESULT))

    # CoT markers are always part of the vocabulary.
    ordered.extend([T_TOKEN, T_END_TOKEN, SEP_TOKEN])

    # Environment / reward tokens only on request.
    if include_env_tokens:
        ordered.extend(ENV_TOKENS)
    if include_reward_tokens:
        ordered.extend(REWARD_TOKENS)

    # Assign IDs by first occurrence; later duplicates are ignored.
    vocab: Dict[str, int] = {}
    for tok in ordered:
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab
105
+
106
+
107
class LanTokenizerSFT(BaseTokenizer):
    """
    LAN tokenizer with SFT capabilities.

    This tokenizer extends the base LAN tokenizer with:
    - <T> token for marking thinking/CoT boundaries
    - <sep> token for separating candidate trajectories

    CoT Format: {prompt} <T> <sep> {traj1} <sep> {traj2} <sep> ... <sep> {trajN} <sep> <T> {answer}

    Where:
    - {prompt}: The game history/board state (PGN moves)
    - {trajN}: Candidate reasoning trajectories
    - {answer}: The final best move
    """

    # Special tokens for CoT format
    T = T_TOKEN
    T_END = T_END_TOKEN
    SEP = SEP_TOKEN

    # Environment / reward tokens (class-level constants for easy access)
    CALL_ENV = CALL_ENV_TOKEN        # "<call_env>"
    VERIFY = VERIFY_TOKEN            # "<verify>"
    REWARD_POS = REWARD_POS_TOKEN    # "<+1>"
    REWARD_NEG = REWARD_NEG_TOKEN    # "<-1>"
    REWARD_ZERO = REWARD_ZERO_TOKEN  # "<0>"
    ENV_TOKENS = ENV_TOKENS          # env token list (reward tokens live in REWARD_TOKENS)

    def __init__(self, config: Optional[dict] = None):
        """
        Args:
            config: Configuration dict with tokenizer settings. Recognised keys:
                include_move_numbers (bool): add digit tokens / emit move numbers.
                include_black_tripledots (bool): allow the "..." move-number suffix.
                bos / eos / unk (str): special token strings.
                keep_result (bool): keep the PGN result marker.
                include_env_tokens (bool): add <call_env> to the vocabulary.
                    Default: False.
                include_reward_tokens (bool): add <verify>, <+1>, <-1>, <0>
                    to the vocabulary. Default: False.
        """
        config = config or {}

        include_move_numbers = config.get("include_move_numbers", False)
        include_black_tripledots = config.get("include_black_tripledots", False)
        bos = config.get("bos", "<bos>")
        eos = config.get("eos", "<eos>")
        unk = config.get("unk", "<unk>")
        keep_result = config.get("keep_result", False)
        include_env_tokens = config.get("include_env_tokens", False)
        include_reward_tokens = config.get("include_reward_tokens", False)

        self._bos = bos
        self._eos = eos
        self._unk = unk
        self._keep_res = keep_result
        self._include_nums = include_move_numbers
        self._include_black_ellipses = include_black_tripledots
        self._include_env_tokens = include_env_tokens
        self._include_reward_tokens = include_reward_tokens

        # Create vocabulary with SFT tokens
        tok2id = _vocab_with_sft(
            include_move_numbers, keep_result, bos, eos, unk,
            include_env_tokens=include_env_tokens,
            include_reward_tokens=include_reward_tokens,
        )
        self._tok2id = tok2id

        # Word-level model over whitespace-separated tokens.
        self.tk = Tokenizer(WordLevel(vocab=tok2id, unk_token=self._unk))
        self.tk.pre_tokenizer = WhitespaceSplit()

    def _pgn_to_tokens(self, text: str) -> Optional[List[str]]:
        """Parse PGN text and return its main line as LAN tokens.

        Returns:
            The token list (possibly empty), or ``None`` when
            ``chess.pgn.read_game`` finds no game in ``text``.
        """
        import os, contextlib
        # Silence whatever the PGN parser writes to stderr while reading.
        with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull):
            g = chess.pgn.read_game(io.StringIO(text))
        if g is None:
            return None

        b, out, n = g.board(), [], 1
        for mv in g.mainline_moves():
            if b.turn == chess.WHITE and self._include_nums:
                # NOTE(review): this runs on White's turn only, so the "..."
                # branch depends on fullmove_number lagging behind n -- confirm
                # that this is the intended condition.
                out += list(str(n)) + (
                    ["..."] if self._include_black_ellipses and b.fullmove_number < n else ["."]
                )

            if b.is_castling(mv):
                castle = "O-O" if chess.square_file(mv.to_square) == 6 else "O-O-O"
                b.push(mv)
                suf = "#" if b.is_checkmate() else ("+" if b.is_check() else "")
                out.append(castle)
                if suf:
                    out.append(suf)
            else:
                # Everything that needs the pre-move position is read before push().
                piece = b.piece_at(mv.from_square).symbol().upper()
                frm = chess.square_name(mv.from_square)
                to = chess.square_name(mv.to_square)
                is_cap = b.is_capture(mv)
                promo = mv.promotion

                b.push(mv)
                suf = "#" if b.is_checkmate() else ("+" if b.is_check() else "")

                # Emit LAN tokens
                out.append(piece)
                out.append(frm)
                if is_cap:
                    out.append("x")
                out.append(to)
                if promo:
                    out += ["=", chess.piece_symbol(promo).upper()]
                if suf:
                    out.append(suf)

            # After Black's move it is White's turn again: next move number.
            if b.turn == chess.WHITE:
                n += 1

        res = g.headers.get("Result")
        if self._keep_res and res in _RESULT:
            out.append(res)

        return out

    def _lan_move_to_tokens(self, move: str) -> List[str]:
        """
        Convert a single LAN move to tokens.

        LAN format: [Piece][from_square][x]?[to_square][=Promo]?[+#]?

        Examples:
            "Ng1f3"   -> ["N", "g1", "f3"]
            "Nd4xe6"  -> ["N", "d4", "x", "e6"]
            "Pe2e4"   -> ["P", "e2", "e4"]
            "Pe4xd5"  -> ["P", "e4", "x", "d5"]
            "O-O"     -> ["O-O"]
            "O-O-O"   -> ["O-O-O"]
            "Pe7e8=Q" -> ["P", "e7", "e8", "=", "Q"]
            "Ng1f3+"  -> ["N", "g1", "f3", "+"]

        A word that does not start with a piece letter is returned unchanged
        as a single token. Trailing characters that do not fit the grammar
        are dropped.
        """
        # Handle castling
        if move in {"O-O", "O-O-O"}:
            return [move]
        if move.rstrip("+#") in {"O-O", "O-O-O"}:
            base = move.rstrip("+#")
            suffix = move[len(base):]
            return [base] + ([suffix] if suffix else [])

        out = []
        i = 0
        n = len(move)

        # Get piece letter (required in LAN format)
        if i < n and move[i] in "KQRBNP":
            out.append(move[i])
            i += 1
        else:
            # No piece letter - might be malformed, return as-is
            return [move]

        # Get from square (required in LAN format)
        if i + 1 < n and move[i] in FILES and move[i + 1] in RANKS:
            out.append(move[i:i + 2])
            i += 2

        # Handle capture
        if i < n and move[i] == "x":
            out.append("x")
            i += 1

        # Get to square (required in LAN format)
        if i + 1 < n and move[i] in FILES and move[i + 1] in RANKS:
            out.append(move[i:i + 2])
            i += 2

        # Handle promotion
        if i < n and move[i] == "=":
            out.append("=")
            i += 1
            if i < n and move[i] in PROMOS:
                out.append(move[i])
                i += 1

        # Handle check/checkmate
        if i < n and move[i] in "+#":
            out.append(move[i])
            i += 1

        return out

    def _active_env_tokens(self) -> set:
        """Return the set of env tokens that are active for this instance."""
        return set(ENV_TOKENS) if self._include_env_tokens else set()

    def _cot_to_tokens(self, text: str) -> List[str]:
        """
        Convert CoT formatted text to tokens.
        Handles special tokens, results, move numbers and LAN moves.
        """
        env_toks = self._active_env_tokens()
        out = []
        for token in text.split():
            if token in {self.T, self.T_END, self.SEP} or token in env_toks:
                # Keep special tokens as-is
                out.append(token)
            elif token in _RESULT:
                # Game result
                out.append(token)
            elif token and token[0].isdigit() and "." in token:
                # Move number like "1." or "15..."
                # Split into digits and dots
                num_part = token.rstrip(".")
                dot_part = token[len(num_part):]
                out.extend(list(num_part))
                if dot_part:
                    out.append("..." if len(dot_part) > 1 else ".")
            elif token and all(c.isdigit() for c in token):
                # Pure number - tokenize each digit
                out.extend(list(token))
            else:
                # LAN move - tokenize it (words without a piece letter, such as
                # reward tokens, come back unchanged)
                out.extend(self._lan_move_to_tokens(token))
        return out

    def encode(self, text: str) -> List[int]:
        """
        Encode text to token IDs, wrapped in BOS ... EOS.

        Args:
            text: Text to encode (can be PGN or CoT formatted)

        Returns:
            List of token IDs
        """
        # Check if this is CoT-formatted text (contains special tokens)
        sft_special = (
            [self.T, self.T_END, self.SEP]
            + (ENV_TOKENS if self._include_env_tokens else [])
        )
        is_cot_format = any(token in text for token in sft_special)

        if is_cot_format:
            t_idx = text.find(self.T)
            if t_idx < 0:
                # A special token other than <T> triggered CoT mode (e.g. a bare
                # "<sep> ..." fragment). There is no prompt/thinking boundary,
                # so tokenize the whole text with the CoT rules instead of
                # raising ValueError on the missing <T>.
                tokens = [self._bos] + self._cot_to_tokens(text) + [self._eos]
            else:
                prompt_part = text[:t_idx].strip()
                rest_part = text[t_idx:]  # starts with <T>

                # Prefer PGN parsing for the prompt; fall back to word-wise
                # CoT tokenization when no game can be read from it.
                pgn_tokens = self._pgn_to_tokens(prompt_part) if prompt_part else None
                if pgn_tokens is None:
                    pgn_tokens = self._cot_to_tokens(prompt_part) if prompt_part else []
                rest_tokens = self._cot_to_tokens(rest_part)
                tokens = [self._bos] + pgn_tokens + rest_tokens + [self._eos]
        else:
            pgn_tokens = self._pgn_to_tokens(text)
            if pgn_tokens is not None and len(pgn_tokens) > 0:
                tokens = [self._bos] + pgn_tokens + [self._eos]
            else:
                # Not valid PGN - treat each word as a LAN move
                lan_tokens = []
                for word in text.split():
                    lan_tokens.extend(self._lan_move_to_tokens(word))
                tokens = [self._bos] + lan_tokens + [self._eos]

        return self.tk.encode(" ".join(tokens)).ids

    def decode(self, ids: List[int]) -> str:
        """
        Decode token IDs to text.

        BOS/EOS tokens are dropped; digit runs, castling and piece moves are
        re-assembled into single words ("N", "g1", "f3", "+" -> "Ng1f3+").

        Args:
            ids: List of token IDs

        Returns:
            Decoded text, words separated by single spaces
        """
        toks = [t for t in self.tk.decode(ids).split() if t not in {self._bos, self._eos}]

        # Hoisted lookups used inside the loop.
        passthrough = {self.T, self.T_END, self.SEP} | set(_RESULT) | self._active_env_tokens()
        pieces = set("KQRBNP")
        promos = set(PROMOS)
        squares = set(SQUARES)

        out: List[str] = []
        i, n = 0, len(toks)

        while i < n:
            t = toks[i]

            if t in passthrough:
                out.append(t)
                i += 1
                continue

            if t and all(ch in DIGITS for ch in t):
                # Merge consecutive digit tokens plus an optional "." / "...".
                j = i
                num = []
                while j < n and all(ch in DIGITS for ch in toks[j]):
                    num.append(toks[j])
                    j += 1
                dots = ""
                if j < n and toks[j] in {".", "..."}:
                    dots = toks[j]
                    j += 1
                out.append("".join(num) + dots)
                i = j
                continue

            if t in {"O-O", "O-O-O"}:
                j = i + 1
                suf = toks[j] if j < n and toks[j] in {"+", "#"} else ""
                if suf:
                    j += 1
                out.append(t + suf)
                i = j
                continue

            if t in pieces:
                piece = t
                j = i + 1
                # Only swallow tokens that really are squares, so that a
                # malformed sequence (e.g. "K <sep>") does not glue a special
                # token into the move text.
                frm = ""
                if j < n and toks[j] in squares:
                    frm = toks[j]
                    j += 1
                cap = ""
                if j < n and toks[j] == "x":
                    cap = "x"
                    j += 1
                to = ""
                if j < n and toks[j] in squares:
                    to = toks[j]
                    j += 1
                promo = ""
                if j + 1 < n and toks[j] == "=" and toks[j + 1] in promos:
                    promo = "=" + toks[j + 1]
                    j += 2
                suf = ""
                if j < n and toks[j] in {"+", "#"}:
                    suf = toks[j]
                    j += 1
                out.append(f"{piece}{frm}{cap}{to}{promo}{suf}")
                i = j
                continue

            # Anything else (e.g. <unk>, reward tokens) is kept verbatim.
            out.append(t)
            i += 1

        return " ".join(out)

    def get_vocab(self) -> Dict[str, int]:
        """Get token-to-id vocabulary mapping."""
        return self._tok2id

    def bos_id(self) -> Optional[int]:
        """Get BOS token ID."""
        return self._tok2id[self._bos]

    def eos_id(self) -> Optional[int]:
        """Get EOS token ID."""
        return self._tok2id[self._eos]

    def pad_id(self) -> Optional[int]:
        """Get PAD token ID (uses BOS as pad by default)."""
        return self._tok2id.get("<pad>", self.bos_id())

    def get_vocab_size(self) -> int:
        """Get vocabulary size."""
        return len(self._tok2id)

    def t_id(self) -> int:
        """Get <T> token ID."""
        return self._tok2id[self.T]

    def sep_id(self) -> int:
        """Get <sep> token ID."""
        return self._tok2id[self.SEP]

    def t_end_id(self) -> int:
        """Get </T> token ID."""
        return self._tok2id[self.T_END]

    # ------------------------------------------------------------------
    # Environment / reward token accessors
    # ------------------------------------------------------------------

    def _require_env_tokens(self) -> None:
        """Raise ValueError unless the env tokens were enabled in the config."""
        if not self._include_env_tokens:
            raise ValueError(
                "Environment tokens are not enabled. "
                "Pass include_env_tokens=True in the config."
            )

    def _require_reward_tokens(self) -> None:
        """Raise ValueError unless the reward tokens were enabled in the config."""
        # <verify> and the reward tokens are added by include_reward_tokens,
        # not by include_env_tokens, so they need their own guard.
        if not self._include_reward_tokens:
            raise ValueError(
                "Reward tokens are not enabled. "
                "Pass include_reward_tokens=True in the config."
            )

    def call_env_id(self) -> int:
        """Get <call_env> token ID."""
        self._require_env_tokens()
        return self._tok2id[CALL_ENV_TOKEN]

    def verify_id(self) -> int:
        """Get <verify> token ID."""
        self._require_reward_tokens()
        return self._tok2id[VERIFY_TOKEN]

    def reward_pos_id(self) -> int:
        """Get <+1> (positive reward) token ID."""
        self._require_reward_tokens()
        return self._tok2id[REWARD_POS_TOKEN]

    def reward_neg_id(self) -> int:
        """Get <-1> (negative reward) token ID."""
        self._require_reward_tokens()
        return self._tok2id[REWARD_NEG_TOKEN]

    def reward_zero_id(self) -> int:
        """Get <0> (zero reward) token ID."""
        self._require_reward_tokens()
        return self._tok2id[REWARD_ZERO_TOKEN]

    def reward_id(self, value) -> int:
        """
        Get reward token ID by numeric value.

        Args:
            value: 1, -1, or 0 (or the strings "+1", "-1", "0")

        Returns:
            Token ID for the corresponding reward token.

        Raises:
            ValueError: if reward tokens are disabled or ``value`` is not one
                of the accepted values.
        """
        self._require_reward_tokens()
        mapping = {1: REWARD_POS_TOKEN, -1: REWARD_NEG_TOKEN, 0: REWARD_ZERO_TOKEN,
                   "+1": REWARD_POS_TOKEN, "-1": REWARD_NEG_TOKEN, "0": REWARD_ZERO_TOKEN}
        if value not in mapping:
            raise ValueError(f"reward value must be one of 1, -1, 0 (or '+1', '-1', '0'), got {value!r}")
        return self._tok2id[mapping[value]]

    def env_token_ids(self) -> Dict[str, int]:
        """Get mapping of the env tokens (``ENV_TOKENS``) to their IDs."""
        self._require_env_tokens()
        return {tok: self._tok2id[tok] for tok in ENV_TOKENS}

    def extract_parts(self, text: str) -> Tuple[Optional[str], Optional[List[str]], str]:
        """
        Extract prompt, trajectories and answer from BoN CoT formatted text.

        Args:
            text: Text in format: {prompt} <T> <sep> {traj1} <sep> ... <sep> <T> {answer}

        Returns:
            prompt: The prompt/context (or None if not present)
            trajectories: List of trajectory strings (or None if not present)
            answer: The final answer

            When the text does not contain two <T> markers the result is
            ``(None, None, text)``.
        """
        if self.T not in text:
            return None, None, text

        # Split by <T> to get prompt, thinking section, and answer
        t_parts = text.split(self.T)
        if len(t_parts) < 3:
            return None, None, text

        # t_parts[0] is prompt (before first <T>)
        # t_parts[1] is the thinking section with trajectories
        # t_parts[2] is the answer
        prompt = t_parts[0].strip() or None
        thinking_section = t_parts[1].strip()
        answer = t_parts[2].strip()

        # Split thinking section by <sep> to get trajectories
        trajectories = [t.strip() for t in thinking_section.split(self.SEP) if t.strip()]

        return prompt, trajectories, answer

    def extract_thinking_and_answer(self, text: str) -> Tuple[Optional[List[str]], str]:
        """
        Extract trajectories and answer from BoN CoT formatted text (ignores prompt).

        Args:
            text: Text in format: {prompt} <T> <sep> {traj1} <sep> ... <sep> <T> {answer}

        Returns:
            trajectories: List of trajectory strings (or None if not present)
            answer: The final answer
        """
        _, trajectories, answer = self.extract_parts(text)
        return trajectories, answer

    def get_sft_special_tokens(self) -> List[str]:
        """Get list of SFT special tokens (plus ``ENV_TOKENS`` if enabled)."""
        toks = [self.T, self.T_END, self.SEP]
        if self._include_env_tokens:
            toks += ENV_TOKENS
        return toks

    def get_sft_token_ids(self) -> Dict[str, int]:
        """Get mapping of SFT special tokens (plus ``ENV_TOKENS`` if enabled) to their IDs."""
        result = {
            self.T: self._tok2id[self.T],
            self.T_END: self._tok2id[self.T_END],
            self.SEP: self._tok2id[self.SEP],
        }
        if self._include_env_tokens:
            for tok in ENV_TOKENS:
                result[tok] = self._tok2id[tok]
        return result

    def parse_cot_line(self, line: str) -> Tuple[Optional[List[str]], Optional[str]]:
        """
        Parse a CoT data line in format: <T> <sep> ... <sep> <T> {answer}

        Args:
            line: A line from the CoT data file

        Returns:
            trajectories: List of trajectory strings
            answer: The final answer/move

            ``(None, None)`` when the stripped line is empty or does not
            start with <T>.
        """
        line = line.strip()
        if not line or not line.startswith(self.T):
            return None, None

        return self.extract_thinking_and_answer(line)
618
+
619
+ # ============================================================
620
+ # HuggingFace-compatible wrapper (auto-generated)
621
+ # ============================================================
622
+ import json as _json
623
+ from pathlib import Path as _Path
624
+ from transformers import PreTrainedTokenizer
625
+ import torch
626
+ from transformers.tokenization_utils_base import BatchEncoding
627
+
628
+ from huggingface_hub import hf_hub_download
629
+
630
class HFTokenizerWrapper(PreTrainedTokenizer):
    """Expose the inlined LAN tokenizer through the ``PreTrainedTokenizer`` API.

    ``vocab.json`` and ``tokenizer_config.json`` are read either from a local
    directory or from the Hub; ``encode`` / ``decode`` / ``__call__`` are
    overridden to delegate to the wrapped tokenizer instead of the usual
    ``_tokenize`` pipeline.
    """

    def __init__(self, model_max_length=2048, **kwargs):
        """Load vocab + config and build the wrapped tokenizer.

        Args:
            model_max_length: Maximum sequence length forwarded to the base class.
            **kwargs: ``name_or_path`` / ``_name_or_path`` (usually supplied by
                ``from_pretrained``), optional ``repo_id`` and ``revision``,
                plus anything else forwarded to ``PreTrainedTokenizer``.

        Raises:
            ValueError: if no repo / directory can be inferred, or if the
                configured ``lan_tokenizer_class`` is not a ``BaseTokenizer``
                subclass defined in this module.
        """
        import os

        # These are usually provided by from_pretrained
        repo_id = kwargs.get("name_or_path") or kwargs.get("_name_or_path")
        revision = kwargs.get("revision", None)

        # An existing local directory is always acceptable, even when its path
        # contains no "/" (e.g. from_pretrained("checkpoint")).
        is_local_dir = bool(repo_id) and os.path.isdir(str(repo_id))
        if not is_local_dir and (not repo_id or "/" not in str(repo_id)):
            # Fallback: user may pass repo_id explicitly
            repo_id = kwargs.get("repo_id", None)
            if not repo_id:
                raise ValueError("Cannot infer repo_id; pass repo_id=... or ensure name_or_path is set.")

        if os.path.isdir(repo_id):
            vocab_path = os.path.join(repo_id, "vocab.json")
            cfg_path = os.path.join(repo_id, "tokenizer_config.json")
        else:
            vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.json", revision=revision)
            cfg_path = hf_hub_download(repo_id=repo_id, filename="tokenizer_config.json", revision=revision)

        with open(vocab_path, "r", encoding="utf-8") as _f:
            saved_vocab = _json.load(_f)
        with open(cfg_path, "r", encoding="utf-8") as _f:
            _tok_cfg = _json.load(_f)

        lan_config = _tok_cfg.get("lan_config", {})
        lan_class_name = _tok_cfg.get("lan_tokenizer_class", "LanTokenizerSFT")

        # The class name comes from a config file: only accept tokenizer
        # classes, never an arbitrary module-level callable.
        _cls = globals().get(lan_class_name)
        if not (isinstance(_cls, type) and issubclass(_cls, BaseTokenizer)):
            raise ValueError(
                f"lan_tokenizer_class {lan_class_name!r} is not a tokenizer class defined in this module."
            )
        custom_tokenizer = _cls(config=lan_config)

        # Override vocab with the saved vocab so IDs match the checkpoint
        custom_tokenizer._tok2id = saved_vocab
        from tokenizers import Tokenizer as _TkTokenizer
        from tokenizers.models import WordLevel as _WordLevel
        from tokenizers.pre_tokenizers import WhitespaceSplit as _WhitespaceSplit
        custom_tokenizer.tk = _TkTokenizer(_WordLevel(vocab=saved_vocab, unk_token=custom_tokenizer._unk))
        custom_tokenizer.tk.pre_tokenizer = _WhitespaceSplit()

        self.custom_tokenizer = custom_tokenizer
        self._vocab = dict(saved_vocab)
        self._id_to_token = {i: t for t, i in self._vocab.items()}

        bos_token = _tok_cfg.get("bos_token")
        eos_token = _tok_cfg.get("eos_token")
        pad_token = _tok_cfg.get("pad_token")
        unk_token = _tok_cfg.get("unk_token")
        # An explicit env_id takes precedence over the env_token string.
        if "env_id" in _tok_cfg:
            env_token = self._id_to_token[_tok_cfg.get("env_id")]
        else:
            env_token = _tok_cfg.get("env_token")
        self.env_token = env_token

        # Drop everything that is passed explicitly below or that is only
        # meaningful to this wrapper, so it is not forwarded twice.
        for _key in ("bos_token", "eos_token", "pad_token", "unk_token", "env_token",
                     "model_max_length", "name_or_path", "lan_config",
                     "lan_tokenizer_class", "tokenizer_class", "auto_map", "use_fast",
                     "revision", "repo_id"):
            kwargs.pop(_key, None)

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            model_max_length=model_max_length,
            **kwargs,
        )

    # ---- PreTrainedTokenizer interface ----

    @property
    def vocab_size(self):
        """Number of entries in the saved vocabulary."""
        return len(self._vocab)

    def get_vocab(self):
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text):
        return []  # we override encode/decode directly

    def _convert_token_to_id(self, token):
        """Map a token to its ID, falling back to the unk ID (or 0)."""
        return self._vocab.get(token, self._vocab.get(self.unk_token, 0))

    def _convert_id_to_token(self, index):
        """Map an ID to its token, falling back to the unk token (or "")."""
        return self._id_to_token.get(index, self.unk_token or "")

    def convert_tokens_to_string(self, tokens):
        """Join tokens into text via the wrapped tokenizer's decode()."""
        ids = [self._convert_token_to_id(t) for t in tokens]
        return self.custom_tokenizer.decode(ids)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Concatenate the sequences unchanged; encode() already handles BOS/EOS."""
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def encode(self, text, add_special_tokens=True, **kwargs):
        """Encode ``text`` with the wrapped tokenizer.

        With ``add_special_tokens=True`` the last ID (the EOS appended by the
        wrapped tokenizer) is dropped and the leading BOS kept; otherwise a
        surrounding BOS/EOS pair is stripped when present. Extra keyword
        arguments are accepted for API compatibility and ignored.
        """
        ids = self.custom_tokenizer.encode(text)
        if add_special_tokens:
            return ids[:-1]  # strip trailing EOS; vLLM adds its own
        if (len(ids) >= 2
                and self.bos_token_id is not None
                and self.eos_token_id is not None
                and ids[0] == self.bos_token_id
                and ids[-1] == self.eos_token_id):
            return ids[1:-1]
        return ids

    def decode(self, token_ids, skip_special_tokens=True, **kwargs):
        """Decode IDs (int, list, tensor or ndarray) with the wrapped tokenizer.

        ``skip_special_tokens`` and extra keyword arguments are accepted for
        API compatibility but not used; which tokens are dropped is decided
        by the wrapped tokenizer's ``decode``.
        """
        import numpy as np
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.detach().cpu().tolist()
        elif isinstance(token_ids, np.ndarray):
            token_ids = token_ids.tolist()
        # A single ID (plain int, or a 0-d tensor/array after tolist()) is
        # decoded as a one-element sequence instead of failing downstream.
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        return self.custom_tokenizer.decode(token_ids)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write ``vocab.json`` (optionally prefixed) and return its path as a 1-tuple."""
        save_directory = _Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)
        vocab_file = save_directory / (
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            _json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (str(vocab_file),)

    def __call__(
        self,
        text,
        text_pair=None,
        add_special_tokens=True,
        truncation=False,
        max_length=None,
        padding=False,
        return_tensors=None,
        **kwargs,
    ):
        """Tokenize one string or a batch of strings.

        Args:
            text: A string, or a list/tuple of strings.
            text_pair: Not supported; must be None.
            add_special_tokens: Forwarded to ``encode``.
            truncation: When truthy, cut sequences to ``max_length`` (or to
                ``model_max_length`` when ``max_length`` is None) on
                ``truncation_side``.
            max_length: Target length for truncation and ``padding="max_length"``.
            padding: Falsy for none, ``"max_length"``, or any other truthy
                value for pad-to-longest. Padding honours ``padding_side``.
            return_tensors: ``"pt"`` for torch tensors, None for Python lists.

        Returns:
            ``BatchEncoding`` with ``input_ids`` and ``attention_mask``.

        Raises:
            ValueError: for ``text_pair`` or ``padding="max_length"`` without
                ``max_length``.
        """
        if text_pair is not None:
            raise ValueError("text_pair not supported for this tokenizer.")

        # Normalize to batch
        is_batched = isinstance(text, (list, tuple))
        texts = list(text) if is_batched else [text]

        input_ids = [self.encode(t, add_special_tokens=add_special_tokens) for t in texts]

        # Truncation: fall back to the model maximum when no explicit length
        # is given, instead of silently not truncating.
        if truncation:
            limit = max_length if max_length is not None else self.model_max_length
            if limit is not None:
                if self.truncation_side == "left":
                    input_ids = [ids[-limit:] for ids in input_ids]
                else:
                    input_ids = [ids[:limit] for ids in input_ids]

        # Attention masks (pre-padding)
        attention_mask = [[1] * len(ids) for ids in input_ids]

        # Padding
        if padding:
            if padding == "max_length":
                if max_length is None:
                    raise ValueError("padding='max_length' requires max_length.")
                pad_to = max_length
            else:
                pad_to = max(len(ids) for ids in input_ids) if input_ids else 0

            pad_id = self.pad_token_id
            if pad_id is None:
                pad_id = self.bos_token_id if self.bos_token_id is not None else 0

            pad_left = getattr(self, "padding_side", "right") == "left"
            for i, ids in enumerate(input_ids):
                pad_len = pad_to - len(ids)
                if pad_len > 0:
                    if pad_left:
                        input_ids[i] = [pad_id] * pad_len + ids
                        attention_mask[i] = [0] * pad_len + attention_mask[i]
                    else:
                        input_ids[i] = ids + [pad_id] * pad_len
                        attention_mask[i] = attention_mask[i] + [0] * pad_len

        data = {"input_ids": input_ids, "attention_mask": attention_mask}

        # Unbatch if single example and no tensor return
        if not is_batched and return_tensors is None:
            data = {"input_ids": data["input_ids"][0], "attention_mask": data["attention_mask"][0]}

        # Tensors (a single example keeps its leading batch dimension here)
        if return_tensors == "pt":
            data = {k: torch.tensor(v, dtype=torch.long) for k, v in data.items()}

        return BatchEncoding(data, tensor_type=None)
816
+
817
+
818
+ __all__ = ["HFTokenizerWrapper"]
C6p5e18_200m_alpha0.200_beta0.100/tokenizer_config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "HFTokenizerWrapper",
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenizer.HFTokenizerWrapper",
6
+ null
7
+ ]
8
+ },
9
+ "model_max_length": 2048,
10
+ "bos_token": "<bos>",
11
+ "eos_token": "<eos>",
12
+ "pad_token": "<bos>",
13
+ "unk_token": "<unk>",
14
+ "env_token": null,
15
+ "use_fast": false,
16
+ "lan_config": {
17
+ "name": "LanTokenizerSFT",
18
+ "include_move_numbers": false,
19
+ "include_black_tripledots": false,
20
+ "bos": "<bos>",
21
+ "eos": "<eos>",
22
+ "unk": "<unk>",
23
+ "keep_result": false
24
+ },
25
+ "lan_tokenizer_class": "LanTokenizerSFT"
26
+ }
C6p5e18_200m_alpha0.200_beta0.100/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 385, "epoch": 2}
C6p5e18_200m_alpha0.200_beta0.100/vocab.json ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<bos>": 0,
3
+ "<eos>": 1,
4
+ "<unk>": 2,
5
+ "K": 3,
6
+ "Q": 4,
7
+ "R": 5,
8
+ "B": 6,
9
+ "N": 7,
10
+ "P": 8,
11
+ "a1": 9,
12
+ "a2": 10,
13
+ "a3": 11,
14
+ "a4": 12,
15
+ "a5": 13,
16
+ "a6": 14,
17
+ "a7": 15,
18
+ "a8": 16,
19
+ "b1": 17,
20
+ "b2": 18,
21
+ "b3": 19,
22
+ "b4": 20,
23
+ "b5": 21,
24
+ "b6": 22,
25
+ "b7": 23,
26
+ "b8": 24,
27
+ "c1": 25,
28
+ "c2": 26,
29
+ "c3": 27,
30
+ "c4": 28,
31
+ "c5": 29,
32
+ "c6": 30,
33
+ "c7": 31,
34
+ "c8": 32,
35
+ "d1": 33,
36
+ "d2": 34,
37
+ "d3": 35,
38
+ "d4": 36,
39
+ "d5": 37,
40
+ "d6": 38,
41
+ "d7": 39,
42
+ "d8": 40,
43
+ "e1": 41,
44
+ "e2": 42,
45
+ "e3": 43,
46
+ "e4": 44,
47
+ "e5": 45,
48
+ "e6": 46,
49
+ "e7": 47,
50
+ "e8": 48,
51
+ "f1": 49,
52
+ "f2": 50,
53
+ "f3": 51,
54
+ "f4": 52,
55
+ "f5": 53,
56
+ "f6": 54,
57
+ "f7": 55,
58
+ "f8": 56,
59
+ "g1": 57,
60
+ "g2": 58,
61
+ "g3": 59,
62
+ "g4": 60,
63
+ "g5": 61,
64
+ "g6": 62,
65
+ "g7": 63,
66
+ "g8": 64,
67
+ "h1": 65,
68
+ "h2": 66,
69
+ "h3": 67,
70
+ "h4": 68,
71
+ "h5": 69,
72
+ "h6": 70,
73
+ "h7": 71,
74
+ "h8": 72,
75
+ "x": 73,
76
+ "=": 74,
77
+ "+": 75,
78
+ "#": 76,
79
+ "O-O": 77,
80
+ "O-O-O": 78,
81
+ ".": 79,
82
+ "...": 80,
83
+ "<T>": 81,
84
+ "</T>": 82,
85
+ "<sep>": 83
86
+ }