from __future__ import annotations import itertools import json import math import os import random import re import sys from dataclasses import dataclass from functools import lru_cache from typing import Any, Dict, Iterable, List, Optional, Sequence import numpy as np import torch CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) PARENT_DIR = os.path.dirname(CURRENT_DIR) if PARENT_DIR not in sys.path: sys.path.insert(0, PARENT_DIR) from aligned_cell_policy.shared_cell_policy import CellExample from formatting_icon import is_consistent_pair def all_remaining_empties_have_legal_value(grid: np.ndarray) -> bool: g = np.asarray(grid, dtype=int).reshape(9, 9) for r in range(9): for c in range(9): if int(g[r, c]) != 0: continue cell = r * 9 + c has_legal = any(is_consistent_pair(g, cell=cell, value=v, t=3, n=9) for v in range(1, 10)) if not has_legal: return False return True @dataclass(frozen=True) class ParsedValues: values: tuple[int, ...] parse_ok: bool strict_canonical: bool def all_digit_values() -> List[int]: return list(range(1, 10)) def make_solved_grid_from_row(row: Dict[str, Any]) -> np.ndarray: grid = parse_grid_from_tuple_prompt(str(row["prompt"])) solved = np.asarray(grid, dtype=int).copy() triples = row.get("metadata", {}).get("target_triples_1based", []) for rr, cc, value in triples: solved[int(rr) - 1, int(cc) - 1] = int(value) return solved def _grid_state_key(grid: np.ndarray) -> tuple[int, ...]: return tuple(int(v) for v in np.asarray(grid, dtype=int).reshape(-1)) def _legal_values_for_cell(state: tuple[int, ...], cell: int) -> tuple[int, ...]: rr, cc = divmod(int(cell), 9) if int(state[cell]) != 0: return tuple() g = np.asarray(state, dtype=int).reshape(9, 9) return tuple( int(value) for value in all_digit_values() if is_consistent_pair(g, cell=int(cell), value=int(value), t=3, n=9) ) @lru_cache(maxsize=200000) def _stage_i_consistent_values_for_grid(state: tuple[int, ...], stage_i: int) -> tuple[tuple[int, ...], ...]: stage_i = max(1, int(stage_i)) out: List[tuple[int, ...]] = [tuple() for _ in range(81)] for cell in range(81): legal_values = _legal_values_for_cell(state, cell) if not legal_values: continue if stage_i <= 1: out[cell] = legal_values continue consistent_values: List[int] = [] for value in legal_values: child = list(state) child[cell] = int(value) child_state = tuple(child) child_stage_values = _stage_i_consistent_values_for_grid(child_state, stage_i - 1) if all(int(child_state[idx]) != 0 or len(child_stage_values[idx]) > 0 for idx in range(81)): consistent_values.append(int(value)) out[cell] = tuple(consistent_values) return tuple(out) def stage_i_consistent_values( grid: np.ndarray, *, target_cell: tuple[int, int], stage_i: int, ) -> List[int]: g = np.asarray(grid, dtype=int).reshape(9, 9) rr, cc = int(target_cell[0]), int(target_cell[1]) if int(g[rr, cc]) != 0: return [] cell = rr * 9 + cc stage_values = _stage_i_consistent_values_for_grid(_grid_state_key(g), int(stage_i)) return [int(value) for value in stage_values[cell]] def canonicalize_values(values: Iterable[int]) -> List[int]: out: List[int] = [] seen = set() for value in values: if isinstance(value, bool): raise ValueError("Boolean values are not allowed.") vv = int(value) if vv < 1 or vv > 9: raise ValueError(f"Value must be in [1,9], got {vv}.") if vv not in seen: seen.add(vv) out.append(vv) return out def values_json_text(values: Iterable[int]) -> str: return json.dumps({"values": canonicalize_values(values)}, separators=(",", ":")) def parse_values_json(text: str) -> ParsedValues: raw = str(text).strip() if not raw: return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) try: obj = json.loads(raw) except Exception: return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) if not isinstance(obj, dict): return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) if set(obj.keys()) != {"values"}: return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) values_obj = obj.get("values") if not isinstance(values_obj, list): return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) try: values = canonicalize_values(values_obj) except Exception: return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) if len(values) != len(values_obj): return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) canonical = values_json_text(values) return ParsedValues(values=tuple(values), parse_ok=True, strict_canonical=(canonical == raw)) def compute_set_precision_recall(pred_values: Sequence[int], target_values: Sequence[int]) -> tuple[float, float]: pred = set(int(v) for v in pred_values) target = set(int(v) for v in target_values) precision = 0.0 if not pred else float(len(pred & target) / max(1, len(pred))) recall = 1.0 if not target else float(len(pred & target) / max(1, len(target))) return precision, recall def completion_ce_loss( model: torch.nn.Module, tokenizer: Any, prompt_text: str, completion_text: str, device: torch.device, ) -> torch.Tensor: prompt_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device) all_ids = tokenizer(prompt_text + completion_text, return_tensors="pt", add_special_tokens=False).input_ids.to( device ) labels = all_ids.clone() labels[:, : int(prompt_ids.shape[1])] = -100 out = model(input_ids=all_ids, labels=labels) return out.loss def batched_completion_ce_loss( model: torch.nn.Module, tokenizer: Any, prompt_texts: Sequence[str], completion_texts: Sequence[str], device: torch.device, ) -> torch.Tensor: if len(prompt_texts) != len(completion_texts): raise ValueError("prompt_texts and completion_texts must have the same length") if not prompt_texts: raise ValueError("batched_completion_ce_loss requires at least one example") full_texts = [str(p) + str(c) for p, c in zip(prompt_texts, completion_texts, strict=True)] batch = tokenizer(full_texts, return_tensors="pt", add_special_tokens=False, padding=True) prompt_batch = tokenizer(list(prompt_texts), return_tensors="pt", add_special_tokens=False, padding=True) input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) prompt_attention = prompt_batch["attention_mask"] prompt_lengths = prompt_attention.sum(dim=1).tolist() labels = input_ids.clone() labels[attention_mask == 0] = -100 for row_idx, prompt_len in enumerate(prompt_lengths): labels[row_idx, : int(prompt_len)] = -100 out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) return out.loss def completion_logprob( model: torch.nn.Module, tokenizer: Any, prompt_text: str, completion_text: str, device: torch.device, ) -> torch.Tensor: prompt_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device) all_ids = tokenizer(prompt_text + completion_text, return_tensors="pt", add_special_tokens=False).input_ids.to( device ) labels = all_ids.clone() labels[:, : int(prompt_ids.shape[1])] = -100 out = model(input_ids=all_ids, labels=labels) num_completion_tokens = int((labels != -100).sum().item()) return -out.loss * max(1, num_completion_tokens) def enumerate_value_permutations( values: Sequence[int], *, max_permutations: int, rng: Optional[random.Random] = None, ) -> List[tuple[int, ...]]: uniq = tuple(canonicalize_values(values)) if len(uniq) <= 1: return [uniq] total = math.factorial(len(uniq)) if total <= max(1, int(max_permutations)): return [tuple(p) for p in itertools.permutations(uniq)] rr = rng or random.Random(0) perms = set() base = list(uniq) max_needed = max(1, int(max_permutations)) while len(perms) < max_needed: shuffled = list(base) rr.shuffle(shuffled) perms.add(tuple(shuffled)) return list(perms) def permutation_invariant_json_ce_loss( model: torch.nn.Module, tokenizer: Any, prompt_text: str, values: Sequence[int], device: torch.device, *, max_permutations: int, rng: Optional[random.Random] = None, ) -> torch.Tensor: permutations = enumerate_value_permutations(values, max_permutations=max_permutations, rng=rng) logps = [ completion_logprob(model, tokenizer, prompt_text, values_json_text(perm), device) for perm in permutations ] stacked = torch.stack(logps, dim=0) return -(torch.logsumexp(stacked, dim=0) - math.log(float(len(permutations)))) def build_supervised_completion( ex: CellExample, *, stage_i: int, rng: Optional[random.Random] = None, randomize_order: bool = False, ) -> str: values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=stage_i) if randomize_order and len(values) > 1: shuffled = list(values) (rng or random).shuffle(shuffled) values = shuffled return values_json_text(values) def summarize_values(values: Iterable[int]) -> str: return "[" + ", ".join(str(int(v)) for v in values) + "]" _TUPLE_PROMPT_RE = re.compile(r"\((\d+),(\d+),(\d+)\)") def parse_grid_from_tuple_prompt(prompt_text: str) -> np.ndarray: triples = _TUPLE_PROMPT_RE.findall(str(prompt_text)) if len(triples) < 81: raise ValueError("Could not recover 81 (row,col,value) tuples from prompt.") grid = np.zeros((9, 9), dtype=int) for rr, cc, vv in triples[:81]: r = int(rr) - 1 c = int(cc) - 1 grid[r, c] = int(vv) return grid