| from typing import Callable, Literal, TypedDict |
|
|
| from jaxtyping import Float, Int64 |
| from torch import Tensor |
|
|
# Permitted dataset stages: exactly one of the three string literals below.
Stage = Literal[
    "train",
    "val",
    "test",
]
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
class BatchedViews(TypedDict, total=False):
    """Tensors describing a set of views, each carrying a leading ``batch`` axis.

    ``total=False`` makes every key optional, so a dict holding only some of
    these fields still satisfies the type. In the jaxtyping shape strings an
    ``_`` marks a dimension whose size is not checked. The axis directly after
    ``batch`` is presumably the number of views per example (TODO confirm).
    """

    # A 4x4 extrinsics matrix for every entry along the first two axes.
    extrinsics: Float[Tensor, "batch _ 4 4"]
    # A 3x3 intrinsics matrix for every entry along the first two axes.
    intrinsics: Float[Tensor, "batch _ 3 3"]
    # Five-axis image tensor; only ``batch`` is named. Presumably
    # (batch, view, channel, height, width) -- TODO confirm against the loader.
    image: Float[Tensor, "batch _ _ _ _"]
    # One float per entry along the first two axes.
    near: Float[Tensor, "batch _"]
    far: Float[Tensor, "batch _"]
    # One int64 value per entry along the first two axes.
    index: Int64[Tensor, "batch _"]
    # NOTE(review): UnbatchedViews declares no ``overlap`` key; presumably it is
    # filled in after batching (e.g. by a DataShim) -- confirm.
    overlap: Float[Tensor, "batch _"]
|
|
|
|
class BatchedExample(TypedDict, total=False):
    """Batched example made of two view sets plus scene names.

    Every key is optional because the TypedDict is declared with
    ``total=False``.
    """

    # Two independent view sets, both in batched (BatchedViews) form.
    target: BatchedViews
    context: BatchedViews
    # Scene names, presumably one per batch element (TODO confirm).
    scene: list[str]
|
|
|
|
class UnbatchedViews(TypedDict, total=False):
    """Tensors describing a set of views for a single example (no batch axis).

    Every key is optional (``total=False``). The leading ``_`` axis is not
    size-checked by jaxtyping and presumably counts views (TODO confirm).
    ``height`` and ``width`` are named dimensions, so when jaxtyping runtime
    checking is enabled their sizes must agree wherever they are bound within
    one checked call.
    """

    # A 4x4 extrinsics matrix per entry along the leading axis.
    extrinsics: Float[Tensor, "_ 4 4"]
    # A 3x3 intrinsics matrix per entry along the leading axis.
    intrinsics: Float[Tensor, "_ 3 3"]
    # Three-channel images of size height x width.
    image: Float[Tensor, "_ 3 height width"]
    # The leading space in " _" is presumably the jaxtyping-recommended trick
    # that stops linters from treating a one-token string as a forward-reference
    # name; keep it.
    near: Float[Tensor, " _"]
    far: Float[Tensor, " _"]
    index: Int64[Tensor, " _"]
|
|
|
|
class UnbatchedExample(TypedDict, total=False):
    """A single example with the same keys as BatchedExample, minus batching.

    Every key is optional (``total=False``).
    """

    # Two independent view sets, both in unbatched (UnbatchedViews) form.
    target: UnbatchedViews
    context: UnbatchedViews
    # A single scene name (the batched form holds a list of names).
    scene: str
|
|
|
|
| |
# Signature of a "data shim": any callable that accepts one BatchedExample and
# returns a BatchedExample.
DataShim = Callable[
    [BatchedExample],
    BatchedExample,
]
|
|
# Convenience unions for code that accepts either the batched or the unbatched
# form of an example / a view set.
AnyExample = BatchedExample | UnbatchedExample
AnyViews = BatchedViews | UnbatchedViews
|
|