| <!--Copyright 2023 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| โ ๏ธ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| # Trainer [[trainer]] | |
| [`Trainer`]๋ Transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๊ตฌํ๋ PyTorch ๋ชจ๋ธ์ ๋ฐ๋ณตํ์ฌ ํ๋ จ ๋ฐ ํ๊ฐ ๊ณผ์ ์ ๋๋ค. ํ๋ จ์ ํ์ํ ์์(๋ชจ๋ธ, ํ ํฌ๋์ด์ , ๋ฐ์ดํฐ์ , ํ๊ฐ ํจ์, ํ๋ จ ํ์ดํผํ๋ผ๋ฏธํฐ ๋ฑ)๋ง ์ ๊ณตํ๋ฉด [`Trainer`]๊ฐ ํ์ํ ๋๋จธ์ง ์์ ์ ์ฒ๋ฆฌํฉ๋๋ค. ์ด๋ฅผ ํตํด ์ง์ ํ๋ จ ๋ฃจํ๋ฅผ ์์ฑํ์ง ์๊ณ ๋ ๋น ๋ฅด๊ฒ ํ๋ จ์ ์์ํ ์ ์์ต๋๋ค. ๋ํ [`Trainer`]๋ ๊ฐ๋ ฅํ ๋ง์ถค ์ค์ ๊ณผ ๋ค์ํ ํ๋ จ ์ต์ ์ ์ ๊ณตํ์ฌ ์ฌ์ฉ์ ๋ง์ถค ํ๋ จ์ด ๊ฐ๋ฅํฉ๋๋ค. | |
| <Tip> | |
| Transformers๋ [`Trainer`] ํด๋์ค ์ธ์๋ ๋ฒ์ญ์ด๋ ์์ฝ๊ณผ ๊ฐ์ ์ํ์ค-ํฌ-์ํ์ค ์์ ์ ์ํ [`Seq2SeqTrainer`] ํด๋์ค๋ ์ ๊ณตํฉ๋๋ค. ๋ํ [TRL](https://hf.co/docs/trl) ๋ผ์ด๋ธ๋ฌ๋ฆฌ์๋ [`Trainer`] ํด๋์ค๋ฅผ ๊ฐ์ธ๊ณ Llama-2 ๋ฐ Mistral๊ณผ ๊ฐ์ ์ธ์ด ๋ชจ๋ธ์ ์๋ ํ๊ท ๊ธฐ๋ฒ์ผ๋ก ํ๋ จํ๋ ๋ฐ ์ต์ ํ๋ [`~trl.SFTTrainer`] ํด๋์ค ์ ๋๋ค. [`~trl.SFTTrainer`]๋ ์ํ์ค ํจํน, LoRA, ์์ํ ๋ฐ DeepSpeed์ ๊ฐ์ ๊ธฐ๋ฅ์ ์ง์ํ์ฌ ํฌ๊ธฐ ์๊ด์์ด ๋ชจ๋ธ ํจ์จ์ ์ผ๋ก ํ์ฅํ ์ ์์ต๋๋ค. | |
| <br> | |
| ์ด๋ค ๋ค๋ฅธ [`Trainer`] ์ ํ ํด๋์ค์ ๋ํด ๋ ์๊ณ ์ถ๋ค๋ฉด [API ์ฐธ์กฐ](./main_classes/trainer)๋ฅผ ํ์ธํ์ฌ ์ธ์ ์ด๋ค ํด๋์ค๊ฐ ์ ํฉํ ์ง ์ผ๋ง๋ ์ง ํ์ธํ์ธ์. ์ผ๋ฐ์ ์ผ๋ก [`Trainer`]๋ ๊ฐ์ฅ ๋ค์ฌ๋ค๋ฅํ ์ต์ ์ผ๋ก, ๋ค์ํ ์์ ์ ์ ํฉํฉ๋๋ค. [`Seq2SeqTrainer`]๋ ์ํ์ค-ํฌ-์ํ์ค ์์ ์ ์ํด ์ค๊ณ๋์๊ณ , [`~trl.SFTTrainer`]๋ ์ธ์ด ๋ชจ๋ธ ํ๋ จ์ ์ํด ์ค๊ณ๋์์ต๋๋ค. | |
| </Tip> | |
| ์์ํ๊ธฐ ์ ์, ๋ถ์ฐ ํ๊ฒฝ์์ PyTorch ํ๋ จ๊ณผ ์คํ์ ํ ์ ์๊ฒ [Accelerate](https://hf.co/docs/accelerate) ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์๋์ง ํ์ธํ์ธ์. | |
| ```bash | |
| pip install accelerate | |
| # ์ ๊ทธ๋ ์ด๋ | |
| pip install accelerate --upgrade | |
| ``` | |
| ์ด ๊ฐ์ด๋๋ [`Trainer`] ํด๋์ค์ ๋ํ ๊ฐ์๋ฅผ ์ ๊ณตํฉ๋๋ค. | |
| ## ๊ธฐ๋ณธ ์ฌ์ฉ๋ฒ [[basic-usage]] | |
| [`Trainer`]๋ ๊ธฐ๋ณธ์ ์ธ ํ๋ จ ๋ฃจํ์ ํ์ํ ๋ชจ๋ ์ฝ๋๋ฅผ ํฌํจํ๊ณ ์์ต๋๋ค. | |
| 1. ์์ค์ ๊ณ์ฐํ๋ ํ๋ จ ๋จ๊ณ๋ฅผ ์ํํฉ๋๋ค. | |
| 2. [`~accelerate.Accelerator.backward`] ๋ฉ์๋๋ก ๊ทธ๋ ์ด๋์ธํธ๋ฅผ ๊ณ์ฐํฉ๋๋ค. | |
| 3. ๊ทธ๋ ์ด๋์ธํธ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค. | |
| 4. ์ ํด์ง ์ํญ ์์ ๋๋ฌํ ๋๊น์ง ์ด ๊ณผ์ ์ ๋ฐ๋ณตํฉ๋๋ค. | |
| [`Trainer`] ํด๋์ค๋ PyTorch์ ํ๋ จ ๊ณผ์ ์ ์ต์ํ์ง ์๊ฑฐ๋ ๋ง ์์ํ ๊ฒฝ์ฐ์๋ ํ๋ จ์ด ๊ฐ๋ฅํ๋๋ก ํ์ํ ๋ชจ๋ ์ฝ๋๋ฅผ ์ถ์ํํ์์ต๋๋ค. ๋ํ ๋งค๋ฒ ํ๋ จ ๋ฃจํ๋ฅผ ์์ ์์ฑํ์ง ์์๋ ๋๋ฉฐ, ํ๋ จ์ ํ์ํ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํฐ์ ๊ฐ์ ํ์ ๊ตฌ์ฑ ์์๋ง ์ ๊ณตํ๋ฉด, [Trainer] ํด๋์ค๊ฐ ๋๋จธ์ง๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค. | |
| ํ๋ จ ์ต์ ์ด๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ง์ ํ๋ ค๋ฉด, [`TrainingArguments`] ํด๋์ค์์ ํ์ธ ํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ๋ชจ๋ธ์ ์ ์ฅํ ๋๋ ํ ๋ฆฌ๋ฅผ `output_dir`์ ์ ์ํ๊ณ , ํ๋ จ ํ์ Hub๋ก ๋ชจ๋ธ์ ํธ์ํ๋ ค๋ฉด `push_to_hub=True`๋ก ์ค์ ํฉ๋๋ค. | |
| ```py | |
| from transformers import TrainingArguments | |
| training_args = TrainingArguments( | |
| output_dir="your-model", | |
| learning_rate=2e-5, | |
| per_device_train_batch_size=16, | |
| per_device_eval_batch_size=16, | |
| num_train_epochs=2, | |
| weight_decay=0.01, | |
| eval_strategy="epoch", | |
| save_strategy="epoch", | |
| load_best_model_at_end=True, | |
| push_to_hub=True, | |
| ) | |
| ``` | |
| `training_args`๋ฅผ [`Trainer`]์ ๋ชจ๋ธ, ๋ฐ์ดํฐ์ , ๋ฐ์ดํฐ์ ์ ์ฒ๋ฆฌ ๋๊ตฌ(๋ฐ์ดํฐ ์ ํ์ ๋ฐ๋ผ ํ ํฌ๋์ด์ , ํน์ง ์ถ์ถ๊ธฐ ๋๋ ์ด๋ฏธ์ง ํ๋ก์ธ์์ผ ์ ์์), ๋ฐ์ดํฐ ์์ง๊ธฐ ๋ฐ ํ๋ จ ์ค ํ์ธํ ์งํ๋ฅผ ๊ณ์ฐํ ํจ์๋ฅผ ํจ๊ป ์ ๋ฌํ์ธ์. | |
| ๋ง์ง๋ง์ผ๋ก, [`~Trainer.train`]๋ฅผ ํธ์ถํ์ฌ ํ๋ จ์ ์์ํ์ธ์! | |
| ```py | |
| from transformers import Trainer | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=dataset["train"], | |
| eval_dataset=dataset["test"], | |
| tokenizer=tokenizer, | |
| data_collator=data_collator, | |
| compute_metrics=compute_metrics, | |
| ) | |
| trainer.train() | |
| ``` | |
| ### ์ฒดํฌํฌ์ธํธ [[checkpoints]] | |
| [`Trainer`] ํด๋์ค๋ [`TrainingArguments`]์ `output_dir` ๋งค๊ฐ๋ณ์์ ์ง์ ๋ ๋๋ ํ ๋ฆฌ์ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํฉ๋๋ค. ์ฒดํฌํฌ์ธํธ๋ `checkpoint-000` ํ์ ํด๋์ ์ ์ฅ๋๋ฉฐ, ์ฌ๊ธฐ์ ๋์ ์ซ์๋ ํ๋ จ ๋จ๊ณ์ ํด๋นํฉ๋๋ค. ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ๋ฉด ๋์ค์ ํ๋ จ์ ์ฌ๊ฐํ ๋ ์ ์ฉํฉ๋๋ค. | |
| ```py | |
| # ์ต์ ์ฒดํฌํฌ์ธํธ์์ ์ฌ๊ฐ | |
| trainer.train(resume_from_checkpoint=True) | |
| # ์ถ๋ ฅ ๋๋ ํ ๋ฆฌ์ ์ ์ฅ๋ ํน์ ์ฒดํฌํฌ์ธํธ์์ ์ฌ๊ฐ | |
| trainer.train(resume_from_checkpoint="your-model/checkpoint-1000") | |
| ``` | |
| ์ฒดํฌํฌ์ธํธ๋ฅผ Hub์ ํธ์ํ๋ ค๋ฉด [`TrainingArguments`]์์ `push_to_hub=True`๋ก ์ค์ ํ์ฌ ์ปค๋ฐํ๊ณ ํธ์ํ ์ ์์ต๋๋ค. ์ฒดํฌํฌ์ธํธ ์ ์ฅ ๋ฐฉ๋ฒ์ ๊ฒฐ์ ํ๋ ๋ค๋ฅธ ์ต์ ์ [`hub_strategy`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.hub_strategy) ๋งค๊ฐ๋ณ์์์ ์ค์ ํฉ๋๋ค: | |
| * `hub_strategy="checkpoint"`๋ ์ต์ ์ฒดํฌํฌ์ธํธ๋ฅผ "last-checkpoint"๋ผ๋ ํ์ ํด๋์ ํธ์ํ์ฌ ํ๋ จ์ ์ฌ๊ฐํ ์ ์์ต๋๋ค. | |
| * `hub_strategy="all_checkpoints"`๋ ๋ชจ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ `output_dir`์ ์ ์๋ ๋๋ ํ ๋ฆฌ์ ํธ์ํฉ๋๋ค(๋ชจ๋ธ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ํด๋๋น ํ๋์ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ณผ ์ ์์ต๋๋ค). | |
| ์ฒดํฌํฌ์ธํธ์์ ํ๋ จ์ ์ฌ๊ฐํ ๋, [`Trainer`]๋ ์ฒดํฌํฌ์ธํธ๊ฐ ์ ์ฅ๋ ๋์ ๋์ผํ Python, NumPy ๋ฐ PyTorch RNG ์ํ๋ฅผ ์ ์งํ๋ ค๊ณ ํฉ๋๋ค. ํ์ง๋ง PyTorch๋ ๊ธฐ๋ณธ ์ค์ ์ผ๋ก '์ผ๊ด๋ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฅํ์ง ์์'์ผ๋ก ๋ง์ด ๋์ด์๊ธฐ ๋๋ฌธ์, RNG ์ํ๊ฐ ๋์ผํ ๊ฒ์ด๋ผ๊ณ ๋ณด์ฅํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์, ์ผ๊ด๋ ๊ฒฐ๊ณผ๊ฐ ๋ณด์ฅ๋๋๋ก ํ์ฑํ ํ๋ ค๋ฉด, [๋๋ค์ฑ ์ ์ด](https://pytorch.org/docs/stable/notes/randomness#controlling-sources-of-randomness) ๊ฐ์ด๋๋ฅผ ์ฐธ๊ณ ํ์ฌ ํ๋ จ์ ์์ ํ ์ผ๊ด๋ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฅ ๋ฐ๋๋ก ๋ง๋ค๊ธฐ ์ํด ํ์ฑํํ ์ ์๋ ํญ๋ชฉ์ ํ์ธํ์ธ์. ๋ค๋ง, ํน์ ์ค์ ์ ๊ฒฐ์ ์ ์ผ๋ก ๋ง๋ค๋ฉด ํ๋ จ์ด ๋๋ ค์ง ์ ์์ต๋๋ค. | |
| ## Trainer ๋ง์ถค ์ค์ [[customize-the-trainer]] | |
| [`Trainer`] ํด๋์ค๋ ์ ๊ทผ์ฑ๊ณผ ์ฉ์ด์ฑ์ ์ผ๋์ ๋๊ณ ์ค๊ณ๋์์ง๋ง, ๋ ๋ค์ํ ๊ธฐ๋ฅ์ ์ํ๋ ์ฌ์ฉ์๋ค์ ์ํด ๋ค์ํ ๋ง์ถค ์ค์ ์ต์ ์ ์ ๊ณตํฉ๋๋ค. [`Trainer`]์ ๋ง์ ๋ฉ์๋๋ ์๋ธํด๋์คํ ๋ฐ ์ค๋ฒ๋ผ์ด๋ํ์ฌ ์ํ๋ ๊ธฐ๋ฅ์ ์ ๊ณตํ ์ ์์ผ๋ฉฐ, ์ด๋ฅผ ํตํด ์ ์ฒด ํ๋ จ ๋ฃจํ๋ฅผ ๋ค์ ์์ฑํ ํ์ ์์ด ์ํ๋ ๊ธฐ๋ฅ์ ์ถ๊ฐํ ์ ์์ต๋๋ค. ์ด๋ฌํ ๋ฉ์๋์๋ ๋ค์์ด ํฌํจ๋ฉ๋๋ค: | |
| * [`~Trainer.get_train_dataloader`]๋ ํ๋ จ ๋ฐ์ดํฐ๋ก๋๋ฅผ ์์ฑํฉ๋๋ค. | |
| * [`~Trainer.get_eval_dataloader`]๋ ํ๊ฐ ๋ฐ์ดํฐ๋ก๋๋ฅผ ์์ฑํฉ๋๋ค. | |
| * [`~Trainer.get_test_dataloader`]๋ ํ ์คํธ ๋ฐ์ดํฐ๋ก๋๋ฅผ ์์ฑํฉ๋๋ค. | |
| * [`~Trainer.log`]๋ ํ๋ จ์ ๋ชจ๋ํฐ๋งํ๋ ๋ค์ํ ๊ฐ์ฒด์ ๋ํ ์ ๋ณด๋ฅผ ๋ก๊ทธ๋ก ๋จ๊น๋๋ค. | |
| * [`~Trainer.create_optimizer_and_scheduler`]๋ `__init__`์์ ์ ๋ฌ๋์ง ์์ ๊ฒฝ์ฐ ์ตํฐ๋ง์ด์ ์ ํ์ต๋ฅ ์ค์ผ์ค๋ฌ๋ฅผ ์์ฑํฉ๋๋ค. ์ด๋ค์ ๊ฐ๊ฐ [`~Trainer.create_optimizer`] ๋ฐ [`~Trainer.create_scheduler`]๋ก ๋ณ๋๋ก ๋ง์ถค ์ค์ ํ ์ ์์ต๋๋ค. | |
| * [`~Trainer.compute_loss`]๋ ํ๋ จ ์ ๋ ฅ ๋ฐฐ์น์ ๋ํ ์์ค์ ๊ณ์ฐํฉ๋๋ค. | |
| * [`~Trainer.training_step`]๋ ํ๋ จ ๋จ๊ณ๋ฅผ ์ํํฉ๋๋ค. | |
| * [`~Trainer.prediction_step`]๋ ์์ธก ๋ฐ ํ ์คํธ ๋จ๊ณ๋ฅผ ์ํํฉ๋๋ค. | |
| * [`~Trainer.evaluate`]๋ ๋ชจ๋ธ์ ํ๊ฐํ๊ณ ํ๊ฐ ์งํ์ ๋ฐํํฉ๋๋ค. | |
| * [`~Trainer.predict`]๋ ํ ์คํธ ์ธํธ์ ๋ํ ์์ธก(๋ ์ด๋ธ์ด ์๋ ๊ฒฝ์ฐ ์งํ ํฌํจ)์ ์ํํฉ๋๋ค. | |
| ์๋ฅผ ๋ค์ด, [`~Trainer.compute_loss`] ๋ฉ์๋๋ฅผ ๋ง์ถค ์ค์ ํ์ฌ ๊ฐ์ค ์์ค์ ์ฌ์ฉํ๋ ค๋ ๊ฒฝ์ฐ: | |
| ```py | |
| from torch import nn | |
| from transformers import Trainer | |
| class CustomTrainer(Trainer): | |
| def compute_loss(self, | |
| model, inputs, return_outputs=False): | |
| labels = inputs.pop("labels") | |
| # ์๋ฐฉํฅ ์ ํ | |
| outputs = model(**inputs) | |
| logits = outputs.get("logits") | |
| # ์๋ก ๋ค๋ฅธ ๊ฐ์ค์น๋ก 3๊ฐ์ ๋ ์ด๋ธ์ ๋ํ ์ฌ์ฉ์ ์ ์ ์์ค์ ๊ณ์ฐ | |
| loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device)) | |
| loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) | |
| return (loss, outputs) if return_outputs else loss | |
| ``` | |
| ### ์ฝ๋ฐฑ [[callbacks]] | |
| [`Trainer`]๋ฅผ ๋ง์ถค ์ค์ ํ๋ ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ [์ฝ๋ฐฑ](callbacks)์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. ์ฝ๋ฐฑ์ ํ๋ จ ๋ฃจํ์์ *๋ณํ๋ฅผ ์ฃผ์ง ์์ต๋๋ค*. ํ๋ จ ๋ฃจํ์ ์ํ๋ฅผ ๊ฒ์ฌํ ํ ์ํ์ ๋ฐ๋ผ ์ผ๋ถ ์์ (์กฐ๊ธฐ ์ข ๋ฃ, ๊ฒฐ๊ณผ ๋ก๊ทธ ๋ฑ)์ ์คํํฉ๋๋ค. ์ฆ, ์ฝ๋ฐฑ์ ์ฌ์ฉ์ ์ ์ ์์ค ํจ์์ ๊ฐ์ ๊ฒ์ ๊ตฌํํ๋ ๋ฐ ์ฌ์ฉํ ์ ์์ผ๋ฉฐ, ์ด๋ฅผ ์ํด์๋ [`~Trainer.compute_loss`] ๋ฉ์๋๋ฅผ ์๋ธํด๋์คํํ๊ณ ์ค๋ฒ๋ผ์ด๋ํด์ผ ํฉ๋๋ค. | |
| ์๋ฅผ ๋ค์ด, ํ๋ จ ๋ฃจํ์ 10๋จ๊ณ ํ ์กฐ๊ธฐ ์ข ๋ฃ ์ฝ๋ฐฑ์ ์ถ๊ฐํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํฉ๋๋ค. | |
| ```py | |
| from transformers import TrainerCallback | |
| class EarlyStoppingCallback(TrainerCallback): | |
| def __init__(self, num_steps=10): | |
| self.num_steps = num_steps | |
| def on_step_end(self, args, state, control, **kwargs): | |
| if state.global_step >= self.num_steps: | |
| return {"should_training_stop": True} | |
| else: | |
| return {} | |
| ``` | |
| ๊ทธ๋ฐ ๋ค์, ์ด๋ฅผ [`Trainer`]์ `callback` ๋งค๊ฐ๋ณ์์ ์ ๋ฌํฉ๋๋ค. | |
| ```py | |
| from transformers import Trainer | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=dataset["train"], | |
| eval_dataset=dataset["test"], | |
| tokenizer=tokenizer, | |
| data_collator=data_collator, | |
| compute_metrics=compute_metrics, | |
| callbacks=[EarlyStoppingCallback()], | |
| ) | |
| ``` | |
| ## ๋ก๊น [[logging]] | |
| <Tip> | |
| ๋ก๊น API์ ๋ํ ์์ธํ ๋ด์ฉ์ [๋ก๊น ](./main_classes/logging) API ๋ ํผ๋ฐ์ค๋ฅผ ํ์ธํ์ธ์. | |
| </Tip> | |
| [`Trainer`]๋ ๊ธฐ๋ณธ์ ์ผ๋ก `logging.INFO`๋ก ์ค์ ๋์ด ์์ด ์ค๋ฅ, ๊ฒฝ๊ณ ๋ฐ ๊ธฐํ ๊ธฐ๋ณธ ์ ๋ณด๋ฅผ ๋ณด๊ณ ํฉ๋๋ค. ๋ถ์ฐ ํ๊ฒฝ์์๋ [`Trainer`] ๋ณต์ ๋ณธ์ด `logging.WARNING`์ผ๋ก ์ค์ ๋์ด ์ค๋ฅ์ ๊ฒฝ๊ณ ๋ง ๋ณด๊ณ ํฉ๋๋ค. [`TrainingArguments`]์ [`log_level`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level) ๋ฐ [`log_level_replica`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level_replica) ๋งค๊ฐ๋ณ์๋ก ๋ก๊ทธ ๋ ๋ฒจ์ ๋ณ๊ฒฝํ ์ ์์ต๋๋ค. | |
| ๊ฐ ๋ ธ๋์ ๋ก๊ทธ ๋ ๋ฒจ ์ค์ ์ ๊ตฌ์ฑํ๋ ค๋ฉด [`log_on_each_node`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments.log_on_each_node) ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ ๋ ธ๋์์ ๋ก๊ทธ ๋ ๋ฒจ์ ์ฌ์ฉํ ์ง ์๋๋ฉด ์ฃผ ๋ ธ๋์์๋ง ์ฌ์ฉํ ์ง ๊ฒฐ์ ํ์ธ์. | |
| <Tip> | |
| [`Trainer`]๋ [`Trainer.__init__`] ๋ฉ์๋์์ ๊ฐ ๋ ธ๋์ ๋ํด ๋ก๊ทธ ๋ ๋ฒจ์ ๋ณ๋๋ก ์ค์ ํ๋ฏ๋ก, ๋ค๋ฅธ Transformers ๊ธฐ๋ฅ์ ์ฌ์ฉํ ๊ฒฝ์ฐ [`Trainer`] ๊ฐ์ฒด๋ฅผ ์์ฑํ๊ธฐ ์ ์ ์ด๋ฅผ ๋ฏธ๋ฆฌ ์ค์ ํ๋ ๊ฒ์ด ์ข์ต๋๋ค. | |
| </Tip> | |
| ์๋ฅผ ๋ค์ด, ๋ฉ์ธ ์ฝ๋์ ๋ชจ๋์ ๊ฐ ๋ ธ๋์ ๋ฐ๋ผ ๋์ผํ ๋ก๊ทธ ๋ ๋ฒจ์ ์ฌ์ฉํ๋๋ก ์ค์ ํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํฉ๋๋ค. | |
| ```py | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig( | |
| format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", | |
| datefmt="%m/%d/%Y %H:%M:%S", | |
| handlers=[logging.StreamHandler(sys.stdout)], | |
| ) | |
| log_level = training_args.get_process_log_level() | |
| logger.setLevel(log_level) | |
| datasets.utils.logging.set_verbosity(log_level) | |
| transformers.utils.logging.set_verbosity(log_level) | |
| trainer = Trainer(...) | |
| ``` | |
| ๊ฐ ๋ ธ๋์์ ๊ธฐ๋ก๋ ๋ด์ฉ์ ๊ตฌ์ฑํ๊ธฐ ์ํด `log_level`๊ณผ `log_level_replica`๋ฅผ ๋ค์ํ ์กฐํฉ์ผ๋ก ์ฌ์ฉํด๋ณด์ธ์. | |
| <hfoptions id="logging"> | |
| <hfoption id="single node"> | |
| ```bash | |
| my_app.py ... --log_level warning --log_level_replica error | |
| ``` | |
| </hfoption> | |
| <hfoption id="multi-node"> | |
| ๋ฉํฐ ๋ ธ๋ ํ๊ฒฝ์์๋ `log_on_each_node 0` ๋งค๊ฐ๋ณ์๋ฅผ ์ถ๊ฐํฉ๋๋ค. | |
| ```bash | |
| my_app.py ... --log_level warning --log_level_replica error --log_on_each_node 0 | |
| # ์ค๋ฅ๋ง ๋ณด๊ณ ํ๋๋ก ์ค์ | |
| my_app.py ... --log_level error --log_level_replica error --log_on_each_node 0 | |
| ``` | |
| </hfoption> | |
| </hfoptions> | |
| ## NEFTune [[neftune]] | |
| [NEFTune](https://hf.co/papers/2310.05914)์ ํ๋ จ ์ค ์๋ฒ ๋ฉ ๋ฒกํฐ์ ๋ ธ์ด์ฆ๋ฅผ ์ถ๊ฐํ์ฌ ์ฑ๋ฅ์ ํฅ์์ํฌ ์ ์๋ ๊ธฐ์ ์ ๋๋ค. [`Trainer`]์์ ์ด๋ฅผ ํ์ฑํํ๋ ค๋ฉด [`TrainingArguments`]์ `neftune_noise_alpha` ๋งค๊ฐ๋ณ์๋ฅผ ์ค์ ํ์ฌ ๋ ธ์ด์ฆ์ ์์ ์กฐ์ ํฉ๋๋ค. | |
| ```py | |
| from transformers import TrainingArguments, Trainer | |
| training_args = TrainingArguments(..., neftune_noise_alpha=0.1) | |
| trainer = Trainer(..., args=training_args) | |
| ``` | |
| NEFTune์ ์์์น ๋ชปํ ๋์์ ํผํ ๋ชฉ์ ์ผ๋ก ์ฒ์ ์๋ฒ ๋ฉ ๋ ์ด์ด๋ก ๋ณต์ํ๊ธฐ ์ํด ํ๋ จ ํ ๋นํ์ฑํ ๋ฉ๋๋ค. | |
| ## GaLore [[galore]] | |
| Gradient Low-Rank Projection (GaLore)์ ์ ์ฒด ๋งค๊ฐ๋ณ์๋ฅผ ํ์ตํ๋ฉด์๋ LoRA์ ๊ฐ์ ์ผ๋ฐ์ ์ธ ์ ๊ณ์ ์ ์ ๋ฐฉ๋ฒ๋ณด๋ค ๋ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ ์ธ ์ ๊ณ์ ํ์ต ์ ๋ต์ ๋๋ค. | |
| ๋จผ์ GaLore ๊ณต์ ๋ฆฌํฌ์งํ ๋ฆฌ๋ฅผ ์ค์นํฉ๋๋ค: | |
| ```bash | |
| pip install galore-torch | |
| ``` | |
| ๊ทธ๋ฐ ๋ค์ `optim`์ `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` ์ค ํ๋์ ํจ๊ป `optim_target_modules`๋ฅผ ์ถ๊ฐํฉ๋๋ค. ์ด๋ ์ ์ฉํ๋ ค๋ ๋์ ๋ชจ๋ ์ด๋ฆ์ ํด๋นํ๋ ๋ฌธ์์ด, ์ ๊ท ํํ์ ๋๋ ์ ์ฒด ๊ฒฝ๋ก์ ๋ชฉ๋ก์ผ ์ ์์ต๋๋ค. ์๋๋ end-to-end ์์ ์คํฌ๋ฆฝํธ์ ๋๋ค(ํ์ํ ๊ฒฝ์ฐ `pip install trl datasets`๋ฅผ ์คํ): | |
| ```python | |
| import torch | |
| import datasets | |
| import trl | |
| from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM | |
| train_dataset = datasets.load_dataset('imdb', split='train') | |
| args = TrainingArguments( | |
| output_dir="./test-galore", | |
| max_steps=100, | |
| per_device_train_batch_size=2, | |
| optim="galore_adamw", | |
| optim_target_modules=["attn", "mlp"] | |
| ) | |
| model_id = "google/gemma-2b" | |
| config = AutoConfig.from_pretrained(model_id) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_config(config).to(0) | |
| trainer = trl.SFTTrainer( | |
| model=model, | |
| args=args, | |
| train_dataset=train_dataset, | |
| dataset_text_field='text', | |
| max_seq_length=512, | |
| ) | |
| trainer.train() | |
| ``` | |
| GaLore๊ฐ ์ง์ํ๋ ์ถ๊ฐ ๋งค๊ฐ๋ณ์๋ฅผ ์ ๋ฌํ๋ ค๋ฉด `optim_args`๋ฅผ ์ค์ ํฉ๋๋ค. ์๋ฅผ ๋ค์ด: | |
| ```python | |
| import torch | |
| import datasets | |
| import trl | |
| from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM | |
| train_dataset = datasets.load_dataset('imdb', split='train') | |
| args = TrainingArguments( | |
| output_dir="./test-galore", | |
| max_steps=100, | |
| per_device_train_batch_size=2, | |
| optim="galore_adamw", | |
| optim_target_modules=["attn", "mlp"], | |
| optim_args="rank=64, update_proj_gap=100, scale=0.10", | |
| ) | |
| model_id = "google/gemma-2b" | |
| config = AutoConfig.from_pretrained(model_id) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_config(config).to(0) | |
| trainer = trl.SFTTrainer( | |
| model=model, | |
| args=args, | |
| train_dataset=train_dataset, | |
| dataset_text_field='text', | |
| max_seq_length=512, | |
| ) | |
| trainer.train() | |
| ``` | |
| ํด๋น ๋ฐฉ๋ฒ์ ๋ํ ์์ธํ ๋ด์ฉ์ [์๋ณธ ๋ฆฌํฌ์งํ ๋ฆฌ](https://github.com/jiaweizzhao/GaLore) ๋๋ [๋ ผ๋ฌธ](https://arxiv.org/abs/2403.03507)์ ์ฐธ๊ณ ํ์ธ์. | |
| ํ์ฌ GaLore ๋ ์ด์ด๋ก ๊ฐ์ฃผ๋๋ Linear ๋ ์ด์ด๋ง ํ๋ จ ํ ์ ์์ผ๋ฉฐ, ์ ๊ณ์ ๋ถํด๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ๋๊ณ ๋๋จธ์ง ๋ ์ด์ด๋ ๊ธฐ์กด ๋ฐฉ์์ผ๋ก ์ต์ ํ๋ฉ๋๋ค. | |
| ํ๋ จ ์์ ์ ์ ์๊ฐ์ด ์ฝ๊ฐ ๊ฑธ๋ฆด ์ ์์ต๋๋ค(NVIDIA A100์์ 2B ๋ชจ๋ธ์ ๊ฒฝ์ฐ ์ฝ 3๋ถ), ํ์ง๋ง ์ดํ ํ๋ จ์ ์ํํ๊ฒ ์งํ๋ฉ๋๋ค. | |
| ๋ค์๊ณผ ๊ฐ์ด ์ตํฐ๋ง์ด์ ์ด๋ฆ์ `layerwise`๋ฅผ ์ถ๊ฐํ์ฌ ๋ ์ด์ด๋ณ ์ต์ ํ๋ฅผ ์ํํ ์๋ ์์ต๋๋ค: | |
| ```python | |
| import torch | |
| import datasets | |
| import trl | |
| from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM | |
| train_dataset = datasets.load_dataset('imdb', split='train') | |
| args = TrainingArguments( | |
| output_dir="./test-galore", | |
| max_steps=100, | |
| per_device_train_batch_size=2, | |
| optim="galore_adamw_layerwise", | |
| optim_target_modules=["attn", "mlp"] | |
| ) | |
| model_id = "google/gemma-2b" | |
| config = AutoConfig.from_pretrained(model_id) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_config(config).to(0) | |
| trainer = trl.SFTTrainer( | |
| model=model, | |
| args=args, | |
| train_dataset=train_dataset, | |
| dataset_text_field='text', | |
| max_seq_length=512, | |
| ) | |
| trainer.train() | |
| ``` | |
| ๋ ์ด์ด๋ณ ์ต์ ํ๋ ๋ค์ ์คํ์ ์ด๋ฉฐ DDP(๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ)๋ฅผ ์ง์ํ์ง ์์ผ๋ฏ๋ก, ๋จ์ผ GPU์์๋ง ํ๋ จ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ ์ ์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ [์ด ๋ฌธ์๋ฅผ](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory)์ ์ฐธ์กฐํ์ธ์. gradient clipping, DeepSpeed ๋ฑ ๋ค๋ฅธ ๊ธฐ๋ฅ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์ง์๋์ง ์์ ์ ์์ต๋๋ค. ์ด๋ฌํ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋ฉด [GitHub์ ์ด์๋ฅผ ์ฌ๋ ค์ฃผ์ธ์](https://github.com/huggingface/transformers/issues). | |
| ## LOMO ์ตํฐ๋ง์ด์ [[lomo-optimizer]] | |
| LOMO ์ตํฐ๋ง์ด์ ๋ [์ ํ๋ ์์์ผ๋ก ๋ํ ์ธ์ด ๋ชจ๋ธ์ ์ ์ฒด ๋งค๊ฐ๋ณ์ ๋ฏธ์ธ ์กฐ์ ](https://hf.co/papers/2306.09782)๊ณผ [์ ์ํ ํ์ต๋ฅ ์ ํตํ ์ ๋ฉ๋ชจ๋ฆฌ ์ต์ ํ(AdaLomo)](https://hf.co/papers/2310.10195)์์ ๋์ ๋์์ต๋๋ค. | |
| ์ด๋ค์ ๋ชจ๋ ํจ์จ์ ์ธ ์ ์ฒด ๋งค๊ฐ๋ณ์ ๋ฏธ์ธ ์กฐ์ ๋ฐฉ๋ฒ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. ์ด๋ฌํ ์ตํฐ๋ง์ด์ ๋ค์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด๊ธฐ ์ํด ๊ทธ๋ ์ด๋์ธํธ ๊ณ์ฐ๊ณผ ๋งค๊ฐ๋ณ์ ์ ๋ฐ์ดํธ๋ฅผ ํ๋์ ๋จ๊ณ๋ก ์ตํฉํฉ๋๋ค. LOMO์์ ์ง์๋๋ ์ตํฐ๋ง์ด์ ๋ `"lomo"`์ `"adalomo"`์ ๋๋ค. ๋จผ์ pypi์์ `pip install lomo-optim`๋ฅผ ํตํด `lomo`๋ฅผ ์ค์นํ๊ฑฐ๋, GitHub ์์ค์์ `pip install git+https://github.com/OpenLMLab/LOMO.git`๋ก ์ค์นํ์ธ์. | |
| <Tip> | |
| ์ ์์ ๋ฐ๋ฅด๋ฉด, `grad_norm` ์์ด `AdaLomo`๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ๋ ๋์ ์ฑ๋ฅ๊ณผ ๋์ ์ฒ๋ฆฌ๋์ ์ ๊ณตํ๋ค๊ณ ํฉ๋๋ค. | |
| </Tip> | |
| ๋ค์์ IMDB ๋ฐ์ดํฐ์ ์์ [google/gemma-2b](https://huggingface.co/google/gemma-2b)๋ฅผ ์ต๋ ์ ๋ฐ๋๋ก ๋ฏธ์ธ ์กฐ์ ํ๋ ๊ฐ๋จํ ์คํฌ๋ฆฝํธ์ ๋๋ค: | |
| ```python | |
| import torch | |
| import datasets | |
| from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM | |
| import trl | |
| train_dataset = datasets.load_dataset('imdb', split='train') | |
| args = TrainingArguments( | |
| output_dir="./test-lomo", | |
| max_steps=1000, | |
| per_device_train_batch_size=4, | |
| optim="adalomo", | |
| gradient_checkpointing=True, | |
| logging_strategy="steps", | |
| logging_steps=1, | |
| learning_rate=2e-6, | |
| save_strategy="no", | |
| run_name="lomo-imdb", | |
| ) | |
| model_id = "google/gemma-2b" | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0) | |
| trainer = trl.SFTTrainer( | |
| model=model, | |
| args=args, | |
| train_dataset=train_dataset, | |
| dataset_text_field='text', | |
| max_seq_length=1024, | |
| ) | |
| trainer.train() | |
| ``` | |
| ## Accelerate์ Trainer [[accelerate-and-trainer]] | |
| [`Trainer`] ํด๋์ค๋ [Accelerate](https://hf.co/docs/accelerate)๋ก ๊ตฌ๋๋๋ฉฐ, ์ด๋ [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) ๋ฐ [DeepSpeed](https://www.deepspeed.ai/)์ ๊ฐ์ ํตํฉ์ ์ง์ํ๋ ๋ถ์ฐ ํ๊ฒฝ์์ PyTorch ๋ชจ๋ธ์ ์ฝ๊ฒ ํ๋ จํ ์ ์๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋๋ค. | |
| <Tip> | |
| FSDP ์ค๋ฉ ์ ๋ต, CPU ์คํ๋ก๋ ๋ฐ [`Trainer`]์ ํจ๊ป ์ฌ์ฉํ ์ ์๋ ๋ ๋ง์ ๊ธฐ๋ฅ์ ์์๋ณด๋ ค๋ฉด [Fully Sharded Data Parallel](fsdp) ๊ฐ์ด๋๋ฅผ ํ์ธํ์ธ์. | |
| </Tip> | |
| [`Trainer`]์ Accelerate๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด [`accelerate.config`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-config) ๋ช ๋ น์ ์คํํ์ฌ ํ๋ จ ํ๊ฒฝ์ ์ค์ ํ์ธ์. ์ด ๋ช ๋ น์ ํ๋ จ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ ๋ ์ฌ์ฉํ `config_file.yaml`์ ์์ฑํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๋ค์ ์์๋ ์ค์ ํ ์ ์๋ ์ผ๋ถ ๊ตฌ์ฑ ์์ ๋๋ค. | |
| <hfoptions id="config"> | |
| <hfoption id="DistributedDataParallel"> | |
| ```yml | |
| compute_environment: LOCAL_MACHINE | |
| distributed_type: MULTI_GPU | |
| downcast_bf16: 'no' | |
| gpu_ids: all | |
| machine_rank: 0 # ๋ ธ๋์ ๋ฐ๋ผ ์์๋ฅผ ๋ณ๊ฒฝํ์ธ์ | |
| main_process_ip: 192.168.20.1 | |
| main_process_port: 9898 | |
| main_training_function: main | |
| mixed_precision: fp16 | |
| num_machines: 2 | |
| num_processes: 8 | |
| rdzv_backend: static | |
| same_network: true | |
| tpu_env: [] | |
| tpu_use_cluster: false | |
| tpu_use_sudo: false | |
| use_cpu: false | |
| ``` | |
| </hfoption> | |
| <hfoption id="FSDP"> | |
| ```yml | |
| compute_environment: LOCAL_MACHINE | |
| distributed_type: FSDP | |
| downcast_bf16: 'no' | |
| fsdp_config: | |
| fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP | |
| fsdp_backward_prefetch_policy: BACKWARD_PRE | |
| fsdp_forward_prefetch: true | |
| fsdp_offload_params: false | |
| fsdp_sharding_strategy: 1 | |
| fsdp_state_dict_type: FULL_STATE_DICT | |
| fsdp_sync_module_states: true | |
| fsdp_transformer_layer_cls_to_wrap: BertLayer | |
| fsdp_use_orig_params: true | |
| machine_rank: 0 | |
| main_training_function: main | |
| mixed_precision: bf16 | |
| num_machines: 1 | |
| num_processes: 2 | |
| rdzv_backend: static | |
| same_network: true | |
| tpu_env: [] | |
| tpu_use_cluster: false | |
| tpu_use_sudo: false | |
| use_cpu: false | |
| ``` | |
| </hfoption> | |
| <hfoption id="DeepSpeed"> | |
| ```yml | |
| compute_environment: LOCAL_MACHINE | |
| deepspeed_config: | |
| deepspeed_config_file: /home/user/configs/ds_zero3_config.json | |
| zero3_init_flag: true | |
| distributed_type: DEEPSPEED | |
| downcast_bf16: 'no' | |
| machine_rank: 0 | |
| main_training_function: main | |
| num_machines: 1 | |
| num_processes: 4 | |
| rdzv_backend: static | |
| same_network: true | |
| tpu_env: [] | |
| tpu_use_cluster: false | |
| tpu_use_sudo: false | |
| use_cpu: false | |
| ``` | |
| </hfoption> | |
| <hfoption id="DeepSpeed with Accelerate plugin"> | |
| ```yml | |
| compute_environment: LOCAL_MACHINE | |
| deepspeed_config: | |
| gradient_accumulation_steps: 1 | |
| gradient_clipping: 0.7 | |
| offload_optimizer_device: cpu | |
| offload_param_device: cpu | |
| zero3_init_flag: true | |
| zero_stage: 2 | |
| distributed_type: DEEPSPEED | |
| downcast_bf16: 'no' | |
| machine_rank: 0 | |
| main_training_function: main | |
| mixed_precision: bf16 | |
| num_machines: 1 | |
| num_processes: 4 | |
| rdzv_backend: static | |
| same_network: true | |
| tpu_env: [] | |
| tpu_use_cluster: false | |
| tpu_use_sudo: false | |
| use_cpu: false | |
| ``` | |
| </hfoption> | |
| </hfoptions> | |
| [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) ๋ช ๋ น์ Accelerate์ [`Trainer`]๋ฅผ ์ฌ์ฉํ์ฌ ๋ถ์ฐ ์์คํ ์์ ํ๋ จ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๋ ๊ถ์ฅ ๋ฐฉ๋ฒ์ด๋ฉฐ, `config_file.yaml`์ ์ง์ ๋ ๋งค๊ฐ๋ณ์๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด ํ์ผ์ Accelerate ์บ์ ํด๋์ ์ ์ฅ๋๋ฉฐ `accelerate_launch`๋ฅผ ์คํํ ๋ ์๋์ผ๋ก ๋ก๋๋ฉ๋๋ค. | |
| ์๋ฅผ ๋ค์ด, FSDP ๊ตฌ์ฑ์ ์ฌ์ฉํ์ฌ [run_glue.py](https://github.com/huggingface/transformers/blob/f4db565b695582891e43a5e042e5d318e28f20b8/examples/pytorch/text-classification/run_glue.py#L4) ํ๋ จ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํฉ๋๋ค: | |
| ```bash | |
| accelerate launch \ | |
| ./examples/pytorch/text-classification/run_glue.py \ | |
| --model_name_or_path google-bert/bert-base-cased \ | |
| --task_name $TASK_NAME \ | |
| --do_train \ | |
| --do_eval \ | |
| --max_seq_length 128 \ | |
| --per_device_train_batch_size 16 \ | |
| --learning_rate 5e-5 \ | |
| --num_train_epochs 3 \ | |
| --output_dir /tmp/$TASK_NAME/ \ | |
| --overwrite_output_dir | |
| ``` | |
| `config_file.yaml` ํ์ผ์ ๋งค๊ฐ๋ณ์๋ฅผ ์ง์ ์ง์ ํ ์๋ ์์ต๋๋ค: | |
| ```bash | |
| accelerate launch --num_processes=2 \ | |
| --use_fsdp \ | |
| --mixed_precision=bf16 \ | |
| --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \ | |
| --fsdp_transformer_layer_cls_to_wrap="BertLayer" \ | |
| --fsdp_sharding_strategy=1 \ | |
| --fsdp_state_dict_type=FULL_STATE_DICT \ | |
| ./examples/pytorch/text-classification/run_glue.py \ | |
| --model_name_or_path google-bert/bert-base-cased \ | |
| --task_name $TASK_NAME \ | |
| --do_train \ | |
| --do_eval \ | |
| --max_seq_length 128 \ | |
| --per_device_train_batch_size 16 \ | |
| --learning_rate 5e-5 \ | |
| --num_train_epochs 3 \ | |
| --output_dir /tmp/$TASK_NAME/ \ | |
| --overwrite_output_dir | |
| ``` | |
| `accelerate_launch`์ ์ฌ์ฉ์ ์ ์ ๊ตฌ์ฑ์ ๋ํด ๋ ์์๋ณด๋ ค๋ฉด [Accelerate ์คํฌ๋ฆฝํธ ์คํ](https://huggingface.co/docs/accelerate/basic_tutorials/launch) ํํ ๋ฆฌ์ผ์ ํ์ธํ์ธ์. |