# Source: phantom project, commit 96da58e ("Add phantom project with submodules and dependencies", xfu314)
"""
Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi
"""
from typing import Callable, Union
import math
from collections import OrderedDict, deque
from packaging.version import parse as parse_version
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
# requires diffusers==0.11.1
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.training_utils import EMAModel
import robomimic.models.obs_nets as ObsNets
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.algo import register_algo_factory_func, PolicyAlgo
@register_algo_factory_func("diffusion_policy")
def algo_config_to_class(algo_config):
    """
    Maps algo config to the DiffusionPolicy algo class to instantiate, along with
    additional algo kwargs.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm

    Raises:
        NotImplementedError: if the transformer backbone is requested (not supported)
        RuntimeError: if neither backbone is enabled in the config
    """
    if algo_config.unet.enabled:
        return DiffusionPolicyUNet, {}
    elif algo_config.transformer.enabled:
        raise NotImplementedError("transformer backbone for diffusion policy is not implemented")
    else:
        raise RuntimeError("no diffusion policy backbone enabled; set algo.unet.enabled or algo.transformer.enabled")
class DiffusionPolicyUNet(PolicyAlgo):
    """
    Diffusion Policy with a 1D conditional UNet noise-prediction backbone.

    Training adds scheduler noise to ground-truth action sequences and regresses
    the injected noise with an L2 loss (DDPM objective). Inference starts from
    Gaussian noise and iteratively denoises with DDPM or DDIM to produce an
    action trajectory, which is executed one action-horizon chunk at a time
    via an internal action queue.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        # set up different observation groups for @MIMO_MLP
        observation_group_shapes = OrderedDict()
        observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)
        encoder_kwargs = ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder)
        obs_encoder = ObsNets.ObservationGroupEncoder(
            observation_group_shapes=observation_group_shapes,
            encoder_kwargs=encoder_kwargs,
        )
        # IMPORTANT!
        # replace all BatchNorm with GroupNorm to work with EMA
        # performance will tank if you forget to do this!
        obs_encoder = replace_bn_with_gn(obs_encoder)
        obs_dim = obs_encoder.output_shape()[0]
        # create network object
        # conditioning dim = per-step obs feature dim * number of stacked obs steps
        noise_pred_net = ConditionalUnet1D(
            input_dim=self.ac_dim,
            global_cond_dim=obs_dim*self.algo_config.horizon.observation_horizon
        )
        # the final arch has 2 parts
        nets = nn.ModuleDict({
            'policy': nn.ModuleDict({
                'obs_encoder': obs_encoder,
                'noise_pred_net': noise_pred_net
            })
        })
        nets = nets.float().to(self.device)
        # setup noise scheduler (exactly one of DDPM / DDIM must be enabled)
        noise_scheduler = None
        if self.algo_config.ddpm.enabled:
            noise_scheduler = DDPMScheduler(
                num_train_timesteps=self.algo_config.ddpm.num_train_timesteps,
                beta_schedule=self.algo_config.ddpm.beta_schedule,
                clip_sample=self.algo_config.ddpm.clip_sample,
                prediction_type=self.algo_config.ddpm.prediction_type
            )
        elif self.algo_config.ddim.enabled:
            noise_scheduler = DDIMScheduler(
                num_train_timesteps=self.algo_config.ddim.num_train_timesteps,
                beta_schedule=self.algo_config.ddim.beta_schedule,
                clip_sample=self.algo_config.ddim.clip_sample,
                set_alpha_to_one=self.algo_config.ddim.set_alpha_to_one,
                steps_offset=self.algo_config.ddim.steps_offset,
                prediction_type=self.algo_config.ddim.prediction_type
            )
        else:
            raise RuntimeError()
        # setup EMA
        # NOTE(review): the EMAModel(parameters=...) keyword signature matches
        # newer diffusers releases, while the file header pins diffusers==0.11.1
        # -- verify against the installed diffusers version
        ema = None
        if self.algo_config.ema.enabled:
            ema = EMAModel(parameters=nets.parameters(), power=self.algo_config.ema.power)
        # set attrs
        self.nets = nets
        # frozen deep copy used as the target of EMA weight copies at inference time
        self._shadow_nets = copy.deepcopy(self.nets).eval()
        self._shadow_nets.requires_grad_(False)
        self.noise_scheduler = noise_scheduler
        self.ema = ema
        # action range check in process_batch_for_training runs only once
        self.action_check_done = False
        # inference-time queues, created in reset()
        self.obs_queue = None
        self.action_queue = None
    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Truncates observations to the observation horizon (To) and actions to
        the prediction horizon (Tp). Also verifies (once) that actions lie in
        [-1, 1], which the diffusion process assumes.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training
        """
        To = self.algo_config.horizon.observation_horizon
        Ta = self.algo_config.horizon.action_horizon
        Tp = self.algo_config.horizon.prediction_horizon
        input_batch = dict()
        input_batch["obs"] = {k: batch["obs"][k][:, :To, :] for k in batch["obs"]}
        input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
        input_batch["actions"] = batch["actions"][:, :Tp, :]
        # check if actions are normalized to [-1,1]
        if not self.action_check_done:
            actions = input_batch["actions"]
            in_range = (-1 <= actions) & (actions <= 1)
            all_in_range = torch.all(in_range).item()
            if not all_in_range:
                raise ValueError('"actions" must be in range [-1,1] for Diffusion Policy! Check if hdf5_normalize_action is enabled.')
            self.action_check_done = True
        return TensorUtils.to_device(TensorUtils.to_float(input_batch), self.device)
    def train_on_batch(self, batch, epoch, validate=False):
        """
        Training on a single batch of data.

        Implements the DDPM training objective: sample noise and a random
        diffusion timestep per example, corrupt the clean action sequence with
        the scheduler's forward process, and regress the injected noise with
        an L2 loss.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): epoch number - required by some Algos that need
                to perform staged training and early stopping

            validate (bool): if True, don't perform any learning updates.

        Returns:
            info (dict): dictionary of relevant inputs, outputs, and losses
                that might be relevant for logging
        """
        To = self.algo_config.horizon.observation_horizon
        Ta = self.algo_config.horizon.action_horizon
        Tp = self.algo_config.horizon.prediction_horizon
        action_dim = self.ac_dim
        B = batch['actions'].shape[0]
        with TorchUtils.maybe_no_grad(no_grad=validate):
            info = super(DiffusionPolicyUNet, self).train_on_batch(batch, epoch, validate=validate)
            actions = batch['actions']
            # encode obs
            inputs = {
                'obs': batch["obs"],
                'goal': batch["goal_obs"]
            }
            for k in self.obs_shapes:
                # first two dimensions should be [B, T] for inputs
                assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k])
            obs_features = TensorUtils.time_distributed(inputs, self.nets['policy']['obs_encoder'], inputs_as_kwargs=True)
            assert obs_features.ndim == 3  # [B, T, D]
            # flatten [B, T, D] -> [B, T*D] to form the global conditioning vector
            obs_cond = obs_features.flatten(start_dim=1)
            # sample noise to add to actions
            noise = torch.randn(actions.shape, device=self.device)
            # sample a diffusion iteration for each data point
            timesteps = torch.randint(
                0, self.noise_scheduler.config.num_train_timesteps,
                (B,), device=self.device
            ).long()
            # add noise to the clean actions according to the noise magnitude at each diffusion iteration
            # (this is the forward diffusion process)
            noisy_actions = self.noise_scheduler.add_noise(
                actions, noise, timesteps)
            # predict the noise residual
            noise_pred = self.nets['policy']['noise_pred_net'](
                noisy_actions, timesteps, global_cond=obs_cond)
            # L2 loss
            loss = F.mse_loss(noise_pred, noise)
            # logging
            losses = {
                'l2_loss': loss
            }
            info["losses"] = TensorUtils.detach(losses)
            if not validate:
                # gradient step
                policy_grad_norms = TorchUtils.backprop_for_loss(
                    net=self.nets,
                    optim=self.optimizers["policy"],
                    loss=loss,
                )
                # update Exponential Moving Average of the model weights
                if self.ema is not None:
                    self.ema.step(self.nets.parameters())
                step_info = {
                    'policy_grad_norms': policy_grad_norms
                }
                info.update(step_info)
        return info
    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        log = super(DiffusionPolicyUNet, self).log_info(info)
        log["Loss"] = info["losses"]["l2_loss"].item()
        if "policy_grad_norms" in info:
            log["Policy_Grad_Norms"] = info["policy_grad_norms"]
        return log
    def reset(self):
        """
        Reset algo state to prepare for environment rollouts.
        """
        # setup inference queues
        To = self.algo_config.horizon.observation_horizon
        Ta = self.algo_config.horizon.action_horizon
        obs_queue = deque(maxlen=To)
        action_queue = deque(maxlen=Ta)
        self.obs_queue = obs_queue
        self.action_queue = action_queue
    def get_action(self, obs_dict, goal_dict=None):
        """
        Get policy action outputs.

        Pops one action from the action queue; when the queue is empty, runs
        a full denoising pass to refill it with the next Ta actions.

        Args:
            obs_dict (dict): current observation [1, Do]
            goal_dict (dict): (optional) goal

        Returns:
            action (torch.Tensor): action tensor [1, Da]
        """
        # obs_dict: key: [1,D]
        To = self.algo_config.horizon.observation_horizon
        Ta = self.algo_config.horizon.action_horizon
        # TODO: obs_queue already handled by frame_stack
        # make sure we have at least To observations in obs_queue
        # if not enough, repeat
        # if already full, append one to the obs_queue
        # n_repeats = max(To - len(self.obs_queue), 1)
        # self.obs_queue.extend([obs_dict] * n_repeats)
        if len(self.action_queue) == 0:
            # no actions left, run inference
            # turn obs_queue into dict of tensors (concat at T dim)
            # import pdb; pdb.set_trace()
            # obs_dict_list = TensorUtils.list_of_flat_dict_to_dict_of_list(list(self.obs_queue))
            # obs_dict_tensor = dict((k, torch.cat(v, dim=0).unsqueeze(0)) for k,v in obs_dict_list.items())
            # run inference
            # [1,T,Da]
            action_sequence = self._get_action_trajectory(obs_dict=obs_dict)
            # put actions into the queue
            # NOTE(review): goal_dict is accepted but not forwarded to
            # _get_action_trajectory -- confirm goals are unused at rollout time
            self.action_queue.extend(action_sequence[0])
        # has action, execute from left to right
        # [Da]
        action = self.action_queue.popleft()
        # [1,Da]
        action = action.unsqueeze(0)
        return action
    def _get_action_trajectory(self, obs_dict, goal_dict=None):
        """
        Run the full reverse-diffusion process to produce an action trajectory.

        Starts from Gaussian noise of shape [B, Tp, Da] and iteratively
        denoises with the configured scheduler, then slices out the Ta-step
        executable chunk starting at index To - 1.

        Args:
            obs_dict (dict): observation tensors, each [B, T, ...]
            goal_dict (dict): (optional) goal

        Returns:
            action (torch.Tensor): action trajectory [B, Ta, Da]
        """
        assert not self.nets.training
        To = self.algo_config.horizon.observation_horizon
        Ta = self.algo_config.horizon.action_horizon
        Tp = self.algo_config.horizon.prediction_horizon
        action_dim = self.ac_dim
        if self.algo_config.ddpm.enabled is True:
            num_inference_timesteps = self.algo_config.ddpm.num_inference_timesteps
        elif self.algo_config.ddim.enabled is True:
            num_inference_timesteps = self.algo_config.ddim.num_inference_timesteps
        else:
            raise ValueError
        # select network: use the EMA shadow weights when EMA is enabled
        nets = self.nets
        if self.ema is not None:
            self.ema.copy_to(parameters=self._shadow_nets.parameters())
            nets = self._shadow_nets
        # encode obs
        inputs = {
            'obs': obs_dict,
            'goal': goal_dict
        }
        for k in self.obs_shapes:
            # first two dimensions should be [B, T] for inputs
            assert inputs['obs'][k].ndim - 2 == len(self.obs_shapes[k])
        # NOTE(review): the obs encoder is taken from self.nets here, while the
        # noise prediction below uses @nets (the EMA shadow copy when EMA is
        # enabled) -- confirm this asymmetry is intended
        obs_features = TensorUtils.time_distributed(inputs, self.nets['policy']['obs_encoder'], inputs_as_kwargs=True)
        assert obs_features.ndim == 3  # [B, T, D]
        B = obs_features.shape[0]
        # reshape observation to (B,obs_horizon*obs_dim)
        obs_cond = obs_features.flatten(start_dim=1)
        # initialize action from Guassian noise
        noisy_action = torch.randn(
            (B, Tp, action_dim), device=self.device)
        naction = noisy_action
        # init scheduler
        self.noise_scheduler.set_timesteps(num_inference_timesteps)
        for k in self.noise_scheduler.timesteps:
            # predict noise
            noise_pred = nets['policy']['noise_pred_net'](
                sample=naction,
                timestep=k,
                global_cond=obs_cond
            )
            # inverse diffusion step (remove noise)
            naction = self.noise_scheduler.step(
                model_output=noise_pred,
                timestep=k,
                sample=naction
            ).prev_sample
        # process action using Ta: execute actions starting at the current step
        start = To - 1
        end = start + Ta
        action = naction[:,start:end]
        return action
    def serialize(self):
        """
        Get dictionary of current model parameters.
        """
        return {
            "nets": self.nets.state_dict(),
            "ema": self.ema.state_dict() if self.ema is not None else None,
        }
    def deserialize(self, model_dict):
        """
        Load model from a checkpoint.

        Args:
            model_dict (dict): a dictionary saved by self.serialize() that contains
                the same keys as @self.network_classes
        """
        self.nets.load_state_dict(model_dict["nets"])
        # checkpoints saved without EMA store None under "ema"
        if model_dict.get("ema", None) is not None:
            self.ema.load_state_dict(model_dict["ema"])
# =================== Vision Encoder Utils =====================
def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Replace all submodules selected by the predicate with
    the output of func.

    Args:
        root_module: module whose submodules are (recursively) inspected
        predicate: return True if a module should be replaced
        func: given the old module, return the new module to use in its place

    Returns:
        the (mutated) root module, or func(root_module) if the root itself matches

    Raises:
        ImportError: if the running torch lacks nn.Module.get_submodule
            (introduced in pytorch 1.9.0)
        RuntimeError: if any matching submodule survives the replacement pass
    """
    # if the root itself matches, replace it wholesale
    if predicate(root_module):
        return func(root_module)
    # capability check instead of version parsing: get_submodule was added in
    # pytorch 1.9.0, which this function relies on below
    if not hasattr(root_module, 'get_submodule'):
        raise ImportError('This function requires pytorch >= 1.9.0')
    matches = [k.split('.') for k, m
        in root_module.named_modules(remove_duplicate=True)
        if predicate(m)]
    for *parent, k in matches:
        parent_module = root_module
        if len(parent) > 0:
            parent_module = root_module.get_submodule('.'.join(parent))
        # nn.Sequential children are addressed by integer index, not attribute
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)
        tgt_module = func(src_module)
        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all matching modules were replaced; use an explicit raise
    # rather than assert, which is stripped under `python -O`
    remaining = [k for k, m
        in root_module.named_modules(remove_duplicate=True)
        if predicate(m)]
    if len(remaining) > 0:
        raise RuntimeError('replace_submodules failed to replace: {}'.format(remaining))
    return root_module
def replace_bn_with_gn(
    root_module: nn.Module,
    features_per_group: int=16) -> nn.Module:
    """
    Replace all BatchNorm2d layers in @root_module with GroupNorm.

    This is done because BatchNorm interacts badly with EMA weight averaging.

    Args:
        root_module: module to mutate in place
        features_per_group: target number of channels per GroupNorm group;
            the group count is num_features // features_per_group (at least 1)

    Returns:
        the mutated root module
    """
    replace_submodules(
        root_module=root_module,
        predicate=lambda x: isinstance(x, nn.BatchNorm2d),
        func=lambda x: nn.GroupNorm(
            # max(1, ...) guards against a zero group count (which GroupNorm
            # rejects) when a layer has fewer than features_per_group channels
            num_groups=max(1, x.num_features//features_per_group),
            num_channels=x.num_features)
    )
    return root_module
# =================== UNet for Diffusion ==============
class SinusoidalPosEmb(nn.Module):
    """
    Sinusoidal positional embedding of a batch of scalar diffusion timesteps.

    Produces a [B, dim] embedding: the first dim//2 features are sines and the
    last dim//2 are cosines over a geometric range of frequencies.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    def forward(self, x):
        # frequencies decay geometrically from 1 down to 1/10000
        half = self.dim // 2
        log_scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-log_scale * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class Downsample1d(nn.Module):
    """
    Halve the temporal length of a [B, C, T] sequence with a strided conv.
    """
    def __init__(self, dim):
        super().__init__()
        # kernel 3, stride 2, padding 1: T -> floor(T/2) + (T odd)
        # (attribute name 'conv' kept for state_dict compatibility)
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
    def forward(self, x):
        return self.conv(x)
class Upsample1d(nn.Module):
    """
    Double the temporal length of a [B, C, T] sequence with a transposed conv.
    """
    def __init__(self, dim):
        super().__init__()
        # kernel 4, stride 2, padding 1: T -> 2*T
        # (attribute name 'conv' kept for state_dict compatibility)
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
    def forward(self, x):
        return self.conv(x)
class Conv1dBlock(nn.Module):
    '''
    Conv1d --> GroupNorm --> Mish
    '''
    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        # "same" padding for odd kernel sizes keeps the temporal length fixed
        conv = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        norm = nn.GroupNorm(n_groups, out_channels)
        # positional Sequential keeps state_dict keys identical (block.0/1/2)
        self.block = nn.Sequential(conv, norm, nn.Mish())
    def forward(self, x):
        return self.block(x)
class ConditionalResidualBlock1D(nn.Module):
    """
    Residual 1D conv block conditioned via FiLM modulation.

    Two Conv1dBlock stages; after the first stage the features are modulated
    with a per-channel scale and bias predicted from the conditioning vector
    (FiLM, https://arxiv.org/abs/1709.07871). A 1x1 conv adapts the residual
    path when the channel count changes.
    """
    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()
        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])
        self.out_channels = out_channels
        # FiLM head: predicts a (scale, bias) pair per output channel
        cond_channels = out_channels * 2
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
            nn.Unflatten(-1, (-1, 1))
        )
        # identity shortcut when channel counts already match
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()
    def forward(self, x, cond):
        '''
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x cond_dim]

        returns:
        out : [ batch_size x out_channels x horizon ]
        '''
        h = self.blocks[0](x)
        # [B, 2, C, 1]: index 0 holds scales, index 1 holds biases
        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        h = film[:, 0, ...] * h + film[:, 1, ...]
        h = self.blocks[1](h)
        return h + self.residual_conv(x)
class ConditionalUnet1D(nn.Module):
    """
    1D conditional UNet that predicts the noise residual on an action
    trajectory, conditioned on the diffusion timestep embedding concatenated
    with a global observation-feature vector (applied via FiLM in every
    residual block).
    """
    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=[256,512,1024],
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this array determines numebr of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm

        NOTE: down_dims has a mutable default, but it is only read (never
        mutated) below, so the shared default is harmless.
        """
        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]
        # diffusion step embedding: sinusoidal encoding followed by a small MLP
        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # FiLM conditioning vector = step embedding + global obs features
        cond_dim = dsed + global_cond_dim
        # (dim_in, dim_out) channel pairs for each UNet level
        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])
        # encoder path: two residual blocks + downsample per level
        # (the last level keeps its temporal resolution via Identity)
        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))
        # decoder path: one fewer level than the encoder (in_out[1:]), so the
        # first encoder level's skip tensor is never consumed in forward();
        # the first residual block takes dim_out*2 channels because the skip
        # connection is concatenated on the channel axis
        # NOTE(review): with len(in_out)-1 iterations, the is_last condition
        # below is never True, so every decoder level upsamples -- this pairs
        # each upsample with one of the encoder's downsamples
        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))
        # project back from the first level's channel count to the action dim
        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )
        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv
        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )
    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)
        # 1. time
        # normalize timestep to a 1-D long tensor on the sample's device
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        global_feature = self.diffusion_step_encoder(timesteps)
        # concatenate observation conditioning onto the step embedding
        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], axis=-1)
        x = sample
        # h holds skip tensors saved before each downsample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)
        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)
        # decoder consumes skips in reverse (LIFO) order via h.pop()
        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)
        x = self.final_conv(x)
        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x