from typing import Union, Optional, Tuple
import logging
import torch
import torch.nn as nn
from diffusion_policy.model.diffusion.positional_embedding import SinusoidalPosEmb
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin

logger = logging.getLogger(__name__)


class TransformerForDiffusion(ModuleAttrMixin):
    def __init__(self,
            input_dim: int,
            output_dim: int,
            horizon: int,
            n_obs_steps: Optional[int] = None,
            cond_dim: int = 0,
            n_layer: int = 12,
            n_head: int = 12,
            n_emb: int = 768,
            p_drop_emb: float = 0.1,
            p_drop_attn: float = 0.1,
            causal_attn: bool = False,
            time_as_cond: bool = True,
            obs_as_cond: bool = False,
            n_cond_layers: int = 0
        ) -> None:
        super().__init__()

        # compute number of tokens for the main trunk and the condition encoder
        if n_obs_steps is None:
            n_obs_steps = horizon

        T = horizon
        T_cond = 1
        if not time_as_cond:
            T += 1
            T_cond -= 1
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            assert time_as_cond
            T_cond += n_obs_steps
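        # Worked example (values follow directly from the assignments above):
        # with horizon=8, n_obs_steps=4, time_as_cond=True and cond_dim > 0,
        # the trunk sees T = 8 action tokens and the condition encoder sees
        # T_cond = 1 + 4 = 5 tokens (one time token plus n_obs_steps obs tokens).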

        # input embedding stem
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        # cond encoder
        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = None

        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)

        self.cond_pos_emb = None
        self.encoder = None
        self.decoder = None
        encoder_only = False
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
            if n_cond_layers > 0:
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4*n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers
                )
            else:
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb)
                )

            # decoder
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4*n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer
            )
        else:
            # encoder only (BERT-style)
            encoder_only = True

            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4*n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer
            )

        # attention mask
        if causal_attn:
            # causal mask to ensure that attention is only applied to the left in the input sequence.
            # torch.nn.Transformer uses an additive mask (rather than a multiplicative one),
            # so the upper triangle must be -inf and everything else (including the diagonal) must be 0.
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer("mask", mask)

            if time_as_cond and obs_as_cond:
                S = T_cond
                t, s = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij'
                )
                mask = t >= (s-1)  # shift by one because the time token occupies the first cond position
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        # decoder head
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        # constants
        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.time_as_cond = time_as_cond
        self.obs_as_cond = obs_as_cond
        self.encoder_only = encoder_only

        # init
        self.apply(self._init_weights)
        logger.info(
            "number of parameters: %e", sum(p.numel() for p in self.parameters())
        )

    def _init_weights(self, module):
        ignore_types = (nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential)
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            weight_names = [
                'in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
            for name in weight_names:
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)

            bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
            for name in bias_names:
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, TransformerForDiffusion):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_obs_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            # no parameters to initialize
            pass
        else:
            raise RuntimeError("Unaccounted module {}".format(module))

    def get_optim_groups(self, weight_decay: float=1e-3):
        """
        This function is doing something simple but is being very defensive:
        we separate all parameters of the model into two buckets: those that will
        experience weight decay for regularization and those that won't (biases,
        and layernorm/embedding weights). We then return the parameter groups
        used to construct the PyTorch optimizer.
        """

        # separate out all parameters into those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name

                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.startswith("bias"):
                    # MultiheadAttention bias parameters start with "bias"
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelisted modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklisted modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special-case the learnable positional embedding parameters as not decayed
        no_decay.add("pos_emb")
        no_decay.add("_dummy_variable")
        if self.cond_pos_emb is not None:
            no_decay.add("cond_pos_emb")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        # create the pytorch optimizer parameter groups
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": weight_decay,
            },
            {
                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
                "weight_decay": 0.0,
            },
        ]
        return optim_groups

    def configure_optimizers(self,
            learning_rate: float=1e-4,
            weight_decay: float=1e-3,
            betas: Tuple[float, float]=(0.9, 0.95)):
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        optimizer = torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas
        )
        return optimizer
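
    # Illustrative usage sketch (parameter values mirror test() below):
    #   model = TransformerForDiffusion(input_dim=16, output_dim=16, horizon=8, n_obs_steps=4, cond_dim=10)
    #   optimizer = model.configure_optimizers(learning_rate=1e-4, weight_decay=1e-3)
    # AdamW then applies weight decay only to the "decay" group built in get_optim_groups().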

    def forward(self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        cond: Optional[torch.Tensor]=None, **kwargs):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        cond: (B,T',cond_dim)
        output: (B,T,output_dim)
        """
        # 1. time embedding
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # wrap scalar timesteps in a 1-element tensor on the sample's device
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to the batch dimension
        timesteps = timesteps.expand(sample.shape[0])
        time_emb = self.time_emb(timesteps).unsqueeze(1)
        # (B,1,n_emb)

        # 2. input embedding
        input_emb = self.input_emb(sample)

        if self.encoder_only:
            # BERT-style: prepend the time token to the input sequence
            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
            t = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[
                :, :t, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(token_embeddings + position_embeddings)
            # (B,T+1,n_emb)
            x = self.encoder(src=x, mask=self.mask)
            # (B,T+1,n_emb)
            x = x[:,1:,:]
            # drop the time token: (B,T,n_emb)
        else:
            # encoder: embed the time token and (optionally) the observation tokens
            cond_embeddings = time_emb
            if self.obs_as_cond:
                cond_obs_emb = self.cond_obs_emb(cond)
                # (B,To,n_emb)
                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
            tc = cond_embeddings.shape[1]
            position_embeddings = self.cond_pos_emb[
                :, :tc, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(cond_embeddings + position_embeddings)
            x = self.encoder(x)
            memory = x
            # (B,T_cond,n_emb)

            # decoder: action tokens cross-attend to the conditioning memory
            token_embeddings = input_emb
            t = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[
                :, :t, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(token_embeddings + position_embeddings)
            # (B,T,n_emb)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask
            )
            # (B,T,n_emb)

        # head
        x = self.ln_f(x)
        x = self.head(x)
        # (B,T,output_dim)
        return x


def test():
    # GPT-style: causal decoder with the time embedding as the only conditioning token
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        causal_attn=True,
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4,8,16))
    out = transformer(sample, timestep)

    # GPT-style with observation conditioning (cond_dim > 0)
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        cond_dim=10,
        causal_attn=True,
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4,8,16))
    cond = torch.zeros((4,4,10))
    out = transformer(sample, timestep, cond)

    # GPT-style with observation conditioning and a transformer condition encoder
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        cond_dim=10,
        causal_attn=True,
        n_cond_layers=4
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4,8,16))
    cond = torch.zeros((4,4,10))
    out = transformer(sample, timestep, cond)

    # BERT-style: encoder-only, with the time embedding prepended as a token
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        time_as_cond=False,
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4,8,16))
    out = transformer(sample, timestep)
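

if __name__ == "__main__":
    # optional entry point: run the smoke tests above directly
    test()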