"""
PHI-SCAN: Physics-informed multi-directional token scanners.

Alternates between scan patterns per layer with zero extra parameters.

Each pattern is a permutation of the ``h * w`` row-major token positions of
an ``h x w`` grid, stored together with its inverse so that a token sequence
can be reordered with ``apply_scan`` and restored with ``unscan``:

* ``row_major``   - identity order (left-to-right, top-to-bottom).
* ``col_major``   - top-to-bottom, then left-to-right.
* ``hilbert``     - order along a Hilbert curve on the smallest power-of-two
  square enclosing the grid, restricted to cells inside the grid.
* ``zigzag_diag`` - anti-diagonals ``x + y = d`` in increasing ``d``,
  alternating the traversal direction on every diagonal.
"""

from typing import Dict, Tuple

import torch

SCAN_PATTERNS = ["row_major", "col_major", "hilbert", "zigzag_diag"]


def _validate_grid(h: int, w: int) -> None:
    """Raise ``ValueError`` unless both grid dimensions are non-negative."""
    if h < 0 or w < 0:
        raise ValueError(f"grid dimensions must be non-negative, got h={h}, w={w}")


def _invert_permutation(perm: torch.Tensor) -> torch.Tensor:
    """Return ``inv`` with ``inv[perm[i]] == i`` for every position ``i``."""
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.numel(), dtype=perm.dtype, device=perm.device)
    return inv


def _hilbert_index(n, x, y):
    """Return the distance of cell ``(x, y)`` along a Hilbert curve on an n x n grid.

    This is the classic iterative ``xy2d`` conversion. ``n`` must be a power
    of two and ``0 <= x, y < n`` (``build_hilbert_permutation`` guarantees
    both).
    """
    d = 0
    s = n // 2
    while s > 0:
        rx = 1 if (x & s) > 0 else 0
        ry = 1 if (y & s) > 0 else 0
        # Each quadrant at scale s contributes s*s cells before the next one.
        d += s * s * ((3 * rx) ^ ry)
        # Reflect/rotate so the sub-curve at the next scale is canonical.
        if ry == 0:
            if rx == 1:
                x = n - 1 - x
                y = n - 1 - y
            x, y = y, x
        s //= 2
    return d


def build_hilbert_permutation(
    h: int, w: int, device="cpu"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the Hilbert-curve scan order for an ``h x w`` grid.

    The curve is laid out on the smallest power-of-two square covering the
    grid. Cells outside the grid are skipped, so for non-square or
    non-power-of-two grids consecutive tokens may not be spatially adjacent.

    Args:
        h: Grid height (number of rows), ``>= 0``.
        w: Grid width (number of columns), ``>= 0``.
        device: Device on which the returned tensors are allocated.

    Returns:
        ``(perm, inv)``: ``perm[k]`` is the row-major index of the k-th token
        in Hilbert order, and ``inv`` is its inverse permutation. Both are
        1-D ``torch.long`` tensors of length ``h * w``.

    Raises:
        ValueError: If ``h`` or ``w`` is negative.
    """
    _validate_grid(h, w)
    # Smallest power-of-two side length that covers the grid (at least 1).
    n = 1 << (max(h, w, 1) - 1).bit_length()
    curve_pos = [_hilbert_index(n, x, y) for y in range(h) for x in range(w)]
    # Hilbert distances of distinct cells are distinct, so this is a strict order.
    order = sorted(range(h * w), key=curve_pos.__getitem__)
    perm = torch.tensor(order, dtype=torch.long, device=device)
    return perm, _invert_permutation(perm)


def build_zigzag_diag_permutation(
    h: int, w: int, device="cpu"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the zigzag anti-diagonal scan order for an ``h x w`` grid.

    Anti-diagonals ``x + y = d`` are visited for ``d = 0, 1, ...``. Even
    diagonals are walked with increasing row ``y``, odd ones with decreasing
    ``y``, which gives the classic zigzag.

    Args:
        h: Grid height (number of rows), ``>= 0``.
        w: Grid width (number of columns), ``>= 0``.
        device: Device on which the returned tensors are allocated.

    Returns:
        ``(perm, inv)`` as in ``build_hilbert_permutation``.

    Raises:
        ValueError: If ``h`` or ``w`` is negative.
    """
    _validate_grid(h, w)
    order = []
    for diag in range(h + w - 1):
        # Rows y whose cell (y, diag - y) lies inside the grid.
        rows = range(max(0, diag - w + 1), min(h - 1, diag) + 1)
        if diag % 2 == 1:
            rows = reversed(rows)
        order.extend(y * w + (diag - y) for y in rows)
    perm = torch.tensor(order, dtype=torch.long, device=device)
    return perm, _invert_permutation(perm)


def build_scan_permutations(
    h: int, w: int, device="cpu"
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """Build every scan pattern in ``SCAN_PATTERNS`` for an ``h x w`` grid.

    Args:
        h: Grid height (number of rows), ``>= 0``.
        w: Grid width (number of columns), ``>= 0``.
        device: Device on which the returned tensors are allocated.

    Returns:
        Mapping from pattern name to ``(perm, inv)``. ``perm`` reorders a
        row-major token sequence into the pattern, and ``inv`` undoes it.

    Raises:
        ValueError: If ``h`` or ``w`` is negative.
    """
    _validate_grid(h, w)
    row_perm = torch.arange(h * w, device=device)
    row_inv = torch.arange(h * w, device=device)
    # Transposing the (h, w) grid of row-major indices lists them column by column.
    col_perm = torch.arange(h * w, device=device).reshape(h, w).t().reshape(-1)
    return {
        "row_major": (row_perm, row_inv),
        "col_major": (col_perm, _invert_permutation(col_perm)),
        "hilbert": build_hilbert_permutation(h, w, device),
        "zigzag_diag": build_zigzag_diag_permutation(h, w, device),
    }


def _reorder_tokens(x: torch.Tensor, index: torch.Tensor, arg_name: str) -> torch.Tensor:
    """Return ``x[:, index, :]`` after validating shapes (shared by scan/unscan)."""
    if x.dim() != 3:
        raise ValueError(f"expected x of shape (B, L, C), got {tuple(x.shape)}")
    seq_len = x.shape[1]
    if index.dim() != 1 or index.numel() != seq_len:
        raise ValueError(
            f"{arg_name} must be 1-D with length L={seq_len}, "
            f"got shape {tuple(index.shape)}"
        )
    # Keep the gather on x's device even if the permutation was built elsewhere.
    return x[:, index.to(x.device), :]


def apply_scan(x: torch.Tensor, perm: torch.Tensor):
    """Reorder the tokens of ``x`` into a scan order.

    Args:
        x: Tensor of shape ``(B, L, C)`` whose tokens are in row-major order.
        perm: 1-D long tensor of length ``L`` from ``build_scan_permutations``.

    Returns:
        Tensor of shape ``(B, L, C)`` with ``out[:, k] == x[:, perm[k]]``.

    Raises:
        ValueError: If ``x`` is not 3-D or ``perm`` is not 1-D with length ``L``.
    """
    return _reorder_tokens(x, perm, "perm")


def unscan(x: torch.Tensor, inv: torch.Tensor):
    """Restore row-major token order after ``apply_scan``.

    Args:
        x: Tensor of shape ``(B, L, C)`` in scan order.
        inv: Inverse permutation paired with the ``perm`` used to scan.

    Returns:
        Tensor of shape ``(B, L, C)`` such that
        ``unscan(apply_scan(t, perm), inv)`` equals ``t``.

    Raises:
        ValueError: If ``x`` is not 3-D or ``inv`` is not 1-D with length ``L``.
    """
    return _reorder_tokens(x, inv, "inv")


def get_scan_pattern(layer_idx: int):
    """Return the scan pattern name for ``layer_idx``, cycling through SCAN_PATTERNS."""
    return SCAN_PATTERNS[layer_idx % len(SCAN_PATTERNS)]