from typing import Iterable, Union

import torch
from einops import repeat
from jaxtyping import Float, Shaped
from torch import Tensor

Real = Union[float, int]

Vector = Union[
    Real,
    Iterable[Real],
    Shaped[Tensor, "3"],
    Shaped[Tensor, "batch 3"],
]


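def sanitize_vector(
    vector: Vector,
    dim: int,
    device: torch.device,
) -> Float[Tensor, "*#batch dim"]:
    """Convert ``vector`` to a float32 tensor of shape (batch, dim) on ``device``.

    Inputs with fewer than two dimensions gain leading singleton dimensions, so a
    single vector becomes a batch of size 1. A trailing dimension of size 1 (for
    example, from a plain number) is repeated ``dim`` times.
    """
    # Normalize the input to a float32 tensor on the target device.
    if isinstance(vector, Tensor):
        vector = vector.type(torch.float32).to(device)
    else:
        vector = torch.tensor(vector, dtype=torch.float32, device=device)
    # Prepend singleton dimensions until the tensor is at least (batch, channels).
    while vector.ndim < 2:
        vector = vector[None]
    # Broadcast a single channel across all ``dim`` channels.
    if vector.shape[-1] == 1:
        vector = repeat(vector, "... () -> ... c", c=dim)
    assert vector.shape[-1] == dim
    assert vector.ndim == 2
    return vector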


Scalar = Union[
    Real,
    Iterable[Real],
    Shaped[Tensor, ""],
    Shaped[Tensor, " batch"],
]


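def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]:
    """Convert ``scalar`` to a 1-D float32 tensor of shape (batch,) on ``device``.

    A 0-D input (a plain number or a 0-D tensor) becomes a batch of size 1.
    """
    # Normalize the input to a float32 tensor on the target device.
    if isinstance(scalar, Tensor):
        scalar = scalar.type(torch.float32).to(device)
    else:
        scalar = torch.tensor(scalar, dtype=torch.float32, device=device)
    # Add a batch dimension to 0-D inputs.
    while scalar.ndim < 1:
        scalar = scalar[None]
    assert scalar.ndim == 1
    return scalar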


Pair = Union[
    Iterable[Real],
    Shaped[Tensor, "2"],
]


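def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]:
    """Convert ``pair`` to a float32 tensor of shape (2,) on ``device``."""
    # Normalize the input to a float32 tensor on the target device.
    if isinstance(pair, Tensor):
        pair = pair.type(torch.float32).to(device)
    else:
        pair = torch.tensor(pair, dtype=torch.float32, device=device)
    # Pairs are never batched: exactly two elements are required.
    assert pair.shape == (2,)
    return pair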