| import torch |
| from einops import repeat |
| from jaxtyping import Int |
| from torch import Tensor |
|
|
# jaxtyping annotation for the index tensors produced in this module: an
# integer tensor whose two dimensions are named "n" and "n-1".
Index = Int[torch.Tensor, "n n-1"]
|
|
|
|
def generate_heterogeneous_index(
    n: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Index, Index]:
    """Generate indices for all pairs except self-pairs.

    Row ``i`` lists every element of ``range(n)`` except ``i`` itself, in
    increasing order.

    Args:
        n: Number of elements.
        device: Device on which the index tensors are allocated.

    Returns:
        ``(index_self, index_other)``, two int64 tensors of shape
        ``(n, n - 1)`` (empty when ``n == 0``), where
        ``index_self[i, j] == i`` and ``index_other[i, j]`` is ``j`` when
        ``j < i`` and ``j + 1`` otherwise.
    """
    arange = torch.arange(n, device=device)

    # Build both indices on the full (n, n) grid and drop the last column at
    # the end. Asking repeat for a width of n - 1 directly breaks for n == 0:
    # the width becomes -1, which (with the torch backend) is not rejected
    # and produces a (0, 1) tensor that does not match index_other's (0, 0).
    index_self = repeat(arange, "h -> h w", w=n)

    # Column j starts out holding j. Adding the upper-triangular mask (which
    # includes the diagonal) bumps every entry with j >= i by one, so row i
    # skips i itself. clone() so the in-place += writes into fresh memory
    # rather than through the view returned by repeat.
    index_other = repeat(arange, "w -> h w", h=n).clone()
    index_other += torch.ones((n, n), device=device, dtype=torch.int64).triu()

    # On the (n, n) grid the last column is surplus (index_other would hold n
    # there, which is out of range).
    return index_self[:, :-1], index_other[:, :-1]
|
|
|
|
def generate_heterogeneous_index_transpose(
    n: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Index, Index]:
    """Generate an index that can be used to "transpose" the heterogeneous index.

    Applying the index a second time inverts the "transpose."

    In the heterogeneous layout, entry ``(i, j)`` refers to the pair
    ``(i, k)`` with ``k = j if j < i else j + 1``. The returned indices point
    at the entry for the reversed pair ``(k, i)``:

    * ``index_self[i, j]`` is the row ``k``.
    * ``index_other[i, j]`` is the position of ``i`` among row ``k``'s others,
      which is ``i - 1`` when ``j < i`` and ``i`` otherwise.

    Args:
        n: Number of elements.
        device: Device on which the index tensors are allocated.

    Returns:
        ``(index_self, index_other)``, two int64 tensors of shape
        ``(n, n - 1)`` (empty when ``n == 0``).
    """
    arange = torch.arange(n, device=device)
    # 1 where j >= i (on or above the diagonal), 0 where j < i.
    upper = torch.ones((n, n), device=device, dtype=torch.int64).triu()

    # Row of the reversed pair: j below the diagonal, j + 1 on/above it.
    # The addition is out of place, so the tensor from repeat needs no clone().
    index_self = repeat(arange, "w -> h w", h=n) + upper

    # Column of the reversed pair: i - 1 strictly below the diagonal, i on or
    # above it. (1 - upper) is the strictly-lower-triangular mask.
    index_other = repeat(arange, "h -> h w", w=n) - (1 - upper)

    # On the (n, n) grid the last column is surplus (index_self would hold n
    # there, which is out of range).
    return index_self[:, :-1], index_other[:, :-1]
|
|