| # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| from typing import Union, Tuple | |
| import torch | |
| from torch import nn | |
| norm_t = Union[Tuple[float, float, float], torch.Tensor] | |
| class InputConditioner(nn.Module): | |
| def __init__(self, | |
| input_scale: float, | |
| norm_mean: norm_t, | |
| norm_std: norm_t, | |
| dtype: torch.dtype = None, | |
| ): | |
| super().__init__() | |
| self.dtype = dtype | |
| self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale) | |
| self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale) | |
| def forward(self, x: torch.Tensor): | |
| y = (x - self.norm_mean) / self.norm_std | |
| # if self.dtype is not None: | |
| # y = y.to(self.dtype) | |
| return y | |
| def get_default_conditioner(): | |
| from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD | |
| return InputConditioner( | |
| input_scale=1.0, | |
| norm_mean=OPENAI_CLIP_MEAN, | |
| norm_std=OPENAI_CLIP_STD, | |
| ) | |
| def _to_tensor(v: norm_t): | |
| return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1) | |