from collections.abc import Iterator
from typing import Iterable, Union

import torch
from einops import repeat
from jaxtyping import Float, Shaped
from torch import Tensor

Real = Union[float, int]


class _SanitizeError(ValueError, AssertionError):
    """Raised when an input cannot be coerced to the expected tensor shape.

    Inherits from ``AssertionError`` as well as ``ValueError`` so callers that
    caught the ``AssertionError`` produced by the former ``assert``-based
    checks keep working, while the checks are no longer stripped under
    ``python -O``.
    """


def _to_float_tensor(value, device: torch.device) -> Tensor:
    """Convert *value* to a float32 tensor placed on *device*.

    Tensors are cast and moved with a single ``Tensor.to`` call. Any other
    value is passed to ``torch.tensor``; one-shot iterators (e.g. generators)
    are materialized into a list first because ``torch.tensor`` cannot infer
    a dtype from them.
    """
    if isinstance(value, Tensor):
        return value.to(device=device, dtype=torch.float32)
    if isinstance(value, Iterator):
        value = list(value)
    return torch.tensor(value, dtype=torch.float32, device=device)


Vector = Union[
    Real,
    Iterable[Real],
    Shaped[Tensor, "3"],
    Shaped[Tensor, "batch 3"],
]


def sanitize_vector(
    vector: Vector,
    dim: int,
    device: torch.device,
) -> Float[Tensor, "*#batch dim"]:
    """Coerce *vector* into a float32 tensor of shape ``(batch, dim)``.

    Scalars and 1-D inputs are promoted to a batch of one. An input whose last
    axis has size 1 is repeated to ``dim`` components.

    Args:
        vector: A number, an iterable of numbers, or a tensor.
        dim: Required size of the last axis.
        device: Device the result is placed on.

    Returns:
        A float32 tensor of shape ``(batch, dim)`` on *device*.

    Raises:
        ValueError: (also an ``AssertionError``) if the result is not 2-D or
            its last axis is not ``dim``.
    """
    vector = _to_float_tensor(vector, device)

    # Promote 0-D and 1-D inputs by prepending batch axes.
    while vector.ndim < 2:
        vector = vector[None]

    # A trailing size-1 axis is broadcast to all ``dim`` components.
    if vector.shape[-1] == 1:
        vector = repeat(vector, "... () -> ... c", c=dim)

    if vector.ndim != 2 or vector.shape[-1] != dim:
        raise _SanitizeError(
            f"expected a vector of shape (batch, {dim}), "
            f"got shape {tuple(vector.shape)}"
        )
    return vector


Scalar = Union[
    Real,
    Iterable[Real],
    Shaped[Tensor, ""],
    Shaped[Tensor, " batch"],
]


def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]:
    """Coerce *scalar* into a 1-D float32 tensor of shape ``(batch,)``.

    A 0-D input becomes a batch of one.

    Args:
        scalar: A number, an iterable of numbers, or a 0-D/1-D tensor.
        device: Device the result is placed on.

    Returns:
        A 1-D float32 tensor on *device*.

    Raises:
        ValueError: (also an ``AssertionError``) if the input has more than
            one dimension.
    """
    scalar = _to_float_tensor(scalar, device)

    # Promote a 0-D value to a batch of one.
    if scalar.ndim == 0:
        scalar = scalar[None]

    if scalar.ndim != 1:
        raise _SanitizeError(
            f"expected a scalar or 1-D batch of scalars, "
            f"got shape {tuple(scalar.shape)}"
        )
    return scalar


Pair = Union[
    Iterable[Real],
    Shaped[Tensor, "2"],
]


def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]:
    """Coerce *pair* into a float32 tensor of shape ``(2,)``.

    Args:
        pair: An iterable of two numbers or a tensor of shape ``(2,)``.
        device: Device the result is placed on.

    Returns:
        A float32 tensor of shape ``(2,)`` on *device*.

    Raises:
        ValueError: (also an ``AssertionError``) if the shape is not ``(2,)``.
    """
    pair = _to_float_tensor(pair, device)
    if pair.shape != (2,):
        raise _SanitizeError(
            f"expected a pair of shape (2,), got shape {tuple(pair.shape)}"
        )
    return pair