Spaces:
Runtime error
Runtime error
| import torch | |
| from einops import einsum, reduce, repeat | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from ..types import BatchedExample | |
def inverse_normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Map a normalized tensor back to its original value range.

    This is the arithmetic inverse of ``normalize_image``: each channel is
    rescaled as ``tensor * std + mean``. With the default statistics, values
    in ``[-1, 1]`` are mapped back to ``[0, 1]``.

    Args:
        tensor: Normalized tensor. ``mean``/``std`` are reshaped to
            ``(-1, 1, 1)``, so they broadcast against the third-from-last
            dimension (presumably the channel axis of a ``(..., C, H, W)``
            image — confirm at call sites).
        mean: Per-channel offset that was subtracted during normalization.
        std: Per-channel scale that was divided out during normalization.

    Returns:
        A new tensor with the same broadcast shape, dtype and device as
        ``tensor``.

    Note:
        ``mean`` and ``std`` are converted to ``tensor``'s dtype, so for an
        integer tensor the fractional defaults would be truncated.
    """

    def _per_channel(values):
        # Match the input's dtype/device and add two trailing singleton dims
        # so the statistics broadcast over the last two axes.
        return torch.as_tensor(
            values, dtype=tensor.dtype, device=tensor.device
        ).view(-1, 1, 1)

    shift = _per_channel(mean)
    scale = _per_channel(std)
    return tensor * scale + shift
def normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Standardize a tensor per channel as ``(tensor - mean) / std``.

    With the default statistics, values in ``[0, 1]`` are mapped to
    ``[-1, 1]``.

    Args:
        tensor: Input tensor. ``mean``/``std`` are reshaped to ``(-1, 1, 1)``,
            so they broadcast against the third-from-last dimension
            (presumably the channel axis of a ``(..., C, H, W)`` image —
            confirm at call sites).
        mean: Per-channel value to subtract.
        std: Per-channel value to divide by.

    Returns:
        A new tensor with the same broadcast shape, dtype and device as
        ``tensor``.

    Note:
        ``mean`` and ``std`` are converted to ``tensor``'s dtype, so for an
        integer tensor the fractional defaults would be truncated.
    """
    # Build both statistics on the same dtype/device as the input.
    placement = dict(dtype=tensor.dtype, device=tensor.device)
    center = torch.as_tensor(mean, **placement).view(-1, 1, 1)
    spread = torch.as_tensor(std, **placement).view(-1, 1, 1)
    return (tensor - center) / spread
def apply_normalize_shim(
    batch: BatchedExample,
    mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
    std: tuple[float, float, float] = (0.5, 0.5, 0.5),
) -> BatchedExample:
    """Normalize the context images of ``batch`` in place.

    Replaces ``batch["context"]["image"]`` with
    ``normalize_image(batch["context"]["image"], mean, std)``. No other
    entries of ``batch`` are modified here.

    Args:
        batch: Batch mapping. It is mutated in place.
        mean: Per-channel mean forwarded to ``normalize_image``.
        std: Per-channel std forwarded to ``normalize_image``.

    Returns:
        The same ``batch`` object that was passed in, for call chaining.
    """
    context = batch["context"]
    context["image"] = normalize_image(context["image"], mean, std)
    return batch