File size: 1,605 Bytes
a6dd040 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from collections.abc import Iterator
from typing import Iterable, Union

import torch
from einops import repeat
from jaxtyping import Float, Shaped
from torch import Tensor
# A plain Python number.
Real = Union[float, int]

# Argument type of `sanitize_vector`: a single number, an iterable of numbers,
# one 3-component tensor, or a batch of 3-component tensors. The tensor shapes
# here name a size of 3, while `sanitize_vector` itself checks the last
# dimension against its `dim` argument.
Vector = Union[
    Real,
    Iterable[Real],
    Shaped[Tensor, "3"],
    Shaped[Tensor, "batch 3"],
]
def sanitize_vector(
    vector: Vector,
    dim: int,
    device: torch.device,
) -> Float[Tensor, "*#batch dim"]:
    """Convert ``vector`` into a 2-D float32 tensor of shape ``(batch, dim)``.

    Accepted inputs:

    - a Python number: becomes shape ``(1, 1)`` and is then repeated to
      ``(1, dim)``;
    - an iterable of numbers (list, tuple, generator, ...): treated as a single
      vector, so it becomes ``(1, len)``; a length of 1 is repeated to ``dim``;
    - a tensor of shape ``()``, ``(dim,)`` or ``(batch, dim)``: leading
      dimensions are added until it is 2-D, and a last dimension of size 1 is
      repeated to ``dim``.

    Args:
        vector: Value to sanitize.
        dim: Required size of the last dimension.
        device: Device on which the returned tensor is placed.

    Returns:
        A float32 tensor on ``device`` with shape ``(batch, dim)``. When the
        input is already a float32 tensor on ``device`` the result may share
        memory with it (no copy is forced).

    Raises:
        AssertionError: If the last dimension is not ``dim`` or the result is
            not 2-D (e.g. a 3-D tensor was passed).
    """
    if isinstance(vector, Tensor):
        # One `.to` call casts and moves together, and returns `vector`
        # itself when neither the dtype nor the device needs to change.
        vector = vector.to(device=device, dtype=torch.float32)
    else:
        if isinstance(vector, Iterator):
            # `torch.tensor` only accepts sequences; materialize one-shot
            # iterables (generators, `map`, `zip`, ...) that the `Iterable`
            # annotation admits.
            vector = list(vector)
        vector = torch.tensor(vector, dtype=torch.float32, device=device)

    # Promote a scalar () or a single vector (n,) to a batch of one.
    while vector.ndim < 2:
        vector = vector[None]

    # A single component is repeated across all `dim` components.
    if vector.shape[-1] == 1:
        vector = repeat(vector, "... () -> ... c", c=dim)

    assert vector.shape[-1] == dim, (
        f"expected last dimension of size {dim}, got shape {tuple(vector.shape)}"
    )
    assert vector.ndim == 2, (
        f"expected a 2-D (batch, {dim}) tensor, got shape {tuple(vector.shape)}"
    )
    return vector
# Argument type of `sanitize_scalar`: a single number, an iterable of numbers,
# a 0-d tensor, or a 1-D batch tensor.
# NOTE: the leading space in " batch" is presumably jaxtyping's convention for
# single-name shape strings (keeps linters from reading it as a forward
# reference); keep it as-is.
Scalar = Union[
    Real,
    Iterable[Real],
    Shaped[Tensor, ""],
    Shaped[Tensor, " batch"],
]
def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]:
    """Convert ``scalar`` into a 1-D float32 tensor of shape ``(batch,)``.

    A Python number or a 0-d tensor becomes a length-1 tensor. An iterable of
    numbers (list, tuple, generator, ...) or a 1-D tensor keeps its length.

    Args:
        scalar: Value to sanitize.
        device: Device on which the returned tensor is placed.

    Returns:
        A float32 tensor on ``device`` with exactly one dimension. When the
        input is already a 1-D float32 tensor on ``device`` it is returned
        without a copy.

    Raises:
        AssertionError: If the result is not 1-D (e.g. a nested list or a 2-D
            tensor was passed).
    """
    if isinstance(scalar, Tensor):
        # Cast and move in one call; a no-op returns `scalar` itself.
        scalar = scalar.to(device=device, dtype=torch.float32)
    else:
        if isinstance(scalar, Iterator):
            # `torch.tensor` rejects generators and other one-shot iterators,
            # even though the `Iterable` annotation allows them.
            scalar = list(scalar)
        scalar = torch.tensor(scalar, dtype=torch.float32, device=device)

    # A 0-d value becomes a batch of one.
    if scalar.ndim == 0:
        scalar = scalar[None]

    assert scalar.ndim == 1, (
        f"expected a 1-D (batch,) tensor, got shape {tuple(scalar.shape)}"
    )
    return scalar
# Argument type of `sanitize_pair`: an iterable of two numbers or a tensor with
# exactly two elements along a single dimension.
Pair = Union[
    Iterable[Real],
    Shaped[Tensor, "2"],
]
def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]:
    """Convert ``pair`` into a float32 tensor of shape exactly ``(2,)``.

    Unlike ``sanitize_vector`` and ``sanitize_scalar``, no dimensions are added
    or repeated: the input must already describe two values.

    Args:
        pair: Two numbers (list, tuple, generator, ...) or a 2-element tensor.
        device: Device on which the returned tensor is placed.

    Returns:
        A float32 tensor of shape ``(2,)`` on ``device``. When the input is
        already such a tensor it is returned without a copy.

    Raises:
        AssertionError: If the result's shape is not ``(2,)``.
    """
    if isinstance(pair, Tensor):
        # Cast and move in one call; a no-op returns `pair` itself.
        pair = pair.to(device=device, dtype=torch.float32)
    else:
        if isinstance(pair, Iterator):
            # `torch.tensor` needs a sequence; materialize one-shot iterators.
            pair = list(pair)
        pair = torch.tensor(pair, dtype=torch.float32, device=device)

    assert pair.shape == (2,), f"expected shape (2,), got {tuple(pair.shape)}"
    return pair
|