File size: 2,832 Bytes
76de008 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | from __future__ import annotations
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
import numpy as np
# Matches one "(row,col,value)" triple: three unsigned integers separated by
# bare commas, with no whitespace tolerated inside the parentheses.
# Consumed by parse_grid_from_tuple_prompt.
_TUPLE_PROMPT_RE = re.compile(r"\((\d+),(\d+),(\d+)\)")
# Matches an optionally negative integer literal; parse_n_value_prediction
# uses it as the fallback scanner when JSON decoding does not succeed.
_INT_RE = re.compile(r"-?\d+")
@dataclass(eq=True, frozen=True)
class CellExample:
    """A single-cell prediction example: a grid plus one cell to fill in.

    ``frozen=True`` prevents rebinding attributes after construction, but the
    ``grid`` array itself can still be mutated in place.

    Caveat: ``grid`` is a NumPy array, which is unhashable, so ``hash()`` on
    an instance raises ``TypeError``. Likewise, ``==`` between instances
    holding distinct multi-element grids raises NumPy's
    ambiguous-truth-value ``ValueError``.
    """

    # Board state for this example. build_cell_examples_from_row treats
    # zero entries as empty cells.
    grid: np.ndarray
    # (row, col) of the cell to predict; build_cell_examples_from_row stores
    # these 0-based.
    target_cell: tuple[int, int]
    # Value expected at ``target_cell``.
    target_value: int
    # 1-based position of this example among those built from one row.
    turn_idx: int
    # Number of examples built from the same row.
    total_turns: int
def _json_int_values(raw: str) -> List[int] | None:
    """Return the integers encoded as JSON in ``raw``, or ``None``.

    Accepts a bare JSON array, or an object whose ``"values"`` entry is an
    array. Every element is coerced with ``int()``, so JSON floats are
    truncated toward zero and numeric strings are converted. Any decode or
    conversion failure yields ``None`` so the caller can fall back to
    scanning the text.
    """
    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; RecursionError is raised for
        # pathologically deep nesting.
        return None
    if isinstance(parsed, dict):
        parsed = parsed.get('values')
    if not isinstance(parsed, list):
        return None
    try:
        return [int(v) for v in parsed]
    except (ValueError, TypeError, OverflowError):
        # e.g. "abc" -> ValueError, null/list -> TypeError, Infinity -> OverflowError.
        return None


def parse_n_value_prediction(text: str | None, n: int) -> Tuple[List[int] | None, bool]:
    """Extract exactly ``n`` integers from a prediction string.

    Strategies are tried in order:

    1. JSON: a bare array, or an object with a ``"values"`` array.
    2. Fallback: every (optionally negative) integer literal in the text,
       matched with ``_INT_RE``.

    The first strategy that yields exactly ``n`` integers wins.

    Args:
        text: Raw prediction text. ``None`` or blank text is a parse failure.
        n: Required number of values (coerced with ``int()``).

    Returns:
        ``(values, True)`` on success, otherwise ``(None, False)``. The flag
        is always equivalent to ``values is not None``.

    Raises:
        TypeError, ValueError: if ``text`` is non-blank and ``n`` cannot be
            converted with ``int()``.
    """
    raw = str(text or '').strip()
    if not raw:
        return None, False
    expected = int(n)
    values = _json_int_values(raw)
    if values is not None and len(values) == expected:
        return values, True
    # JSON missing, malformed, or of the wrong length: scan for integers.
    values = [int(match.group(0)) for match in _INT_RE.finditer(raw)]
    if len(values) == expected:
        return values, True
    return None, False
def parse_grid_from_tuple_prompt(prompt_text: str, size: int = 4) -> np.ndarray:
    """Rebuild a ``size`` x ``size`` integer grid from ``(row,col,value)`` tuples.

    The first ``size * size`` tuples matched by ``_TUPLE_PROMPT_RE`` are used.
    Their row and column numbers are 1-based. Cells not mentioned by any
    tuple stay 0.

    Args:
        prompt_text: Text containing the tuples (converted with ``str()``).
        size: Side length of the square grid. The default of 4 keeps the
            original behavior.

    Returns:
        An integer array of shape ``(size, size)``.

    Raises:
        ValueError: if fewer than ``size * size`` tuples are found, or a tuple's
            row/column falls outside ``1..size``.
    """
    n_cells = size * size
    triples = _TUPLE_PROMPT_RE.findall(str(prompt_text))
    if len(triples) < n_cells:
        raise ValueError(f'Could not recover {n_cells} (row,col,value) tuples from prompt.')
    grid = np.zeros((size, size), dtype=int)
    for rr, cc, vv in triples[:n_cells]:
        row_idx, col_idx = int(rr) - 1, int(cc) - 1
        # Without this check a 0 coordinate becomes index -1 and silently
        # overwrites the last row/column; too-large ones raised IndexError.
        if not (0 <= row_idx < size and 0 <= col_idx < size):
            raise ValueError(
                f'Tuple ({rr},{cc},{vv}) lies outside the {size}x{size} grid.'
            )
        grid[row_idx, col_idx] = int(vv)
    return grid
def _targets_from_completion(
    grid: np.ndarray,
    metadata: Dict[str, Any],
    completion: Any,
) -> List[Tuple[int, int, int]]:
    """Pair the grid's empty cells with values parsed from ``completion``.

    Empty cells are the zero entries of ``grid``, in the row-major order
    returned by ``np.argwhere``. The i-th parsed value is assigned to the
    i-th empty cell.

    Returns:
        1-based ``(row, col, value)`` triples.

    Raises:
        ValueError: if the completion cannot be parsed into the expected number
            of values, or that number differs from the grid's empty-cell count.
    """
    empty_locs = [(r + 1, c + 1) for r, c in np.argwhere(grid == 0).tolist()]
    # Use the metadata count when present. Otherwise derive it from the grid,
    # rather than asking the parser for 0 values, which rejects any
    # completion that contains digits.
    n_values = int(metadata.get('empties', 0) or 0) or len(empty_locs)
    parsed, _ = parse_n_value_prediction(str(completion or ''), n_values)
    if parsed is None:
        raise ValueError('Row is missing metadata and completion could not be parsed.')
    if len(parsed) != len(empty_locs):
        # zip() below would otherwise silently drop or misalign values.
        raise ValueError(
            f'Completion has {len(parsed)} values but the grid has '
            f'{len(empty_locs)} empty cells.'
        )
    return [(int(r), int(c), int(v)) for (r, c), v in zip(empty_locs, parsed)]


def build_cell_examples_from_row(row: Dict[str, Any]) -> List[CellExample]:
    """Expand one row (``prompt`` / ``metadata`` / ``completion``) into examples.

    The starting grid is parsed from ``row['prompt']``. Targets normally come
    from ``metadata['target_triples_1based']`` as 1-based
    ``(row, col, value)`` entries. If either that entry or
    ``metadata['empty_locs_1based']`` is missing or empty, the targets are
    instead reconstructed from ``row['completion']``.

    Each example gets its own copy of the grid, a 0-based ``target_cell``,
    and a 1-based ``turn_idx``.

    Raises:
        KeyError: if ``row`` has no ``'prompt'`` key.
        ValueError: if the prompt grid or the completion cannot be parsed.
    """
    grid = parse_grid_from_tuple_prompt(str(row['prompt']))
    # ``or {}`` also handles an explicit ``metadata: None``.
    metadata = dict(row.get('metadata') or {})
    empty_locs = metadata.get('empty_locs_1based')
    target_triples = metadata.get('target_triples_1based')
    if not empty_locs or not target_triples:
        target_triples = _targets_from_completion(grid, metadata, row.get('completion'))
    total_turns = len(target_triples)
    return [
        CellExample(
            # Copy so in-place edits to one example's grid cannot leak into others.
            grid=np.asarray(grid, dtype=int).copy(),
            target_cell=(int(triple[0]) - 1, int(triple[1]) - 1),
            target_value=int(triple[2]),
            turn_idx=idx,
            total_turns=total_turns,
        )
        for idx, triple in enumerate(target_triples, start=1)
    ]
|