Spaces:
Runtime error
Runtime error
| import torch | |
| from einops import repeat | |
| from jaxtyping import Int | |
| from torch import Tensor | |
# Integer tensor annotated with the symbolic shape (n, n - 1): one row per item
# and one column per *other* item.
Index = Int[Tensor, "n n-1"]
def generate_heterogeneous_index(
    n: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Index, Index]:
    """Enumerate every ordered pair of distinct items as two index tensors.

    Args:
        n: Number of items.
        device: Device on which the index tensors are created.

    Returns:
        A pair ``(index_self, index_other)``, each with ``n`` rows and
        ``n - 1`` columns. Every entry in row ``i`` of ``index_self`` is ``i``;
        row ``i`` of ``index_other`` lists the remaining items in ascending
        order, i.e. ``0, ..., i - 1, i + 1, ..., n - 1``.
    """
    item_ids = torch.arange(n, device=device)

    # Each row repeats its own item id across the n - 1 columns.
    selves = repeat(item_ids, "h -> h w", w=n - 1)

    # Start with every row equal to [0, 1, ..., n - 1]. The copy is taken
    # before mutating in place (in case `repeat` hands back a broadcast view).
    others = repeat(item_ids, "w -> h w", h=n).clone()

    # Adding 1 on and above the diagonal shifts row i to
    # [0, ..., i - 1, i + 1, ..., n], which skips the item itself.
    shift_past_self = torch.ones((n, n), device=device, dtype=torch.int64).triu()
    others += shift_past_self

    # The final column holds the out-of-range value n; drop it.
    others = others[:, :-1]

    return selves, others
def generate_heterogeneous_index_transpose(
    n: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Index, Index]:
    """Build the index pair that "transposes" the heterogeneous index.

    Indexing with the returned pair a second time undoes the "transpose."

    Args:
        n: Number of items.
        device: Device on which the index tensors are created.

    Returns:
        A pair of integer tensors, each with ``n`` rows and ``n - 1`` columns.
        At position ``[i, j]`` the first tensor holds ``j`` when ``j < i`` and
        ``j + 1`` otherwise; the second holds ``i - 1`` when ``j < i`` and
        ``i`` otherwise.
    """
    item_ids = torch.arange(n, device=device)

    # 1 on and above the diagonal, 0 strictly below it.
    on_or_above_diag = torch.ones((n, n), device=device, dtype=torch.int64).triu()
    # Complementary mask: 1 strictly below the diagonal, 0 elsewhere.
    below_diag = 1 - on_or_above_diag

    # Column ids, bumped by one wherever the column is at or past the row, so
    # that row i never refers to item i.
    # NOTE(review): the addition below is out-of-place, so this clone is not
    # strictly required; it is kept to mirror the original implementation.
    column_ids = repeat(item_ids, "w -> h w", h=n).clone()
    row_index = column_ids + on_or_above_diag

    # Row ids, lowered by one wherever the column is before the row; this is
    # the slot that item i occupies inside the other item's row.
    row_ids = repeat(item_ids, "h -> h w", w=n)
    column_index = row_ids - below_diag

    # Drop the last column so both results have n - 1 columns.
    return row_index[:, :-1], column_index[:, :-1]