File size: 4,789 Bytes
e062359 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import (
BertConfig,
BertForMaskedLM,
GPT2Config,
GPT2ForSequenceClassification,
)
from accelerate import PartialState
from accelerate.inference import prepare_pippy
from accelerate.test_utils import torch_device
from accelerate.utils import DistributedType, set_seed
model_to_config = {
"bert": (BertForMaskedLM, BertConfig, 512),
"gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
}
def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
initializer, config, seq_len = model_to_config[model_name]
config_args = {}
# Eventually needed for batch inference tests on gpt-2 when bs != 1
# if model_name == "gpt2":
# config_args["pad_token_id"] = 0
model_config = config(**config_args)
model = initializer(model_config)
kwargs = dict(low=0, high=model_config.vocab_size, device=device, dtype=torch.int64, requires_grad=False)
trace_input = torch.randint(size=(1, seq_len), **kwargs)
inference_inputs = torch.randint(size=(num_processes, seq_len), **kwargs)
return model, trace_input, inference_inputs
def test_bert(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, trace_input, inference_inputs = get_model_and_data_for_text("bert", "cpu", batch_size)
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
# For inference args need to be a tuple
inputs = inference_inputs.to(torch_device)
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"
def test_gpt2(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, trace_input, inference_inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
# For inference args need to be a tuple
inputs = inference_inputs.to(torch_device)
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"
# Currently disabled, enable again once PyTorch pippy interface can trace a resnet34
# def test_resnet(batch_size: int = 2):
# set_seed(42)
# state = PartialState()
# model = resnet34()
# input_tensor = torch.rand(1, 3, 224, 224)
# model = prepare_pippy(
# model,
# example_args=(input_tensor,),
# )
# inference_inputs = torch.rand(batch_size, 3, 224, 224)
# inputs = send_to_device(inference_inputs, torch_device)
# with torch.no_grad():
# output = model(inputs)
# # Zach: Check that we just grab the real outputs we need at the end
# if not state.is_last_process:
# assert output is None, "Output was not generated on just the last process!"
# else:
# assert output is not None, "Output was not generated in the last process!"
if __name__ == "__main__":
state = PartialState()
state.print("Testing pippy integration...")
try:
if state.distributed_type in [DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.MULTI_HPU]:
state.print("Testing GPT2...")
test_gpt2()
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
# due to references
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
# test_gpt2(3)
state.print("Testing BERT...")
test_bert()
else:
print("Less than two GPUs found, not running tests!")
finally:
state.destroy_process_group()
|