ArabGlossBERT / arabert /araelectra /flops_computation.py
TymaaHammouda's picture
Upload 106 files
ceed500 verified
"""Computes the flops needed for training/running transformer networks."""
import collections
# We checked this code with TensorFlow"s FLOPs counting, although we had to
# correct for this issue: https://github.com/tensorflow/tensorflow/issues/22071
# Assumptions going into the FLOPs counting
# - An "operation" is a mathematical operation, not a machine instruction. So
# an "exp" takes one opp like and add, even though in practice an exp
# might be slower. This is not too bad an assumption because
# matrix-multiplies dominate the compute for most models, so minor details
# about activation functions don"t matter too much. Similarly, we count
# matrix-multiplies as 2*m*n flops instead of m*n, as one might if
# if considering fused multiply-add ops.
# - Backward pass takes the same number of FLOPs as forward pass. No exactly
# right (e.g., for softmax cross entropy loss the backward pass is faster).
# Importantly, it really is the same for matrix-multiplies, which is most of
# the compute anyway.
# - We assume "dense" embedding lookups (i.e., multiplication by a one-hot
# vector). On some hardware accelerators, these dense operations are
# actually faster than sparse lookups.
# Please open a github issue if you spot a problem with this code!
# I am not sure if the below constants are 100% right, but they are only applied
# to O(hidden_size) activations, which is generally a lot less compute than the
# matrix-multiplies, which are O(hidden_size^2), so they don't affect the total
# number of FLOPs much.
# random number, >=, multiply activations by dropout mask, multiply activations
# by correction (1 / (1 - dropout_rate))
DROPOUT_FLOPS = 4
# compute mean activation (sum), computate variance of activation
# (square and sum), bias (add), scale (multiply)
LAYER_NORM_FLOPS = 5
# GELU: 0.5 * x * (1 + tanh(sqrt(2 / np.pi) * (x + 0.044715 * pow(x, 3))))
ACTIVATION_FLOPS = 8
# max/substract (for stability), exp, sum, divide
SOFTMAX_FLOPS = 5
class TransformerHparams(object):
"""Computes the train/inference FLOPs for transformers."""
def __init__(self, h, l, s=512, v=30522, e=None, i=None, heads=None,
head_size=None, output_frac=0.15625, sparse_embed_lookup=False,
decoder=False):
self.h = h # hidden size
self.l = l # number of layers
self.s = s # sequence length
self.v = v # vocab size
self.e = h if e is None else e # embedding size
self.i = h * 4 if i is None else i # intermediate size
self.kqv = h if head_size is None else head_size * heads # attn proj sizes
self.heads = max(h // 64, 1) if heads is None else heads # attention heads
self.output_frac = output_frac # percent of tokens using an output softmax
self.sparse_embed_lookup = sparse_embed_lookup # sparse embedding lookups
self.decoder = decoder # decoder has extra attn to encoder states
def get_block_flops(self):
"""Get the forward-pass FLOPs for a single transformer block."""
attn_mul = 2 if self.decoder else 1
block_flops = dict(
kqv=3 * 2 * self.h * self.kqv * attn_mul,
kqv_bias=3 * self.kqv * attn_mul,
attention_scores=2 * self.kqv * self.s * attn_mul,
attn_softmax=SOFTMAX_FLOPS * self.s * self.heads * attn_mul,
attention_dropout=DROPOUT_FLOPS * self.s * self.heads * attn_mul,
attention_scale=self.s * self.heads * attn_mul,
attention_weighted_avg_values=2 * self.h * self.s * attn_mul,
attn_output=2 * self.h * self.h * attn_mul,
attn_output_bias=self.h * attn_mul,
attn_output_dropout=DROPOUT_FLOPS * self.h * attn_mul,
attn_output_residual=self.h * attn_mul,
attn_output_layer_norm=LAYER_NORM_FLOPS * attn_mul,
intermediate=2 * self.h * self.i,
intermediate_act=ACTIVATION_FLOPS * self.i,
intermediate_bias=self.i,
output=2 * self.h * self.i,
output_bias=self.h,
output_dropout=DROPOUT_FLOPS * self.h,
output_residual=self.h,
output_layer_norm=LAYER_NORM_FLOPS * self.h,
)
return sum(block_flops.values()) * self.s
def get_embedding_flops(self, output=False):
"""Get the forward-pass FLOPs the transformer inputs or output softmax."""
embedding_flops = {}
if output or (not self.sparse_embed_lookup):
embedding_flops["main_multiply"] = 2 * self.e * self.v
# input embedding post-processing
if not output:
embedding_flops.update(dict(
tok_type_and_position=2 * self.e * (self.s + 2),
add_tok_type_and_position=2 * self.e,
emb_layer_norm=LAYER_NORM_FLOPS * self.e,
emb_dropout=DROPOUT_FLOPS * self.e
))
# projection layer if e != h
if self.e != self.h or output:
embedding_flops.update(dict(
hidden_kernel=2 * self.h * self.e,
hidden_bias=self.e if output else self.h
))
# extra hidden layer and output softmax
if output:
embedding_flops.update(dict(
hidden_activation=ACTIVATION_FLOPS * self.e,
hidden_layernorm=LAYER_NORM_FLOPS * self.e,
output_softmax=SOFTMAX_FLOPS * self.v,
output_target_word=2 * self.v
))
return self.output_frac * sum(embedding_flops.values()) * self.s
return sum(embedding_flops.values()) * self.s
def get_binary_classification_flops(self):
classification_flops = dict(
hidden=2 * self.h * self.h,
hidden_bias=self.h,
hidden_act=ACTIVATION_FLOPS * self.h,
logits=2 * self.h
)
return sum(classification_flops.values()) * self.s
def get_train_flops(self, batch_size, train_steps, discriminator=False):
"""Get the FLOPs for pre-training the transformer."""
# 2* for forward/backward pass
return 2 * batch_size * train_steps * (
(self.l * self.get_block_flops()) +
self.get_embedding_flops(output=False) +
(self.get_binary_classification_flops() if discriminator else
self.get_embedding_flops(output=True))
)
def get_infer_flops(self):
"""Get the FLOPs for running inference with the transformer on a
classification task."""
return ((self.l * self.get_block_flops()) +
self.get_embedding_flops(output=False) +
self.get_binary_classification_flops())
def get_electra_train_flops(
h_d, l_d, h_g, l_g, batch_size, train_steps, tied_embeddings,
e=None, s=512, output_frac=0.15625):
"""Get the FLOPs needed for pre-training ELECTRA."""
if e is None:
e = h_d
disc = TransformerHparams(
h_d, l_d, s=s, e=e,
output_frac=output_frac).get_train_flops(batch_size, train_steps, True)
gen = TransformerHparams(
h_g, l_g, s=s, e=e if tied_embeddings else None,
output_frac=output_frac).get_train_flops(batch_size, train_steps)
return disc + gen
MODEL_FLOPS = collections.OrderedDict([
# These runtimes were computed with tensorflow FLOPs counting instead of the
# script, as the neural architectures are quite different.
# 768648884 words in LM1b benchmark, 10 epochs with batch size 20,
# seq length 128, 568093262680 FLOPs per example.
("elmo", 2 * 10 * 768648884 * 568093262680 / (20.0 * 128)),
# 15064773691518 is FLOPs for forward pass on 32 examples.
# Therefore 2 * steps * batch_size * 15064773691518 / 32 is XLNet compute
("xlnet", 2 * 500000 * 8192 * 15064773691518 / 32.0),
# Runtimes computed with the script
("gpt", TransformerHparams(768, 12, v=40000, output_frac=1.0).get_train_flops(
128, 960800)),
("bert_small", TransformerHparams(256, 12, e=128, s=128).get_train_flops(128, 1.45e6)),
("bert_base", TransformerHparams(768, 12).get_train_flops(256, 1e6)),
("bert_large", TransformerHparams(1024, 24).get_train_flops(256, 1e6)),
("electra_small", get_electra_train_flops(256, 12, 64, 12, 128, 1e6, True, s=128, e=128)),
("electra_base", get_electra_train_flops(768, 12, 256, 12, 256, 766000, True)),
("electra_400k", get_electra_train_flops(1024, 24, 256, 24, 2048, 400000, True)),
("electra_1.75M", get_electra_train_flops(1024, 24, 256, 24, 2048, 1750000, True)),
# RoBERTa, ALBERT, and T5 have minor architectural differences from
# BERT/ELECTRA, but I believe they don't significantly effect the runtime,
# so we use this script for those models as well.
("roberta", TransformerHparams(1024, 24, v=50265).get_train_flops(8000, 500000)),
("albert", TransformerHparams(4096, 12, v=30000, e=128).get_train_flops(
4096, 1.5e6)),
("t5_11b", TransformerHparams(
1024, # hidden size
24, # layers
v=32000, # vocab size
i=65536, # ff intermediate hidden size
heads=128, head_size=128, # heads/head size
output_frac=0.0 # encoder has no output softmax
).get_train_flops(2048, 1e6) + # 1M steps with batch size 2048
TransformerHparams(
1024,
24,
v=32000,
i=65536,
heads=128, head_size=128,
output_frac=1.0, # decoder has output softmax for all positions
decoder=True
).get_train_flops(2048, 1e6))
])
def main():
for k, v in MODEL_FLOPS.items():
print(k, v)
if __name__ == "__main__":
main()