Spaces:
Runtime error
Runtime error
| from typing import Optional, Protocol, runtime_checkable | |
| import torch | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from .types import Pair, sanitize_pair | |
class ConversionFunction(Protocol):
    """Structural type for a callable that remaps 2D coordinates.

    Implementations take a float tensor whose trailing dimension has size 2
    (any number of leading batch dimensions) and return a float tensor
    annotated with the same ``"*batch 2"`` shape.
    """

    def __call__(
        self,
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        """Map the coordinates in ``xy`` to another coordinate system."""
        ...
def generate_conversions(
    shape: tuple[int, int],
    device: torch.device,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> tuple[
    ConversionFunction,  # conversion from world coordinates to pixel coordinates
    ConversionFunction,  # conversion from pixel coordinates to world coordinates
]:
    """Build a pair of world<->pixel coordinate conversion functions.

    Args:
        shape: Image size as ``(height, width)``.
        device: Device on which the constant tensors used by the returned
            closures are created.
        x_range: World-space extent along x. Defaults to ``(0, width)``.
        y_range: World-space extent along y. Defaults to ``(0, height)``.

    Returns:
        A ``(world_to_pixel, pixel_to_world)`` tuple. Both callables accept a
        float tensor whose last dimension is ``(x, y)``.

    Note:
        No check is made for a zero-width range; if a range's maximum equals
        its minimum, ``world_to_pixel`` divides by zero.
    """
    height, width = shape

    # Fall back to the pixel extent when no explicit world range is given.
    if x_range is None:
        x_range = (0, width)
    if y_range is None:
        y_range = (0, height)
    x_bounds = sanitize_pair(x_range, device)
    y_bounds = sanitize_pair(y_range, device)

    # Stacking along the last dim and unpacking along the first yields one
    # tensor of lower bounds and one of upper bounds, each ordered (x, y).
    # NOTE(review): assumes sanitize_pair returns a 2-element 1D tensor
    # ordered (min, max) -- confirm against .types.
    lower, upper = torch.stack((x_bounds, y_bounds), dim=-1)
    span = upper - lower
    size_xy = torch.tensor((width, height), dtype=torch.float32, device=device)

    def convert_world_to_pixel(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        """Normalize world coordinates by the range, then scale to pixels."""
        offset = xy - lower
        return offset / span * size_xy

    def convert_pixel_to_world(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        """Normalize pixel coordinates by the image size, then map to world."""
        normalized = xy / size_xy
        return normalized * span + lower

    return convert_world_to_pixel, convert_pixel_to_world