from typing import Union, List, Dict
from collections import namedtuple
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType


def extract(a, t, x_shape):
    """
    Overview:
        Gather the entries of ``a`` indexed by ``t`` and reshape them so they broadcast
        against a tensor of shape ``x_shape``.
    Arguments:
        - a (:obj:`torch.Tensor`): coefficient tensor of shape [timesteps,]
        - t (:obj:`torch.Tensor`): index tensor of shape [batch,]
        - x_shape (:obj:`tuple`): shape of the tensor the result will be broadcast against
    """
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
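

# A minimal usage sketch for ``extract`` (toy values chosen for illustration; not part
# of the module API):
def _demo_extract():
    a = torch.linspace(0, 1, steps=10)       # per-timestep coefficients, shape [10]
    t = torch.tensor([0, 9])                 # one diffusion timestep per batch element
    coef = extract(a, t, x_shape=(2, 4, 3))  # -> shape [2, 1, 1], broadcastable against x
    assert coef.shape == (2, 1, 1)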


def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32):
    """
    Overview:
        Cosine beta schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    Arguments:
        - timesteps (:obj:`int`): number of diffusion steps
        - s (:obj:`float`): small offset that keeps the betas away from zero near t = 0
        - dtype (:obj:`torch.dtype`): dtype of the returned betas
    Returns:
        Tensor of betas with shape [timesteps,], computed from the cosine schedule.
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
    return torch.tensor(betas_clipped, dtype=dtype)
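

# A minimal sketch of how the schedule is typically consumed, following standard DDPM
# notation (assumed usage, not code from this module):
def _demo_cosine_beta_schedule():
    betas = cosine_beta_schedule(1000)             # [1000]
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)  # \bar{alpha}_t used in q(x_t | x_0)
    assert betas.shape == (1000, )
    assert betas.min() >= 0 and betas.max() <= 0.999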


def apply_conditioning(x, conditions, action_dim):
    """
    Overview:
        Write the conditioning values into ``x``: for every conditioned trajectory timestep,
        the observation dims (everything after ``action_dim``) are overwritten in place.
    Arguments:
        - x (:obj:`torch.Tensor`): trajectory tensor of shape [batch, horizon, transition_dim]
        - conditions (:obj:`dict`): mapping from trajectory timestep to conditioning value
        - action_dim (:obj:`int`): action dimension; only dims after it are conditioned
    """
    for t, val in conditions.items():
        x[:, t, action_dim:] = val.clone()
    return x
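

# A minimal sketch of ``apply_conditioning`` (toy shapes, assumed for illustration):
def _demo_apply_conditioning():
    x = torch.zeros(2, 8, 5)      # [batch, horizon, action_dim + obs_dim]
    cond = {0: torch.ones(2, 3)}  # pin the observation at trajectory timestep 0
    x = apply_conditioning(x, cond, action_dim=2)
    assert torch.all(x[:, 0, 2:] == 1)  # obs dims at t=0 were overwritten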


class DiffusionConv1d(nn.Module):
    """
    Overview:
        Conv1d with activation and normalization for diffusion models.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            padding: int,
            activation: nn.Module = None,
            n_groups: int = 8
    ) -> None:
        """
        Overview:
            Create a 1-dim convolution layer with activation and GroupNorm. The input is
            temporarily expanded by one dimension while the norm is computed.
        Arguments:
            - in_channels (:obj:`int`): Number of channels in the input tensor
            - out_channels (:obj:`int`): Number of channels in the output tensor
            - kernel_size (:obj:`int`): Size of the convolving kernel
            - padding (:obj:`int`): Zero-padding added to both sides of the input
            - activation (:obj:`nn.Module`): the optional activation function
            - n_groups (:obj:`int`): number of groups for GroupNorm; must divide ``out_channels``
        """
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.GroupNorm(n_groups, out_channels)
        self.act = activation

    def forward(self, inputs) -> torch.Tensor:
        """
        Overview:
            Compute conv1d, normalization and activation for the inputs.
        Arguments:
            - inputs (:obj:`torch.Tensor`): input tensor
        Returns:
            - out (:obj:`torch.Tensor`): output tensor
        """
        x = self.conv1(inputs)
        # [batch, channels, horizon] -> [batch, channels, 1, horizon]
        x = x.unsqueeze(-2)
        x = self.norm(x)
        # [batch, channels, 1, horizon] -> [batch, channels, horizon]
        x = x.squeeze(-2)
        # activation is optional, so skip it when it was not provided
        out = self.act(x) if self.act is not None else x
        return out
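

# A minimal shape-check sketch for ``DiffusionConv1d`` (assumed toy sizes):
def _demo_diffusion_conv1d():
    conv = DiffusionConv1d(8, 16, kernel_size=5, padding=2, activation=nn.Mish())
    y = conv(torch.randn(2, 8, 10))  # [batch, channels, horizon] -> [2, 16, 10]
    assert y.shape == (2, 16, 10)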


class SinusoidalPosEmb(nn.Module):
    """
    Overview:
        Sinusoidal position embedding, used here to embed diffusion timesteps.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim: int) -> None:
        """
        Overview:
            Initialization of SinusoidalPosEmb class
        Arguments:
            - dim (:obj:`int`): dimension of the embedding
        """
        super().__init__()
        self.dim = dim

    def forward(self, x) -> torch.Tensor:
        """
        Overview:
            Compute the sinusoidal position embedding.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor of shape [batch,]
        Returns:
            - emb (:obj:`torch.Tensor`): output tensor of shape [batch, dim]
        """
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        return emb
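

# A minimal sketch of ``SinusoidalPosEmb`` on a batch of diffusion timesteps (assumed sizes):
def _demo_sinusoidal_pos_emb():
    emb = SinusoidalPosEmb(32)
    timesteps = torch.arange(4)  # [4]
    out = emb(timesteps)         # [4, 32]: first half sin, second half cos
    assert out.shape == (4, 32)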


class Residual(nn.Module):
    """
    Overview:
        Basic residual block.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, fn):
        """
        Overview:
            Initialization of Residual class
        Arguments:
            - fn (:obj:`nn.Module`): the wrapped module of the residual block
        """
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """
        Overview:
            Compute the residual block: ``fn(x) + x``.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        return self.fn(x, *args, **kwargs) + x


class LayerNorm(nn.Module):
    """
    Overview:
        LayerNorm over the channel dimension (dim=1), since temporal inputs have
        shape [batch, dim, horizon].
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, eps=1e-5) -> None:
        """
        Overview:
            Initialization of LayerNorm class
        Arguments:
            - dim (:obj:`int`): dimension of input
            - eps (:obj:`float`): eps of LayerNorm
        """
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1))

    def forward(self, x):
        """
        Overview:
            Compute LayerNorm.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b


class PreNorm(nn.Module):
    """
    Overview:
        Apply LayerNorm over the channel dimension (dim=1) before calling ``fn``,
        since temporal inputs have shape [batch, dim, horizon].
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, fn) -> None:
        """
        Overview:
            Initialization of PreNorm class
        Arguments:
            - dim (:obj:`int`): dimension of input
            - fn (:obj:`nn.Module`): the wrapped module
        """
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        """
        Overview:
            Compute PreNorm: normalize, then apply ``fn``.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        x = self.norm(x)
        return self.fn(x)


class LinearAttention(nn.Module):
    """
    Overview:
        Linear attention head.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, heads=4, dim_head=32) -> None:
        """
        Overview:
            Initialization of LinearAttention class
        Arguments:
            - dim (:obj:`int`): dimension of input
            - heads (:obj:`int`): number of attention heads
            - dim_head (:obj:`int`): dimension of each head
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv1d(hidden_dim, dim, 1)

    def forward(self, x):
        """
        Overview:
            Compute linear attention.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor of shape [batch, dim, horizon]
        """
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(lambda t: t.reshape(t.shape[0], self.heads, -1, t.shape[-1]), qkv)
        q = q * self.scale
        k = k.softmax(dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = out.reshape(out.shape[0], -1, out.shape[-1])
        return self.to_out(out)
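

# A minimal sketch of the attention stack exactly as the U-Net composes it below,
# i.e. ``Residual(PreNorm(dim, LinearAttention(dim)))``; sizes are assumptions:
def _demo_linear_attention():
    block = Residual(PreNorm(32, LinearAttention(32, heads=4, dim_head=32)))
    y = block(torch.randn(2, 32, 16))  # [batch, dim, horizon] -> same shape
    assert y.shape == (2, 32, 16)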


class ResidualTemporalBlock(nn.Module):
    """
    Overview:
        Residual block for temporal data.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self, in_channels: int, out_channels: int, embed_dim: int, kernel_size: int = 5, mish: bool = True
    ) -> None:
        """
        Overview:
            Initialization of ResidualTemporalBlock class
        Arguments:
            - in_channels (:obj:`int`): dim of in_channels
            - out_channels (:obj:`int`): dim of out_channels
            - embed_dim (:obj:`int`): dim of the time embedding
            - kernel_size (:obj:`int`): kernel_size of conv1d
            - mish (:obj:`bool`): whether to use Mish (otherwise SiLU) as the activation function
        """
        super().__init__()
        if mish:
            act = nn.Mish()
        else:
            act = nn.SiLU()
        self.blocks = nn.ModuleList(
            [
                DiffusionConv1d(in_channels, out_channels, kernel_size, kernel_size // 2, act),
                DiffusionConv1d(out_channels, out_channels, kernel_size, kernel_size // 2, act),
            ]
        )
        self.time_mlp = nn.Sequential(
            act,
            nn.Linear(embed_dim, out_channels),
        )
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, t):
        """
        Overview:
            Compute the residual block.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor of shape [batch, in_channels, horizon]
            - t (:obj:`torch.Tensor`): time embedding of shape [batch, embed_dim]
        """
        out = self.blocks[0](x) + self.time_mlp(t).unsqueeze(-1)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)
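

# A minimal shape-check sketch for ``ResidualTemporalBlock`` (assumed toy sizes):
def _demo_residual_temporal_block():
    block = ResidualTemporalBlock(8, 16, embed_dim=32)
    x = torch.randn(2, 8, 10)  # [batch, in_channels, horizon]
    t = torch.randn(2, 32)     # time embedding of dim embed_dim
    assert block(x, t).shape == (2, 16, 10)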


class DiffusionUNet1d(nn.Module):
    """
    Overview:
        Diffusion U-Net for 1-dim vector data.
    Interfaces:
        ``__init__``, ``forward``, ``get_pred``
    """

    def __init__(
            self,
            transition_dim: int,
            dim: int = 32,
            dim_mults: SequenceType = [1, 2, 4, 8],
            returns_condition: bool = False,
            condition_dropout: float = 0.1,
            calc_energy: bool = False,
            kernel_size: int = 5,
            attention: bool = False,
    ) -> None:
        """
        Overview:
            Initialization of DiffusionUNet1d class
        Arguments:
            - transition_dim (:obj:`int`): dim of transition, it is obs_dim + action_dim
            - dim (:obj:`int`): dim of the base layer
            - dim_mults (:obj:`SequenceType`): channel multipliers for each U-Net resolution
            - returns_condition (:obj:`bool`): whether to use the return as a condition
            - condition_dropout (:obj:`float`): dropout probability of the returns condition
            - calc_energy (:obj:`bool`): whether to treat the output as an energy function and return its gradient
            - kernel_size (:obj:`int`): kernel_size of conv1d
            - attention (:obj:`bool`): whether to use attention
        """
        super().__init__()
        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        if calc_energy:
            mish = False
            act = nn.SiLU()
        else:
            mish = True
            act = nn.Mish()
        self.time_dim = dim
        self.returns_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            act,
            nn.Linear(dim * 4, dim),
        )
        self.returns_condition = returns_condition
        self.condition_dropout = condition_dropout
        self.calc_energy = calc_energy
        if self.returns_condition:
            self.returns_mlp = nn.Sequential(
                nn.Linear(1, dim),
                act,
                nn.Linear(dim, dim * 4),
                act,
                nn.Linear(dim * 4, dim),
            )
            self.mask_dist = torch.distributions.Bernoulli(probs=1 - self.condition_dropout)
            embed_dim = 2 * dim
        else:
            embed_dim = dim
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolution = len(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolution - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim, kernel_size, mish=mish),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim, kernel_size, mish=mish),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))) if attention else nn.Identity(),
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity()
                    ]
                )
            )
        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
        self.mid_atten = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if attention else nn.Identity()
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolution - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim, kernel_size, mish=mish),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim, kernel_size, mish=mish),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))) if attention else nn.Identity(),
                        nn.ConvTranspose1d(dim_in, dim_in, 4, 2, 1) if not is_last else nn.Identity()
                    ]
                )
            )
        self.final_conv = nn.Sequential(
            DiffusionConv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, activation=act),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
        """
        Overview:
            Compute the diffusion U-Net forward pass.
        Arguments:
            - x (:obj:`torch.Tensor`): noised trajectory
            - cond (:obj:`tuple`): [ (time, state), ... ] state is the initial state of the env, time = 0
            - time (:obj:`torch.Tensor`): timestep of the diffusion step
            - returns (:obj:`torch.Tensor`): normalized return used as the condition of the trajectory
            - use_dropout (:obj:`bool`): whether to apply the returns-condition mask
            - force_dropout (:obj:`bool`): whether to drop the returns condition entirely
        """
        if self.calc_energy:
            x_inp = x
        # [batch, horizon, transition_dim] -> [batch, transition_dim, horizon]
        x = x.transpose(1, 2)
        t = self.time_mlp(time)
        if self.returns_condition:
            assert returns is not None
            returns_embed = self.returns_mlp(returns)
            if use_dropout:
                mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
                returns_embed = mask * returns_embed
            if force_dropout:
                returns_embed = 0 * returns_embed
            t = torch.cat([t, returns_embed], dim=-1)
        h = []
        for resnet, resnet2, atten, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_atten(x)
        x = self.mid_block2(x, t)
        for resnet, resnet2, atten, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            x = upsample(x)
        x = self.final_conv(x)
        # [batch, transition_dim, horizon] -> [batch, horizon, transition_dim]
        x = x.transpose(1, 2)
        if self.calc_energy:
            # Energy function: return the gradient of the energy w.r.t. the input
            # (requires ``x`` to have ``requires_grad=True``).
            energy = ((x - x_inp) ** 2).mean()
            grad = torch.autograd.grad(outputs=energy, inputs=x_inp, create_graph=True)
            return grad[0]
        else:
            return x

    def get_pred(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
        """
        Overview:
            Compute the diffusion U-Net prediction directly, without the energy-gradient branch.
        Arguments:
            - x (:obj:`torch.Tensor`): noised trajectory
            - cond (:obj:`tuple`): [ (time, state), ... ] state is the initial state of the env, time = 0
            - time (:obj:`torch.Tensor`): timestep of the diffusion step
            - returns (:obj:`torch.Tensor`): normalized return used as the condition of the trajectory
            - use_dropout (:obj:`bool`): whether to apply the returns-condition mask
            - force_dropout (:obj:`bool`): whether to drop the returns condition entirely
        """
        # [batch, horizon, transition_dim] -> [batch, transition_dim, horizon]
        x = x.transpose(1, 2)
        t = self.time_mlp(time)
        if self.returns_condition:
            assert returns is not None
            returns_embed = self.returns_mlp(returns)
            if use_dropout:
                mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
                returns_embed = mask * returns_embed
            if force_dropout:
                returns_embed = 0 * returns_embed
            t = torch.cat([t, returns_embed], dim=-1)
        h = []
        # each entry of ``self.downs``/``self.ups`` holds four modules (including the
        # attention block, which is nn.Identity when attention is disabled)
        for resnet, resnet2, atten, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_atten(x)
        x = self.mid_block2(x, t)
        for resnet, resnet2, atten, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            x = upsample(x)
        x = self.final_conv(x)
        # [batch, transition_dim, horizon] -> [batch, horizon, transition_dim]
        x = x.transpose(1, 2)
        return x
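

# A minimal end-to-end sketch of ``DiffusionUNet1d`` as a noise predictor. The
# transition_dim/horizon values are assumptions; the horizon must survive the stride-2
# downsampling, i.e. be divisible by 2 ** (len(dim_mults) - 1):
def _demo_diffusion_unet1d():
    unet = DiffusionUNet1d(transition_dim=14, dim=32, dim_mults=[1, 2, 4, 8])
    x = torch.randn(2, 32, 14)            # [batch, horizon, transition_dim]
    time = torch.randint(0, 1000, (2, ))  # one diffusion timestep per sample
    eps = unet(x, cond=None, time=time)   # predicted noise, same shape as x
    assert eps.shape == x.shape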


class TemporalValue(nn.Module):
    """
    Overview:
        Temporal net for the value function.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            horizon: int,
            transition_dim: int,
            dim: int = 32,
            time_dim: int = None,
            out_dim: int = 1,
            kernel_size: int = 5,
            dim_mults: SequenceType = [1, 2, 4, 8],
    ) -> None:
        """
        Overview:
            Initialization of TemporalValue class
        Arguments:
            - horizon (:obj:`int`): horizon of the trajectory
            - transition_dim (:obj:`int`): dim of transition, it is obs_dim + action_dim
            - dim (:obj:`int`): dim of the base layer
            - time_dim (:obj:`int`): dim of the time embedding
            - out_dim (:obj:`int`): dim of the output
            - kernel_size (:obj:`int`): kernel_size of conv1d
            - dim_mults (:obj:`SequenceType`): channel multipliers for each resolution
        """
        super().__init__()
        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )
        self.blocks = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=kernel_size, embed_dim=time_dim),
                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=kernel_size, embed_dim=time_dim),
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1)
                    ]
                )
            )
            horizon = horizon // 2
        mid_dim = dims[-1]
        mid_dim_2 = mid_dim // 2
        mid_dim_3 = mid_dim // 4
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim_2, kernel_size=kernel_size, embed_dim=time_dim)
        self.mid_down1 = nn.Conv1d(mid_dim_2, mid_dim_2, 3, 2, 1)
        horizon = horizon // 2
        self.mid_block2 = ResidualTemporalBlock(mid_dim_2, mid_dim_3, kernel_size=kernel_size, embed_dim=time_dim)
        self.mid_down2 = nn.Conv1d(mid_dim_3, mid_dim_3, 3, 2, 1)
        horizon = horizon // 2
        fc_dim = mid_dim_3 * max(horizon, 1)
        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        """
        Overview:
            Compute the temporal value forward pass.
        Arguments:
            - x (:obj:`torch.Tensor`): noised trajectory
            - cond (:obj:`tuple`): [ (time, state), ... ] state is the initial state of the env, time = 0
            - time (:obj:`torch.Tensor`): timestep of the diffusion step
        """
        # [batch, horizon, transition_dim] -> [batch, transition_dim, horizon]
        x = x.transpose(1, 2)
        t = self.time_mlp(time)
        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_down1(x)
        x = self.mid_block2(x, t)
        x = self.mid_down2(x)
        x = x.view(len(x), -1)
        out = self.final_block(torch.cat([x, t], dim=-1))
        return out
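

# A minimal sketch of ``TemporalValue`` as a trajectory critic (assumed toy sizes;
# ``cond`` is accepted but unused by the forward pass above):
def _demo_temporal_value():
    value = TemporalValue(horizon=32, transition_dim=14, dim=32)
    x = torch.randn(2, 32, 14)  # [batch, horizon, transition_dim]
    time = torch.randint(0, 1000, (2, ))
    v = value(x, None, time)    # scalar value per trajectory: [2, 1]
    assert v.shape == (2, 1)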