| |
| |
| |
| |
| |
|
|
| |
|
|
| from typing import Sequence, Tuple, Union |
|
|
| import torch |
|
|
|
|
| """ |
| Some functions which depend on PyTorch or Python versions. |
| """ |
|
|
|
|
def meshgrid_ij(
    *A: Union[torch.Tensor, Sequence[torch.Tensor]],
) -> Tuple[torch.Tensor, ...]:
    """
    Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij.

    The arguments are forwarded unchanged to ``torch.meshgrid``. They may be
    individual tensors, or a single sequence of tensors.

    On PyTorch versions whose ``torch.meshgrid`` declares an ``indexing``
    keyword, "ij" is requested explicitly. Otherwise the call is made without
    it, relying on the pre-1.10 behaviour, which was "ij".
    """
    # Probe the keyword-only defaults of torch.meshgrid once. The ``indexing``
    # keyword only shows up here on versions that accept it.
    kw_defaults = torch.meshgrid.__kwdefaults__
    accepts_indexing = kw_defaults is not None and "indexing" in kw_defaults

    if not accepts_indexing:
        # Legacy PyTorch: no ``indexing`` keyword exists; behaviour is "ij".
        return torch.meshgrid(*A)

    # Newer PyTorch: ask for matrix ("ij") indexing explicitly.
    return torch.meshgrid(*A, indexing="ij")
|
|
|
|
def prod(iterable, *, start=1):
    """
    Like math.prod in Python 3.8 and later.

    Multiply ``start`` by each element of ``iterable`` in order and return
    the result. An empty iterable returns ``start``.

    Args:
        iterable: Values to multiply together.
        start: Initial value of the product (keyword-only, default 1).

    Returns:
        ``start * iterable[0] * iterable[1] * ...``, evaluated left to right.
    """
    result = start
    for item in iterable:
        # Out-of-place multiply on purpose. ``*=`` would mutate a caller-owned
        # mutable ``start`` (list, tensor, array) in place. It would also reject
        # dtype promotion on tensors/arrays (e.g. int tensor times float),
        # unlike math.prod.
        result = result * item
    return result
|
|