# xfu314's picture
# Add phantom project with submodules and dependencies
# 96da58e
"""
Implementation of transformers, mostly based on Andrej's minGPT model.
See https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
for more details.
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from robomimic.models.base_nets import Module
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.torch_utils as TorchUtils
class GEGLU(nn.Module):
    """
    Gated-GELU activation.

    The last dimension of the input is split into two equal halves; the first
    half is multiplied elementwise by the GELU of the second half. The output
    therefore has half as many features as the input along the last dimension.

    References:
        Shazeer et al., "GLU Variants Improve Transformer," 2020.
        https://arxiv.org/abs/2002.05202

    Implementation: https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py
    """

    def geglu(self, x):
        """
        Apply the gated GELU to @x.

        Args:
            x (torch.Tensor): tensor whose last dimension has even size

        Returns:
            torch.Tensor: tensor with the same leading dimensions as @x and a
                last dimension of half the size
        """
        # the split below needs two equally sized halves
        assert x.shape[-1] % 2 == 0
        value, gate = torch.chunk(x, 2, dim=-1)
        gated = value * F.gelu(gate)
        return gated

    def forward(self, x):
        """
        Forward pass - see @geglu.
        """
        return self.geglu(x)
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding of (possibly non-integer) timesteps.

    Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
    """

    def __init__(self, embed_dim):
        """
        Standard sinusoidal positional encoding scheme in transformers.

        Positional encoding of the k'th position in the sequence is given by:
            p(k, 2i)   = sin(k / n^(2i/d))
            p(k, 2i+1) = cos(k / n^(2i/d))

        n: set to 10K in original Transformer paper
        d: the embedding dimension
        i: index of the sin/cos pair along the embedding space (ranges from 0 to d/2)

        Args:
            embed_dim: The number of dimensions to project the timesteps into.
        """
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, x):
        """
        Encode timesteps into sinusoidal features.

        Args:
            x (torch.Tensor): timesteps of shape BxT

        Returns:
            torch.Tensor: encoding of shape (B, T, @self.embed_dim), detached
                from the autograd graph. Even channels hold the sine terms and
                odd channels hold the cosine terms.
        """
        # computing 1/n^(2i/d) in log space and then exponentiating; shape (ceil(d/2),)
        inv_freq = torch.exp(
            torch.arange(0, self.embed_dim, 2, device=x.device)
            * (-math.log(10000.0) / self.embed_dim)
        )

        # broadcasting (B, T, 1) * (ceil(d/2),) -> (B, T, ceil(d/2)); no explicit repeat needed
        angles = x.unsqueeze(-1) * inv_freq

        pe = torch.zeros((x.shape[0], x.shape[1], self.embed_dim), device=x.device)
        pe[:, :, 0::2] = torch.sin(angles)
        # for odd embed_dim there is one fewer cosine channel than sine channels,
        # so drop the last frequency to keep the shapes compatible
        pe[:, :, 1::2] = torch.cos(angles[:, :, : self.embed_dim // 2])
        return pe.detach()
class CausalSelfAttention(Module):
    """
    Multi-head self-attention with a causal mask, followed by a linear output
    projection and dropout.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        context_length,
        attn_dropout=0.1,
        output_dropout=0.1,
    ):
        """
        Multi-head masked self-attention layer + output projection.

        Each element of the input sequence is projected to a query, key, and value
        vector. Queries are compared against keys with a scaled dot-product, the
        scores are softmax-normalized, and the result weights a sum over the value
        vectors. With @num_heads > 1 the embedding dimension is split into
        @num_heads equal partitions that attend independently; the per-head
        outputs are concatenated again before the output projection.

        A lower-triangular mask restricts every position to attend only to itself
        and to earlier positions.

        Args:
            embed_dim (int): dimension of embeddings to use for keys, queries, and values
                used in self-attention

            num_heads (int): number of attention heads - must divide @embed_dim evenly.
                Self-attention is computed over this many partitions of the embedding
                dimension separately.

            context_length (int): maximum length of input sequences

            attn_dropout (float): dropout probability for attention weights

            output_dropout (float): dropout probability for final outputs
        """
        super().__init__()

        assert (
            embed_dim % num_heads == 0
        ), "num_heads: {} does not divide embed_dim: {} exactly".format(num_heads, embed_dim)

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.context_length = context_length
        self.attn_dropout = attn_dropout
        self.output_dropout = output_dropout

        # sub-modules are created and registered in a fixed order:
        # fused q/k/v projection, the two dropouts, then the output projection
        self.nets = nn.ModuleDict()
        qkv_proj = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
        self.nets["qkv"] = qkv_proj
        self.nets["attn_dropout"] = nn.Dropout(self.attn_dropout)
        self.nets["output_dropout"] = nn.Dropout(self.output_dropout)
        out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.nets["output"] = out_proj

        # causal mask: lower-triangular matrix of ones with two leading singleton
        # dimensions so it broadcasts over batch and heads
        lower_tri = torch.tril(torch.ones(context_length, context_length))
        self.register_buffer(
            "mask", lower_tri.view(1, 1, context_length, context_length)
        )

    def forward(self, x):
        """
        Forward pass through Self-Attention block.

        Args:
            x (torch.Tensor): input of shape (B, T, D), where T is at most
                @self.context_length and D equals @self.embed_dim

        Returns:
            torch.Tensor: output of shape (B, T, D)
        """
        # enforce shape consistency
        assert x.dim() == 3
        batch, seq_len, width = x.shape
        assert (
            seq_len <= self.context_length
        ), "self-attention module can only handle sequences up to {} in length but got length {}".format(
            self.context_length, seq_len
        )
        assert width == self.embed_dim

        n_heads = self.num_heads
        head_dim = width // n_heads

        def split_heads(t):
            # [B, T, D] -> [B, T, NH, DH] -> [B, NH, T, DH]
            return t.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)

        # one fused projection, then slice into query / key / value and split per head
        q, k, v = (split_heads(part) for part in self.nets["qkv"](x).chunk(3, dim=-1))

        # scaled pair-wise dot-products between all timesteps
        # [B, NH, T, DH] x [B, NH, DH, T] -> [B, NH, T, T]
        scores = torch.matmul(q, k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))

        # positions where the mask is 0 get -inf so they receive zero weight after
        # the softmax; slicing to [:T, :T] supports sequences shorter than the context
        blocked = self.mask[..., :seq_len, :seq_len] == 0
        scores = scores.masked_fill(blocked, float("-inf"))

        # normalize over the last dimension, then apply dropout on the attention weights
        weights = F.softmax(scores, dim=-1)
        weights = self.nets["attn_dropout"](weights)

        # weighted sum of value vectors
        # [B, NH, T, T] x [B, NH, T, DH] -> [B, NH, T, DH]
        attended = torch.matmul(weights, v)

        # merge heads: [B, NH, T, DH] -> [B, T, NH, DH] -> [B, T, D]
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)

        # output projection + dropout
        return self.nets["output_dropout"](self.nets["output"](merged))

    def output_shape(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # the module maps (B, T, D) -> (B, T, D), so the shape is passed through unchanged
        return list(input_shape)
class SelfAttentionBlock(Module):
    """
    A single Transformer Block, that can be chained together repeatedly.
    It consists of a @CausalSelfAttention module and a small MLP, along with
    layer normalization and residual connections on each input.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        context_length,
        attn_dropout=0.1,
        output_dropout=0.1,
        activation=nn.GELU(),
    ):
        """
        Args:
            embed_dim (int): dimension of embeddings to use for keys, queries, and values
                used in self-attention

            num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            context_length (int): expected length of input sequences

            attn_dropout (float): dropout probability for attention outputs

            output_dropout (float): dropout probability for final outputs

            activation (nn.Module): activation module instance placed between the two
                linear layers of the MLP. When it is a @GEGLU (or a subclass), the first
                linear layer is made twice as wide, because GEGLU halves the last dimension.
                Note that the default instance is created once at definition time and is
                therefore shared by every block that relies on the default.
        """
        super(SelfAttentionBlock, self).__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.context_length = context_length
        self.attn_dropout = attn_dropout
        self.output_dropout = output_dropout
        self.nets = nn.ModuleDict()

        # self-attention block
        self.nets["attention"] = CausalSelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            context_length=context_length,
            attn_dropout=attn_dropout,
            output_dropout=output_dropout,
        )

        # GEGLU consumes two features per output feature, so the hidden projection
        # must be doubled to keep the MLP's hidden width at 4 * embed_dim.
        # isinstance (rather than an exact type comparison) also covers GEGLU subclasses.
        mult = 2 if isinstance(activation, GEGLU) else 1

        # small 2-layer MLP
        self.nets["mlp"] = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim * mult),
            activation,
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(output_dropout),
        )

        # layer normalization for inputs to self-attention module and MLP
        self.nets["ln1"] = nn.LayerNorm(embed_dim)
        self.nets["ln2"] = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Forward pass - chain self-attention + MLP blocks, with residual connections and layer norms.

        Layer normalization is applied to the input of each sub-block (pre-norm), and the
        sub-block output is added back onto the un-normalized input.
        """
        x = x + self.nets["attention"](self.nets["ln1"](x))
        x = x + self.nets["mlp"](self.nets["ln2"](x))
        return x

    def output_shape(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # this module doesn't modify the size of the input, it goes from (B, T, D) -> (B, T, D)
        return list(input_shape)
class GPT_Backbone(Module):
    """
    GPT-style transformer backbone: a stack of @SelfAttentionBlock modules followed
    by a final layer normalization. Operates on sequences of exactly
    @context_length embeddings of size @embed_dim.
    """

    def __init__(
        self,
        embed_dim,
        context_length,
        attn_dropout=0.1,
        block_output_dropout=0.1,
        num_layers=6,
        num_heads=8,
        activation="gelu",
    ):
        """
        Args:
            embed_dim (int): dimension of embeddings to use for keys, queries, and values
                used in self-attention

            context_length (int): expected length of input sequences

            attn_dropout (float): dropout probability for attention outputs for each transformer block

            block_output_dropout (float): dropout probability for final outputs for each transformer block

            num_layers (int): number of transformer blocks to stack

            num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            activation (str): string denoting the activation function to use in each transformer block.
                Supported values are "gelu" and "geglu".

        Raises:
            ValueError: if @activation is not one of the supported strings
        """
        super(GPT_Backbone, self).__init__()

        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.context_length = context_length
        self.attn_dropout = attn_dropout
        self.block_output_dropout = block_output_dropout

        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "geglu":
            self.activation = GEGLU()
        else:
            # fail here with a clear message instead of an AttributeError on
            # self.activation later, inside _create_networks
            raise ValueError(
                "Unsupported activation: {!r} (expected 'gelu' or 'geglu')".format(activation)
            )

        # create networks
        self._create_networks()

        # initialize weights
        self.apply(self._init_weights)

        print(
            "Created {} model with number of parameters: {}".format(
                self.__class__.__name__, sum(p.numel() for p in self.parameters())
            )
        )

    def _create_networks(self):
        """
        Helper function to create networks.
        """
        self.nets = nn.ModuleDict()

        # transformer - cascaded transformer blocks
        # (the same activation module instance is handed to every block)
        self.nets["transformer"] = nn.Sequential(
            *[
                SelfAttentionBlock(
                    embed_dim=self.embed_dim,
                    num_heads=self.num_heads,
                    context_length=self.context_length,
                    attn_dropout=self.attn_dropout,
                    output_dropout=self.block_output_dropout,
                    activation=self.activation,
                )
                for _ in range(self.num_layers)
            ]
        )

        # decoder head
        self.nets["output_ln"] = nn.LayerNorm(self.embed_dim)

    def _init_weights(self, module):
        """
        Weight initializer, applied to every sub-module via @self.apply.

        Linear and embedding weights are drawn from N(0, 0.02^2), linear biases
        are zeroed, and layer norms are reset to weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def output_shape(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # the blocks and the final layer norm keep the feature size at @self.embed_dim,
        # so inputs (B, T, D_in) produce outputs (B, T, @self.embed_dim).
        # list(...) lets tuple shapes be concatenated as well.
        return list(input_shape[:-1]) + [self.embed_dim]

    def forward(self, inputs):
        """
        Run the stacked transformer blocks followed by the final layer norm.

        Args:
            inputs (torch.Tensor): tensor of shape (B, @self.context_length, @self.embed_dim)

        Returns:
            torch.Tensor: tensor of the same shape as @inputs
        """
        assert inputs.shape[1:] == (self.context_length, self.embed_dim), inputs.shape
        x = self.nets["transformer"](inputs)
        transformer_output = self.nets["output_ln"](x)
        return transformer_output