Spaces:
Sleeping
Sleeping
| """ | |
| Implementation of transformers, mostly based on Andrej's minGPT model. | |
| See https://github.com/karpathy/minGPT/blob/master/mingpt/model.py | |
| for more details. | |
| """ | |
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from robomimic.models.base_nets import Module | |
| import robomimic.utils.tensor_utils as TensorUtils | |
| import robomimic.utils.torch_utils as TorchUtils | |
class GEGLU(nn.Module):
    """
    GELU-gated linear unit: splits the last dimension in half and gates the
    first half with a GELU of the second half, i.e. out = a * GELU(b).

    References:
    Shazeer et al., "GLU Variants Improve Transformer," 2020.
    https://arxiv.org/abs/2002.05202
    Implementation: https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py
    """
    def geglu(self, x):
        # last dimension must split evenly into (values, gates)
        width = x.shape[-1]
        assert width % 2 == 0
        values = x[..., : width // 2]
        gates = x[..., width // 2 :]
        return values * F.gelu(gates)

    def forward(self, x):
        return self.geglu(x)
class PositionalEncoding(nn.Module):
    """
    Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
    """
    def __init__(self, embed_dim):
        """
        Standard sinusoidal positional encoding scheme in transformers.

        Positional encoding of the k'th position in the sequence is given by:
            p(k, 2i)   = sin(k / n^(2i/d))
            p(k, 2i+1) = cos(k / n^(2i/d))

        n: set to 10K in original Transformer paper
        d: the embedding dimension
        i: positions along the projected embedding space (ranges from 0 to d/2)

        Args:
            embed_dim: The number of dimensions to project the timesteps into.
        """
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, x):
        """
        Compute sinusoidal positional encodings for a batch of timesteps.

        Args:
            x: timestep values of shape (B, T)

        Returns:
            pe: positional encodings of shape (B, T, @self.embed_dim), detached
                from the autograd graph.
        """
        position = x.unsqueeze(-1)  # (B, T, 1), broadcasts against the frequency terms
        # computing 1/n^(2i/d) in log space for numerical stability; shape (ceil(d/2),).
        # Broadcasting replaces the explicit repeat over batch and time used previously.
        div_term = torch.exp(
            torch.arange(0, self.embed_dim, 2, device=x.device)
            * (-math.log(10000.0) / self.embed_dim)
        )
        pe = torch.zeros((x.shape[0], x.shape[1], self.embed_dim), device=x.device)
        pe[:, :, 0::2] = torch.sin(position * div_term)
        # for odd embed_dim the cosine slots number one fewer than the sine slots,
        # so only the first floor(d/2) frequencies are used (no-op for even d)
        pe[:, :, 1::2] = torch.cos(position * div_term[: self.embed_dim // 2])
        return pe.detach()
class CausalSelfAttention(Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        context_length,
        attn_dropout=0.1,
        output_dropout=0.1,
    ):
        """
        Multi-head masked self-attention layer + projection (MLP layer).

        Each element of an input sequence is projected to key, query, and value
        embeddings of size @embed_dim. Attention weights are obtained by comparing
        each query against every key with a scaled dot-product followed by a
        softmax, and are used to take a linear combination of the value embeddings.

        With @num_heads > 1 this operation runs independently on equal-size
        partitions of the embedding dimension, letting different portions of the
        embedding model different kinds of attention; the per-head outputs are
        concatenated back together.

        A causal (lower-triangular) mask guarantees that each output position only
        depends on inputs at the same or earlier positions.

        Args:
            embed_dim (int): dimension of embeddings to use for keys, queries, and values
                used in self-attention

            num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            context_length (int): expected length of input sequences

            attn_dropout (float): dropout probability for attention outputs

            output_dropout (float): dropout probability for final outputs
        """
        super(CausalSelfAttention, self).__init__()

        assert (
            embed_dim % num_heads == 0
        ), "num_heads: {} does not divide embed_dim: {} exactly".format(num_heads, embed_dim)

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.context_length = context_length
        self.attn_dropout = attn_dropout
        self.output_dropout = output_dropout

        self.nets = nn.ModuleDict()
        # one joint projection produces keys, queries, and values for all heads at once
        self.nets["qkv"] = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
        # dropout applied to attention weights and to the final projection
        self.nets["attn_dropout"] = nn.Dropout(self.attn_dropout)
        self.nets["output_dropout"] = nn.Dropout(self.output_dropout)
        # output projection layer
        self.nets["output"] = nn.Linear(self.embed_dim, self.embed_dim)

        # lower-triangular matrix of 1s - entry (i, j) is 1 iff j <= i, so each
        # timestep may attend only to itself and earlier timesteps
        causal_mask = torch.tril(torch.ones(context_length, context_length)).view(
            1, 1, context_length, context_length
        )
        self.register_buffer("mask", causal_mask)

    def forward(self, x):
        """
        Forward pass through the self-attention block.

        Input should be shape (B, T, D) where B is batch size, T is seq length
        (at most @self.context_length), and D is input dimension (@self.embed_dim).
        """
        # enforce shape consistency
        assert x.dim() == 3
        B, T, D = x.shape
        assert (
            T <= self.context_length
        ), "self-attention module can only handle sequences up to {} in length but got length {}".format(
            self.context_length, T
        )
        assert D == self.embed_dim

        num_heads = self.num_heads
        head_dim = D // num_heads  # embedding dimension handled by each head

        # joint projection, then split into query / key / value
        q, k, v = torch.chunk(self.nets["qkv"](x), 3, dim=-1)

        # split heads: [B, T, D] -> [B, NH, T, DH]
        q = q.view(B, T, num_heads, head_dim).transpose(1, 2)
        k = k.view(B, T, num_heads, head_dim).transpose(1, 2)
        v = v.view(B, T, num_heads, head_dim).transpose(1, 2)

        # scaled pair-wise dot-products, broadcast over batch and heads:
        # [B, NH, T, DH] x [B, NH, DH, T] -> [B, NH, T, T]
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # set future positions to -inf so they contribute nothing to the softmax;
        # the [:T, :T] slice supports sequences shorter than @self.context_length
        scores = scores.masked_fill(self.mask[..., :T, :T] == 0, float("-inf"))

        # softmax over the last dimension gives, for every sequence member,
        # an attention distribution over all (allowed) positions
        weights = F.softmax(scores, dim=-1)
        weights = self.nets["attn_dropout"](weights)

        # attention-weighted sum of values:
        # [B, NH, T, T] x [B, NH, T, DH] -> [B, NH, T, DH]
        attended = weights @ v

        # merge heads back: [B, NH, T, DH] -> [B, T, NH * DH] = [B, T, D]
        attended = attended.transpose(1, 2).contiguous().view(B, T, D)

        # output projection + dropout
        return self.nets["output_dropout"](self.nets["output"](attended))

    def output_shape(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # this module is shape-preserving: (B, T, D) -> (B, T, D)
        return list(input_shape)
class SelfAttentionBlock(Module):
    """
    A single Transformer Block, that can be chained together repeatedly.
    It consists of a @CausalSelfAttention module and a small MLP, along with
    layer normalization and residual connections on each input.
    """
    def __init__(
        self,
        embed_dim,
        num_heads,
        context_length,
        attn_dropout=0.1,
        output_dropout=0.1,
        activation=nn.GELU(),
    ):
        """
        Args:
            embed_dim (int): dimension of embeddings to use for keys, queries, and values
                used in self-attention

            num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            context_length (int): expected length of input sequences

            attn_dropout (float): dropout probability for attention outputs

            output_dropout (float): dropout probability for final outputs

            activation (nn.Module): activation module to use inside the MLP (e.g.
                nn.GELU() or GEGLU()). Note the default is a shared module instance,
                which is safe because GELU holds no parameters or state.
        """
        super(SelfAttentionBlock, self).__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.context_length = context_length
        self.attn_dropout = attn_dropout
        self.output_dropout = output_dropout
        self.nets = nn.ModuleDict()

        # self-attention block
        self.nets["attention"] = CausalSelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            context_length=context_length,
            attn_dropout=attn_dropout,
            output_dropout=output_dropout,
        )

        # GEGLU halves its input dimension (split + gate), so the first MLP layer must
        # produce twice as many features to keep the hidden size at 4 * embed_dim.
        # isinstance (rather than an exact type comparison) also covers GEGLU subclasses.
        mult = 2 if isinstance(activation, GEGLU) else 1

        # small 2-layer MLP
        self.nets["mlp"] = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim * mult),
            activation,
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(output_dropout),
        )

        # layer normalization for inputs to self-attention module and MLP
        self.nets["ln1"] = nn.LayerNorm(embed_dim)
        self.nets["ln2"] = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Forward pass - chain self-attention + MLP blocks, with residual connections and layer norms.
        """
        x = x + self.nets["attention"](self.nets["ln1"](x))
        x = x + self.nets["mlp"](self.nets["ln2"](x))
        return x

    def output_shape(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # this module doesn't modify the size of the input, it goes from (B, T, D) -> (B, T, D)
        return list(input_shape)
class GPT_Backbone(Module):
    """the GPT model, with a context size of block_size"""
    def __init__(
        self,
        embed_dim,
        context_length,
        attn_dropout=0.1,
        block_output_dropout=0.1,
        num_layers=6,
        num_heads=8,
        activation="gelu",
    ):
        """
        Args:
            embed_dim (int): dimension of embeddings to use for keys, queries, and values
                used in self-attention

            context_length (int): expected length of input sequences

            attn_dropout (float): dropout probability for attention outputs for each transformer block

            block_output_dropout (float): dropout probability for final outputs for each transformer block

            num_layers (int): number of transformer blocks to stack

            num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            activation (str): string denoting the activation function to use in each transformer block -
                one of "gelu" or "geglu"

        Raises:
            ValueError: if @activation is not a supported activation name
        """
        super(GPT_Backbone, self).__init__()

        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.context_length = context_length
        self.attn_dropout = attn_dropout
        self.block_output_dropout = block_output_dropout

        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "geglu":
            self.activation = GEGLU()
        else:
            # fail fast here instead of leaving self.activation unset, which would
            # otherwise surface later as a confusing AttributeError in _create_networks
            raise ValueError(
                "unsupported activation: {} (expected 'gelu' or 'geglu')".format(activation)
            )

        # create networks
        self._create_networks()

        # initialize weights
        self.apply(self._init_weights)

        print(
            "Created {} model with number of parameters: {}".format(
                self.__class__.__name__, sum(p.numel() for p in self.parameters())
            )
        )

    def _create_networks(self):
        """
        Helper function to create networks.
        """
        self.nets = nn.ModuleDict()

        # transformer - cascaded transformer blocks
        self.nets["transformer"] = nn.Sequential(
            *[
                SelfAttentionBlock(
                    embed_dim=self.embed_dim,
                    num_heads=self.num_heads,
                    context_length=self.context_length,
                    attn_dropout=self.attn_dropout,
                    output_dropout=self.block_output_dropout,
                    activation=self.activation,
                )
                for _ in range(self.num_layers)
            ]
        )

        # decoder head
        self.nets["output_ln"] = nn.LayerNorm(self.embed_dim)

    def _init_weights(self, module):
        """
        Weight initializer (applied recursively via self.apply): normal init for
        linear / embedding weights, zeros for biases, identity-style init for
        layer norms.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def output_shape(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # the transformer is shape-preserving: inputs (B, T, @self.embed_dim) produce
        # outputs of the same shape (forward asserts the last dim is embed_dim).
        # NOTE: this previously referenced a nonexistent @self.output_dim attribute,
        # which raised AttributeError whenever output_shape was called.
        return list(input_shape[:-1]) + [self.embed_dim]

    def forward(self, inputs):
        """
        Forward pass: stack of self-attention blocks followed by a final layer norm.
        Expects inputs of shape (B, @self.context_length, @self.embed_dim).
        """
        assert inputs.shape[1:] == (self.context_length, self.embed_dim), inputs.shape
        x = self.nets["transformer"](inputs)
        transformer_output = self.nets["output_ln"](x)
        return transformer_output