| | from typing import Union, Dict |
| |
|
| | import unittest |
| | import zarr |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from diffusion_policy.common.pytorch_util import dict_apply |
| | from diffusion_policy.model.common.dict_of_tensor_mixin import DictOfTensorMixin |
| |
|
| |
|
class LinearNormalizer(DictOfTensorMixin):
    """Affine normalizer (``x * scale + offset``) for a tensor or a dict of tensors.

    Fitted parameters live in ``self.params_dict`` (assumed to be supplied by
    ``DictOfTensorMixin``). When fitted on a dict, each field's parameters are
    stored under the field name; when fitted on a bare array they are stored
    under the reserved key ``'_default'``.
    """

    # NOTE: the misspelling is kept on purpose; this is a public attribute name.
    avaliable_modes = ['limits', 'gaussian']

    @torch.no_grad()
    def fit(self,
            data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
            last_n_dims=1,
            dtype=torch.float32,
            mode='limits',
            output_max=1.,
            output_min=-1.,
            range_eps=1e-4,
            fit_offset=True):
        """Fit normalization parameters from ``data``.

        Args:
            data: Either a dict mapping field name to array, or a single array.
            last_n_dims, dtype, mode, output_max, output_min, range_eps,
            fit_offset: Forwarded unchanged to the module-level ``_fit``.
        """
        fit_kwargs = dict(
            last_n_dims=last_n_dims,
            dtype=dtype,
            mode=mode,
            output_max=output_max,
            output_min=output_min,
            range_eps=range_eps,
            fit_offset=fit_offset)
        if isinstance(data, dict):
            for key, value in data.items():
                self.params_dict[key] = _fit(value, **fit_kwargs)
        else:
            self.params_dict['_default'] = _fit(data, **fit_kwargs)

    def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> Union[Dict, torch.Tensor]:
        """Alias for :meth:`normalize`."""
        return self.normalize(x)

    def __getitem__(self, key: str):
        """Return a ``SingleFieldLinearNormalizer`` wrapping the parameters of ``key``."""
        return SingleFieldLinearNormalizer(self.params_dict[key])

    def __setitem__(self, key: str, value: 'SingleFieldLinearNormalizer'):
        """Store the parameters of ``value`` under ``key``."""
        self.params_dict[key] = value.params_dict

    def _has_only_default(self) -> bool:
        """True when the only fitted entry is the bare-array ``'_default'`` one."""
        return len(self.params_dict) == 1 and '_default' in self.params_dict

    def _normalize_impl(self, x, forward=True):
        """Apply (``forward=True``) or invert (``forward=False``) the transform.

        A dict input is processed field by field with the parameters stored
        under the same key (a missing key raises ``KeyError``); any other
        input uses the ``'_default'`` parameters.

        Raises:
            RuntimeError: for non-dict input when ``'_default'`` was never fitted.
        """
        if isinstance(x, dict):
            result = dict()
            for key, value in x.items():
                params = self.params_dict[key]
                result[key] = _normalize(value, params, forward=forward)
            return result

        if '_default' not in self.params_dict:
            raise RuntimeError("Not initialized")
        params = self.params_dict['_default']
        return _normalize(x, params, forward=forward)

    def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> Union[Dict, torch.Tensor]:
        """Map raw values to normalized values (dict in, dict out)."""
        return self._normalize_impl(x, forward=True)

    def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> Union[Dict, torch.Tensor]:
        """Map normalized values back to raw values (dict in, dict out)."""
        return self._normalize_impl(x, forward=False)

    def get_input_stats(self) -> Dict:
        """Return the stored input statistics.

        Returns:
            The ``'input_stats'`` entry of ``'_default'`` when that is the only
            fitted entry; otherwise a dict ``{field: input_stats}`` over every
            field except ``'_default'``.

        Raises:
            RuntimeError: if nothing has been fitted yet.
        """
        if len(self.params_dict) == 0:
            raise RuntimeError("Not initialized")
        if self._has_only_default():
            return self.params_dict['_default']['input_stats']

        result = dict()
        for key, value in self.params_dict.items():
            if key != '_default':
                result[key] = value['input_stats']
        return result

    def get_output_stats(self, key='_default'):
        """Return the input statistics pushed through :meth:`normalize`.

        Args:
            key: Unused; kept so existing callers that pass it keep working.

        Returns:
            Same layout as :meth:`get_input_stats`, with every statistic
            normalized by the parameters of the field it belongs to.
        """
        input_stats = self.get_input_stats()
        # Decide the layout with the same condition get_input_stats uses,
        # rather than probing for a 'min' key: a multi-field normalizer may
        # legitimately contain a field that is itself called 'min'.
        if self._has_only_default():
            return dict_apply(input_stats, self.normalize)

        result = dict()
        for field, group in input_stats.items():
            this_dict = dict()
            for name, value in group.items():
                # Wrap in a one-entry dict so the field's own parameters are used.
                this_dict[name] = self.normalize({field: value})[field]
            result[field] = this_dict
        return result
| |
|
| |
|
class SingleFieldLinearNormalizer(DictOfTensorMixin):
    """Affine normalizer (``x * scale + offset``) for a single tensor field.

    ``self.params_dict`` (assumed to be supplied by ``DictOfTensorMixin``)
    holds ``'scale'``, ``'offset'`` and an ``'input_stats'`` sub-dict.
    """

    # NOTE: the misspelling is kept on purpose; this is a public attribute name.
    avaliable_modes = ['limits', 'gaussian']

    @torch.no_grad()
    def fit(self,
            data: Union[torch.Tensor, np.ndarray, zarr.Array],
            last_n_dims=1,
            dtype=torch.float32,
            mode='limits',
            output_max=1.,
            output_min=-1.,
            range_eps=1e-4,
            fit_offset=True):
        """Fit parameters from ``data``; all arguments are forwarded to ``_fit``."""
        self.params_dict = _fit(data,
            last_n_dims=last_n_dims,
            dtype=dtype,
            mode=mode,
            output_max=output_max,
            output_min=output_min,
            range_eps=range_eps,
            fit_offset=fit_offset)

    @classmethod
    def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
        """Construct a normalizer and fit it on ``data`` in one step."""
        obj = cls()
        obj.fit(data, **kwargs)
        return obj

    @classmethod
    def create_manual(cls,
            scale: Union[torch.Tensor, np.ndarray],
            offset: Union[torch.Tensor, np.ndarray],
            input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]]):
        """Build a normalizer from explicitly supplied parameters.

        Args:
            scale: Multiplicative factor.
            offset: Additive term; must match ``scale`` in shape and dtype.
            input_stats_dict: Named statistics; each must match ``scale`` in
                shape and dtype.

        Every array is flattened to 1-D before being stored.

        Raises:
            AssertionError: on a shape or dtype mismatch with ``scale``.
        """
        def to_frozen_param(x):
            if not isinstance(x, torch.Tensor):
                x = torch.from_numpy(x)
            # Stored as non-trainable parameters, matching what `_fit`
            # produces: a normalizer is a fixed transform and must not be
            # picked up as trainable by an optimizer.
            return nn.Parameter(x.flatten(), requires_grad=False)

        # check
        for x in [offset] + list(input_stats_dict.values()):
            assert x.shape == scale.shape
            assert x.dtype == scale.dtype

        params_dict = nn.ParameterDict({
            'scale': to_frozen_param(scale),
            'offset': to_frozen_param(offset),
            'input_stats': nn.ParameterDict(
                dict_apply(input_stats_dict, to_frozen_param))
        })
        return cls(params_dict)

    @classmethod
    def create_identity(cls, dtype=torch.float32):
        """Build a one-element normalizer with scale 1 and offset 0."""
        scale = torch.tensor([1], dtype=dtype)
        offset = torch.tensor([0], dtype=dtype)
        input_stats_dict = {
            'min': torch.tensor([-1], dtype=dtype),
            'max': torch.tensor([1], dtype=dtype),
            'mean': torch.tensor([0], dtype=dtype),
            'std': torch.tensor([1], dtype=dtype)
        }
        return cls.create_manual(scale, offset, input_stats_dict)

    def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Map raw values to normalized values."""
        return _normalize(x, self.params_dict, forward=True)

    def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Map normalized values back to raw values."""
        return _normalize(x, self.params_dict, forward=False)

    def get_input_stats(self):
        """Return the stored ``'input_stats'`` entry."""
        return self.params_dict['input_stats']

    def get_output_stats(self):
        """Return the input statistics pushed through :meth:`normalize`."""
        return dict_apply(self.params_dict['input_stats'], self.normalize)

    def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Alias for :meth:`normalize`."""
        return self.normalize(x)
| |
|
| |
|
| |
|
def _fit(data: Union[torch.Tensor, np.ndarray, "zarr.Array"],
        last_n_dims=1,
        dtype=torch.float32,
        mode='limits',
        output_max=1.,
        output_min=-1.,
        range_eps=1e-4,
        fit_offset=True):
    """Compute affine normalization parameters (``x * scale + offset``).

    Args:
        data: Tensor, ndarray, or any sliceable array-like (for example a
            zarr array) that ``data[:]`` turns into something
            ``np.asarray`` accepts.
        last_n_dims: Number of trailing dimensions that are flattened into
            the feature axis; all leading dimensions are treated as samples.
            ``0`` fits a single scalar scale/offset.
        dtype: Dtype the data is cast to before fitting; ``None`` keeps it.
        mode: ``'limits'`` maps the observed range onto
            ``[output_min, output_max]``; ``'gaussian'`` divides by the
            standard deviation.
        output_max: Upper bound of the target range (``'limits'`` only).
        output_min: Lower bound of the target range (``'limits'`` only).
        range_eps: Features whose range (``'limits'``) or std
            (``'gaussian'``) is below this are treated as constant and are
            not rescaled.
        fit_offset: When False the offset is forced to zero.

    Returns:
        ``nn.ParameterDict`` with ``'scale'``, ``'offset'`` and an
        ``'input_stats'`` sub-dict (``min``/``max``/``mean``/``std``); every
        entry is a parameter with ``requires_grad=False``.

    Raises:
        AssertionError: on an unknown ``mode``, negative ``last_n_dims``,
            ``output_max <= output_min``, or (``'limits'`` with
            ``fit_offset=False``) a target range that does not straddle 0.
    """
    assert mode in ['limits', 'gaussian']
    assert last_n_dims >= 0
    assert output_max > output_min

    # convert data to torch and type
    if not isinstance(data, (torch.Tensor, np.ndarray)):
        # Lazy/array-like containers are materialized by slicing.
        data = np.asarray(data[:])
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
    if dtype is not None:
        data = data.type(dtype)

    # convert shape: (samples, features)
    dim = 1
    if last_n_dims > 0:
        dim = int(np.prod(data.shape[-last_n_dims:]))
    data = data.reshape(-1, dim)

    # compute input stats min max mean std
    input_min = data.min(dim=0).values
    input_max = data.max(dim=0).values
    input_mean = data.mean(dim=0)
    if data.shape[0] > 1:
        input_std = data.std(dim=0)
    else:
        # The default (Bessel-corrected) std of a single sample is NaN, and
        # NaN slips past the `< range_eps` guard below, which would make
        # scale and offset NaN. One sample has no spread, so use 0.
        input_std = torch.zeros_like(input_mean)

    # compute scale and offset
    if mode == 'limits':
        if fit_offset:
            # unit scale
            input_range = input_max - input_min
            ignore_dim = input_range < range_eps
            # Constant features get scale 1 ...
            input_range[ignore_dim] = output_max - output_min
            scale = (output_max - output_min) / input_range
            offset = output_min - scale * input_min
            # ... and are shifted onto the middle of the target range.
            offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
        else:
            # use this when data is pre-zero-centered.
            assert output_max > 0
            assert output_min < 0
            # Symmetric scaling about 0: the largest magnitude maps onto the
            # tighter of the two output bounds.
            output_abs = min(abs(output_min), abs(output_max))
            input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
            ignore_dim = input_abs < range_eps
            input_abs[ignore_dim] = output_abs
            scale = output_abs / input_abs
            offset = torch.zeros_like(input_mean)
    elif mode == 'gaussian':
        ignore_dim = input_std < range_eps
        scale = input_std.clone()
        scale[ignore_dim] = 1
        scale = 1 / scale

        if fit_offset:
            offset = - input_mean * scale
        else:
            offset = torch.zeros_like(input_mean)

    def frozen(tensor):
        # Normalizer parameters are statistics, not trainable weights.
        return nn.Parameter(tensor, requires_grad=False)

    this_params = nn.ParameterDict({
        'scale': frozen(scale),
        'offset': frozen(offset),
        'input_stats': nn.ParameterDict({
            'min': frozen(input_min),
            'max': frozen(input_max),
            'mean': frozen(input_mean),
            'std': frozen(input_std)
        })
    })
    return this_params
| |
|
| |
|
def _normalize(x, params, forward=True):
    """Apply or invert the affine transform described by ``params``.

    ``x`` (tensor or ndarray) is moved to the device and dtype of
    ``params['scale']``, viewed as rows of ``scale.shape[0]`` features,
    transformed, and returned in its original shape.

    ``forward=True`` computes ``x * scale + offset``; ``forward=False``
    computes ``(x - offset) / scale``.
    """
    assert 'scale' in params
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    gain = params['scale']
    bias = params['offset']
    tensor = tensor.to(device=gain.device, dtype=gain.dtype)
    original_shape = tensor.shape
    rows = tensor.reshape(-1, gain.shape[0])
    rows = rows * gain + bias if forward else (rows - bias) / gain
    return rows.reshape(original_shape)
| |
|
| |
|
def test():
    """Smoke-test both normalizer classes on random data.

    Covers limits mode (with and without offset), gaussian mode, dict
    input, the stats accessors, and a state_dict save/load round trip.
    """

    def check_extremes(tensor, upper, lower, **tolerance):
        # The max and min of the normalized data should land on the expected bounds.
        assert np.allclose(tensor.max(), upper, **tolerance)
        assert np.allclose(tensor.min(), lower, **tolerance)

    # --- SingleFieldLinearNormalizer, limits mode, one constant feature ---
    samples = torch.zeros((100,10,9,2)).uniform_()
    samples[...,0,0] = 0

    single = SingleFieldLinearNormalizer()
    single.fit(samples, mode='limits', last_n_dims=2)
    normed = single.normalize(samples)
    assert normed.shape == samples.shape
    check_extremes(normed, 1., -1.)
    restored = single.unnormalize(normed)
    assert torch.allclose(samples, restored, atol=1e-7)

    # Exercise the stats accessors; the results are not inspected.
    input_stats = single.get_input_stats()
    output_stats = single.get_output_stats()

    # --- limits mode without offset: non-negative data maps into [0, 1] ---
    single = SingleFieldLinearNormalizer()
    single.fit(samples, mode='limits', last_n_dims=1, fit_offset=False)
    normed = single.normalize(samples)
    assert normed.shape == samples.shape
    check_extremes(normed, 1., 0., atol=1e-3)
    restored = single.unnormalize(normed)
    assert torch.allclose(samples, restored, atol=1e-7)

    # --- gaussian mode with one scalar scale/offset ---
    samples = torch.zeros((100,10,9,2)).uniform_()
    single = SingleFieldLinearNormalizer()
    single.fit(samples, mode='gaussian', last_n_dims=0)
    normed = single.normalize(samples)
    assert normed.shape == samples.shape
    assert np.allclose(normed.mean(), 0., atol=1e-3)
    assert np.allclose(normed.std(), 1., atol=1e-3)
    restored = single.unnormalize(normed)
    assert torch.allclose(samples, restored, atol=1e-7)

    # --- LinearNormalizer on a bare tensor ---
    samples = torch.zeros((100,10,9,2)).uniform_()
    samples[...,0,0] = 0

    multi = LinearNormalizer()
    multi.fit(samples, mode='limits', last_n_dims=2)
    normed = multi.normalize(samples)
    assert normed.shape == samples.shape
    check_extremes(normed, 1., -1.)
    restored = multi.unnormalize(normed)
    assert torch.allclose(samples, restored, atol=1e-7)

    input_stats = multi.get_input_stats()
    output_stats = multi.get_output_stats()

    # --- LinearNormalizer on a dict of tensors ---
    batch = {
        'obs': torch.zeros((1000,128,9,2)).uniform_() * 512,
        'action': torch.zeros((1000,128,2)).uniform_() * 512
    }
    multi = LinearNormalizer()
    multi.fit(batch)
    normed_batch = multi.normalize(batch)
    restored_batch = multi.unnormalize(normed_batch)
    for name in batch:
        assert torch.allclose(batch[name], restored_batch[name], atol=1e-4)

    input_stats = multi.get_input_stats()
    output_stats = multi.get_output_stats()

    # --- state_dict round trip into a fresh normalizer ---
    saved = multi.state_dict()
    reloaded = LinearNormalizer()
    reloaded.load_state_dict(saved)
    normed_batch = reloaded.normalize(batch)
    restored_batch = reloaded.unnormalize(normed_batch)
    for name in batch:
        assert torch.allclose(batch[name], restored_batch[name], atol=1e-4)
| |
|