<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
โš ๏ธ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Trainer [[trainer]]
[`Trainer`] is a complete training and evaluation loop for PyTorch models implemented in the Transformers library. You only need to provide the components required for training (a model, tokenizer, dataset, evaluation function, training hyperparameters, etc.), and [`Trainer`] takes care of the rest. This makes it possible to start training quickly without writing your own training loop. At the same time, [`Trainer`] offers powerful customization and a wide range of training options, so you can tailor training to your exact needs.
<Tip>
In addition to the [`Trainer`] class, Transformers also provides a [`Seq2SeqTrainer`] class for sequence-to-sequence tasks like translation or summarization. The [TRL](https://hf.co/docs/trl) library also offers the [`~trl.SFTTrainer`] class, which wraps the [`Trainer`] class and is optimized for training language models like Llama-2 and Mistral with autoregressive techniques. [`~trl.SFTTrainer`] supports features like sequence packing, LoRA, quantization, and DeepSpeed for efficiently scaling to models of any size.
<br>
To learn more about these other [`Trainer`]-type classes, check the [API reference](./main_classes/trainer) to see when each one is most appropriate. In general, [`Trainer`] is the most versatile option and is suitable for a broad range of tasks. [`Seq2SeqTrainer`] is designed for sequence-to-sequence tasks, and [`~trl.SFTTrainer`] is designed for training language models.
</Tip>
Before you start, make sure [Accelerate](https://hf.co/docs/accelerate), a library for running PyTorch training in distributed environments, is installed.
```bash
pip install accelerate
# upgrade
pip install accelerate --upgrade
```
This guide provides an overview of the [`Trainer`] class.
## Basic usage [[basic-usage]]
[`Trainer`] includes all the code you'll find in a basic training loop:
1. Perform a training step to calculate the loss.
2. Calculate the gradients with the [`~accelerate.Accelerator.backward`] method.
3. Update the weights based on the gradients.
4. Repeat this process until a predetermined number of epochs is reached.
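These steps correspond to a standard PyTorch loop. As a toy sketch only (the linear model and synthetic batch below are made-up stand-ins, not part of the Trainer API), the loop looks roughly like this:

```python
import torch
from torch import nn

# toy model and synthetic batch standing in for a real model and dataset
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fct = nn.CrossEntropyLoss()
inputs = torch.randn(32, 4)
labels = torch.randint(0, 2, (32,))

num_epochs = 2
losses = []
for epoch in range(num_epochs):           # 4. repeat until the epoch count is reached
    optimizer.zero_grad()
    loss = loss_fct(model(inputs), labels)  # 1. training step computes the loss
    loss.backward()                         # 2. backward computes the gradients
    optimizer.step()                        # 3. update the weights from the gradients
    losses.append(loss.item())
```

[`Trainer`] runs a loop of this shape for you, plus logging, checkpointing, evaluation, and distributed training on top.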
The [`Trainer`] class abstracts all of this code away so you don't have to manually write a training loop every time, and so you can still train even if you're just getting started with PyTorch. You only need to provide the essential components required for training, such as a model and a dataset, and the [`Trainer`] class handles everything else.
To specify training options or hyperparameters, look at the [`TrainingArguments`] class. For example, let's define the directory to save the model in `output_dir`, and set `push_to_hub=True` to push the model to the Hub after training.
```py
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="your-model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
)
```
Pass `training_args` to the [`Trainer`] along with a model, dataset, something to preprocess the dataset with (depending on your data type it could be a tokenizer, feature extractor, or image processor), a data collator, and a function to compute the metrics you want to track during training.
Finally, call [`~Trainer.train`] to start training!
```py
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
```
### Checkpoints [[checkpoints]]
The [`Trainer`] class saves model checkpoints to the directory specified in the `output_dir` parameter of [`TrainingArguments`]. Checkpoints are saved in `checkpoint-000` subfolders, where the number at the end corresponds to the training step. Saving checkpoints is useful for resuming training later.
```py
# resume from the latest checkpoint
trainer.train(resume_from_checkpoint=True)
# resume from a specific checkpoint saved in the output directory
trainer.train(resume_from_checkpoint="your-model/checkpoint-1000")
```
์ฒดํฌํฌ์ธํŠธ๋ฅผ Hub์— ํ‘ธ์‹œํ•˜๋ ค๋ฉด [`TrainingArguments`]์—์„œ `push_to_hub=True`๋กœ ์„ค์ •ํ•˜์—ฌ ์ปค๋ฐ‹ํ•˜๊ณ  ํ‘ธ์‹œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ ๋ฐฉ๋ฒ•์„ ๊ฒฐ์ •ํ•˜๋Š” ๋‹ค๋ฅธ ์˜ต์…˜์€ [`hub_strategy`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.hub_strategy) ๋งค๊ฐœ๋ณ€์ˆ˜์—์„œ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค:
* `hub_strategy="checkpoint"` pushes the latest checkpoint to a subfolder named "last-checkpoint" from which you can resume training.
* `hub_strategy="all_checkpoints"` pushes all checkpoints to the directory defined in `output_dir` (you'll see one checkpoint per folder in your model repository).
์ฒดํฌํฌ์ธํŠธ์—์„œ ํ›ˆ๋ จ์„ ์žฌ๊ฐœํ•  ๋•Œ, [`Trainer`]๋Š” ์ฒดํฌํฌ์ธํŠธ๊ฐ€ ์ €์žฅ๋  ๋•Œ์™€ ๋™์ผํ•œ Python, NumPy ๋ฐ PyTorch RNG ์ƒํƒœ๋ฅผ ์œ ์ง€ํ•˜๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ PyTorch๋Š” ๊ธฐ๋ณธ ์„ค์ •์œผ๋กœ '์ผ๊ด€๋œ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์žฅํ•˜์ง€ ์•Š์Œ'์œผ๋กœ ๋งŽ์ด ๋˜์–ด์žˆ๊ธฐ ๋•Œ๋ฌธ์—, RNG ์ƒํƒœ๊ฐ€ ๋™์ผํ•  ๊ฒƒ์ด๋ผ๊ณ  ๋ณด์žฅํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ, ์ผ๊ด€๋œ ๊ฒฐ๊ณผ๊ฐ€ ๋ณด์žฅ๋˜๋„๋ก ํ™œ์„ฑํ™” ํ•˜๋ ค๋ฉด, [๋žœ๋ค์„ฑ ์ œ์–ด](https://pytorch.org/docs/stable/notes/randomness#controlling-sources-of-randomness) ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ ํ›ˆ๋ จ์„ ์™„์ „ํžˆ ์ผ๊ด€๋œ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์žฅ ๋ฐ›๋„๋ก ๋งŒ๋“ค๊ธฐ ์œ„ํ•ด ํ™œ์„ฑํ™”ํ•  ์ˆ˜ ์žˆ๋Š” ํ•ญ๋ชฉ์„ ํ™•์ธํ•˜์„ธ์š”. ๋‹ค๋งŒ, ํŠน์ • ์„ค์ •์„ ๊ฒฐ์ •์ ์œผ๋กœ ๋งŒ๋“ค๋ฉด ํ›ˆ๋ จ์ด ๋А๋ ค์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
## Customize the Trainer [[customize-the-trainer]]
While the [`Trainer`] class is designed to be accessible and easy to use, it also offers plenty of customization options for users who want more. Many of the [`Trainer`]'s methods can be subclassed and overridden to provide the functionality you want, letting you add features without rewriting the entire training loop. These methods include:
* [`~Trainer.get_train_dataloader`] creates the training dataloader.
* [`~Trainer.get_eval_dataloader`] creates the evaluation dataloader.
* [`~Trainer.get_test_dataloader`] creates the test dataloader.
* [`~Trainer.log`] logs information on the various objects that monitor training.
* [`~Trainer.create_optimizer_and_scheduler`] creates an optimizer and learning rate scheduler if they weren't passed in the `__init__`; these can also be customized separately with [`~Trainer.create_optimizer`] and [`~Trainer.create_scheduler`], respectively.
* [`~Trainer.compute_loss`] computes the loss on a batch of training inputs.
* [`~Trainer.training_step`] performs the training step.
* [`~Trainer.prediction_step`] performs the prediction and test step.
* [`~Trainer.evaluate`] evaluates the model and returns the evaluation metrics.
* [`~Trainer.predict`] makes predictions (with metrics if labels are available) on the test set.
For example, if you want to customize the [`~Trainer.compute_loss`] method to use a weighted loss instead:
```py
import torch
from torch import nn
from transformers import Trainer

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss for 3 labels with different weights
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
```
### Callbacks [[callbacks]]
Another option for customizing the [`Trainer`] is to use [callbacks](callbacks). Callbacks *don't change* anything in the training loop. They inspect the state of the training loop and then execute some action (early stopping, logging results, etc.) depending on that state. In other words, a callback can't be used to implement something like a custom loss function; for that you need to subclass and override the [`~Trainer.compute_loss`] method.
For example, to add an early stopping callback to the training loop after 10 steps:
```py
from transformers import TrainerCallback

class EarlyStoppingCallback(TrainerCallback):
    def __init__(self, num_steps=10):
        self.num_steps = num_steps

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step >= self.num_steps:
            # setting this flag on the control object tells the Trainer to stop
            control.should_training_stop = True
        return control
```
Then pass it to the [`Trainer`]'s `callbacks` parameter.
```py
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback()],
)
```
## Logging [[logging]]
<Tip>
Check out the [logging](./main_classes/logging) API reference for more information about the different logging levels.
</Tip>
[`Trainer`] is set to `logging.INFO` by default, which reports errors, warnings, and other basic information. In distributed environments, a [`Trainer`] replica is set to `logging.WARNING`, which reports only errors and warnings. You can change the log level with the [`log_level`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level) and [`log_level_replica`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level_replica) parameters in [`TrainingArguments`].
To configure the log level setting for each node, use the [`log_on_each_node`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments.log_on_each_node) parameter to determine whether to use the log level on each node or only on the main node.
<Tip>
[`Trainer`] sets the log level separately for each node in the [`Trainer.__init__`] method, so if you're using other Transformers functionality, consider setting this earlier, before creating the [`Trainer`] object.
</Tip>
For example, to set your main code and modules to use the same log level according to each node:
```py
import logging
import sys

import datasets
import transformers
from transformers import Trainer

logger = logging.getLogger(__name__)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)

trainer = Trainer(...)
```
Use different combinations of `log_level` and `log_level_replica` to configure what gets logged on each of the nodes.
<hfoptions id="logging">
<hfoption id="single node">
```bash
my_app.py ... --log_level warning --log_level_replica error
```
</hfoption>
<hfoption id="multi-node">
Add the `log_on_each_node 0` parameter for multi-node environments.
```bash
my_app.py ... --log_level warning --log_level_replica error --log_on_each_node 0
# set to only report errors
my_app.py ... --log_level error --log_level_replica error --log_on_each_node 0
```
</hfoption>
</hfoptions>
## NEFTune [[neftune]]
[NEFTune](https://hf.co/papers/2310.05914) is a technique that can improve performance by adding noise to the embedding vectors during training. To enable it in [`Trainer`], set the `neftune_noise_alpha` parameter in [`TrainingArguments`] to control how much noise is added.
```py
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(..., neftune_noise_alpha=0.1)
trainer = Trainer(..., args=training_args)
```
NEFTune is disabled after training to restore the original embedding layer and avoid any unexpected behavior.
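For intuition, the NEFTune paper scales uniform noise by `alpha / sqrt(seq_len * hidden_dim)` and adds it to the input embeddings during training. Below is a standalone sketch of that scaling rule (our own illustration, not the actual hook Transformers installs):

```python
import torch

def add_neftune_noise(embeddings: torch.Tensor, noise_alpha: float) -> torch.Tensor:
    # NEFTune: uniform noise in [-mag, mag] with mag = alpha / sqrt(L * d)
    seq_len, hidden_dim = embeddings.shape[-2], embeddings.shape[-1]
    mag = noise_alpha / (seq_len * hidden_dim) ** 0.5
    return embeddings + torch.zeros_like(embeddings).uniform_(-mag, mag)

emb = torch.randn(2, 16, 32)  # (batch, seq_len, hidden_dim)
noisy = add_neftune_noise(emb, noise_alpha=0.1)
```

Because the noise magnitude shrinks with sequence length and hidden size, even a `neftune_noise_alpha` of 0.1 perturbs each embedding value only slightly.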
## GaLore [[galore]]
Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning while being more memory-efficient than common low-rank adaptation methods such as LoRA.
First, install the official GaLore repository:
```bash
pip install galore-torch
```
Then add `optim_target_modules` together with one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim`. It can be a list of strings, regular expressions, or full paths corresponding to the target module names you want to adapt. Below is an end-to-end example script (run `pip install trl datasets` if needed):
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```
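To illustrate how string or regex entries in `optim_target_modules` select modules, here is a rough, stdlib-only sketch (the module names below are hypothetical):

```python
import re

# hypothetical module names from a transformer model
module_names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.gate_proj",
    "model.embed_tokens",
]
target_patterns = ["attn", "mlp"]  # as passed to optim_target_modules above

# a module is treated as a GaLore target if any pattern matches its name
galore_targets = [
    name for name in module_names
    if any(re.search(pattern, name) for pattern in target_patterns)
]
# the attention and MLP projections match; the embedding layer does not
```

Modules that don't match any pattern are left to the regular optimizer.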
To pass extra arguments supported by GaLore, set `optim_args`. For example:
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=["attn", "mlp"],
    optim_args="rank=64, update_proj_gap=100, scale=0.10",
)

model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```
Read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).
Currently, only Linear layers that are considered GaLore layers can be trained this way; they are trained using low-rank decomposition while the remaining layers are optimized in the conventional manner.
Note that it takes a bit of time before training starts (~3 minutes for a 2B model on an NVIDIA A100), but training should go smoothly afterwards.
๋‹ค์Œ๊ณผ ๊ฐ™์ด ์˜ตํ‹ฐ๋งˆ์ด์ € ์ด๋ฆ„์— `layerwise`๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ๋ ˆ์ด์–ด๋ณ„ ์ตœ์ ํ™”๋ฅผ ์ˆ˜ํ–‰ํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค:
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw_layerwise",
    optim_target_modules=["attn", "mlp"]
)

model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()
```
๋ ˆ์ด์–ด๋ณ„ ์ตœ์ ํ™”๋Š” ๋‹ค์†Œ ์‹คํ—˜์ ์ด๋ฉฐ DDP(๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ)๋ฅผ ์ง€์›ํ•˜์ง€ ์•Š์œผ๋ฏ€๋กœ, ๋‹จ์ผ GPU์—์„œ๋งŒ ํ›ˆ๋ จ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ [์ด ๋ฌธ์„œ๋ฅผ](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory)์„ ์ฐธ์กฐํ•˜์„ธ์š”. gradient clipping, DeepSpeed ๋“ฑ ๋‹ค๋ฅธ ๊ธฐ๋Šฅ์€ ๊ธฐ๋ณธ์ ์œผ๋กœ ์ง€์›๋˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•˜๋ฉด [GitHub์— ์ด์Šˆ๋ฅผ ์˜ฌ๋ ค์ฃผ์„ธ์š”](https://github.com/huggingface/transformers/issues).
## LOMO optimizer [[lomo-optimizer]]
The LOMO optimizers were introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
Both are efficient full-parameter fine-tuning methods. To reduce memory usage, these optimizers fuse the gradient computation and the parameter update into a single step. The supported optimizers for LOMO are `"lomo"` and `"adalomo"`. First install `lomo` from pypi with `pip install lomo-optim`, or install it from the GitHub source with `pip install git+https://github.com/OpenLMLab/LOMO.git`.
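As a rough illustration of the fused idea (plain SGD in a gradient hook, our own sketch rather than the actual LOMO implementation): the parameter is updated as soon as its gradient is computed, so the full set of gradients never needs to be stored.

```python
import torch
from torch import nn

lr = 1e-2
param = nn.Parameter(torch.randn(4))

def fused_sgd_hook(grad):
    # update the parameter the moment its gradient becomes available ...
    with torch.no_grad():
        param.add_(grad, alpha=-lr)
    # ... then discard the gradient so nothing accumulates in memory
    return torch.zeros_like(grad)

param.register_hook(fused_sgd_hook)

before = param.detach().clone()
loss = (param ** 2).sum()
loss.backward()  # the hook fires during backward and applies the update in place
```

In a real model this happens per parameter during the backward pass, which is why LOMO's peak memory is close to inference-only usage.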
<Tip>
According to the authors, it is recommended to use `AdaLomo` without `grad_norm` to get better performance and higher throughput.
</Tip>
๋‹ค์Œ์€ IMDB ๋ฐ์ดํ„ฐ์…‹์—์„œ [google/gemma-2b](https://huggingface.co/google/gemma-2b)๋ฅผ ์ตœ๋Œ€ ์ •๋ฐ€๋„๋กœ ๋ฏธ์„ธ ์กฐ์ •ํ•˜๋Š” ๊ฐ„๋‹จํ•œ ์Šคํฌ๋ฆฝํŠธ์ž…๋‹ˆ๋‹ค:
```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-lomo",
    max_steps=1000,
    per_device_train_batch_size=4,
    optim="adalomo",
    gradient_checkpointing=True,
    logging_strategy="steps",
    logging_steps=1,
    learning_rate=2e-6,
    save_strategy="no",
    run_name="lomo-imdb",
)

model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=1024,
)

trainer.train()
```
## Accelerate and Trainer [[accelerate-and-trainer]]
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
<Tip>
Learn more about FSDP sharding strategies, CPU offloading, and other features you can use with the [`Trainer`] in the [Fully Sharded Data Parallel](fsdp) guide.
</Tip>
To use Accelerate with [`Trainer`], run the [`accelerate.config`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-config) command to set up your training environment. This command creates a `config_file.yaml` that's used when you launch your training script. For example, here are some of the configurations you can set up:
<hfoptions id="config">
<hfoption id="DistributedDataParallel">
```yml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0 # change the rank according to the node
main_process_ip: 192.168.20.1
main_process_port: 9898
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
<hfoption id="FSDP">
```yml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BertLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
<hfoption id="DeepSpeed">
```yml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: /home/user/configs/ds_zero3_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
<hfoption id="DeepSpeed with Accelerate plugin">
```yml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 0.7
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</hfoption>
</hfoptions>
The [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) command is the recommended way to launch your training script on a distributed system with Accelerate and [`Trainer`], using the parameters specified in `config_file.yaml`. This file is saved to the Accelerate cache folder and automatically loaded when you run `accelerate_launch`.
For example, to run the [run_glue.py](https://github.com/huggingface/transformers/blob/f4db565b695582891e43a5e042e5d318e28f20b8/examples/pytorch/text-classification/run_glue.py#L4) training script with the FSDP configuration:
```bash
accelerate launch \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
```
You could also specify the parameters from the `config_file.yaml` file directly on the command line:
```bash
accelerate launch --num_processes=2 \
--use_fsdp \
--mixed_precision=bf16 \
--fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
--fsdp_transformer_layer_cls_to_wrap="BertLayer" \
--fsdp_sharding_strategy=1 \
--fsdp_state_dict_type=FULL_STATE_DICT \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
```
Check out the [Launching your Accelerate scripts](https://huggingface.co/docs/accelerate/basic_tutorials/launch) tutorial to learn more about `accelerate_launch` and custom configurations.