| <!--Copyright 2023 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| ⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| # Trainer [[trainer]] | |
| [`Trainer`] is a complete training and evaluation loop for PyTorch models implemented in the Transformers library. You only pass it the pieces needed for training (model, tokenizer, dataset, evaluation function, training hyperparameters, etc.), and [`Trainer`] handles the rest. This lets you start training quickly without writing your own training loop. [`Trainer`] is also highly customizable and offers many training options, so you can tailor it to your needs. | |
| <Tip> | |
| In addition to the [`Trainer`] class, Transformers also provides a [`Seq2SeqTrainer`] class for sequence-to-sequence tasks such as translation or summarization. The [TRL](https://hf.co/docs/trl) library also has the [`~trl.SFTTrainer`] class, which wraps [`Trainer`] and is optimized for training language models such as Llama-2 and Mistral with autoregressive techniques. [`~trl.SFTTrainer`] supports features such as sequence packing, LoRA, quantization, and DeepSpeed to scale efficiently to any model size. | |
| <br> | |
| To learn more about these other [`Trainer`]-type classes, check the [API reference](./main_classes/trainer) to see when each one is appropriate. In general, [`Trainer`] is the most versatile option and suits a broad range of tasks. [`Seq2SeqTrainer`] is designed for sequence-to-sequence tasks, and [`~trl.SFTTrainer`] is designed for training language models. | |
| </Tip> | |
| Before you start, make sure the [Accelerate](https://hf.co/docs/accelerate) library, which enables and runs PyTorch training across distributed environments, is installed. | |
| ```bash | |
| pip install accelerate | |
| # upgrade | |
| pip install accelerate --upgrade | |
| ``` | |
| This guide provides an overview of the [`Trainer`] class. | |
| ## Basic usage [[basic-usage]] | |
| [`Trainer`] includes all the code you would find in a basic training loop: | |
| 1. Perform a training step to calculate the loss. | |
| 2. Calculate the gradients with the [`~accelerate.Accelerator.backward`] method. | |
| 3. Update the weights based on the gradients. | |
| 4. Repeat this process until a predetermined number of epochs is reached. | |
| The [`Trainer`] class abstracts all of this code away, so training is possible even if you are unfamiliar with PyTorch training or are just getting started. You also don't have to write a training loop by hand every time: provide the essential components, such as a model and a dataset, and the [`Trainer`] class handles the rest. | |
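The four steps above can be sketched in plain Python. This is only a toy illustration, assuming a one-weight model and a hand-derived gradient; it is not how [`Trainer`] is implemented, and all names in it are made up for the example.

```python
# Toy data for y = 2 * x; the "model" is a single weight w.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
w = 0.0
learning_rate = 0.05
num_epochs = 50

for epoch in range(num_epochs):          # 4. repeat for a fixed number of epochs
    for x, y in data:
        loss = (w * x - y) ** 2          # 1. training step: compute the loss
        grad = 2 * (w * x - y) * x       # 2. gradient of the loss with respect to w
        w -= learning_rate * grad        # 3. update the weight

print(round(w, 3))
```

With a real model, step 2 is handled by automatic differentiation rather than a hand-written formula, which is the part [`Trainer`] delegates to Accelerate.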
| To specify training options or hyperparameters, look in the [`TrainingArguments`] class. For example, define the directory where the model is saved with `output_dir`, and set `push_to_hub=True` to push the model to the Hub after training. | |
| ```py | |
| from transformers import TrainingArguments | |
| training_args = TrainingArguments( | |
| output_dir="your-model", | |
| learning_rate=2e-5, | |
| per_device_train_batch_size=16, | |
| per_device_eval_batch_size=16, | |
| num_train_epochs=2, | |
| weight_decay=0.01, | |
| eval_strategy="epoch", | |
| save_strategy="epoch", | |
| load_best_model_at_end=True, | |
| push_to_hub=True, | |
| ) | |
| ``` | |
| Pass `training_args` to [`Trainer`] along with a model, a dataset, something to preprocess the dataset (a tokenizer, feature extractor, or image processor, depending on the data type), a data collator, and a function that computes the metrics you want to track during training. | |
| Finally, call [`~Trainer.train`] to start training! | |
| ```py | |
| from transformers import Trainer | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=dataset["train"], | |
| eval_dataset=dataset["test"], | |
| tokenizer=tokenizer, | |
| data_collator=data_collator, | |
| compute_metrics=compute_metrics, | |
| ) | |
| trainer.train() | |
| ``` | |
| ### Checkpoints [[checkpoints]] | |
| The [`Trainer`] class saves model checkpoints to the directory specified in the `output_dir` parameter of [`TrainingArguments`]. Checkpoints are saved in `checkpoint-000` subfolders, where the trailing number corresponds to the training step. Saved checkpoints are useful for resuming training later. | |
| ```py | |
| # resume from the latest checkpoint | |
| trainer.train(resume_from_checkpoint=True) | |
| # resume from a specific checkpoint saved in the output directory | |
| trainer.train(resume_from_checkpoint="your-model/checkpoint-1000") | |
| ``` | |
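Because checkpoints live in `checkpoint-<step>` subfolders, the most recent one can be found by comparing step numbers. The sketch below uses a hypothetical helper, `find_latest_checkpoint`, written only for illustration; it is not the lookup [`Trainer`] itself performs.

```python
import os
import re
import tempfile


def find_latest_checkpoint(output_dir):
    """Return the path of the `checkpoint-<step>` subfolder with the highest step, or None."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        path = os.path.join(output_dir, name)
        if match and os.path.isdir(path) and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


# Demo on a throwaway directory that mimics an `output_dir`.
with tempfile.TemporaryDirectory() as output_dir:
    for step in (500, 1000, 1500):
        os.makedirs(os.path.join(output_dir, f"checkpoint-{step}"))
    latest = os.path.basename(find_latest_checkpoint(output_dir))
```

The steps are compared as integers on purpose: sorted as strings, `checkpoint-500` would come after `checkpoint-1500`.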
| To push checkpoints to the Hub, set `push_to_hub=True` in [`TrainingArguments`] to commit and push them. Other options for deciding how checkpoints are saved are set with the [`hub_strategy`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.hub_strategy) parameter: | |
| * `hub_strategy="checkpoint"` pushes the latest checkpoint to a subfolder named "last-checkpoint", from which you can resume training. | |
| * `hub_strategy="all_checkpoints"` pushes all checkpoints to the directory defined in `output_dir` (you'll see one checkpoint per folder in your model repository). | |
| When you resume training from a checkpoint, [`Trainer`] tries to keep the Python, NumPy, and PyTorch RNG states the same as they were when the checkpoint was saved. However, several PyTorch default settings are non-deterministic, so the RNG states are not guaranteed to be identical. To make training fully deterministic, see the [Controlling sources of randomness](https://pytorch.org/docs/stable/notes/randomness#controlling-sources-of-randomness) guide for the settings you can enable. Keep in mind that making certain settings deterministic may slow down training. | |
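The idea of saving and restoring an RNG state can be illustrated with Python's built-in `random` module. This is a minimal sketch covering only the Python RNG; NumPy and PyTorch expose their own state functions, and nothing here is the code [`Trainer`] runs.

```python
import random

random.seed(0)
_ = [random.random() for _ in range(5)]   # "training" consumes some random numbers

saved_state = random.getstate()           # what a checkpoint would store
expected = [random.random() for _ in range(3)]

random.setstate(saved_state)              # what resuming would restore
resumed = [random.random() for _ in range(3)]
```

After the state is restored, the generator produces the same numbers it would have produced without the interruption, so `resumed` equals `expected`.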
| ## Customize the Trainer [[customize-the-trainer]] | |
| While the [`Trainer`] class is designed to be accessible and easy to use, it also offers many customization options for users who want more. Many of the [`Trainer`] methods can be subclassed and overridden to add the functionality you want without rewriting the entire training loop. These methods include: | |
| * [`~Trainer.get_train_dataloader`] creates a training DataLoader. | |
| * [`~Trainer.get_eval_dataloader`] creates an evaluation DataLoader. | |
| * [`~Trainer.get_test_dataloader`] creates a test DataLoader. | |
| * [`~Trainer.log`] logs information on the various objects that watch training. | |
| * [`~Trainer.create_optimizer_and_scheduler`] creates an optimizer and learning rate scheduler if they weren't passed in `__init__`. They can also be customized separately with [`~Trainer.create_optimizer`] and [`~Trainer.create_scheduler`], respectively. | |
| * [`~Trainer.compute_loss`] computes the loss on a batch of training inputs. | |
| * [`~Trainer.training_step`] performs the training step. | |
| * [`~Trainer.prediction_step`] performs the prediction and test step. | |
| * [`~Trainer.evaluate`] evaluates the model and returns the evaluation metrics. | |
| * [`~Trainer.predict`] makes predictions on the test set (with metrics if labels are available). | |
| For example, to customize the [`~Trainer.compute_loss`] method so that it uses a weighted loss: | |
| ```py | |
| import torch | |
| from torch import nn | |
| from transformers import Trainer | |
| class CustomTrainer(Trainer): | |
|     # **kwargs absorbs extra keyword arguments that newer Transformers versions may pass | |
|     def compute_loss(self, model, inputs, return_outputs=False, **kwargs): | |
|         labels = inputs.pop("labels") | |
|         # forward pass | |
|         outputs = model(**inputs) | |
|         logits = outputs.get("logits") | |
|         # compute a custom loss for 3 labels with different weights | |
|         loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device)) | |
|         loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) | |
|         return (loss, outputs) if return_outputs else loss | |
| ``` | |
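To see what the class weights do numerically, here is a dependency-free sketch of a weighted cross-entropy. It assumes the weighted-mean convention that PyTorch documents for `nn.CrossEntropyLoss` with `weight` and the default `mean` reduction (the sum of weighted per-example losses divided by the sum of the target weights). The function name and the toy inputs are made up for the example.

```python
import math


def weighted_cross_entropy(logits, labels, weights):
    """Weighted mean of per-example cross-entropy, normalized by the sum of the target weights."""
    total, weight_sum = 0.0, 0.0
    for row, label in zip(logits, labels):
        log_norm = math.log(sum(math.exp(v) for v in row))  # log of the softmax denominator
        nll = log_norm - row[label]                          # -log softmax(row)[label]
        total += weights[label] * nll
        weight_sum += weights[label]
    return total / weight_sum


# Two examples with identical logits: the first is labeled 0 (well predicted),
# the second is labeled 2 (poorly predicted).
logits = [[2.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
labels = [0, 2]
unweighted = weighted_cross_entropy(logits, labels, [1.0, 1.0, 1.0])
weighted = weighted_cross_entropy(logits, labels, [1.0, 2.0, 3.0])
```

Because label 2 carries the largest weight, the mistake on the second example counts for more and `weighted` is larger than `unweighted`.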
| ### Callbacks [[callbacks]] | |
| Another way to customize [`Trainer`] is to use [callbacks](callbacks). Callbacks *don't change* anything in the training loop. They inspect the training loop state and then execute some action (early stopping, logging results, etc.) depending on that state. In other words, a callback can't be used to implement something like a custom loss function; for that, you need to subclass [`Trainer`] and override the [`~Trainer.compute_loss`] method. | |
| For example, to add a callback that stops training early after 10 steps: | |
| ```py | |
| from transformers import TrainerCallback | |
| class EarlyStoppingCallback(TrainerCallback): | |
|     def __init__(self, num_steps=10): | |
|         self.num_steps = num_steps | |
|     def on_step_end(self, args, state, control, **kwargs): | |
|         # ask the Trainer to stop once the step budget is reached | |
|         if state.global_step >= self.num_steps: | |
|             control.should_training_stop = True | |
|         return control | |
| ``` | |
| Then pass it to the `callbacks` parameter of [`Trainer`]. | |
| ```py | |
| from transformers import Trainer | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=dataset["train"], | |
| eval_dataset=dataset["test"], | |
| tokenizer=tokenizer, | |
| data_collator=data_collator, | |
| compute_metrics=compute_metrics, | |
| callbacks=[EarlyStoppingCallback()], | |
| ) | |
| ``` | |
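The division of labor between the loop and its callbacks can be sketched without any library. The classes below (`ToyState`, `ToyControl`, `StopAfter`) and the `run_toy_loop` function are hypothetical stand-ins written for this illustration; they only mimic the idea that a callback reads the state and sets a flag, while the loop decides what to do with it.

```python
class ToyControl:
    """Hypothetical stand-in for the control object a callback can modify."""

    def __init__(self):
        self.should_training_stop = False


class ToyState:
    """Hypothetical stand-in for the training state a callback can read."""

    def __init__(self):
        self.global_step = 0


class StopAfter:
    def __init__(self, num_steps=10):
        self.num_steps = num_steps

    def on_step_end(self, state, control):
        if state.global_step >= self.num_steps:
            control.should_training_stop = True


def run_toy_loop(callbacks, max_steps=100):
    state, control = ToyState(), ToyControl()
    while state.global_step < max_steps and not control.should_training_stop:
        state.global_step += 1            # the training step itself would run here
        for callback in callbacks:        # callbacks only observe state and set flags
            callback.on_step_end(state, control)
    return state.global_step


steps_run = run_toy_loop([StopAfter(num_steps=10)])
```

The callback never touches the loss or the weights, which is why anything that changes the computation belongs in a subclass instead.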
| ## Logging [[logging]] | |
| <Tip> | |
| For more details on the logging API, see the [logging](./main_classes/logging) API reference. | |
| </Tip> | |
| [`Trainer`] is set to `logging.INFO` by default, which reports errors, warnings, and other basic information. In distributed environments, a [`Trainer`] replica is set to `logging.WARNING`, which reports only errors and warnings. You can change the log level with the [`log_level`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level) and [`log_level_replica`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level_replica) parameters of [`TrainingArguments`]. | |
| To configure the log level per node, use the [`log_on_each_node`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments.log_on_each_node) parameter to decide whether the log level applies on each node or only on the main node. | |
| <Tip> | |
| [`Trainer`] sets the log level separately for each node in the [`Trainer.__init__`] method, so if you use other Transformers functionality before creating the [`Trainer`] object, consider setting the log level earlier. | |
| </Tip> | |
| For example, to make your main code and modules use the same log level on each node: | |
| ```py | |
| import logging | |
| import sys | |
| import datasets | |
| import transformers | |
| from transformers import Trainer | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig( | |
| format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", | |
| datefmt="%m/%d/%Y %H:%M:%S", | |
| handlers=[logging.StreamHandler(sys.stdout)], | |
| ) | |
| # `training_args` is the TrainingArguments instance created earlier | |
| log_level = training_args.get_process_log_level() | |
| logger.setLevel(log_level) | |
| datasets.utils.logging.set_verbosity(log_level) | |
| transformers.utils.logging.set_verbosity(log_level) | |
| trainer = Trainer(...) | |
| ``` | |
| Use different combinations of `log_level` and `log_level_replica` to configure what gets logged on each node. | |
| <hfoptions id="logging"> | |
| <hfoption id="single node"> | |
| ```bash | |
| my_app.py ... --log_level warning --log_level_replica error | |
| ``` | |
| </hfoption> | |
| <hfoption id="multi-node"> | |
| In multi-node environments, add the `log_on_each_node 0` parameter. | |
| ```bash | |
| my_app.py ... --log_level warning --log_level_replica error --log_on_each_node 0 | |
| # set to only report errors | |
| my_app.py ... --log_level error --log_level_replica error --log_on_each_node 0 | |
| ``` | |
| </hfoption> | |
| </hfoptions> | |
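How the three options interact can be modeled in a few lines. The function below, `pick_log_level`, is a hypothetical, simplified reading of the behavior described in this section (main process gets `log_level`, replicas get `log_level_replica`, and `log_on_each_node` decides whether "main" is per node or global). It is not the library's implementation, and the process indices are assumed inputs.

```python
import logging


def pick_log_level(log_level, log_level_replica, log_on_each_node,
                   process_index, local_process_index):
    """Simplified model: the main process gets `log_level`, replicas get `log_level_replica`."""
    if log_on_each_node:
        is_main = local_process_index == 0   # one "main" process per node
    else:
        is_main = process_index == 0         # a single "main" process overall
    return log_level if is_main else log_level_replica


# First local process on a second node (global rank 4), with --log_on_each_node 0:
level = pick_log_level(logging.WARNING, logging.ERROR, False,
                       process_index=4, local_process_index=0)
```

Under this model, the same process would log at `WARNING` if `log_on_each_node` were enabled, since it is the first process on its node.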
| ## NEFTune [[neftune]] | |
| [NEFTune](https://hf.co/papers/2310.05914) is a technique that can improve performance by adding noise to the embedding vectors during training. To enable it in [`Trainer`], set the `neftune_noise_alpha` parameter in [`TrainingArguments`] to control how much noise is added. | |
| ```py | |
| from transformers import TrainingArguments, Trainer | |
| training_args = TrainingArguments(..., neftune_noise_alpha=0.1) | |
| trainer = Trainer(..., args=training_args) | |
| ``` | |
| NEFTune is disabled after training so that the original embedding layer is restored and unexpected behavior is avoided. | |
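As a rough illustration of what "adding noise to the embedding vectors" means, the sketch below applies the scaling rule described in the NEFTune paper: uniform noise in [-1, 1] multiplied by `alpha / sqrt(seq_len * hidden_dim)`. Plain Python lists stand in for tensors, and the function name is made up for the example; the integration in [`Trainer`] may differ in detail.

```python
import math
import random


def add_neftune_noise(embeddings, alpha, rng):
    """Add Uniform(-1, 1) noise scaled by alpha / sqrt(seq_len * hidden_dim) to each value."""
    seq_len, hidden_dim = len(embeddings), len(embeddings[0])
    scale = alpha / math.sqrt(seq_len * hidden_dim)
    return [[value + rng.uniform(-1.0, 1.0) * scale for value in token] for token in embeddings]


embeddings = [[0.0] * 8 for _ in range(4)]           # 4 tokens, hidden size 8
noisy = add_neftune_noise(embeddings, alpha=0.1, rng=random.Random(0))
max_shift = max(abs(v) for token in noisy for v in token)
```

A larger `alpha` widens the noise range, while longer sequences and larger hidden sizes shrink the per-value perturbation.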
| ## GaLore [[galore]] | |
| Gradient Low-Rank Projection (GaLore) is a low-rank training strategy that allows full-parameter learning while being more memory-efficient than common low-rank adaptation methods such as LoRA. | |
| First, install the official GaLore package: | |
| ```bash | |
| pip install galore-torch | |
| ``` | |
| Then set `optim` to one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` and add `optim_target_modules`, which can be a list of strings, regular expressions, or full paths corresponding to the names of the target modules you want to adapt. Below is an end-to-end example script (run `pip install trl datasets` if needed): | |
| ```python | |
| import torch | |
| import datasets | |
| import trl | |
| from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM | |
| train_dataset = datasets.load_dataset('imdb', split='train') | |
| args = TrainingArguments( | |
| output_dir="./test-galore", | |
| max_steps=100, | |
| per_device_train_batch_size=2, | |
| optim="galore_adamw", | |
| optim_target_modules=["attn", "mlp"] | |
| ) | |
| model_id = "google/gemma-2b" | |
| config = AutoConfig.from_pretrained(model_id) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_config(config).to(0) | |
| trainer = trl.SFTTrainer( | |
| model=model, | |
| args=args, | |
| train_dataset=train_dataset, | |
| dataset_text_field='text', | |
| max_seq_length=512, | |
| ) | |
| trainer.train() | |
| ``` | |
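To make the "strings, regular expressions, or full paths" wording concrete, here is a hypothetical matcher, `is_target_module`, that treats each entry as an exact name, a substring, or a full regular expression. The module names are invented for the example, and the library's actual matching rules may differ.

```python
import re


def is_target_module(module_name, optim_target_modules):
    """Hypothetical matcher: exact name, substring, or full regular-expression match."""
    for target in optim_target_modules:
        if module_name == target or target in module_name:
            return True
        if re.fullmatch(target, module_name):
            return True
    return False


module_names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.up_proj",
    "model.embed_tokens",
]
selected = [name for name in module_names if is_target_module(name, ["attn", "mlp"])]
```

With `["attn", "mlp"]`, the attention and MLP projections are selected and the embedding layer is left to the regular optimizer path.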
| To pass additional arguments supported by GaLore, set `optim_args`. For example: | |
| ```python | |
| import torch | |
| import datasets | |
| import trl | |
| from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM | |
| train_dataset = datasets.load_dataset('imdb', split='train') | |
| args = TrainingArguments( | |
| output_dir="./test-galore", | |
| max_steps=100, | |
| per_device_train_batch_size=2, | |
| optim="galore_adamw", | |
| optim_target_modules=["attn", "mlp"], | |
| optim_args="rank=64, update_proj_gap=100, scale=0.10", | |
| ) | |
| model_id = "google/gemma-2b" | |
| config = AutoConfig.from_pretrained(model_id) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_config(config).to(0) | |
| trainer = trl.SFTTrainer( | |
| model=model, | |
| args=args, | |
| train_dataset=train_dataset, | |
| dataset_text_field='text', | |
| max_seq_length=512, | |
| ) | |
| trainer.train() | |
| ``` | |
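The `optim_args` value is a single string of comma-separated `key=value` pairs. The sketch below shows one way such a string could be split into a dictionary; `parse_optim_args` is a hypothetical helper for illustration, and the conversion of values to numbers is an assumption about what a consumer would do next.

```python
def parse_optim_args(optim_args):
    """Split a 'key=value, key=value' string into a dict of strings."""
    parsed = {}
    for pair in optim_args.replace(" ", "").split(","):
        if pair:
            key, value = pair.split("=")
            parsed[key] = value
    return parsed


galore_kwargs = parse_optim_args("rank=64, update_proj_gap=100, scale=0.10")
rank = int(galore_kwargs["rank"])
scale = float(galore_kwargs["scale"])
```

Keeping the parsed values as strings until they are needed avoids guessing a type for options the parser does not know about.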
| For more details on the method, see the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507). | |
| Currently, only Linear layers that are considered GaLore layers can be trained with low-rank decomposition; the remaining layers are optimized in the conventional way. | |
| It can take some time before training starts (about 3 minutes for a 2B model on an NVIDIA A100), but training should run smoothly afterwards. | |
| You can also perform layer-wise optimization by appending `layerwise` to the optimizer name, as shown below: | |
| ```python | |
| import torch | |
| import datasets | |
| import trl | |
| from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM | |
| train_dataset = datasets.load_dataset('imdb', split='train') | |
| args = TrainingArguments( | |
| output_dir="./test-galore", | |
| max_steps=100, | |
| per_device_train_batch_size=2, | |
| optim="galore_adamw_layerwise", | |
| optim_target_modules=["attn", "mlp"] | |
| ) | |
| model_id = "google/gemma-2b" | |
| config = AutoConfig.from_pretrained(model_id) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_config(config).to(0) | |
| trainer = trl.SFTTrainer( | |
| model=model, | |
| args=args, | |
| train_dataset=train_dataset, | |
| dataset_text_field='text', | |
| max_seq_length=512, | |
| ) | |
| trainer.train() | |
| ``` | |
| Layer-wise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel), so the training script can only be run on a single GPU. See [this section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for details. Other features such as gradient clipping and DeepSpeed may not be supported out of the box. If you run into such a problem, please [open an issue on GitHub](https://github.com/huggingface/transformers/issues). | |
| ## LOMO optimizer [[lomo-optimizer]] | |
| The LOMO optimizers were introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195). | |
| Both are efficient full-parameter fine-tuning methods. These optimizers fuse the gradient computation and the parameter update into a single step to reduce memory usage. The optimizers supported for LOMO are `"lomo"` and `"adalomo"`. First, install `lomo` from PyPI with `pip install lomo-optim`, or install it from source with `pip install git+https://github.com/OpenLMLab/LOMO.git`. | |
| <Tip> | |
| According to the authors, using `AdaLomo` without `grad_norm` gives better performance and higher throughput. | |
| </Tip> | |
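The memory benefit of fusing the gradient computation with the update can be shown with a toy linear model. This sketch is only an analogy under strong assumptions (plain gradient descent, a hand-written gradient, made-up function names); it is not the LOMO implementation and ignores everything that makes the real method work on large networks.

```python
def grad_for(index, inputs, error):
    """Gradient of the squared error with respect to weight `index` of a linear model."""
    return 2.0 * error * inputs[index]


def standard_step(weights, inputs, target, lr):
    error = sum(w * x for w, x in zip(weights, inputs)) - target
    grads = [grad_for(i, inputs, error) for i in range(len(weights))]  # all gradients held at once
    return [w - lr * g for w, g in zip(weights, grads)], len(grads)


def fused_step(weights, inputs, target, lr):
    error = sum(w * x for w, x in zip(weights, inputs)) - target
    updated = list(weights)
    for i in range(len(updated)):
        grad = grad_for(i, inputs, error)   # compute one gradient ...
        updated[i] -= lr * grad             # ... apply it immediately, then let it go
    return updated, 1                       # only one gradient is alive at any time


weights, inputs = [0.5, -0.25, 1.0], [1.0, 2.0, 3.0]
standard, standard_peak = standard_step(weights, inputs, target=4.0, lr=0.01)
fused, fused_peak = fused_step(weights, inputs, target=4.0, lr=0.01)
```

Both variants produce the same weights; the difference is that the fused variant never needs storage for all gradients simultaneously.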
| Below is a simple script that fine-tunes [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset in full precision: | |
| ```python | |
| import torch | |
| import datasets | |
| from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM | |
| import trl | |
| train_dataset = datasets.load_dataset('imdb', split='train') | |
| args = TrainingArguments( | |
| output_dir="./test-lomo", | |
| max_steps=1000, | |
| per_device_train_batch_size=4, | |
| optim="adalomo", | |
| gradient_checkpointing=True, | |
| logging_strategy="steps", | |
| logging_steps=1, | |
| learning_rate=2e-6, | |
| save_strategy="no", | |
| run_name="lomo-imdb", | |
| ) | |
| model_id = "google/gemma-2b" | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0) | |
| trainer = trl.SFTTrainer( | |
| model=model, | |
| args=args, | |
| train_dataset=train_dataset, | |
| dataset_text_field='text', | |
| max_seq_length=1024, | |
| ) | |
| trainer.train() | |
| ``` | |
| ## Accelerate and Trainer [[accelerate-and-trainer]] | |
| The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments, with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/). | |
| <Tip> | |
| To learn more about FSDP sharding strategies, CPU offloading, and other features you can use with [`Trainer`], see the [Fully Sharded Data Parallel](fsdp) guide. | |
| </Tip> | |
| To use Accelerate with [`Trainer`], run the [`accelerate config`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-config) command to set up your training environment. This command creates a `config_file.yaml` that is used when you launch your training script. The following are some example configurations you can set up. | |
| <hfoptions id="config"> | |
| <hfoption id="DistributedDataParallel"> | |
| ```yml | |
| compute_environment: LOCAL_MACHINE | |
| distributed_type: MULTI_GPU | |
| downcast_bf16: 'no' | |
| gpu_ids: all | |
| machine_rank: 0 # change the rank according to the node | |
| main_process_ip: 192.168.20.1 | |
| main_process_port: 9898 | |
| main_training_function: main | |
| mixed_precision: fp16 | |
| num_machines: 2 | |
| num_processes: 8 | |
| rdzv_backend: static | |
| same_network: true | |
| tpu_env: [] | |
| tpu_use_cluster: false | |
| tpu_use_sudo: false | |
| use_cpu: false | |
| ``` | |
| </hfoption> | |
| <hfoption id="FSDP"> | |
| ```yml | |
| compute_environment: LOCAL_MACHINE | |
| distributed_type: FSDP | |
| downcast_bf16: 'no' | |
| fsdp_config: | |
|   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP | |
|   fsdp_backward_prefetch_policy: BACKWARD_PRE | |
|   fsdp_forward_prefetch: true | |
|   fsdp_offload_params: false | |
|   fsdp_sharding_strategy: 1 | |
|   fsdp_state_dict_type: FULL_STATE_DICT | |
|   fsdp_sync_module_states: true | |
|   fsdp_transformer_layer_cls_to_wrap: BertLayer | |
|   fsdp_use_orig_params: true | |
| machine_rank: 0 | |
| main_training_function: main | |
| mixed_precision: bf16 | |
| num_machines: 1 | |
| num_processes: 2 | |
| rdzv_backend: static | |
| same_network: true | |
| tpu_env: [] | |
| tpu_use_cluster: false | |
| tpu_use_sudo: false | |
| use_cpu: false | |
| ``` | |
| </hfoption> | |
| <hfoption id="DeepSpeed"> | |
| ```yml | |
| compute_environment: LOCAL_MACHINE | |
| deepspeed_config: | |
|   deepspeed_config_file: /home/user/configs/ds_zero3_config.json | |
|   zero3_init_flag: true | |
| distributed_type: DEEPSPEED | |
| downcast_bf16: 'no' | |
| machine_rank: 0 | |
| main_training_function: main | |
| num_machines: 1 | |
| num_processes: 4 | |
| rdzv_backend: static | |
| same_network: true | |
| tpu_env: [] | |
| tpu_use_cluster: false | |
| tpu_use_sudo: false | |
| use_cpu: false | |
| ``` | |
| </hfoption> | |
| <hfoption id="DeepSpeed with Accelerate plugin"> | |
| ```yml | |
| compute_environment: LOCAL_MACHINE | |
| deepspeed_config: | |
|   gradient_accumulation_steps: 1 | |
|   gradient_clipping: 0.7 | |
|   offload_optimizer_device: cpu | |
|   offload_param_device: cpu | |
|   zero3_init_flag: true | |
|   zero_stage: 2 | |
| distributed_type: DEEPSPEED | |
| downcast_bf16: 'no' | |
| machine_rank: 0 | |
| main_training_function: main | |
| mixed_precision: bf16 | |
| num_machines: 1 | |
| num_processes: 4 | |
| rdzv_backend: static | |
| same_network: true | |
| tpu_env: [] | |
| tpu_use_cluster: false | |
| tpu_use_sudo: false | |
| use_cpu: false | |
| ``` | |
| </hfoption> | |
| </hfoptions> | |
| The [`accelerate launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to launch a training script on a distributed system with Accelerate and [`Trainer`], using the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and is loaded automatically when you run `accelerate launch`. | |
| For example, to run the [run_glue.py](https://github.com/huggingface/transformers/blob/f4db565b695582891e43a5e042e5d318e28f20b8/examples/pytorch/text-classification/run_glue.py#L4) training script with the FSDP configuration: | |
| ```bash | |
| accelerate launch \ | |
| ./examples/pytorch/text-classification/run_glue.py \ | |
| --model_name_or_path google-bert/bert-base-cased \ | |
| --task_name $TASK_NAME \ | |
| --do_train \ | |
| --do_eval \ | |
| --max_seq_length 128 \ | |
| --per_device_train_batch_size 16 \ | |
| --learning_rate 5e-5 \ | |
| --num_train_epochs 3 \ | |
| --output_dir /tmp/$TASK_NAME/ \ | |
| --overwrite_output_dir | |
| ``` | |
| You can also specify the parameters from the `config_file.yaml` file directly on the command line: | |
| ```bash | |
| accelerate launch --num_processes=2 \ | |
| --use_fsdp \ | |
| --mixed_precision=bf16 \ | |
| --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \ | |
| --fsdp_transformer_layer_cls_to_wrap="BertLayer" \ | |
| --fsdp_sharding_strategy=1 \ | |
| --fsdp_state_dict_type=FULL_STATE_DICT \ | |
| ./examples/pytorch/text-classification/run_glue.py \ | |
| --model_name_or_path google-bert/bert-base-cased \ | |
| --task_name $TASK_NAME \ | |
| --do_train \ | |
| --do_eval \ | |
| --max_seq_length 128 \ | |
| --per_device_train_batch_size 16 \ | |
| --learning_rate 5e-5 \ | |
| --num_train_epochs 3 \ | |
| --output_dir /tmp/$TASK_NAME/ \ | |
| --overwrite_output_dir | |
| ``` | |
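When the launch options come from a script rather than a shell, the same command line can be assembled programmatically. The helper below, `build_launch_command`, is a hypothetical convenience written for this sketch; it only formats the flags that appear in the command above and does not validate them against the Accelerate CLI.

```python
def build_launch_command(script, launch_options, script_args):
    """Assemble an `accelerate launch` argument list from plain dictionaries."""
    command = ["accelerate", "launch"]
    for name, value in launch_options.items():
        # `True` stands for a bare flag such as --use_fsdp; other values become --name=value.
        command.append(f"--{name}" if value is True else f"--{name}={value}")
    command.append(script)
    for name, value in script_args.items():
        command.extend([f"--{name}"] if value is True else [f"--{name}", str(value)])
    return command


command = build_launch_command(
    "./examples/pytorch/text-classification/run_glue.py",
    {"num_processes": 2, "use_fsdp": True, "mixed_precision": "bf16"},
    {"model_name_or_path": "google-bert/bert-base-cased", "do_train": True, "max_seq_length": 128},
)
```

The resulting list keeps launcher options before the script path and script arguments after it, which is the ordering used in the shell example; it could then be handed to `subprocess.run`.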
| To learn more about `accelerate launch` and custom configurations, see the [Launching your Accelerate scripts](https://huggingface.co/docs/accelerate/basic_tutorials/launch) tutorial. | |