| """Computes the flops needed for training/running transformer networks.""" |
|
|
| import collections |
|
|
# Dropout: draw a random number per element, compare it against the dropout
# rate, multiply by the resulting mask, and rescale by 1 / (1 - dropout_rate).
DROPOUT_FLOPS = 4

# Layer norm: compute the mean and variance of the activations, then
# normalize, scale, and shift each element.
LAYER_NORM_FLOPS = 5

# Activation (GELU, as used in BERT/ELECTRA):
# 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
ACTIVATION_FLOPS = 8

# Softmax: max and subtract (for numerical stability), exp, sum, and divide.
SOFTMAX_FLOPS = 5


class TransformerHparams(object):
  """Computes the train/inference FLOPs for transformers."""

  def __init__(self, h, l, s=512, v=30522, e=None, i=None, heads=None,
               head_size=None, output_frac=0.15625, sparse_embed_lookup=False,
               decoder=False):
    self.h = h  # hidden size
    self.l = l  # number of layers
    self.s = s  # sequence length
    self.v = v  # vocab size
    self.e = h if e is None else e  # embedding size
    self.i = h * 4 if i is None else i  # intermediate (feed-forward) size
    self.kqv = h if head_size is None else head_size * heads  # attn proj size
    self.heads = max(h // 64, 1) if heads is None else heads  # attention heads
    self.output_frac = output_frac  # fraction of tokens with an output softmax
    self.sparse_embed_lookup = sparse_embed_lookup  # sparse embedding lookups
    self.decoder = decoder  # decoder has additional attention over the encoder

  def get_block_flops(self):
    """Get the forward-pass FLOPs for a single transformer block."""
    # A decoder block attends twice (self-attention plus attention over the
    # encoder outputs), so its attention FLOPs are doubled.
    attn_mul = 2 if self.decoder else 1
    block_flops = dict(
        kqv=3 * 2 * self.h * self.kqv * attn_mul,
        kqv_bias=3 * self.kqv * attn_mul,
        attention_scores=2 * self.kqv * self.s * attn_mul,
        attn_softmax=SOFTMAX_FLOPS * self.s * self.heads * attn_mul,
        attention_dropout=DROPOUT_FLOPS * self.s * self.heads * attn_mul,
        attention_scale=self.s * self.heads * attn_mul,
        attention_weighted_avg_values=2 * self.h * self.s * attn_mul,
        attn_output=2 * self.h * self.h * attn_mul,
        attn_output_bias=self.h * attn_mul,
        attn_output_dropout=DROPOUT_FLOPS * self.h * attn_mul,
        attn_output_residual=self.h * attn_mul,
        attn_output_layer_norm=LAYER_NORM_FLOPS * attn_mul,
        intermediate=2 * self.h * self.i,
        intermediate_act=ACTIVATION_FLOPS * self.i,
        intermediate_bias=self.i,
        output=2 * self.h * self.i,
        output_bias=self.h,
        output_dropout=DROPOUT_FLOPS * self.h,
        output_residual=self.h,
        output_layer_norm=LAYER_NORM_FLOPS * self.h,
    )
    # The entries above are per-token costs; scale by the sequence length.
    return sum(block_flops.values()) * self.s

  def get_embedding_flops(self, output=False):
    """Get the forward-pass FLOPs for the transformer inputs or output
    softmax."""
    embedding_flops = {}
    if output or (not self.sparse_embed_lookup):
      # Dense embedding lookup (multiplication by a one-hot vector) or the
      # output projection onto the vocabulary.
      embedding_flops["main_multiply"] = 2 * self.e * self.v
    # Input embeddings: token-type and position lookups, layer norm, dropout.
    if not output:
      embedding_flops.update(dict(
          tok_type_and_position=2 * self.e * (self.s + 2),
          add_tok_type_and_position=2 * self.e,
          emb_layer_norm=LAYER_NORM_FLOPS * self.e,
          emb_dropout=DROPOUT_FLOPS * self.e
      ))
    # Projection between the embedding size and the hidden size (needed for
    # factorized embeddings and for the output softmax).
    if self.e != self.h or output:
      embedding_flops.update(dict(
          hidden_kernel=2 * self.h * self.e,
          hidden_bias=self.e if output else self.h
      ))
    if output:
      embedding_flops.update(dict(
          hidden_activation=ACTIVATION_FLOPS * self.e,
          hidden_layernorm=LAYER_NORM_FLOPS * self.e,
          output_softmax=SOFTMAX_FLOPS * self.v,
          output_target_word=2 * self.v
      ))
      # Only output_frac of the tokens (e.g., the masked-out positions for
      # masked LM pre-training) incur the output softmax.
      return self.output_frac * sum(embedding_flops.values()) * self.s
    return sum(embedding_flops.values()) * self.s

  def get_binary_classification_flops(self):
    """Get the forward-pass FLOPs for binary classification of each token."""
    classification_flops = dict(
        hidden=2 * self.h * self.h,
        hidden_bias=self.h,
        hidden_act=ACTIVATION_FLOPS * self.h,
        logits=2 * self.h
    )
    return sum(classification_flops.values()) * self.s

  def get_train_flops(self, batch_size, train_steps, discriminator=False):
    """Get the FLOPs for pre-training the transformer."""
    # The leading factor of 2 assumes the backward pass costs about the same
    # as the forward pass.
    return 2 * batch_size * train_steps * (
        (self.l * self.get_block_flops()) +
        self.get_embedding_flops(output=False) +
        (self.get_binary_classification_flops() if discriminator else
         self.get_embedding_flops(output=True))
    )

  def get_infer_flops(self):
    """Get the FLOPs for running inference with the transformer on a
    classification task."""
    return ((self.l * self.get_block_flops()) +
            self.get_embedding_flops(output=False) +
            self.get_binary_classification_flops())
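
# Usage sketch (hypothetical configuration, not one of the models below): a
# 6-layer, 512-hidden encoder over 128-token sequences.
#
#   tiny = TransformerHparams(512, 6, s=128)
#   tiny.get_infer_flops()            # per-example inference FLOPs
#   tiny.get_train_flops(32, 10000)   # training FLOPs: batch size 32, 10k steps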


def get_electra_train_flops(
    h_d, l_d, h_g, l_g, batch_size, train_steps, tied_embeddings,
    e=None, s=512, output_frac=0.15625):
  """Get the FLOPs needed for pre-training ELECTRA.

  h_d/l_d are the discriminator's hidden size and number of layers; h_g/l_g
  are the generator's.
  """
  if e is None:
    e = h_d
  disc = TransformerHparams(
      h_d, l_d, s=s, e=e,
      output_frac=output_frac).get_train_flops(batch_size, train_steps, True)
  gen = TransformerHparams(
      h_g, l_g, s=s, e=e if tied_embeddings else None,
      output_frac=output_frac).get_train_flops(batch_size, train_steps)
  return disc + gen
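
# Example (hypothetical sizes): pre-training FLOPs for an ELECTRA model with a
# 512-hidden discriminator, a 128-hidden generator (both 12 layers), 100k steps
# at batch size 256, and tied embeddings:
#
#   get_electra_train_flops(512, 12, 128, 12, 256, 100000, True)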


MODEL_FLOPS = collections.OrderedDict([
    # The ELMo and XLNet estimates are computed from hard-coded constants
    # rather than TransformerHparams, as those architectures differ from the
    # BERT-style transformers modeled above.
    ("elmo", 2 * 10 * 768648884 * 568093262680 / (20.0 * 128)),
    ("xlnet", 2 * 500000 * 8192 * 15064773691518 / 32.0),
    ("gpt", TransformerHparams(768, 12, v=40000, output_frac=1.0).get_train_flops(
        128, 960800)),
    ("bert_small", TransformerHparams(256, 12, e=128, s=128).get_train_flops(
        128, 1.45e6)),
    ("bert_base", TransformerHparams(768, 12).get_train_flops(256, 1e6)),
    ("bert_large", TransformerHparams(1024, 24).get_train_flops(256, 1e6)),
    ("electra_small", get_electra_train_flops(
        256, 12, 64, 12, 128, 1e6, True, s=128, e=128)),
    ("electra_base", get_electra_train_flops(
        768, 12, 256, 12, 256, 766000, True)),
    ("electra_400k", get_electra_train_flops(
        1024, 24, 256, 24, 2048, 400000, True)),
    ("electra_1.75M", get_electra_train_flops(
        1024, 24, 256, 24, 2048, 1750000, True)),
    ("roberta", TransformerHparams(
        1024, 24, v=50265).get_train_flops(8000, 500000)),
    ("albert", TransformerHparams(
        4096, 12, v=30000, e=128).get_train_flops(4096, 1.5e6)),
    # T5-11B is an encoder-decoder model, so its FLOPs are the sum of an
    # encoder stack and a decoder stack.
    ("t5_11b", TransformerHparams(
        1024,  # hidden size
        24,  # layers
        v=32000,
        i=65536,  # T5-11B uses a very wide feed-forward layer
        heads=128, head_size=128,
        output_frac=0.0  # the encoder has no output softmax
    ).get_train_flops(2048, 1e6) +  # 1M steps at batch size 2048
    TransformerHparams(
        1024,
        24,
        v=32000,
        i=65536,
        heads=128, head_size=128,
        output_frac=1.0,  # the decoder has an output softmax at every position
        decoder=True
    ).get_train_flops(2048, 1e6))
])


def main():
  for k, v in MODEL_FLOPS.items():
    print(k, v)


if __name__ == "__main__":
  main()