| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| from typing import Any, Callable |
|
|
| from transformers import is_torch_available, is_torch_xpu_available |
| from transformers.testing_utils import ( |
| TestCasePlus, |
| backend_device_count, |
| backend_torch_accelerator_module, |
| execute_subprocess_async, |
| get_torch_dist_unique_port, |
| require_torch_multi_accelerator, |
| torch_device, |
| ) |
| from transformers.utils import is_ccl_available, is_ipex_available |
|
|
|
|
| if is_torch_available(): |
| import functools |
|
|
| import torch |
|
|
| if is_torch_xpu_available(): |
| if is_ipex_available(): |
| import intel_extension_for_pytorch |
| if is_ccl_available(): |
| import oneccl_bindings_for_pytorch |
| import torch.distributed |
| from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method |
| from torch.distributed.device_mesh import init_device_mesh |
| from torch.distributed.fsdp import FullyShardedDataParallel |
| from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy |
|
|
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from transformers.models.gpt2.modeling_gpt2 import GPT2Block |
|
|
# Prompt pool for the generation workers: two distinct sentences repeated to
# eight entries; each worker selects one entry by its rank.
data = ["Hello world!", "The quick brown fox jumps over the lazy dog."] * 4
|
|
def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
    """Manage the creation and destruction of the distributed process group for the wrapped function.

    Each call to the returned wrapper initialises the default process group with
    ``world_size`` equal to the number of devices reported for ``torch_device``,
    runs ``func``, and destroys the process group afterwards, even if ``func`` raises.

    Args:
        func: The function to run inside an initialised process group.

    Returns:
        A wrapper with the same name and docstring as ``func`` that returns ``func``'s result.
    """
    # Local import: the module-level ``functools`` import may sit behind a
    # torch-availability guard, but this decorator is applied at import time.
    import functools

    # wraps() keeps func's __name__/__doc__ (and sets __wrapped__) so tracebacks
    # and logs show the real function instead of "wrapped".
    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        device_count = backend_device_count(torch_device)
        torch.distributed.init_process_group(world_size=device_count)
        try:
            return func(*args, **kwargs)
        finally:
            torch.distributed.destroy_process_group()

    return wrapped
|
|
@manage_process_group
def fsdp_generate():
    """Run ``generate`` on a tiny GPT-2 wrapped in FSDP (v1), once per torchrun worker.

    Each worker pins its accelerator, wraps the model with an auto-wrap policy keyed
    on ``GPT2Block``, tokenizes one prompt from ``data``, and generates up to
    ``max_length=30`` tokens while the full parameters are summoned.
    """
    torch_accelerator_module = backend_torch_accelerator_module(torch_device)
    rank = torch.distributed.get_rank()
    # The device index is the global rank. NOTE(review): this assumes a single-node launch.
    device = torch.device(rank)
    torch_accelerator_module.set_device(device)

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

    fsdp_model = FullyShardedDataParallel(
        model,
        auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block}),
        limit_all_gathers=True,
        use_orig_params=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    # Wrap around the prompt pool so launches with more ranks than prompts
    # do not fail with IndexError.
    prompt = data[rank % len(data)]
    batch = tokenizer(prompt, return_tensors="pt", return_attention_mask=True).to(device)

    # generate() is called on the inner (unwrapped) module, so gather the full
    # parameters around the call.
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        _ = fsdp_model.module.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=30,
        )
|
|
@manage_process_group
def fsdp2_generate():
    """Run ``generate`` on a tiny GPT-2 sharded with FSDP2 (``fully_shard``), once per torchrun worker.

    Each worker pins its accelerator, shards every ``GPT2Block`` and then the root
    model over a 1-D device mesh spanning all ranks, registers ``generate`` as an
    FSDP forward method, and generates up to ``max_length=30`` tokens for one prompt.
    """
    torch_accelerator_module = backend_torch_accelerator_module(torch_device)
    rank = torch.distributed.get_rank()
    # The device index is the global rank. NOTE(review): this assumes a single-node launch.
    device = torch.device(rank)
    torch_accelerator_module.set_device(device)

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

    # 1-D mesh over all ranks: shard each transformer block first, then the root module.
    mesh = init_device_mesh(device.type, (torch.distributed.get_world_size(),))
    for submodule in model.modules():
        if isinstance(submodule, GPT2Block):
            fully_shard(submodule, mesh=mesh)
    fully_shard(model, mesh=mesh)

    # generate() is not forward(); registering it makes FSDP2 treat it as a forward
    # method, so its parameter-gathering hooks run around the call.
    register_fsdp_forward_method(model, "generate")

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    # Wrap around the prompt pool so launches with more ranks than prompts
    # do not fail with IndexError.
    prompt = data[rank % len(data)]
    batch = tokenizer(prompt, return_tensors="pt", return_attention_mask=True).to(device)

    _ = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_length=30,
    )
|
|
|
|
class TestFSDPGeneration(TestCasePlus):
    """Launch ``test_fsdp.py`` under torchrun, once per FSDP flavour, on all local accelerators."""

    def _run_torchrun(self, mode_flag: str) -> None:
        """Run ``test_fsdp.py`` under torchrun with one process per device and the given mode flag.

        Args:
            mode_flag: Script flag selecting the scenario (``"--fsdp"`` or ``"--fsdp2"``).
        """
        device_count = backend_device_count(torch_device)
        # Build argv as a list directly. Splitting an f-string on whitespace would
        # break if the test directory path contained spaces.
        cmd = [
            "torchrun",
            f"--nproc_per_node={device_count}",
            f"--master_port={get_torch_dist_unique_port()}",
            f"{self.test_file_dir}/test_fsdp.py",
            mode_flag,
        ]
        execute_subprocess_async(cmd, env=self.get_env())

    @require_torch_multi_accelerator
    def test_fsdp_generate(self):
        self._run_torchrun("--fsdp")

    @require_torch_multi_accelerator
    def test_fsdp2_generate(self):
        self._run_torchrun("--fsdp2")
| |
|
|
|
|
if __name__ == "__main__":

    class CLIArgs(argparse.Namespace):
        # Typed view of the parsed flags; exactly one is expected to be set.
        fsdp: bool
        fsdp2: bool

    cli = argparse.ArgumentParser()
    # The two scenario flags cannot be combined.
    selector = cli.add_mutually_exclusive_group()
    for flag in ("--fsdp", "--fsdp2"):
        selector.add_argument(flag, action="store_true")
    args = cli.parse_args(namespace=CLIArgs())

    # Guard clause: refuse to run when no scenario was requested.
    if not (args.fsdp or args.fsdp2):
        raise ValueError("Missing test selection")

    scenario = fsdp_generate if args.fsdp else fsdp2_generate
    scenario()
|
|