| import torch |
| from einops import einsum, reduce, repeat |
| from jaxtyping import Float |
| from torch import Tensor |
|
|
| from ..types import BatchedExample |
|
|
|
|
def inverse_normalize_image(
    tensor: Tensor,
    mean: tuple[float, ...] = (0.5, 0.5, 0.5),
    std: tuple[float, ...] = (0.5, 0.5, 0.5),
) -> Tensor:
    """Undo a per-channel ``(x - mean) / std`` normalization.

    Computes ``tensor * std + mean``. With the default statistics this maps
    -1 to 0 and 1 to 1.

    Args:
        tensor: Image tensor to de-normalize. Its last three dimensions are
            broadcast against the ``(N, 1, 1)`` statistics; presumably laid
            out as ``(..., channels, height, width)`` -- confirm with callers.
        mean: Per-channel offsets that were subtracted during normalization.
        std: Per-channel scales that were divided out during normalization.

    Returns:
        A new tensor holding the de-normalized values.
    """
    # Build the statistics with the input's dtype/device so the arithmetic
    # below needs no implicit casts or device transfers, and reshape them to
    # (N, 1, 1) so they broadcast across the two trailing axes.
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor * std + mean
|
|
|
|
def normalize_image(
    tensor: Tensor,
    mean: tuple[float, ...] = (0.5, 0.5, 0.5),
    std: tuple[float, ...] = (0.5, 0.5, 0.5),
) -> Tensor:
    """Apply a per-channel ``(x - mean) / std`` normalization.

    With the default statistics this maps 0 to -1 and 1 to 1.

    Args:
        tensor: Image tensor to normalize. Its last three dimensions are
            broadcast against the ``(N, 1, 1)`` statistics; presumably laid
            out as ``(..., channels, height, width)`` -- confirm with callers.
        mean: Per-channel offsets to subtract.
        std: Per-channel scales to divide by.

    Returns:
        A new tensor holding the normalized values.
    """
    # Build the statistics with the input's dtype/device so the arithmetic
    # below needs no implicit casts or device transfers, and reshape them to
    # (N, 1, 1) so they broadcast across the two trailing axes.
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return (tensor - mean) / std
|
|
|
|
def apply_normalize_shim(
    batch: BatchedExample,
    mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
    std: tuple[float, float, float] = (0.5, 0.5, 0.5),
) -> BatchedExample:
    """Normalize the context images of a batch.

    Replaces ``batch["context"]["image"]`` with the result of
    ``normalize_image`` applied to it. No other entry of the batch is written
    by this function.

    Args:
        batch: Batch whose ``"context"`` entry holds an ``"image"`` tensor.
        mean: Per-channel offsets forwarded to ``normalize_image``.
        std: Per-channel scales forwarded to ``normalize_image``.

    Returns:
        The same ``batch`` object that was passed in; it is updated in place,
        not copied.
    """
    batch["context"]["image"] = normalize_image(batch["context"]["image"], mean, std)
    return batch
|
|