| from typing import Dict |
| import torch |
| from equi_diffpo.model.common.normalizer import LinearNormalizer |
| from equi_diffpo.policy.base_image_policy import BaseImagePolicy |
| from equi_diffpo.common.pytorch_util import dict_apply |
|
|
| from robomimic.algo import algo_factory |
| from robomimic.algo.algo import PolicyAlgo |
| import robomimic.utils.obs_utils as ObsUtils |
| from equi_diffpo.common.robomimic_config_util import get_robomimic_config |
|
|
class RobomimicImagePolicy(BaseImagePolicy):
    """Adapter exposing a robomimic image-policy algorithm (e.g. ``bc_rnn``)
    through the :class:`BaseImagePolicy` interface.

    Observation keys and shapes are read from ``shape_meta`` and used to
    configure robomimic's observation modalities before the underlying
    :class:`PolicyAlgo` is built via ``algo_factory``.
    """
    def __init__(self,
            shape_meta: dict,
            algo_name='bc_rnn',
            obs_type='image',
            task_name='square',
            dataset_type='ph',
            crop_shape=(76,76)
        ):
        """
        Args:
            shape_meta: dict with an 'action' entry (1-D 'shape') and an
                'obs' dict mapping key -> {'shape': ..., 'type': 'rgb'|'low_dim'}.
                'type' defaults to 'low_dim' when absent.
            algo_name: robomimic algorithm name passed to the config factory.
            obs_type: hdf5 observation type used to select the base config.
            task_name: robomimic task the base config is taken from.
            dataset_type: dataset variant (e.g. 'ph', 'mh').
            crop_shape: (height, width) for robomimic's CropRandomizer, or
                None to disable random cropping entirely.

        Raises:
            RuntimeError: if an obs entry has a 'type' other than
                'rgb' or 'low_dim'.
        """
        super().__init__()

        # parse shape_meta
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            shape = attr['shape']
            obs_key_shapes[key] = list(shape)

            # NOTE: renamed from `type` to avoid shadowing the builtin
            modality_type = attr.get('type', 'low_dim')
            if modality_type == 'rgb':
                obs_config['rgb'].append(key)
            elif modality_type == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {modality_type}")

        # get raw robomimic config for this task/algo combination
        config = get_robomimic_config(
            algo_name=algo_name,
            hdf5_type=obs_type,
            task_name=task_name,
            dataset_type=dataset_type)

        with config.unlocked():
            # set config with shape_meta
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                # disable random cropping on every encoder that uses it
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                # set random crop parameter
                ch, cw = crop_shape
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw

        # init global state — robomimic requires this before algo_factory
        ObsUtils.initialize_obs_utils_with_config(config)

        # build the underlying robomimic policy on CPU; moved later via .to()
        model: PolicyAlgo = algo_factory(
                algo_name=config.algo_name,
                config=config,
                obs_key_shapes=obs_key_shapes,
                ac_dim=action_dim,
                device='cpu',
            )

        self.model = model
        # expose the robomimic networks so nn.Module.to()/parameters() see them
        self.nets = model.nets
        self.normalizer = LinearNormalizer()
        self.config = config

    def to(self, *args, **kwargs):
        """Move parameters/buffers and keep the robomimic model's device in sync.

        Fix: returns ``super().to(...)`` (i.e. ``self``) so the common
        ``policy = policy.to(device)`` chaining no longer yields None.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            # robomimic tracks its target device on the algo object itself
            self.model.device = device
        return super().to(*args, **kwargs)

    # =========== inference =============
    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Predict a single action step from (B, T, ...) observations.

        Only the first timestep of each observation sequence is fed to the
        robomimic model; the result is re-expanded to (B, 1, action_dim).
        """
        nobs_dict = self.normalizer(obs_dict)
        # robomimic expects (B, ...) observations — take timestep 0
        robomimic_obs_dict = dict_apply(nobs_dict, lambda x: x[:,0,...])
        naction = self.model.get_action(robomimic_obs_dict)
        action = self.normalizer['action'].unnormalize(naction)
        # (B, Da) -> (B, 1, Da)
        result = {
            'action': action[:,None,:]
        }
        return result

    def reset(self):
        """Reset any recurrent/internal state of the robomimic model."""
        self.model.reset()

    # =========== training ==============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy normalization statistics from a fitted normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def train_on_batch(self, batch, epoch, validate=False):
        """Run one robomimic training (or validation) step on a batch.

        Observations and actions are normalized first, then handed to the
        robomimic algo's own batch-processing and training routines.
        """
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        robomimic_batch = {
            'obs': nobs,
            'actions': nactions
        }
        input_batch = self.model.process_batch_for_training(
            robomimic_batch)
        info = self.model.train_on_batch(
            batch=input_batch, epoch=epoch, validate=validate)
        # keys depend on the robomimic algo (losses, grad norms, ...)
        return info

    def on_epoch_end(self, epoch):
        """Forward the epoch-end hook (e.g. LR scheduling) to robomimic."""
        self.model.on_epoch_end(epoch)

    def get_optimizer(self):
        """Return the robomimic policy optimizer used for training."""
        return self.model.optimizers['policy']
|
|
|
|
def test(cfg_path: str = '~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml'):
    """Smoke-test: construct a RobomimicImagePolicy from a task config.

    Args:
        cfg_path: path to a task yaml providing ``shape_meta`` ('~' is
            expanded). Defaults to the previously hard-coded location, so
        existing no-argument calls behave the same.

    Returns:
        The constructed policy, so interactive callers can inspect it
        (previously it was built and discarded).
    """
    import os
    from omegaconf import OmegaConf
    cfg = OmegaConf.load(os.path.expanduser(cfg_path))
    shape_meta = cfg.shape_meta

    policy = RobomimicImagePolicy(shape_meta=shape_meta)
    return policy
|
|
|
|