"""
Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi
"""
from typing import Callable, Union
import math
from collections import OrderedDict, deque
from packaging.version import parse as parse_version
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
# requires diffusers==0.11.1
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.training_utils import EMAModel

import robomimic.models.obs_nets as ObsNets
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.algo import register_algo_factory_func, PolicyAlgo


@register_algo_factory_func("diffusion_policy")
def algo_config_to_class(algo_config):
    """
    Maps algo config to the BC algo class to instantiate, along with additional algo kwargs.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm

    Raises:
        NotImplementedError: if the UNet backbone is disabled and the
            transformer backbone is enabled (not implemented here)
        RuntimeError: if neither backbone is enabled
    """
    if algo_config.unet.enabled:
        return DiffusionPolicyUNet, {}
    elif algo_config.transformer.enabled:
        raise NotImplementedError(
            "The transformer backbone for Diffusion Policy is not implemented."
        )
    else:
        raise RuntimeError(
            "No Diffusion Policy backbone is enabled: set algo.unet.enabled "
            "(or algo.transformer.enabled) in the config."
        )


class DiffusionPolicyUNet(PolicyAlgo):
    """
    Diffusion Policy with a 1D conditional UNet as the noise prediction network.
    """

    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.

        Also builds the noise scheduler (DDPM or DDIM, depending on the config),
        the optional EMA tracker, a frozen shadow copy of the networks that
        receives the EMA weights at inference time, and the rollout queues.

        Raises:
            RuntimeError: if neither the DDPM nor the DDIM scheduler is enabled
        """
        # set up different observation groups for @MIMO_MLP
        observation_group_shapes = OrderedDict()
        observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)
        encoder_kwargs = ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder)

        obs_encoder = ObsNets.ObservationGroupEncoder(
            observation_group_shapes=observation_group_shapes,
            encoder_kwargs=encoder_kwargs,
        )
        # IMPORTANT!
        # replace all BatchNorm with GroupNorm to work with EMA
        # performance will tank if you forget to do this!
        obs_encoder = replace_bn_with_gn(obs_encoder)

        obs_dim = obs_encoder.output_shape()[0]

        # create network object; the UNet is conditioned on the flattened
        # features of all observation steps, hence obs_dim * observation_horizon
        noise_pred_net = ConditionalUnet1D(
            input_dim=self.ac_dim,
            global_cond_dim=obs_dim * self.algo_config.horizon.observation_horizon,
        )

        # the final arch has 2 parts
        nets = nn.ModuleDict({
            'policy': nn.ModuleDict({
                'obs_encoder': obs_encoder,
                'noise_pred_net': noise_pred_net,
            })
        })
        nets = nets.float().to(self.device)

        # setup noise scheduler
        if self.algo_config.ddpm.enabled:
            noise_scheduler = DDPMScheduler(
                num_train_timesteps=self.algo_config.ddpm.num_train_timesteps,
                beta_schedule=self.algo_config.ddpm.beta_schedule,
                clip_sample=self.algo_config.ddpm.clip_sample,
                prediction_type=self.algo_config.ddpm.prediction_type,
            )
        elif self.algo_config.ddim.enabled:
            noise_scheduler = DDIMScheduler(
                num_train_timesteps=self.algo_config.ddim.num_train_timesteps,
                beta_schedule=self.algo_config.ddim.beta_schedule,
                clip_sample=self.algo_config.ddim.clip_sample,
                set_alpha_to_one=self.algo_config.ddim.set_alpha_to_one,
                steps_offset=self.algo_config.ddim.steps_offset,
                prediction_type=self.algo_config.ddim.prediction_type,
            )
        else:
            raise RuntimeError()

        # setup EMA
        ema = None
        if self.algo_config.ema.enabled:
            ema = EMAModel(parameters=nets.parameters(), power=self.algo_config.ema.power)

        # set attrs
        self.nets = nets
        # frozen copy of the networks; EMA weights are copied into it at inference time
        self._shadow_nets = copy.deepcopy(self.nets).eval()
        self._shadow_nets.requires_grad_(False)
        self.noise_scheduler = noise_scheduler
        self.ema = ema
        self.action_check_done = False
        self.obs_queue = None
        self.action_queue = None

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Observations are truncated to the observation horizon and actions to
        the prediction horizon (both along dim 1). On the first call, the
        actions are additionally checked to lie in [-1, 1].

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training

        Raises:
            ValueError: if the first processed batch contains actions outside [-1, 1]
        """
        To = self.algo_config.horizon.observation_horizon
        Tp = self.algo_config.horizon.prediction_horizon

        input_batch = dict()
        input_batch["obs"] = {k: batch["obs"][k][:, :To, :] for k in batch["obs"]}
        input_batch["goal_obs"] = batch.get("goal_obs", None)  # goals may not be present
        input_batch["actions"] = batch["actions"][:, :Tp, :]

        # check if actions are normalized to [-1,1] (only done once, on the first batch)
        if not self.action_check_done:
            actions = input_batch["actions"]
            in_range = (-1 <= actions) & (actions <= 1)
            all_in_range = torch.all(in_range).item()
            if not all_in_range:
                raise ValueError('"actions" must be in range [-1,1] for Diffusion Policy! Check if hdf5_normalize_action is enabled.')
            self.action_check_done = True

        return TensorUtils.to_device(TensorUtils.to_float(input_batch), self.device)

    def train_on_batch(self, batch, epoch, validate=False):
        """
        Training on a single batch of data.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): epoch number - required by some Algos that need
                to perform staged training and early stopping

            validate (bool): if True, don't perform any learning updates.

        Returns:
            info (dict): dictionary of relevant inputs, outputs, and losses
                that might be relevant for logging
        """
        B = batch['actions'].shape[0]

        with TorchUtils.maybe_no_grad(no_grad=validate):
            info = super(DiffusionPolicyUNet, self).train_on_batch(batch, epoch, validate=validate)
            actions = batch['actions']

            # encode obs
            inputs = {
                'obs': batch["obs"],
                'goal': batch["goal_obs"],
            }
            for k in self.obs_shapes:
                # first two dimensions should be [B, T] for inputs
                assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k])

            obs_features = TensorUtils.time_distributed(
                inputs, self.nets['policy']['obs_encoder'], inputs_as_kwargs=True)
            assert obs_features.ndim == 3  # [B, T, D]

            # flatten the time axis into the feature axis: [B, T * D]
            obs_cond = obs_features.flatten(start_dim=1)

            # sample noise to add to actions
            noise = torch.randn(actions.shape, device=self.device)

            # sample a diffusion iteration for each data point
            timesteps = torch.randint(
                0, self.noise_scheduler.config.num_train_timesteps,
                (B,), device=self.device
            ).long()

            # add noise to the clean actions according to the noise magnitude at each diffusion iteration
            # (this is the forward diffusion process)
            noisy_actions = self.noise_scheduler.add_noise(
                actions, noise, timesteps)

            # predict the noise residual
            noise_pred = self.nets['policy']['noise_pred_net'](
                noisy_actions, timesteps, global_cond=obs_cond)

            # L2 loss
            loss = F.mse_loss(noise_pred, noise)

            # logging
            losses = {
                'l2_loss': loss
            }
            info["losses"] = TensorUtils.detach(losses)

            if not validate:
                # gradient step
                policy_grad_norms = TorchUtils.backprop_for_loss(
                    net=self.nets,
                    optim=self.optimizers["policy"],
                    loss=loss,
                )

                # update Exponential Moving Average of the model weights
                if self.ema is not None:
                    self.ema.step(self.nets.parameters())

                step_info = {
                    'policy_grad_norms': policy_grad_norms
                }
                info.update(step_info)

        return info

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        log = super(DiffusionPolicyUNet, self).log_info(info)
        log["Loss"] = info["losses"]["l2_loss"].item()
        if "policy_grad_norms" in info:
            log["Policy_Grad_Norms"] = info["policy_grad_norms"]
        return log

    def reset(self):
        """
        Reset algo state to prepare for environment rollouts.
        """
        # setup inference queues
        To = self.algo_config.horizon.observation_horizon
        Ta = self.algo_config.horizon.action_horizon
        self.obs_queue = deque(maxlen=To)
        self.action_queue = deque(maxlen=Ta)

    def get_action(self, obs_dict, goal_dict=None):
        """
        Get policy action outputs.

        A new action sequence is only predicted when the action queue is
        empty; otherwise the next queued action is returned.

        Args:
            obs_dict (dict): current observation [1, Do]
            goal_dict (dict): (optional) goal

        Returns:
            action (torch.Tensor): action tensor [1, Da]
        """
        # TODO: obs_queue already handled by frame_stack, so @self.obs_queue is
        # not filled here and obs_dict is passed to the network as-is.
        # NOTE(review): goal_dict is accepted but not forwarded to
        # _get_action_trajectory - confirm whether goal conditioning is intended.
        if len(self.action_queue) == 0:
            # no actions left, run inference
            # [1,T,Da]
            action_sequence = self._get_action_trajectory(obs_dict=obs_dict)

            # put actions into the queue
            self.action_queue.extend(action_sequence[0])

        # has action, execute from left to right
        # [Da]
        action = self.action_queue.popleft()

        # [1,Da]
        action = action.unsqueeze(0)
        return action

    def _get_action_trajectory(self, obs_dict, goal_dict=None):
        """
        Run the reverse diffusion process to sample an action sequence.

        Args:
            obs_dict (dict): observations; each entry must have two leading
                dims [B, T] followed by the shape given in @self.obs_shapes
            goal_dict (dict): (optional) goal

        Returns:
            action (torch.Tensor): the slice [To - 1, To - 1 + Ta) along dim 1
                of the sampled [B, Tp, Da] sequence

        Raises:
            ValueError: if neither the DDPM nor the DDIM scheduler is enabled
        """
        assert not self.nets.training
        To = self.algo_config.horizon.observation_horizon
        Ta = self.algo_config.horizon.action_horizon
        Tp = self.algo_config.horizon.prediction_horizon
        action_dim = self.ac_dim

        # truthiness checks, consistent with the scheduler selection in _create_networks
        if self.algo_config.ddpm.enabled:
            num_inference_timesteps = self.algo_config.ddpm.num_inference_timesteps
        elif self.algo_config.ddim.enabled:
            num_inference_timesteps = self.algo_config.ddim.num_inference_timesteps
        else:
            raise ValueError("Either algo.ddpm or algo.ddim must be enabled.")

        # select network: when EMA is enabled, load the averaged weights into
        # the shadow copy and run inference with it
        nets = self.nets
        if self.ema is not None:
            self.ema.copy_to(parameters=self._shadow_nets.parameters())
            nets = self._shadow_nets

        # encode obs
        inputs = {
            'obs': obs_dict,
            'goal': goal_dict,
        }
        for k in self.obs_shapes:
            # first two dimensions should be [B, T] for inputs
            assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k])

        # use the encoder of the *selected* network so that the encoder and the
        # noise prediction net come from the same set of weights (EMA or raw)
        obs_features = TensorUtils.time_distributed(
            inputs, nets['policy']['obs_encoder'], inputs_as_kwargs=True)
        assert obs_features.ndim == 3  # [B, T, D]
        B = obs_features.shape[0]

        # reshape observation to (B,obs_horizon*obs_dim)
        obs_cond = obs_features.flatten(start_dim=1)

        # initialize action from Gaussian noise
        naction = torch.randn(
            (B, Tp, action_dim), device=self.device)

        # init scheduler
        self.noise_scheduler.set_timesteps(num_inference_timesteps)

        for k in self.noise_scheduler.timesteps:
            # predict noise
            noise_pred = nets['policy']['noise_pred_net'](
                sample=naction,
                timestep=k,
                global_cond=obs_cond
            )

            # inverse diffusion step (remove noise)
            naction = self.noise_scheduler.step(
                model_output=noise_pred,
                timestep=k,
                sample=naction
            ).prev_sample

        # process action using Ta
        start = To - 1
        end = start + Ta
        action = naction[:, start:end]
        return action

    def serialize(self):
        """
        Get dictionary of current model parameters.

        Returns:
            dict with the state dict of @self.nets under "nets" and the EMA
            state dict (or None when EMA is disabled) under "ema"
        """
        return {
            "nets": self.nets.state_dict(),
            "ema": self.ema.state_dict() if self.ema is not None else None,
        }

    def deserialize(self, model_dict):
        """
        Load model from a checkpoint.

        Args:
            model_dict (dict): a dictionary saved by self.serialize() that contains
                the same keys as @self.network_classes
        """
        self.nets.load_state_dict(model_dict["nets"])
        ema_state = model_dict.get("ema", None)
        # only restore EMA state when EMA is enabled for this instance; a
        # checkpoint saved with EMA can still be loaded with EMA disabled
        if ema_state is not None and self.ema is not None:
            self.ema.load_state_dict(ema_state)


# =================== Vision Encoder Utils =====================
def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Replace all submodules selected by the predicate with
    the output of func.

    If @root_module itself matches, func(root_module) is returned directly;
    otherwise matching submodules are replaced in place and @root_module is
    returned.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.

    Raises:
        ImportError: if the installed pytorch is older than 1.9.0
    """
    if predicate(root_module):
        return func(root_module)

    if parse_version(torch.__version__) < parse_version('1.9.0'):
        raise ImportError('This function requires pytorch >= 1.9.0')

    # dotted names of all matching submodules, split into path components
    matched_paths = [
        name.split('.')
        for name, module in root_module.named_modules(remove_duplicate=True)
        if predicate(module)
    ]
    for *parent, k in matched_paths:
        parent_module = root_module
        if len(parent) > 0:
            parent_module = root_module.get_submodule('.'.join(parent))
        # nn.Sequential children are addressed by integer index, not attribute
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)
        tgt_module = func(src_module)
        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)

    # verify that all modules are replaced
    remaining = [
        name.split('.')
        for name, module in root_module.named_modules(remove_duplicate=True)
        if predicate(module)
    ]
    assert len(remaining) == 0
    return root_module


def replace_bn_with_gn(
        root_module: nn.Module,
        features_per_group: int = 16) -> nn.Module:
    """
    Replace all BatchNorm layers with GroupNorm.

    Only nn.BatchNorm2d instances are matched. Each is replaced by a
    GroupNorm over the same number of channels, with
    num_features // features_per_group groups (at least one group, so layers
    with fewer than @features_per_group channels are handled too).

    Args:
        root_module (nn.Module): module whose submodules are modified in place
        features_per_group (int): target number of channels per group

    Returns:
        root_module (nn.Module): the same module that was passed in
    """
    replace_submodules(
        root_module=root_module,
        predicate=lambda x: isinstance(x, nn.BatchNorm2d),
        func=lambda x: nn.GroupNorm(
            # clamp to 1: a floor division result of 0 is not a valid group count
            num_groups=max(1, x.num_features // features_per_group),
            num_channels=x.num_features)
    )
    return root_module


# =================== UNet for Diffusion ==============

class SinusoidalPosEmb(nn.Module):
    """
    Sinusoidal embedding: maps a 1D tensor x to the concatenation of
    sin and cos of x scaled by dim // 2 geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        # frequencies decay geometrically from 1 to 1/10000
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample1d(nn.Module):
    """
    Conv1d with kernel size 3, stride 2 and padding 1 (channel count unchanged).
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    """
    ConvTranspose1d with kernel size 4, stride 2 and padding 1 (channel count unchanged).
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    """
    Two Conv1dBlocks with FiLM conditioning applied between them and a
    residual connection from the input to the output.
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels * 2
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
            nn.Unflatten(-1, (-1, 1))
        )

        # make sure dimensions compatible
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)

        # split the conditioning vector into a scale half and a bias half
        embed = embed.reshape(
            embed.shape[0], 2, self.out_channels, 1)
        scale = embed[:, 0, ...]
        bias = embed[:, 1, ...]
        out = scale * out + bias

        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out


class ConditionalUnet1D(nn.Module):
    """
    1D UNet over the time axis of a sequence, conditioned (via FiLM) on a
    diffusion step embedding and an optional global conditioning vector.
    """

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this sequence determines number of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        # every level except the last one halves the time resolution
        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # NOTE: this loop has len(in_out) - 1 iterations, so is_last is never
            # True and every up level upsamples. That matches the number of
            # Downsample1d layers above, so the expression is kept unchanged.
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                # input channels are doubled by the skip connection concatenation
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1, -2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], dim=-1)

        x = sample
        h = []  # skip connections, consumed in reverse order by the up path
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1, -2)
        # (B,T,C)
        return x