| """ |
| One-time data preparation for Rubik's-cube autoresearch experiments. |
| |
| This version replaces the generic text corpus with a synthetic cube-policy |
| dataset built from canonicalized NxN cube states and structured move tokens. |
| |
| Usage: |
| python prepare.py |
| python prepare.py --force |
| |
| Artifacts are stored in ~/.cache/autoresearch-rubiks/. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import json |
| import os |
| import pickle |
| import random |
| from collections import defaultdict |
| from contextlib import nullcontext |
|
|
| import torch |
|
|
| from rubiks import ( |
| Cube, |
| Episode, |
| Move, |
| build_answer_tokens, |
| build_prompt_tokens, |
| build_training_examples_from_solution, |
| parse_answer_tokens, |
| random_scramble, |
| scramble_length_for_size, |
| ) |
| from teacher_dwalton import solve_cube_222, solve_cube_333 |
|
|
| |
| |
| |
|
|
# --- Sequence / teacher configuration ---------------------------------------
MAX_SEQ_LEN = 72  # hard upper bound on an encoded example's token length
TIME_BUDGET = 10800  # seconds; presumably the overall run budget — not read in this file
TEACHER_BACKEND = "dwalton76/rubiks-cube-NxNxN-solver"  # provenance tag recorded in DATA_CONFIG
PROMPT_FORMAT_VERSION = "flat24-history3-jointmove-v1"  # bump whenever prompt layout changes


# --- Split definitions -------------------------------------------------------
TRAIN_SIZES = (2, 3)  # cube sizes with an integrated teacher solver
ID_VAL_SIZES = TRAIN_SIZES  # in-distribution validation reuses the training sizes
OOD_DEV_SIZES = ()  # no out-of-distribution sizes wired up yet
OOD_TEST_SIZES = ()


# --- Episode counts per split ------------------------------------------------
TRAIN_EPISODES_PER_SIZE = 65536
_TRAIN_EPISODES_OVERRIDE = {}  # optional per-size overrides, e.g. {3: 1024}
ID_VAL_EPISODES_PER_SIZE = 256
OOD_DEV_EPISODES_PER_SIZE = 0
OOD_TEST_EPISODES_PER_SIZE = 0


# --- Move-space limits (passed to random_scramble) ---------------------------
MAX_MOVE_DEPTH = 2
MAX_MOVE_WIDTH = 2
MAX_GENERATION_TOKENS = 20  # not read in this file; presumably caps eval-time decoding — TODO confirm


# --- Reproducibility / curriculum --------------------------------------------
TRAIN_RNG_SEED = 42
VAL_RNG_SEED = 99
TRAIN_USE_CURRICULUM = False  # when True, scramble lengths come from the list below
TRAIN_CURRICULUM_SCRAMBLE_LENGTHS = (2, 4, 6, 8, 10, 14, 18)


# --- Rollout / search evaluation knobs ---------------------------------------
ROLLOUT_MIN_STEPS = 200  # floor on rollout length regardless of per-episode budget
SEARCH_SELECTOR = "hybrid_greedy_v1"  # see _select_move_with_search for the options
SEARCH_RESIDUAL_DELTA = 2  # allowed residual regression when shortlisting candidates
SEARCH_LOOKAHEAD_TOP_K = 3  # shortlist size fed into lookahead
SEARCH_SECOND_SCORE_DISCOUNT = 0.5  # weight of the second-ply score in two-ply search
|
|
|
|
| def _stable_version(prefix: str, payload: dict[str, object]) -> str: |
| digest = hashlib.sha1(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:12] |
| return f"{prefix}-{digest}" |
|
|
|
|
# Canonical record of every knob that affects dataset contents. A stable hash
# of this dict becomes DATA_VERSION, which is stored inside the pickled
# payload so stale artifacts are detected and regenerated (see
# build_dataset_payload / load_dataset).
DATA_CONFIG = {
    "prompt_format_version": PROMPT_FORMAT_VERSION,
    "teacher_backend": TEACHER_BACKEND,
    "train_sizes": TRAIN_SIZES,
    "id_val_sizes": ID_VAL_SIZES,
    "ood_dev_sizes": OOD_DEV_SIZES,
    "ood_test_sizes": OOD_TEST_SIZES,
    "train_episodes_per_size": TRAIN_EPISODES_PER_SIZE,
    "id_val_episodes_per_size": ID_VAL_EPISODES_PER_SIZE,
    "ood_dev_episodes_per_size": OOD_DEV_EPISODES_PER_SIZE,
    "ood_test_episodes_per_size": OOD_TEST_EPISODES_PER_SIZE,
    "max_move_depth": MAX_MOVE_DEPTH,
    "max_move_width": MAX_MOVE_WIDTH,
    "train_rng_seed": TRAIN_RNG_SEED,
    "val_rng_seed": VAL_RNG_SEED,
    "train_use_curriculum": TRAIN_USE_CURRICULUM,
    "train_curriculum_scramble_lengths": TRAIN_CURRICULUM_SCRAMBLE_LENGTHS,
}
# Bump the "rubiks-v8" prefix manually for changes this dict cannot capture
# (e.g. a change in the episode-generation code itself).
DATA_VERSION = _stable_version("rubiks-v8", DATA_CONFIG)
|
|
|
|
def get_experiment_manifest() -> dict[str, object]:
    """Assemble the data and evaluation settings that define this experiment."""
    data_section: dict[str, object] = {
        "data_version": DATA_VERSION,
        "dataset_path": DATASET_PATH,
    }
    data_section.update(DATA_CONFIG)

    eval_section: dict[str, object] = {
        "rollout_min_steps": ROLLOUT_MIN_STEPS,
        "search_selector": SEARCH_SELECTOR,
        "search_residual_delta": SEARCH_RESIDUAL_DELTA,
        "search_lookahead_top_k": SEARCH_LOOKAHEAD_TOP_K,
        "search_second_score_discount": SEARCH_SECOND_SCORE_DISCOUNT,
        "no_inverse_rule": True,
        "state_avoidance": True,
    }

    return {"data": data_section, "eval": eval_section}
|
|
| |
| |
| |
|
|
# Artifact layout: everything lives under ~/.cache/autoresearch-rubiks/.
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch-rubiks")
DATA_DIR = os.path.join(CACHE_DIR, "data")  # pickled dataset payloads
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")  # fixed-vocabulary tokenizer
DATASET_PATH = os.path.join(DATA_DIR, "datasets.pkl")
TOKENIZER_PATH = os.path.join(TOKENIZER_DIR, "tokenizer.json")
|
|
| |
| |
| |
|
|
|
|
def report_environment():
    """Print the runtime device (cuda/mps/cpu) detected for this process.

    Delegates to get_runtime_device() instead of re-implementing the
    cuda > mps > cpu probe, so this report can never disagree with the
    device selection used elsewhere in the module.
    """
    print(f"Environment detected: {get_runtime_device()}")
    print()
|
|
|
|
def get_runtime_device() -> str:
    """Pick the best available torch device, preferring cuda, then mps, then cpu."""
    probes = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for name, is_available in probes:
        if is_available():
            return name
    return "cpu"
|
|
|
|
| |
| |
| |
|
|
|
|
def build_vocab() -> list[str]:
    """Return the fixed token vocabulary for cube-policy sequences.

    Order matters: token ids are assigned by list position, so this list
    must remain byte-stable across runs.
    """
    faces = ("U", "R", "F", "D", "L", "B")
    turns = ("CW", "CCW", "HALF")

    specials = [
        "<|pad|>",
        "<|bos|>",
        "<|eos|>",
        "<TASK_POLICY>",
        "<SIZE>",
        "</SIZE>",
        "<STATE>",
        "</STATE>",
        "<TARGET>",
        "<DONE>",
        "<MOVE>",
        "</MOVE>",
        "<FACE>",
        "</FACE>",
        "<DEPTH>",
        "</DEPTH>",
        "<WIDTH>",
        "</WIDTH>",
        "<TURN>",
        "</TURN>",
        "<ROW>",
        "</ROW>",
    ]
    grid_and_face = [
        token
        for face in faces
        for token in (f"<GRID_{face}>", f"</GRID_{face}>", f"FACE_{face}")
    ]
    color_tokens = [f"COL_{color}" for color in ("W", "Y", "G", "B", "R", "O")]
    turn_tokens = [f"TURN_{turn}" for turn in turns]
    digit_tokens = [f"DIGIT_{digit}" for digit in range(10)]
    # Joint move tokens: one id per (face, turn) pair enables single-token moves.
    move_tokens = [f"MOVE_{face}_{turn}" for face in faces for turn in turns]

    return specials + grid_and_face + color_tokens + turn_tokens + digit_tokens + move_tokens
|
|
|
|
class Tokenizer:
    """Whitespace tokenizer over a fixed, pre-built vocabulary.

    Token ids are positions in ``id_to_token``; ``token_to_id`` is the
    inverse mapping. The whole tokenizer round-trips through one JSON file.
    """

    def __init__(self, token_to_id: dict[str, int], id_to_token: list[str]):
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token

    @classmethod
    def from_directory(cls, tokenizer_dir=TOKENIZER_DIR) -> "Tokenizer":
        """Load a tokenizer previously written by :meth:`save`."""
        path = os.path.join(tokenizer_dir, "tokenizer.json")
        with open(path, "r", encoding="utf-8") as f:
            payload = json.load(f)
        return cls(payload["token_to_id"], payload["id_to_token"])

    def save(self, tokenizer_dir=TOKENIZER_DIR):
        """Persist both mappings as ``tokenizer.json`` under *tokenizer_dir*."""
        os.makedirs(tokenizer_dir, exist_ok=True)
        payload = {
            "token_to_id": self.token_to_id,
            "id_to_token": self.id_to_token,
        }
        path = os.path.join(tokenizer_dir, "tokenizer.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

    def get_vocab_size(self) -> int:
        return len(self.id_to_token)

    def get_pad_token_id(self) -> int:
        return self.token_to_id["<|pad|>"]

    def get_bos_token_id(self) -> int:
        return self.token_to_id["<|bos|>"]

    def get_eos_token_id(self) -> int:
        return self.token_to_id["<|eos|>"]

    def encode_tokens(self, tokens: list[str]) -> list[int]:
        """Map token strings to ids; raise ValueError on any unknown token."""
        lookup = self.token_to_id
        try:
            return [lookup[token] for token in tokens]
        except KeyError as exc:
            raise ValueError(f"Unknown token: {exc.args[0]}") from exc

    def decode_ids(self, ids: list[int]) -> list[str]:
        table = self.id_to_token
        return [table[idx] for idx in ids]

    def encode(self, text, prepend=None, num_threads=8):
        """Encode a whitespace-separated string (or list of them) into id lists.

        ``prepend`` may be a token string or an id to insert at position 0.
        ``num_threads`` is accepted for API compatibility and ignored.
        """
        del num_threads
        if isinstance(text, str):
            ids = self.encode_tokens(text.strip().split())
            if prepend is not None:
                head = prepend if isinstance(prepend, int) else self.token_to_id[prepend]
                ids.insert(0, head)
            return ids
        if isinstance(text, list):
            return [self.encode(item, prepend=prepend) for item in text]
        raise ValueError(f"Invalid input type: {type(text)}")

    def decode(self, ids: list[int]) -> str:
        return " ".join(self.decode_ids(ids))
|
|
|
|
def ensure_tokenizer(force: bool = False):
    """Create and persist the fixed-vocabulary tokenizer if not already cached.

    An existing tokenizer.json is reused unless *force* is set.
    """
    if not force and os.path.exists(TOKENIZER_PATH):
        print(f"Tokenizer: already present at {TOKENIZER_DIR}")
        return

    # Build the vocabulary once. Previously build_vocab() was called twice
    # (once for each mapping), which duplicated work and risked the two
    # mappings diverging if vocabulary construction ever became stateful.
    vocab = build_vocab()
    tokenizer = Tokenizer(
        token_to_id={token: idx for idx, token in enumerate(vocab)},
        id_to_token=vocab,
    )
    tokenizer.save(TOKENIZER_DIR)
    print(f"Tokenizer: wrote fixed vocabulary with {tokenizer.get_vocab_size()} tokens")
|
|
|
|
| |
| |
| |
|
|
|
|
def encode_supervised_example(tokenizer: Tokenizer, prompt_tokens: list[str], answer_tokens: list[str]) -> dict[str, object]:
    """Encode one (prompt, answer) pair for next-token supervision.

    The full sequence is [BOS, prompt..., answer...]. Inputs are that
    sequence minus its last token; targets are the sequence shifted left by
    one, with every position that would predict a prompt token masked to -1
    so the loss covers only answer tokens (the last prompt position predicts
    the first answer token and is supervised).

    Returns a dict with input_ids, targets, prompt_len, answer_len.
    Raises ValueError when the example exceeds MAX_SEQ_LEN, and
    AssertionError if the internal length invariant is ever violated.
    """
    prompt_ids = tokenizer.encode_tokens(prompt_tokens)
    answer_ids = tokenizer.encode_tokens(answer_tokens)
    full_ids = [tokenizer.get_bos_token_id(), *prompt_ids, *answer_ids]
    # Position i of input_ids predicts full_ids[i + 1].
    input_ids = full_ids[:-1]
    # Positions 0..len(prompt)-1 predict prompt tokens -> masked with -1;
    # full_ids[len(prompt)+1:] is exactly the answer ids.
    targets = [-1] * len(prompt_ids) + full_ids[len(prompt_ids) + 1:]
    if len(input_ids) != len(targets):
        raise AssertionError("input/target length mismatch")
    if len(input_ids) > MAX_SEQ_LEN:
        raise ValueError(f"Example length {len(input_ids)} exceeds MAX_SEQ_LEN={MAX_SEQ_LEN}")
    return {
        "input_ids": input_ids,
        "targets": targets,
        "prompt_len": len(prompt_ids),
        "answer_len": len(answer_ids),
    }
|
|
|
|
def episode_to_examples(tokenizer: Tokenizer, episode: Episode) -> list[dict[str, object]]:
    """Expand one teacher episode into per-step supervised examples."""
    records: list[dict[str, object]] = []
    steps = build_training_examples_from_solution(
        episode.size,
        episode.scramble,
        episode.solution,
    )
    for prompt_tokens, answer_tokens, distance in steps:
        record = encode_supervised_example(tokenizer, prompt_tokens, answer_tokens)
        # Tag each example with metadata used by eval/curriculum logic.
        record["size"] = episode.size
        record["distance_to_goal"] = distance
        records.append(record)
    return records
|
|
|
|
def _generate_one_worker(args):
    """Module-level worker for parallel episode generation.

    Re-inserts the bundled NxNxN solver checkout onto sys.path, since child
    processes may not inherit the parent's path mutations.
    """
    import sys

    size, seed, scramble_length = args
    solver_root = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "rubiks-cube-NxNxN-solver"
    )
    if os.path.exists(solver_root) and solver_root not in sys.path:
        sys.path.insert(0, solver_root)
    # Each worker gets its own RNG seeded deterministically by the caller.
    return generate_teacher_episode(
        size=size,
        rng=random.Random(seed),
        scramble_length=scramble_length,
    )
|
|
|
|
def generate_teacher_episode(size: int, rng: random.Random, scramble_length: int | None = None) -> Episode:
    """Scramble an NxN cube and have the teacher backend solve it.

    Raises NotImplementedError for sizes without an integrated teacher.
    """
    if scramble_length is None:
        scramble_length = scramble_length_for_size(size)
    scramble = random_scramble(
        size=size,
        length=scramble_length,
        rng=rng,
        max_depth=MAX_MOVE_DEPTH,
        max_width=MAX_MOVE_WIDTH,
    )
    cube = Cube(size)
    cube.apply_moves(scramble)

    # Dispatch to the per-size teacher solver.
    teacher_by_size = {2: solve_cube_222, 3: solve_cube_333}
    teacher = teacher_by_size.get(size)
    if teacher is None:
        raise NotImplementedError(f"Teacher backend is not integrated for size {size} yet")
    solution = teacher(cube)

    # Rollout budget: at least 8 steps, or twice the teacher path length.
    return Episode(
        size=size,
        scramble=scramble,
        solution=solution,
        max_rollout_steps=max(8, len(solution) * 2),
    )
|
|
|
|
def _generate_perturbed_examples(
    tokenizer: Tokenizer, episode: Episode, rng: random.Random, num_perturbations: int = 2,
) -> list[dict[str, object]]:
    """DAgger-style data augmentation: inject random wrong moves into teacher
    solution paths, then query the teacher for corrections from the resulting
    off-policy states. This teaches the model to recover from mistakes.

    Returns up to num_perturbations * 3 encoded examples. Perturbations that
    land on a solved state, fail to solve, or exceed the sequence budget are
    skipped silently (best-effort augmentation).
    """
    if len(episode.solution) == 0:
        return []

    # Bug fix: the correction solver was hard-coded to solve_cube_222, so
    # size-3 episodes either raised inside the try (and were silently
    # dropped) or risked wrong labels. Dispatch on episode size, mirroring
    # generate_teacher_episode; unsupported sizes yield no augmentation.
    teacher_by_size = {2: solve_cube_222, 3: solve_cube_333}
    teacher = teacher_by_size.get(episode.size)
    if teacher is None:
        return []

    examples: list[dict[str, object]] = []
    for _ in range(num_perturbations):
        cube = Cube(episode.size)
        cube.apply_moves(episode.scramble)
        history: list[Move] = []

        # Replay a random prefix of the teacher solution...
        inject_point = rng.randint(0, len(episode.solution) - 1)
        for i in range(inject_point):
            cube.apply_move(episode.solution[i])
            history.append(episode.solution[i])

        # ...then inject 1-2 random (likely wrong) outer-layer moves.
        num_random = rng.randint(1, 2)
        for _ in range(num_random):
            face = rng.choice(("U", "R", "F", "D", "L", "B"))
            turns = rng.choice((1, -1, 2))
            wrong_move = Move(face=face, depth=1, width=1, turns=turns)
            cube.apply_move(wrong_move)
            history.append(wrong_move)

        # The random moves may accidentally solve the cube; nothing to learn.
        if cube.has_uniform_faces():
            continue

        # Ask the teacher for a fresh correction from the off-policy state.
        try:
            correction = teacher(cube)
        except Exception:
            continue  # best-effort: the backend may reject some states

        if not correction:
            continue

        # Supervise the first corrective move from the perturbed state.
        prompt_tokens = build_prompt_tokens(episode.size, cube, history=history)
        answer_tokens = build_answer_tokens(correction[0])
        try:
            encoded = encode_supervised_example(tokenizer, prompt_tokens, answer_tokens)
            encoded["size"] = episode.size
            examples.append(encoded)
        except (ValueError, AssertionError):
            continue

        # Also supervise the next couple of steps along the correction path.
        for step_idx in range(1, min(3, len(correction))):
            cube.apply_move(correction[step_idx - 1])
            history.append(correction[step_idx - 1])
            if cube.has_uniform_faces():
                break
            prompt_tokens = build_prompt_tokens(episode.size, cube, history=history)
            answer_tokens = build_answer_tokens(correction[step_idx])
            try:
                encoded = encode_supervised_example(tokenizer, prompt_tokens, answer_tokens)
                encoded["size"] = episode.size
                examples.append(encoded)
            except (ValueError, AssertionError):
                break

    return examples
|
|
|
|
def build_dataset_payload(force: bool = False):
    """Generate (or reuse) the pickled dataset of supervised examples.

    Reuses an existing payload only when its stored data_version matches the
    current DATA_VERSION hash; pass *force* to always regenerate. Episode
    generation fans out across worker processes; teacher solves dominate
    the runtime.
    """
    if not force and os.path.exists(DATASET_PATH):
        with open(DATASET_PATH, "rb") as f:
            payload = pickle.load(f)
        # Stale caches (config changed -> different hash) fall through and
        # are regenerated below.
        if payload.get("data_version") == DATA_VERSION:
            print(f"Data: already prepared at {DATASET_PATH}")
            return

    os.makedirs(DATA_DIR, exist_ok=True)
    tokenizer = Tokenizer.from_directory()
    # Independent RNGs keep the train and validation scramble streams disjoint.
    train_rng = random.Random(TRAIN_RNG_SEED)
    val_rng = random.Random(VAL_RNG_SEED)

    train_examples: list[dict[str, object]] = []
    id_val_examples: list[dict[str, object]] = []
    ood_dev_examples: list[dict[str, object]] = []
    # Raw episodes are stored alongside examples for rollout-style evaluation.
    eval_episodes = {
        "id": [],
        "ood_dev": [],
        "ood_test": [],
    }

    def generate_episode_batch(size: int, count: int, rng: random.Random, curriculum: bool = False) -> list[Episode]:
        # Parallel generation: each worker receives its own seed drawn from
        # the caller's RNG, so output is reproducible for a fixed seed.
        from multiprocessing import Pool, cpu_count
        args = []
        for _ in range(count):
            sl = rng.choice(TRAIN_CURRICULUM_SCRAMBLE_LENGTHS) if curriculum else None
            seed = rng.randint(0, 2**31)
            args.append((size, seed, sl))
        n_workers = min(cpu_count(), 16)
        print(f" Generating {count} episodes for size {size} with {n_workers} workers...")
        with Pool(n_workers) as pool:
            episodes = pool.map(_generate_one_worker, args, chunksize=max(1, count // (n_workers * 4)))
        return episodes

    # --- Training split ----------------------------------------------------
    for size in TRAIN_SIZES:
        n_episodes = _TRAIN_EPISODES_OVERRIDE.get(size, TRAIN_EPISODES_PER_SIZE)
        episodes = generate_episode_batch(
            size,
            n_episodes,
            rng=train_rng,
            curriculum=TRAIN_USE_CURRICULUM,
        )
        for episode in episodes:
            train_examples.extend(episode_to_examples(tokenizer, episode))

    # --- In-distribution validation: examples AND rollout episodes ---------
    for size in ID_VAL_SIZES:
        episodes = generate_episode_batch(size, ID_VAL_EPISODES_PER_SIZE, rng=val_rng)
        eval_episodes["id"].extend(episode.to_dict() for episode in episodes)
        for episode in episodes:
            id_val_examples.extend(episode_to_examples(tokenizer, episode))

    # --- OOD dev (size tuples currently empty, loops are no-ops) -----------
    for size in OOD_DEV_SIZES:
        episodes = generate_episode_batch(size, OOD_DEV_EPISODES_PER_SIZE, rng=val_rng)
        eval_episodes["ood_dev"].extend(episode.to_dict() for episode in episodes)
        for episode in episodes:
            ood_dev_examples.extend(episode_to_examples(tokenizer, episode))

    # OOD test keeps rollout episodes only — no supervised examples.
    for size in OOD_TEST_SIZES:
        episodes = generate_episode_batch(size, OOD_TEST_EPISODES_PER_SIZE, rng=val_rng)
        eval_episodes["ood_test"].extend(episode.to_dict() for episode in episodes)

    payload = {
        "data_version": DATA_VERSION,
        "config": {
            "max_seq_len": MAX_SEQ_LEN,
            **DATA_CONFIG,
        },
        "train_examples": train_examples,
        "id_val_examples": id_val_examples,
        "ood_dev_examples": ood_dev_examples,
        "eval_episodes": eval_episodes,
    }
    with open(DATASET_PATH, "wb") as f:
        pickle.dump(payload, f)

    print(
        "Data: wrote "
        f"{len(train_examples):,} train examples, "
        f"{len(id_val_examples):,} ID val examples, "
        f"{len(ood_dev_examples):,} OOD-dev val examples"
    )
    print(f"Data: evaluation suite stored at {DATASET_PATH}")
|
|
|
|
def load_dataset() -> dict[str, object]:
    """Load the prepared dataset payload, failing fast on stale versions."""
    with open(DATASET_PATH, "rb") as f:
        payload = pickle.load(f)
    version = payload.get("data_version")
    if version != DATA_VERSION:
        raise RuntimeError("Dataset version mismatch. Re-run prepare.py --force")
    return payload
|
|
|
|
| |
| |
| |
|
|
|
|
| def _example_stream(examples: list[dict[str, object]], shuffle: bool): |
| order = list(range(len(examples))) |
| rng = random.Random(1234 if shuffle else 0) |
| epoch = 1 |
| while True: |
| if shuffle: |
| rng.shuffle(order) |
| for idx in order: |
| yield examples[idx], epoch |
| epoch += 1 |
|
|
|
|
def make_dataloader(tokenizer: Tokenizer, B: int, T: int, split: str):
    """Yield (inputs, targets, distances, epoch) batches forever.

    inputs/targets are (B, T) long tensors on the runtime device; distances
    is a (B,) float32 tensor of per-example distance_to_goal (0 if absent).
    WARNING: the yielded tensors are the SAME objects every iteration,
    overwritten in place — consumers must clone to retain a batch.
    """
    if split not in {"train", "val", "ood_val"}:
        raise ValueError(f"Unsupported split: {split}")
    payload = load_dataset()
    if split == "train":
        examples = payload["train_examples"]
        shuffle = True
    elif split == "val":
        examples = payload["id_val_examples"]
        shuffle = False
    else:
        examples = payload["ood_dev_examples"]
        shuffle = False

    stream = _example_stream(examples, shuffle=shuffle)
    device = get_runtime_device()
    pad_id = tokenizer.get_pad_token_id()

    # CPU staging buffers (pinned under CUDA so the copies below can overlap
    # compute via non_blocking=True)...
    cpu_inputs = torch.full((B, T), pad_id, dtype=torch.long, pin_memory=(device == "cuda"))
    cpu_targets = torch.full((B, T), -1, dtype=torch.long, pin_memory=(device == "cuda"))
    cpu_distances = torch.zeros(B, dtype=torch.float32, pin_memory=(device == "cuda"))
    # ...and their persistent device-side counterparts that get yielded.
    inputs = torch.full((B, T), pad_id, dtype=torch.long, device=device)
    targets = torch.full((B, T), -1, dtype=torch.long, device=device)
    distances = torch.zeros(B, dtype=torch.float32, device=device)

    while True:
        for row_idx in range(B):
            example, epoch = next(stream)
            input_ids = example["input_ids"]
            target_ids = example["targets"]
            seq_len = min(len(input_ids), T)  # truncate over-length sequences
            # Reset the row first: examples are variable length.
            cpu_inputs[row_idx].fill_(pad_id)
            cpu_targets[row_idx].fill_(-1)
            cpu_inputs[row_idx, :seq_len] = torch.tensor(input_ids[:seq_len], dtype=torch.long)
            cpu_targets[row_idx, :seq_len] = torch.tensor(target_ids[:seq_len], dtype=torch.long)
            cpu_distances[row_idx] = float(example.get("distance_to_goal", 0))

        inputs.copy_(cpu_inputs, non_blocking=(device == "cuda"))
        targets.copy_(cpu_targets, non_blocking=(device == "cuda"))
        distances.copy_(cpu_distances, non_blocking=(device == "cuda"))
        # epoch is taken from the last example drawn; rows of one batch may
        # straddle an epoch boundary, in which case the later epoch is reported.
        yield inputs, targets, distances, epoch
|
|
|
|
| |
| |
| |
|
|
|
|
| def _autocast_context(device_type: str): |
| if device_type == "cuda": |
| return torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) |
| if device_type == "cpu": |
| return torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16) |
| return nullcontext() |
|
|
|
|
| def _tensorize_batch(examples: list[dict[str, object]], tokenizer: Tokenizer, T: int, device: str): |
| pad_id = tokenizer.get_pad_token_id() |
| inputs = torch.full((len(examples), T), pad_id, dtype=torch.long, device=device) |
| targets = torch.full((len(examples), T), -1, dtype=torch.long, device=device) |
| for row_idx, example in enumerate(examples): |
| seq_len = min(len(example["input_ids"]), T) |
| inputs[row_idx, :seq_len] = torch.tensor(example["input_ids"][:seq_len], dtype=torch.long, device=device) |
| targets[row_idx, :seq_len] = torch.tensor(example["targets"][:seq_len], dtype=torch.long, device=device) |
| return inputs, targets |
|
|
|
|
@torch.no_grad()
def evaluate_move_accuracy(model, tokenizer: Tokenizer, examples: list[dict[str, object]], batch_size: int) -> float:
    """Fraction of examples whose supervised tokens are all predicted exactly.

    An example counts as correct only when every unmasked (target != -1)
    position matches the argmax prediction; examples with no supervised
    positions are excluded from the denominator. Returns 0.0 for empty input.
    """
    if not examples:
        return 0.0
    device = next(model.parameters()).device
    autocast_ctx = _autocast_context(device.type)
    total = correct = 0
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        inputs, targets = _tensorize_batch(chunk, tokenizer, MAX_SEQ_LEN, device)
        with autocast_ctx:
            logits = model(inputs)
            preds = logits.argmax(dim=-1)
        for row in range(len(chunk)):
            supervised = targets[row] != -1
            if supervised.sum().item() == 0:
                continue
            total += 1
            correct += int(torch.equal(preds[row][supervised], targets[row][supervised]))
    return correct / total if total else 0.0
|
|
|
|
def _generate_answer_ids(model, prompt_ids: list[int], tokenizer: Tokenizer,
                         last_move=None) -> list[int]:
    """Generate answer with single-token move prediction.

    One forward pass selects among the 18 joint MOVE_<face>_<turn> tokens
    plus <DONE>. When last_move is given, the exact inverse of that move is
    excluded (no-inverse rule) so the policy cannot immediately undo itself.
    Returns a single-element id list.
    """
    device = next(model.parameters()).device
    t2i = tokenizer.token_to_id
    autocast_ctx = _autocast_context(device.type)

    input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

    # Precompute the forbidden turn name (inverse of last_move), if any.
    forbidden_turn = None
    if last_move:
        inverse_turns = {1: -1, -1: 1, 2: 2}[last_move.turns]
        forbidden_turn = {1: "CW", -1: "CCW", 2: "HALF"}[inverse_turns]

    valid_ids = [t2i["<DONE>"]]
    for face in ("U", "R", "F", "D", "L", "B"):
        for turn in ("CW", "CCW", "HALF"):
            if last_move and face == last_move.face and turn == forbidden_turn:
                continue
            valid_ids.append(t2i[f"MOVE_{face}_{turn}"])

    with autocast_ctx:
        logits = model(input_ids)
    last_logits = logits[0, -1].float()
    # Additive mask: -inf everywhere except permitted output tokens.
    mask = torch.full_like(last_logits, float("-inf"))
    mask[valid_ids] = 0.0
    chosen = int((last_logits + mask).argmax().item())

    return [chosen]
|
|
|
|
def _build_prompt_ids(tokenizer: Tokenizer, cube: Cube, history: list[Move]) -> list[int]:
    """Encode the current cube state (plus move history) as BOS-prefixed ids."""
    tokens = build_prompt_tokens(cube.size, cube, history=history)
    encoded = tokenizer.encode_tokens(tokens)
    return [tokenizer.get_bos_token_id()] + encoded
|
|
|
|
def _enumerate_move_candidates(
    model,
    tokenizer: Tokenizer,
    cube: Cube,
    history: list[Move],
    visited_states: set[str],
) -> list[dict[str, object]]:
    """Score every legal depth-1/width-1 move from the current state.

    Runs one forward pass and reads the final-position logits of all 18
    joint MOVE_<face>_<turn> tokens. Candidates are pruned by two rules:
    the exact inverse of the previous move is skipped (no-inverse rule),
    and moves whose successor state was already visited are skipped
    (state avoidance) — so the result may be empty.

    Each candidate dict carries: move, score (raw logit), residual (sticker
    heuristic of the successor), next_cube, next_state_str, is_goal.
    """
    device = next(model.parameters()).device
    t2i = tokenizer.token_to_id
    autocast_ctx = _autocast_context(device.type)

    face_names = ("U", "R", "F", "D", "L", "B")
    turn_names = ("CW", "CCW", "HALF")
    turn_to_val = {"CW": 1, "CCW": -1, "HALF": 2}
    last_move = history[-1] if history else None
    prompt_ids = _build_prompt_ids(tokenizer, cube, history)
    input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

    with autocast_ctx:
        logits = model(input_ids)
    last_logits = logits[0, -1].float()

    candidates: list[dict[str, object]] = []
    for face in face_names:
        for turn_name in turn_names:
            turns = turn_to_val[turn_name]
            # No-inverse rule: never immediately undo the previous move.
            if last_move and face == last_move.face:
                inv = {1: -1, -1: 1, 2: 2}[last_move.turns]
                if turns == inv:
                    continue

            tid = t2i[f"MOVE_{face}_{turn_name}"]
            score = last_logits[tid].item()
            move = Move(face=face, depth=1, width=1, turns=turns)
            next_cube = cube.copy()
            next_cube.apply_move(move)
            next_state_str = next_cube.to_kociemba_string()
            # State avoidance: skip successors we have already been in.
            if next_state_str in visited_states:
                continue

            # 2x2 accepts any uniform-face coloring as solved; larger cubes
            # require the canonical solved state.
            is_goal = next_cube.has_uniform_faces() if cube.size == 2 else next_cube.is_solved()
            candidates.append(
                {
                    "move": move,
                    "score": score,
                    "residual": _cube_residual_error(next_cube),
                    "next_cube": next_cube,
                    "next_state_str": next_state_str,
                    "is_goal": is_goal,
                }
            )
    return candidates
|
|
|
|
def _shortlist_candidates(candidates: list[dict[str, object]], current_residual: int) -> list[dict[str, object]]:
    """Keep the top-K candidates by model score, preferring moves whose
    residual does not regress by more than SEARCH_RESIDUAL_DELTA.

    Falls back to ranking the full candidate list when the residual filter
    would leave nothing.
    """
    budget = current_residual + SEARCH_RESIDUAL_DELTA
    within_budget = [c for c in candidates if c["residual"] <= budget]
    pool = within_budget or candidates
    ranked = sorted(pool, key=lambda c: c["score"], reverse=True)
    return ranked[:SEARCH_LOOKAHEAD_TOP_K]
|
|
|
|
@torch.no_grad()
def _select_move_greedy(
    model,
    tokenizer: Tokenizer,
    cube: Cube,
    history: list[Move],
    visited_states: set[str],
) -> Move | None:
    """Pick the best-ranked 1-ply candidate; fall back to raw decoding.

    Returns None only when the fallback decode cannot be parsed as a move.
    """
    residual_now = _cube_residual_error(cube)
    candidates = _enumerate_move_candidates(model, tokenizer, cube, history, visited_states)
    if candidates:
        best = _shortlist_candidates(candidates, residual_now)[0]
        return best["move"]

    # All candidates were pruned (e.g. every successor was already visited):
    # fall back to unconstrained single-token decoding.
    previous = history[-1] if history else None
    prompt_ids = _build_prompt_ids(tokenizer, cube, history)
    answer_ids = _generate_answer_ids(model, prompt_ids, tokenizer, last_move=previous)
    try:
        return parse_answer_tokens(tokenizer.decode_ids(answer_ids))
    except Exception:
        return None
|
|
|
|
@torch.no_grad()
def _select_move_two_ply(
    model,
    tokenizer: Tokenizer,
    cube: Cube,
    history: list[Move],
    visited_states: set[str],
) -> Move | None:
    """Select a move with deterministic 2-ply lookahead.

    Each shortlisted first move is expanded one further ply; first moves are
    then ranked lexicographically by (solve depth, residual after the best
    reply, immediate residual, -combined score) — lower is better on every
    component. Falls back to raw single-token decoding when every first-ply
    candidate is pruned; returns None only if that decode fails to parse.
    """
    current_residual = _cube_residual_error(cube)
    candidates = _enumerate_move_candidates(model, tokenizer, cube, history, visited_states)

    if candidates:
        shortlist = _shortlist_candidates(candidates, current_residual)
        ranked: list[tuple[tuple[float, ...], Move]] = []
        for candidate in shortlist:
            # A first move that reaches the goal dominates: depth key 0.0.
            if candidate["is_goal"]:
                ranked.append(
                    (
                        (0.0, 0.0, float(candidate["residual"]), -float(candidate["score"])),
                        candidate["move"],
                    )
                )
                continue

            # Expand the second ply from this candidate's successor state,
            # treating that successor as visited to respect state avoidance.
            next_history = [*history, candidate["move"]]
            next_visited = set(visited_states)
            next_visited.add(candidate["next_state_str"])
            second_candidates = _enumerate_move_candidates(
                model,
                tokenizer,
                candidate["next_cube"],
                next_history,
                next_visited,
            )

            if second_candidates:
                # Best reply: goals first, then lowest residual, then highest score.
                second_best = min(
                    second_candidates,
                    key=lambda item: (
                        0 if item["is_goal"] else 1,
                        item["residual"],
                        -item["score"],
                    ),
                )
                final_residual = second_best["residual"]
                solve_depth = 1.0 if second_best["is_goal"] else 2.0
                combined_score = candidate["score"] + SEARCH_SECOND_SCORE_DISCOUNT * second_best["score"]
            else:
                # Dead end at ply 2 (all replies pruned): judge by ply 1 alone.
                final_residual = candidate["residual"]
                solve_depth = 2.0
                combined_score = candidate["score"]

            ranked.append(
                (
                    (
                        solve_depth,
                        float(final_residual),
                        float(candidate["residual"]),
                        -float(combined_score),
                    ),
                    candidate["move"],
                )
            )

        # Lexicographic sort over the key tuples; first entry wins.
        ranked.sort(key=lambda item: item[0])
        return ranked[0][1]

    # No unpruned first-ply candidates: unconstrained decode fallback.
    prompt_ids = _build_prompt_ids(tokenizer, cube, history)
    last_move = history[-1] if history else None
    answer_ids = _generate_answer_ids(model, prompt_ids, tokenizer, last_move=last_move)
    answer_tokens = tokenizer.decode_ids(answer_ids)
    try:
        return parse_answer_tokens(answer_tokens)
    except Exception:
        return None
|
|
|
|
@torch.no_grad()
def _select_move_value_guided(
    model,
    tokenizer: Tokenizer,
    cube: Cube,
    history: list[Move],
    visited_states: set[str],
) -> Move | None:
    """Select move using value head to evaluate candidate next-states.
    Combines residual filtering (for 2x2) with value-guided ranking.

    Returns None when every candidate was pruned; callers treat that as a
    dead end.
    """
    candidates = _enumerate_move_candidates(model, tokenizer, cube, history, visited_states)
    if not candidates:
        return None

    # An immediately solving move needs no further evaluation.
    for c in candidates:
        if c["is_goal"]:
            return c["move"]

    # 2x2 only: drop candidates whose sticker residual regresses too far,
    # falling back to all candidates if the filter empties the pool.
    if cube.size == 2:
        current_residual = _cube_residual_error(cube)
        acceptable = [c for c in candidates if c["residual"] <= current_residual + SEARCH_RESIDUAL_DELTA]
        pool = acceptable if acceptable else candidates
    else:
        pool = candidates

    # Nothing to rank when only one option remains.
    if len(pool) == 1:
        return pool[0]["move"]

    device = next(model.parameters()).device
    autocast_ctx = _autocast_context(device.type)

    # Encode every successor state and batch them through the value head once.
    prompt_ids_list = []
    for c in pool:
        next_history = [*history, c["move"]]
        ids = _build_prompt_ids(tokenizer, c["next_cube"], next_history)
        prompt_ids_list.append(ids)

    max_len = max(len(ids) for ids in prompt_ids_list)
    pad_id = tokenizer.get_pad_token_id()
    batch = torch.full((len(prompt_ids_list), max_len), pad_id, dtype=torch.long, device=device)
    for i, ids in enumerate(prompt_ids_list):
        batch[i, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)

    with autocast_ctx:
        values = model.predict_value(batch).float()

    # Policy prior: softmax over the raw joint-move logits of the pool.
    policy_scores = torch.tensor([c["score"] for c in pool], device=device)
    policy_probs = torch.softmax(policy_scores, dim=0)

    # Blend standardized policy and value signals. Negating values assumes
    # the value head predicts a cost (e.g. distance-to-goal, lower better) —
    # NOTE(review): confirm against the value-head training target.
    alpha = 0.3
    value_scores = -values
    if value_scores.std() > 1e-6:
        value_scores = (value_scores - value_scores.mean()) / value_scores.std()
    if policy_probs.std() > 1e-6:
        policy_norm = (policy_probs - policy_probs.mean()) / policy_probs.std()
    else:
        policy_norm = policy_probs
    combined = alpha * policy_norm + (1 - alpha) * value_scores

    best_idx = combined.argmax().item()
    return pool[best_idx]["move"]
|
|
|
|
@torch.no_grad()
def _select_move_with_search(
    model,
    tokenizer: Tokenizer,
    cube: Cube,
    history: list[Move],
    visited_states: set[str],
) -> Move | None:
    """Dispatch to the move selector named by SEARCH_SELECTOR."""
    args = (model, tokenizer, cube, history, visited_states)
    if SEARCH_SELECTOR == "hybrid_greedy_v1":
        return _select_move_greedy(*args)
    if SEARCH_SELECTOR == "two_ply_v1":
        return _select_move_two_ply(*args)
    if SEARCH_SELECTOR == "value_guided_v1":
        return _select_move_value_guided(*args)
    if SEARCH_SELECTOR == "hybrid_auto_v1":
        # 2x2 is small enough for greedy residual search; larger cubes use
        # the value head.
        if cube.size == 2:
            return _select_move_greedy(*args)
        return _select_move_value_guided(*args)
    raise ValueError(f"Unsupported search selector: {SEARCH_SELECTOR}")
|
|
|
|
| def _cube_residual_error(cube: Cube) -> int: |
| """Count stickers that don't match their face's majority color. |
| |
| For 2x2 cubes where 'solved' means uniform faces (any global orientation), |
| this is a better heuristic than checking against canonical colors. |
| """ |
| from collections import Counter |
| error = 0 |
| for face in ("U", "R", "F", "D", "L", "B"): |
| colors = [c for row in cube.face_grid(face) for c in row] |
| most_common_count = Counter(colors).most_common(1)[0][1] |
| error += len(colors) - most_common_count |
| return error |
|
|
|
|
@torch.no_grad()
def _beam_search_solve(model, tokenizer: Tokenizer, cube: Cube, beam_width: int = 8, max_steps: int = 200) -> bool:
    """Beam search rollout: keep multiple partial solutions, expand the most promising.

    Returns True as soon as any beam (or candidate expansion) reaches the
    goal, False when all beams die out or max_steps is exhausted. The input
    cube is not mutated; expansion works on copies produced by the candidate
    enumerator.
    """
    device = next(model.parameters()).device
    autocast_ctx = _autocast_context(device.type)

    def is_goal(c):
        # 2x2 accepts any uniform-face coloring; larger sizes need canonical solve.
        return c.has_uniform_faces() if c.size == 2 else c.is_solved()

    if is_goal(cube):
        return True

    # Beam entry layout: (cube, move history, visited-state set, path score).
    initial_state = cube.to_kociemba_string()
    beams = [(cube.copy(), [], {initial_state}, 0.0)]

    for step in range(max_steps):
        if not beams:
            break

        # Expand every beam one ply; flatten candidates with parent index.
        all_candidates = []
        for beam_idx, (b_cube, b_history, b_visited, b_score) in enumerate(beams):
            if is_goal(b_cube):
                return True

            candidates = _enumerate_move_candidates(model, tokenizer, b_cube, b_history, b_visited)
            for c in candidates:
                # Early exit: a candidate landing on the goal ends the search.
                if c["is_goal"]:
                    return True
                all_candidates.append((beam_idx, c))

        if not all_candidates:
            break

        # Batch every candidate successor prompt through the value head once.
        prompt_ids_list = []
        for beam_idx, c in all_candidates:
            b_history = beams[beam_idx][1]
            next_history = [*b_history, c["move"]]
            ids = _build_prompt_ids(tokenizer, c["next_cube"], next_history)
            prompt_ids_list.append(ids)

        max_len = max(len(ids) for ids in prompt_ids_list)
        pad_id = tokenizer.get_pad_token_id()
        batch = torch.full((len(prompt_ids_list), max_len), pad_id, dtype=torch.long, device=device)
        for i, ids in enumerate(prompt_ids_list):
            batch[i, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)

        with autocast_ctx:
            values = model.predict_value(batch).float()

        # Score = parent path score + policy logit - 0.5 * predicted value.
        # The subtraction assumes lower predicted value means closer to goal —
        # NOTE(review): confirm against the value-head training target.
        scored = []
        for i, (beam_idx, c) in enumerate(all_candidates):
            parent_score = beams[beam_idx][3]

            candidate_score = parent_score + c["score"] - 0.5 * values[i].item()
            scored.append((candidate_score, beam_idx, c, values[i].item()))

        # Keep the best beam_width candidates with unique successor states.
        scored.sort(key=lambda x: -x[0])
        new_beams = []
        seen_states = set()
        for score, beam_idx, c, val in scored:
            if len(new_beams) >= beam_width:
                break
            state_str = c["next_state_str"]
            if state_str in seen_states:
                continue
            seen_states.add(state_str)

            # Each surviving beam inherits (and extends) its parent's
            # history and visited set.
            parent = beams[beam_idx]
            new_history = [*parent[1], c["move"]]
            new_visited = set(parent[2])
            new_visited.add(state_str)
            new_beams.append((c["next_cube"], new_history, new_visited, score))

        beams = new_beams

    # Final sweep: the last expansion may have produced a solved beam.
    for b_cube, _, _, _ in beams:
        if is_goal(b_cube):
            return True
    return False
|
|
|
|
@torch.no_grad()
def evaluate_rollouts(model, tokenizer: Tokenizer, episodes: list[Episode]) -> tuple[float, dict[int, dict[str, float]]]:
    """Run search-guided rollouts over *episodes* and summarize solve statistics.

    Returns the overall solve rate together with a per-cube-size breakdown
    containing episode counts, solve/invalid rates, and the mean residual
    error of cubes left unsolved.
    """
    if not episodes:
        return 0.0, {}

    def _solved(c: Cube, n: int) -> bool:
        # 2x2 cubes only need uniform faces; larger cubes require a full solve.
        return c.has_uniform_faces() if n == 2 else c.is_solved()

    totals = defaultdict(lambda: {"count": 0, "solved": 0, "invalid": 0, "residual_sum": 0.0})
    n_solved = 0

    for ep in episodes:
        state = Cube(ep.size)
        state.apply_moves(ep.scramble)
        history: list[Move] = []
        seen = {state.to_kociemba_string()}
        bad_move = False
        step_budget = max(ep.max_rollout_steps, ROLLOUT_MIN_STEPS)
        for _ in range(step_budget):
            if _solved(state, ep.size):
                break
            proposal = _select_move_with_search(
                model,
                tokenizer,
                state,
                history,
                seen,
            )
            if proposal is None:
                # Search produced no legal, unvisited continuation.
                break
            try:
                state.apply_move(proposal)
            except Exception:
                bad_move = True
                break
            history.append(proposal)
            seen.add(state.to_kociemba_string())

        done = _solved(state, ep.size)
        n_solved += int(done)
        stats = totals[ep.size]
        stats["count"] += 1
        stats["solved"] += int(done)
        stats["invalid"] += int(bad_move)
        stats["residual_sum"] += 0.0 if done else _cube_residual_error(state)

    per_size_metrics: dict[int, dict[str, float]] = {
        size: {
            "count": s["count"],
            "solve_rate": s["solved"] / s["count"],
            "invalid_rate": s["invalid"] / s["count"],
            "mean_residual": s["residual_sum"] / s["count"],
        }
        for size, s in totals.items()
    }
    return n_solved / len(episodes), per_size_metrics
|
|
|
|
@torch.no_grad()
def evaluate_policy(model, tokenizer: Tokenizer, batch_size: int) -> dict[str, object]:
    """Full evaluation pass: next-move accuracy plus rollout solve rates.

    Loads the cached dataset, scores move-prediction accuracy on the ID and
    OOD-dev example splits, rolls out the policy on all three episode splits,
    and bundles everything into a single metrics dict.
    """
    payload = load_dataset()
    splits = payload["eval_episodes"]
    eval_sets = {
        name: [Episode.from_dict(item) for item in splits[name]]
        for name in ("id", "ood_dev", "ood_test")
    }

    id_acc = evaluate_move_accuracy(model, tokenizer, payload["id_val_examples"], batch_size)
    ood_dev_acc = evaluate_move_accuracy(model, tokenizer, payload["ood_dev_examples"], batch_size)

    id_rate, id_sizes = evaluate_rollouts(model, tokenizer, eval_sets["id"])
    dev_rate, dev_sizes = evaluate_rollouts(model, tokenizer, eval_sets["ood_dev"])
    test_rate, test_sizes = evaluate_rollouts(model, tokenizer, eval_sets["ood_test"])

    # The OOD-dev solve rate is the headline number whenever that split exists;
    # otherwise fall back to the in-distribution solve rate.
    primary = dev_rate if eval_sets["ood_dev"] else id_rate

    return {
        "primary_metric": primary,
        "id_move_accuracy": id_acc,
        "ood_dev_move_accuracy": ood_dev_acc,
        "id_solve_rate": id_rate,
        "ood_dev_solve_rate": dev_rate,
        "ood_test_solve_rate": test_rate,
        "size_metrics": {
            "id": id_sizes,
            "ood_dev": dev_sizes,
            "ood_test": test_sizes,
        },
    }
|
|
|
|
| |
| |
| |
|
|
|
|
def main() -> None:
    """CLI entry point: (re)build the tokenizer and dataset artifacts."""
    cli = argparse.ArgumentParser(description="Prepare Rubik's-cube policy data")
    cli.add_argument("--force", action="store_true", help="Regenerate tokenizer and dataset")
    opts = cli.parse_args()

    report_environment()
    print(f"Cache directory: {CACHE_DIR}")
    print()

    ensure_tokenizer(force=opts.force)
    build_dataset_payload(force=opts.force)
    print()
    print("Done! Ready to train.")
|
|
|
|
# Allow running this module directly as a data-preparation script.
if __name__ == "__main__":
    main()
|
|