File size: 661 Bytes
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""Shared data utilities for PAWN training pipelines."""

import torch


# Per-device cache of the shift amounts 0..63 used by ``unpack_grid``; the
# same small index tensor is reused for every call on a given device instead
# of being re-allocated each time.
_unpack_bits_cache: dict[torch.device, torch.Tensor] = dict()


def unpack_grid(packed_grid: torch.Tensor) -> torch.Tensor:
    """Expand bit-packed legal-move masks into a dense 0/1 float grid.

    packed_grid: (..., 64) int64 — each entry packs 64 destination flags,
        bit ``k`` standing for destination ``k``
    Returns: (..., 64, 64) float32 — element ``[..., src, k]`` is 1.0 when
        bit ``k`` of ``packed_grid[..., src]`` is set, else 0.0
    """
    dev = packed_grid.device

    # EAFP lookup of the cached shift vector for this device.
    try:
        shifts = _unpack_bits_cache[dev]
    except KeyError:
        shifts = torch.arange(64, device=dev, dtype=torch.long)
        _unpack_bits_cache[dev] = shifts

    # (..., 64, 1) >> (64,) broadcasts to (..., 64, 64): every packed word is
    # shifted by every bit position. Masking with 1 keeps only the low bit.
    # For bit 63 of a negative int64 the arithmetic shift yields -1, and
    # -1 & 1 is still 1, so the sign bit decodes correctly.
    words = packed_grid.unsqueeze(-1)
    flags = (words >> shifts) & 1
    return flags.to(torch.float32)