"""
Implementation of transformers, mostly based on Andrej's minGPT model.
See https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
for more details.
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from robomimic.models.base_nets import Module
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.torch_utils as TorchUtils
class GEGLU(nn.Module):
    """
    GEGLU activation: split the last dimension into two halves [a, b] and
    return a * gelu(b), halving the feature dimension.

    References:
        Shazeer et al., "GLU Variants Improve Transformer," 2020.
        https://arxiv.org/abs/2002.05202
        Implementation: https://github.com/pfnet-research/deep-table/blob/237c8be8a405349ce6ab78075234c60d9bfe60b7/deep_table/nn/layers/activation.py
    """

    def geglu(self, x):
        # need an even feature dimension to split into two equal halves
        assert x.shape[-1] % 2 == 0
        value, gate = x.chunk(2, dim=-1)
        return value * F.gelu(gate)

    def forward(self, x):
        return self.geglu(x)
class PositionalEncoding(nn.Module):
    """
    Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
    """

    def __init__(self, embed_dim):
        """
        Standard sinusoidal positional encoding scheme in transformers.

        Positional encoding of the k'th position in the sequence is given by:
            p(k, 2i)   = sin(k/n^(i/d))
            p(k, 2i+1) = cos(k/n^(i/d))

        n: set to 10K in original Transformer paper
        d: the embedding dimension
        i: positions along the projected embedding space (ranges from 0 to d/2)

        Args:
            embed_dim: The number of dimensions to project the timesteps into.
        """
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, x):
        """
        Map timesteps of shape (B, T) to sinusoidal encodings of shape
        (B, T, @self.embed_dim).
        """
        timesteps = x
        # 1 / n^(i/d) for even i, computed in log space for numerical stability
        inv_freq = torch.exp(
            torch.arange(0, self.embed_dim, 2, device=x.device)
            * (-math.log(10000.0) / self.embed_dim)
        )
        # broadcast (B, T, 1) against (d/2,) -> angle matrix of shape (B, T, d/2)
        angles = timesteps.unsqueeze(-1) * inv_freq
        enc = torch.zeros((x.shape[0], x.shape[1], self.embed_dim), device=x.device)
        enc[:, :, 0::2] = torch.sin(angles)
        enc[:, :, 1::2] = torch.cos(angles)
        # encodings are a fixed function of position - no gradients needed
        return enc.detach()
class CausalSelfAttention(Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        context_length,
        attn_dropout=0.1,
        output_dropout=0.1,
    ):
        """
        Multi-head masked self-attention layer + projection (MLP layer).

        Each input in the sequence is projected to key, query, and value
        embeddings of size @embed_dim. Every query is dot-producted against all
        keys, softmax-normalized into attention scores, and used to take a
        weighted sum of the value vectors. With @num_heads > 1, this happens in
        parallel over equal-size partitions of the embedding dimension, and the
        per-head outputs are concatenated. A causal (lower-triangular) mask
        ensures each output only depends on inputs at or before its position.

        Args:
            embed_dim (int): dimension of embeddings to use for keys, queries, and values
                used in self-attention

            num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            context_length (int): expected length of input sequences

            attn_dropout (float): dropout probability for attention outputs

            output_dropout (float): dropout probability for final outputs
        """
        super(CausalSelfAttention, self).__init__()

        assert (
            embed_dim % num_heads == 0
        ), "num_heads: {} does not divide embed_dim: {} exactly".format(num_heads, embed_dim)

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.context_length = context_length
        self.attn_dropout = attn_dropout
        self.output_dropout = output_dropout
        self.nets = nn.ModuleDict()

        # single fused projection producing keys, queries, and values for all heads
        self.nets["qkv"] = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)

        # dropout layers
        self.nets["attn_dropout"] = nn.Dropout(self.attn_dropout)
        self.nets["output_dropout"] = nn.Dropout(self.output_dropout)

        # output projection layer
        self.nets["output"] = nn.Linear(self.embed_dim, self.embed_dim)

        # causal mask (lower-triangular matrix of 1s) so attention only looks backwards;
        # registered as a buffer so it moves with the module but is not a parameter
        causal_mask = torch.tril(torch.ones(context_length, context_length)).view(
            1, 1, context_length, context_length
        )
        self.register_buffer("mask", causal_mask)

    def forward(self, x):
        """
        Forward pass through Self-Attention block.
        Input should be shape (B, T, D) where B is batch size, T is seq length (@self.context_length), and
        D is input dimension (@self.embed_dim).
        """

        # enforce shape consistency
        assert len(x.shape) == 3
        B, T, D = x.shape
        assert (
            T <= self.context_length
        ), "self-attention module can only handle sequences up to {} in length but got length {}".format(
            self.context_length, T
        )
        assert D == self.embed_dim

        num_heads = self.num_heads
        head_dim = D // num_heads  # embedding dimension per attention head

        # project inputs to queries, keys, values and split heads:
        # (B, T, D) -> three tensors of shape (B, NH, T, DH)
        queries, keys, values = self.nets["qkv"](x).chunk(3, dim=-1)
        queries = queries.view(B, T, num_heads, head_dim).transpose(1, 2)
        keys = keys.view(B, T, num_heads, head_dim).transpose(1, 2)
        values = values.view(B, T, num_heads, head_dim).transpose(1, 2)

        # scaled pair-wise dot-products between all timesteps, broadcast over
        # batch and heads: (B, NH, T, DH) x (B, NH, DH, T) -> (B, NH, T, T)
        att_scores = (queries @ keys.transpose(-2, -1)) * (1.0 / math.sqrt(keys.size(-1)))

        # set future positions to -inf so they contribute nothing to the softmax;
        # [:T, :T] allows sequences shorter than @self.context_length
        att_scores = att_scores.masked_fill(self.mask[..., :T, :T] == 0, float("-inf"))

        # (B, NH, T, T): last dim holds each timestep's attention distribution
        att_weights = F.softmax(att_scores, dim=-1)

        # dropout on attention weights
        att_weights = self.nets["attn_dropout"](att_weights)

        # attention-weighted sum of values: (B, NH, T, T) x (B, NH, T, DH) -> (B, NH, T, DH)
        out = att_weights @ values

        # merge heads: (B, NH, T, DH) -> (B, T, NH, DH) -> (B, T, D)
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        # output projection + dropout
        out = self.nets["output"](out)
        out = self.nets["output_dropout"](out)
        return out

    def output_shape(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # this module doesn't modify the size of the input, it goes from (B, T, D) -> (B, T, D)
        return list(input_shape)
class SelfAttentionBlock(Module):
    """
    A single Transformer Block, that can be chained together repeatedly.
    It consists of a @CausalSelfAttention module and a small MLP, along with
    layer normalization and residual connections on each input.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        context_length,
        attn_dropout=0.1,
        output_dropout=0.1,
        activation=nn.GELU(),
    ):
        """
        Args:
            embed_dim (int): dimension of embeddings to use for keys, queries, and values
                used in self-attention

            num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            context_length (int): expected length of input sequences

            attn_dropout (float): dropout probability for attention outputs

            output_dropout (float): dropout probability for final outputs

            activation (nn.Module): activation module to use inside the MLP,
                e.g. nn.GELU() or GEGLU(). (Note: previously documented as a str.)
        """
        super(SelfAttentionBlock, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.context_length = context_length
        self.attn_dropout = attn_dropout
        self.output_dropout = output_dropout
        self.nets = nn.ModuleDict()

        # self-attention block
        self.nets["attention"] = CausalSelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            context_length=context_length,
            attn_dropout=attn_dropout,
            output_dropout=output_dropout,
        )

        # GEGLU halves its input dimension, so widen the first linear layer by 2x
        # to keep the hidden width at 4 * embed_dim after the activation.
        # use isinstance (robust to subclasses) rather than exact type comparison.
        if isinstance(activation, GEGLU):
            mult = 2
        else:
            mult = 1

        # small 2-layer MLP
        self.nets["mlp"] = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim * mult),
            activation,
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(output_dropout)
        )

        # layer normalization for inputs to self-attention module and MLP
        self.nets["ln1"] = nn.LayerNorm(embed_dim)
        self.nets["ln2"] = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Forward pass - chain self-attention + MLP blocks, with residual connections and layer norms.
        """
        x = x + self.nets["attention"](self.nets["ln1"](x))
        x = x + self.nets["mlp"](self.nets["ln2"](x))
        return x

    def output_shape(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # this module doesn't modify the size of the input, it goes from (B, T, D) -> (B, T, D)
        return list(input_shape)
class GPT_Backbone(Module):
    """the GPT model, with a context size of block_size"""

    def __init__(
        self,
        embed_dim,
        context_length,
        attn_dropout=0.1,
        block_output_dropout=0.1,
        num_layers=6,
        num_heads=8,
        activation="gelu",
    ):
        """
        Args:
            embed_dim (int): dimension of embeddings to use for keys, queries, and values
                used in self-attention

            context_length (int): expected length of input sequences

            attn_dropout (float): dropout probability for attention outputs for each transformer block

            block_output_dropout (float): dropout probability for final outputs for each transformer block

            num_layers (int): number of transformer blocks to stack

            num_heads (int): number of attention heads - must divide @embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            activation (str): string denoting the activation function to use in each transformer block -
                one of "gelu" or "geglu"

        Raises:
            ValueError: if @activation is not a recognized string
        """
        super(GPT_Backbone, self).__init__()

        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.context_length = context_length
        self.attn_dropout = attn_dropout
        self.block_output_dropout = block_output_dropout

        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "geglu":
            self.activation = GEGLU()
        else:
            # previously an unrecognized string silently left @self.activation unset,
            # deferring the failure to _create_networks - fail fast instead
            raise ValueError("unsupported activation: {}".format(activation))

        # create networks
        self._create_networks()

        # initialize weights
        self.apply(self._init_weights)

        print(
            "Created {} model with number of parameters: {}".format(
                self.__class__.__name__, sum(p.numel() for p in self.parameters())
            )
        )

    def _create_networks(self):
        """
        Helper function to create networks.
        """
        self.nets = nn.ModuleDict()

        # transformer - cascaded transformer blocks
        self.nets["transformer"] = nn.Sequential(
            *[
                SelfAttentionBlock(
                    embed_dim=self.embed_dim,
                    num_heads=self.num_heads,
                    context_length=self.context_length,
                    attn_dropout=self.attn_dropout,
                    output_dropout=self.block_output_dropout,
                    activation=self.activation,
                )
                for _ in range(self.num_layers)
            ]
        )

        # decoder head
        self.nets["output_ln"] = nn.LayerNorm(self.embed_dim)

    def _init_weights(self, module):
        """
        Weight initializer (GPT-style: normal(0, 0.02) for linear/embedding
        weights, zeros for biases, identity init for layer norms).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def output_shape(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        # this module goes from (B, T, @self.embed_dim) -> (B, T, @self.embed_dim).
        # BUG FIX: previously returned input_shape[:-1] + [self.output_dim], but
        # @self.output_dim is never set, so calling this raised AttributeError.
        return input_shape[:-1] + [self.embed_dim]

    def forward(self, inputs):
        assert inputs.shape[1:] == (self.context_length, self.embed_dim), inputs.shape
        x = self.nets["transformer"](inputs)
        transformer_output = self.nets["output_ln"](x)
        return transformer_output