| |
| import time |
| from argparse import ArgumentParser |
|
|
| import jax |
| import numpy as np |
|
|
| from transformers import BertConfig, FlaxBertModel |
|
|
|
|
# Command-line interface: choose the numeric precision used for the benchmark.
parser = ArgumentParser()
parser.add_argument("--precision", type=str, choices=["float32", "bfloat16"], default="float32")
args = parser.parse_args()

# Translate the validated CLI choice into the JAX dtype later handed to
# FlaxBertModel.from_pretrained(dtype=...). `choices` above guarantees the key
# is one of these two entries.
dtype = {
    "float32": jax.numpy.float32,
    "bfloat16": jax.numpy.bfloat16,
}[args.precision]
|
|
# Exclusive upper bound for the synthetic token ids; presumably the vocabulary
# size of the bert-base-uncased checkpoint loaded below (TODO confirm).
VOCAB_SIZE = 30_522
# Benchmark batch size and sequence length of each synthetic batch.
BS, SEQ_LEN = 32, 128
|
|
|
|
def get_input_data(batch_size=1, seq_length=384):
    """Build one synthetic batch of BERT keyword inputs.

    Every array has shape ``(batch_size, seq_length)`` and dtype ``int32``.

    Returns:
        dict with keys ``"input_ids"`` (uniform random ids in
        ``[1, VOCAB_SIZE)``), ``"token_type_ids"`` and ``"attention_mask"``
        (both all ones, as two separate arrays).
    """
    dims = (batch_size, seq_length)
    # Lower bound 1 skips id 0 (presumably the padding id -- confirm).
    token_ids = np.random.randint(1, VOCAB_SIZE, size=dims).astype(np.int32)
    segment_ids, mask = (np.ones(dims).astype(np.int32) for _ in range(2))
    return dict(input_ids=token_ids, token_type_ids=segment_ids, attention_mask=mask)
|
|
|
|
# Single fixed synthetic batch reused by every benchmarked forward pass.
inputs = get_input_data(batch_size=BS, seq_length=SEQ_LEN)

_CHECKPOINT = "bert-base-uncased"
# The config override switches the activation to "gelu_new" (in transformers,
# presumably the tanh-approximated GELU -- confirm why it was chosen here).
config = BertConfig.from_pretrained(_CHECKPOINT, hidden_act="gelu_new")
# `dtype` (from --precision) sets the model's computation dtype.
model = FlaxBertModel.from_pretrained(_CHECKPOINT, config=config, dtype=dtype)
|
|
|
|
@jax.jit
def _bert_forward(params, batch):
    """Jitted BERT forward pass that takes weights and inputs as traced arguments.

    ``params`` is the model's parameter pytree and ``batch`` the keyword-input
    dict produced by ``get_input_data``. Passing them as arguments keeps them
    out of the compiled program. Arrays a jitted function merely closes over
    are embedded as compile-time constants, which inflates compile time and
    memory for large weights. It may also let XLA constant-fold part of the
    computation this script is trying to measure.
    """
    return model(**batch, params=params)


# Place weights and the fixed input batch on device once, so timed calls do
# not re-transfer host (NumPy) arrays on every invocation.
_device_params, _device_inputs = jax.device_put((model.params, inputs))


def func():
    """Run one forward pass of ``model`` on the module-level ``inputs``.

    Returns whatever ``model(...)`` returns. The work is dispatched
    asynchronously, so the returned arrays may still be computing when this
    function returns. Block on them before reading a timer.
    """
    return _bert_forward(_device_params, _device_inputs)
|
|
|
|
# Benchmark parameters: number of untimed warm-up calls and of timed calls.
(nwarmup, nbenchmark) = (5, 100)


def _wait_for(tree):
    """Block until every array leaf of the pytree ``tree`` is computed; return ``tree``.

    JAX dispatches device work asynchronously, so ``func()`` can return before
    its computation has finished. Timing without this wait would mostly
    measure Python dispatch overhead instead of model execution.
    """
    for leaf in jax.tree_util.tree_leaves(tree):
        leaf.block_until_ready()
    return tree


# Warm-up: the first call traces and compiles `func`; the remaining calls are
# extra untimed iterations. Drain all queued work so none of it leaks into the
# timed region.
outputs = None
for _ in range(nwarmup):
    outputs = func()
_wait_for(outputs)

start = time.perf_counter()  # monotonic, high-resolution interval clock
for _ in range(nbenchmark):
    outputs = func()
# NOTE(review): waiting only on the final result assumes the calls complete in
# dispatch order on the device (the usual single-stream case). Confirm this
# for multi-device setups.
_wait_for(outputs)
end = time.perf_counter()
print(end - start)
print(f"Throughput: {((nbenchmark * BS) / (end - start)):.3f} examples/sec")
|
|