"""Shared data utilities for PAWN training pipelines."""

import torch

# Per-device cache of the shift amounts 0..63, so repeated calls on the same
# device reuse one small tensor instead of allocating a new arange each time.
_unpack_bits_cache: dict[torch.device, torch.Tensor] = {}


def unpack_grid(
    packed_grid: torch.Tensor,
    *,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Unpack bit-packed legal move grid to dense binary targets.

    packed_grid: (..., 64) int64 — each value is a 64-bit destination mask
    dtype: dtype of the returned tensor (default float32). Pass e.g.
        ``torch.bool`` for masks or a half-precision dtype to avoid an extra
        float32 copy followed by a conversion.

    Returns: (..., 64, 64) tensor of 0/1 values in ``dtype``, where
        ``out[..., i, j]`` is bit ``j`` (bit 0 = least significant) of
        ``packed_grid[..., i]``.
    """
    device = packed_grid.device
    bits = _unpack_bits_cache.get(device)
    if bits is None:
        bits = torch.arange(64, device=device, dtype=torch.long)
        _unpack_bits_cache[device] = bits
    # unsqueeze(-1) broadcasts every mask against all 64 shift amounts.
    # The right shift is arithmetic for signed int64, so negative values
    # sign-extend. `& 1` keeps only the lowest bit of each shifted value,
    # which means bit 63 (the sign bit) is still extracted correctly.
    return ((packed_grid.unsqueeze(-1) >> bits) & 1).to(dtype)