| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import torch |
| from transformers import ( |
| BertConfig, |
| BertForMaskedLM, |
| GPT2Config, |
| GPT2ForSequenceClassification, |
| ) |
|
|
| from accelerate import PartialState |
| from accelerate.inference import prepare_pippy |
| from accelerate.test_utils import torch_device |
| from accelerate.utils import DistributedType, set_seed |
|
|
|
|
# Lookup table for the models exercised below:
#   short name -> (model class, config class, sequence length of generated inputs).
# NOTE(review): the lengths presumably match each config's default maximum
# number of positions — confirm before changing them.
model_to_config = dict(
    bert=(BertForMaskedLM, BertConfig, 512),
    gpt2=(GPT2ForSequenceClassification, GPT2Config, 1024),
)
|
|
|
|
def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
    """Build a randomly initialised text model together with random token ids.

    Args:
        model_name: Key into ``model_to_config`` (``"bert"`` or ``"gpt2"``).
        device: Device on which the generated input tensors are allocated.
        num_processes: Number of rows in the inference batch. The tests in this
            file pass their ``batch_size`` here.

    Returns:
        ``(model, trace_input, inference_inputs)``, where ``trace_input`` has
        shape ``(1, seq_len)`` and ``inference_inputs`` has shape
        ``(num_processes, seq_len)``. Both are ``int64`` tensors with values in
        ``[0, vocab_size)``.
    """
    model_cls, config_cls, seq_len = model_to_config[model_name]

    # Default configuration. The model is constructed before any inputs are
    # sampled, so seeded runs consume the RNG in a fixed order.
    model_config = config_cls()
    model = model_cls(model_config)

    randint_options = {
        "low": 0,
        "high": model_config.vocab_size,
        "device": device,
        "dtype": torch.int64,
        "requires_grad": False,
    }
    # Single-row example input (the tests pass it to ``prepare_pippy`` as
    # ``example_args``), followed by the batch actually run through the model.
    trace_input = torch.randint(size=(1, seq_len), **randint_options)
    inference_inputs = torch.randint(size=(num_processes, seq_len), **randint_options)
    return model, trace_input, inference_inputs
|
|
|
|
def test_bert(batch_size: int = 2):
    """Run a pipeline-parallel forward pass of a randomly initialised BERT model.

    The model is built on CPU and wrapped with ``prepare_pippy``. A random batch
    of ``batch_size`` rows is fed through it. Only the last process should
    receive a result; every other process must get ``None``.
    """
    set_seed(42)
    state = PartialState()
    model, trace_input, inference_inputs = get_model_and_data_for_text("bert", "cpu", batch_size)
    pipelined = prepare_pippy(
        model,
        example_args=(trace_input,),
        no_split_module_classes=model._no_split_modules,
    )

    inputs = inference_inputs.to(torch_device)
    with torch.no_grad():
        output = pipelined(inputs)

    if state.is_last_process:
        assert output is not None, "Output was not generated in the last process!"
    else:
        assert output is None, "Output was not generated on just the last process!"
|
|
|
|
def test_gpt2(batch_size: int = 2):
    """Run a pipeline-parallel forward pass of a randomly initialised GPT-2 classifier.

    The model is built on CPU and wrapped with ``prepare_pippy``. A random batch
    of ``batch_size`` rows is fed through it. Only the last process should
    receive a result; every other process must get ``None``.
    """
    set_seed(42)
    state = PartialState()
    model, trace_input, inference_inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
    pipelined = prepare_pippy(
        model,
        example_args=(trace_input,),
        no_split_module_classes=model._no_split_modules,
    )

    inputs = inference_inputs.to(torch_device)
    with torch.no_grad():
        output = pipelined(inputs)

    if state.is_last_process:
        assert output is not None, "Output was not generated in the last process!"
    else:
        assert output is None, "Output was not generated on just the last process!"
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
if __name__ == "__main__":
    state = PartialState()
    state.print("Testing pippy integration...")
    # The pipeline tests are only run on multi-device GPU/XPU/HPU launches.
    supported_types = (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_XPU,
        DistributedType.MULTI_HPU,
    )
    try:
        if state.distributed_type in supported_types:
            state.print("Testing GPT2...")
            test_gpt2()
            state.print("Testing BERT...")
            test_bert()
        else:
            # Use state.print like the rest of the script, so the skip notice is
            # emitted once rather than once per process (e.g. multi-CPU launches).
            state.print("Less than two GPUs found, not running tests!")
    finally:
        # Always tear down the process group, even if a test assertion fails.
        state.destroy_process_group()
|
|