| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import sys |
| | from pathlib import Path |
| |
|
| | from transformers import is_torch_available |
| | from transformers.testing_utils import ( |
| | TestCasePlus, |
| | execute_subprocess_async, |
| | require_accelerate, |
| | require_torch_multi_accelerator, |
| | run_first, |
| | slow, |
| | ) |
| |
|
| |
|
| | if is_torch_available(): |
| | import torch |
| |
|
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | DataCollatorForLanguageModeling, |
| | HfArgumentParser, |
| | Trainer, |
| | TrainingArguments, |
| | ) |
| |
|
| |
|
class TestContextParallel(TestCasePlus):
    """Test Trainer with Torch context parallelism enabled via accelerate's ParallelismConfig."""

    def _launch_training(self, accelerate_config: str, losses_filename: str) -> Path:
        """Run this file as a training script under ``accelerate launch``.

        Writes *accelerate_config* to a fresh auto-removed temp dir, launches
        this very file (its ``__main__`` block does the training) as a
        subprocess with a fixed set of training flags, and returns the path of
        the JSON file the subprocess writes its per-step losses to.

        Args:
            accelerate_config: Full YAML text of the accelerate config to use.
            losses_filename: Basename of the JSON losses file to request via
                ``--loss_output_file``.

        Returns:
            Path to the JSON losses file produced by the subprocess.
        """
        output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
        config_path = output_dir / "context_parallel_config.yaml"
        losses_path = output_dir / losses_filename

        with open(config_path, "w") as f:
            f.write(accelerate_config)

        # `__file__` doubles as the training script: when launched by
        # accelerate, the `__main__` block at the bottom of this file runs.
        cmd = f"""
        accelerate launch
        --config_file {config_path}
        {__file__}
        --output_dir {output_dir}
        --report_to none
        --max_steps 10
        --per_device_train_batch_size 1
        --gradient_accumulation_steps 1
        --logging_steps 1
        --remove_unused_columns False
        --seed 42
        --loss_output_file {losses_path}
        """.split()

        execute_subprocess_async(cmd, env=self.get_env())
        return losses_path

    @require_torch_multi_accelerator
    @require_accelerate
    @slow
    @run_first
    def test_cp_equivalence(self):
        """Test that CP produces the same losses as without CP."""
        world_size = 2

        # Run 1: FSDP2 with context parallelism spreading the sequence
        # dimension over `world_size` processes.
        cp_yes_losses_path = self._launch_training(
            f"""distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
mixed_precision: bf16
num_processes: {world_size}
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 1
  parallelism_config_tp_size: 1
  parallelism_config_cp_size: {world_size}
  parallelism_config_cp_comm_strategy: alltoall
""",
            "cp_yes_losses.json",
        )

        # Run 2: single-process baseline with no context parallelism.
        cp_no_losses_path = self._launch_training(
            """distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_version: 2
mixed_precision: bf16
num_processes: 1
""",
            "cp_no_losses.json",
        )

        with open(cp_yes_losses_path) as f:
            cp_yes_losses = json.load(f)
        with open(cp_no_losses_path) as f:
            cp_no_losses = json.load(f)

        assert len(cp_yes_losses) == len(cp_no_losses), (
            f"Different number of losses: CP has {len(cp_yes_losses)}, no-CP has {len(cp_no_losses)}"
        )

        # CP changes the order of the sequence-dimension reductions (and the
        # run uses bf16 mixed precision), so compare with a loose tolerance
        # rather than exact equality.
        cp_yes_losses_tensor = torch.tensor(cp_yes_losses)
        cp_no_losses_tensor = torch.tensor(cp_no_losses)

        torch.testing.assert_close(
            cp_yes_losses_tensor,
            cp_no_losses_tensor,
            rtol=2e-2,
            atol=2e-2,
            msg=f"CP losses {cp_yes_losses} do not match non-CP losses {cp_no_losses}",
        )
| |
|
| |
|
if __name__ == "__main__":
    # When launched via `accelerate launch` (see TestContextParallel), this
    # file acts as the training script itself.

    # `--loss_output_file` is our own flag, not a TrainingArguments field, so
    # strip it (flag + value) out of argv before HfArgumentParser sees it.
    loss_output_file = None
    if "--loss_output_file" in sys.argv:
        flag_pos = sys.argv.index("--loss_output_file")
        loss_output_file = sys.argv[flag_pos + 1]
        del sys.argv[flag_pos : flag_pos + 2]

    parser = HfArgumentParser((TrainingArguments,))
    training_args = parser.parse_args_into_dataclasses()[0]

    # Small Llama-style model keeps the subprocess runs fast.
    model_name = "HuggingFaceTB/SmolLM-135M"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="sdpa",
        dtype=torch.float32,
    )

    # Tiny synthetic corpus: eight repetitive samples are plenty for
    # producing comparable per-step losses.
    texts = [
        "The quick brown fox jumps over the lazy dog. " * 10,
        "Hello world, this is a test sentence for training. " * 10,
    ] * 4

    # Pre-tokenize everything up front; padding to a fixed length keeps all
    # samples the same shape.
    train_dataset = [
        tokenizer(text, max_length=128, truncation=True, padding="max_length") for text in texts
    ]

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    trainer.train()

    assert trainer.state.global_step > 0, "Training should have completed at least one step"

    # Only the main process writes the losses file the parent test compares.
    if loss_output_file and training_args.process_index == 0:
        step_losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
        with open(loss_output_file, "w") as f:
            json.dump(step_losses, f)
|