
Model training anatomy [[model-training-anatomy]]

To understand the performance optimization techniques that can improve training speed and memory utilization, it helps to be familiar with how the GPU is utilized during training and how compute intensity varies depending on the operation performed.

Let's start with a motivating example of GPU utilization and a model training run. For the demonstration, we need to install a few libraries:

pip install transformers datasets accelerate nvidia-ml-py3

The nvidia-ml-py3 library lets us monitor the memory usage of models from within Python. You may be familiar with the nvidia-smi command in the terminal; this library gives access to the same information directly in Python.

Next, we create some dummy data: random token IDs between 100 and 30000 and binary labels for a classifier. In total, we get 512 sequences, each of length 512, and store them in a [~datasets.Dataset] with PyTorch format.

>>> import numpy as np
>>> from datasets import Dataset


>>> seq_len, dataset_size = 512, 512
>>> dummy_data = {
...     "input_ids": np.random.randint(100, 30000, (dataset_size, seq_len)),
...     "labels": np.random.randint(0, 2, (dataset_size)),  # upper bound is exclusive, so labels are 0 or 1
... }
>>> ds = Dataset.from_dict(dummy_data)
>>> ds.set_format("pt")

To print summary statistics for the GPU utilization and for the training run with the [Trainer], we define two helper functions:

>>> from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit


>>> def print_gpu_utilization():
...     nvmlInit()
...     handle = nvmlDeviceGetHandleByIndex(0)
...     info = nvmlDeviceGetMemoryInfo(handle)
...     print(f"GPU memory occupied: {info.used//1024**2} MB.")


>>> def print_summary(result):
...     print(f"Time: {result.metrics['train_runtime']:.2f}")
...     print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
...     print_gpu_utilization()

Let's verify that we start with free GPU memory:

>>> print_gpu_utilization()
GPU memory occupied: 0 MB.

That looks good: the GPU memory is not occupied, as we would expect before loading any model. If that is not the case on your machine, make sure to stop all processes that are using GPU memory. However, not all free GPU memory can be used by the user. When a model is loaded to the GPU, the kernels are also loaded, which can take up 1-2GB of memory. To see how much that is, we load a tiny tensor onto the GPU, which triggers the kernels to be loaded as well.

>>> import torch


>>> torch.ones((1, 1)).to("cuda")
>>> print_gpu_utilization()
GPU memory occupied: 1343 MB.

We see that the kernels alone take up 1.3GB of GPU memory. Now let's see how much space the model uses.

Load Model [[load-model]]

First, we load the google-bert/bert-large-uncased model. We load the model weights directly onto the GPU so that we can check how much space the weights alone use.

>>> from transformers import AutoModelForSequenceClassification


>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-large-uncased").to("cuda")
>>> print_gpu_utilization()
GPU memory occupied: 2631 MB.

We can see that the model weights alone take up 1.3 GB of GPU memory (the 2631 MB reported here minus the 1343 MB already occupied by the kernels). The exact number depends on the specific GPU you are using. On newer GPUs a model can sometimes take up more space, because the weights are loaded in an optimized fashion that speeds up the usage of the model. Now we can also quickly check whether we get the same result with the nvidia-smi CLI:

nvidia-smi
Tue Jan 11 08:58:05 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    39W / 300W |   2631MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      3721      C   ...nvs/codeparrot/bin/python     2629MiB |
+-----------------------------------------------------------------------------+

We get the same number as before, and you can also see that we are using a V100 GPU with 16GB of memory. So now we can start training the model and see how the GPU memory consumption changes. First, we set up a few standard training arguments:

default_args = {
    "output_dir": "tmp",
    "eval_strategy": "steps",
    "num_train_epochs": 1,
    "log_level": "error",
    "report_to": "none",
}

If you plan to run multiple experiments, restart the Python kernel between experiments so that the memory is properly cleared.

Memory utilization at vanilla training [[memory-utilization-at-vanilla-training]]

Let's use the [Trainer] to train the model with a batch size of 4 and without any GPU performance optimization techniques:

>>> from transformers import TrainingArguments, Trainer, logging

>>> logging.set_verbosity_error()


>>> training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
>>> trainer = Trainer(model=model, args=training_args, train_dataset=ds)
>>> result = trainer.train()
>>> print_summary(result)
Time: 57.82
Samples/second: 8.86
GPU memory occupied: 14949 MB.

We see that even a relatively small batch size almost fills up the GPU's entire memory. However, a larger batch size can often result in faster model convergence or better final performance. So ideally we want to tune the batch size to our model's needs and not to the GPU's limitations. Interestingly, we use much more memory than the size of the model. To understand a bit better why this is the case, let's have a look at a model's operations and memory needs.

Anatomy of Model's Operations [[anatomy-of-models-operations]]

The Transformer architecture includes 3 main groups of operations, grouped below by compute intensity.

  1. Tensor Contractions

    Linear layers and the components of Multi-Head Attention all perform batched **matrix-matrix multiplications**. These operations are the most compute-intensive part of training a transformer.

  2. Statistical Normalizations

    Softmax and layer normalization are less compute-intensive than tensor contractions. They involve one or more **reduction operations**, the result of which is then applied via a map.

  3. Element-wise Operators

    These are the remaining operators: **biases, dropout, activations, and residual connections**. These are the least compute-intensive operations.

This knowledge can be helpful when analyzing performance bottlenecks.

This summary is derived from Data Movement Is All You Need: A Case Study on Optimizing Transformers 2020.
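To make "compute intensity" more concrete, here is a rough sketch that compares FLOPs per byte of memory traffic for a matrix multiplication and for an element-wise addition. The function names, the 1024-sized shapes, and the 2-byte (fp16) element size are illustrative assumptions rather than values from the paper, and the model ignores caches and kernel fusion.

```python
def matmul_intensity(m, k, n, bytes_per_element=2):
    """FLOPs per byte for C = A @ B with A of shape (m, k) and B of shape (k, n)."""
    flops = 2 * m * k * n  # one multiply and one add per inner-product term
    bytes_moved = (m * k + k * n + m * n) * bytes_per_element  # read A and B, write C
    return flops / bytes_moved


def elementwise_add_intensity(num_elements, bytes_per_element=2):
    """FLOPs per byte for z = x + y over num_elements values."""
    flops = num_elements  # one add per element
    bytes_moved = 3 * num_elements * bytes_per_element  # read x and y, write z
    return flops / bytes_moved


# Hypothetical sizes chosen only for illustration.
matmul = matmul_intensity(1024, 1024, 1024)
add = elementwise_add_intensity(1024 * 1024)
print(f"matmul: {matmul:.1f} FLOPs/byte")
print(f"element-wise add: {add:.3f} FLOPs/byte")
```

Under these assumptions the matrix multiplication performs roughly 341 FLOPs for every byte it moves, while the addition performs about 0.17, which is why element-wise operators tend to be limited by memory bandwidth rather than by compute.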

Anatomy of Model's Memory [[anatomy-of-models-memory]]

We have seen that training the model uses much more memory than just putting the model on the GPU. This is because many components use GPU memory during training. The components in GPU memory are the following:

  1. model weights
  2. optimizer states
  3. gradients
  4. forward activations saved for gradient computation
  5. temporary buffers
  6. functionality-specific memory

A typical model trained in mixed precision with AdamW requires 18 bytes per model parameter plus activation memory. For inference there are no optimizer states or gradients, so we can subtract those. We therefore end up with 6 bytes per model parameter for mixed precision inference, plus activation memory.

Let's look at the details.

Model Weights:

  • 4 bytes * number of parameters for fp32 training
  • 6 bytes * number of parameters for mixed precision training (keeps one model in fp32 and one in fp16 in memory)

Optimizer States:

  • 8 bytes * number of parameters for normal AdamW (maintains 2 states)
  • 2 bytes * number of parameters for 8-bit AdamW optimizers like bitsandbytes
  • 4 bytes * number of parameters for optimizers like SGD with momentum (maintains only 1 state)

Gradients

  • 4 bytes * number of parameters for either fp32 or mixed precision training (gradients are always kept in fp32)
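The per-parameter costs for weights, optimizer states, and gradients can be combined into a rough estimate. The following is a minimal sketch: the helper names are hypothetical, and the parameter count of roughly 335 million is an approximate figure assumed here for a BERT-large classifier. Activations, temporary buffers, and CUDA kernels are deliberately left out.

```python
# Bytes per parameter, following the per-component figures listed in this section.
WEIGHT_BYTES = {"fp32": 4, "mixed": 6}
OPTIMIZER_BYTES = {"adamw": 8, "adamw_8bit": 2, "sgd_momentum": 4}
GRADIENT_BYTES = 4


def bytes_per_parameter(precision="mixed", optimizer="adamw", training=True):
    """Per-parameter memory for weights, plus optimizer states and gradients when training."""
    total = WEIGHT_BYTES[precision]
    if training:
        total += OPTIMIZER_BYTES[optimizer] + GRADIENT_BYTES
    return total


def estimate_gib(num_params, **kwargs):
    """Estimated memory in GiB, excluding activations, temporary buffers and kernels."""
    return num_params * bytes_per_parameter(**kwargs) / 1024**3


# Assumption: an approximate parameter count, used only for illustration.
num_params = 335_000_000
for label, kwargs in [
    ("mixed precision + AdamW training", {}),
    ("fp32 + AdamW training", {"precision": "fp32"}),
    ("mixed precision inference", {"training": False}),
]:
    per_param = bytes_per_parameter(**kwargs)
    print(f"{label}: {per_param} bytes/param, ~{estimate_gib(num_params, **kwargs):.2f} GiB")
```

The first and last cases reproduce the 18 and 6 bytes per parameter quoted earlier. Whatever the training run uses beyond such an estimate comes from activations, temporary buffers, kernels, and allocator overhead.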

Forward Activations

  • The size depends on many factors, the key ones being sequence length, hidden size, and batch size.

These include the inputs and outputs passed to and returned by the forward and backward functions, as well as the forward activations saved for gradient computation.
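To get a feel for how sequence length, hidden size, and batch size drive activation memory, the sketch below computes the size of two individual tensors. The helper names are hypothetical, the example values (hidden size 1024, 16 attention heads) are assumed here as the BERT-large configuration, and how many such tensors are actually saved for the backward pass depends on the model implementation, so this is not a full activation estimate.

```python
def hidden_state_mib(batch_size, seq_len, hidden_size, bytes_per_element=4):
    """Size in MiB of a single (batch_size, seq_len, hidden_size) tensor."""
    return batch_size * seq_len * hidden_size * bytes_per_element / 1024**2


def attention_scores_mib(batch_size, num_heads, seq_len, bytes_per_element=4):
    """Size in MiB of one (batch_size, num_heads, seq_len, seq_len) attention-score tensor."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element / 1024**2


# Batch size 4 and sequence length 512 match the training run above; fp32 elements.
print(f"hidden state: {hidden_state_mib(4, 512, 1024):.0f} MiB")
print(f"attention scores: {attention_scores_mib(4, 16, 512):.0f} MiB")
```

With these values a single hidden-state tensor is 8 MiB and a single attention-score tensor is 64 MiB. The hidden state grows linearly with sequence length, whereas the attention scores grow quadratically.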

Temporary Memory

Additionally, there are all kinds of temporary variables that are released as soon as the calculation is done, but which may require additional memory in the moment and can trigger an OOM. Therefore, when coding it is crucial to think strategically about such temporary variables and sometimes to explicitly free them as soon as they are no longer needed.

Functionality-specific memory

Your software may also have special memory needs. For example, when generating text with beam search, the software needs to maintain multiple copies of inputs and outputs.

forward vs backward Execution Speed

For convolutions and linear layers, the backward pass requires 2x the FLOPs of the forward pass, which generally translates into roughly 2x slower execution (sometimes more, because sizes in the backward pass tend to be more awkward). Activations are usually bandwidth-limited, and an activation typically has to read more data in the backward pass than in the forward pass. (For example, the activation forward reads once and writes once, whereas the activation backward reads twice, gradOutput and the output of the forward, and writes once, gradInput.)
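The 2x factor for linear layers follows from counting matrix multiplications: the forward pass needs one, while the backward pass needs one for the gradient with respect to the input and one for the gradient with respect to the weight. A minimal sketch of that count, with hypothetical function names and layer sizes, and with the bias term ignored:

```python
def linear_forward_flops(batch, in_features, out_features):
    """FLOPs for y = x @ W.T with x of shape (batch, in_features)."""
    return 2 * batch * in_features * out_features


def linear_backward_flops(batch, in_features, out_features):
    """FLOPs for the two matrix multiplications in the backward pass of a linear layer."""
    grad_input = 2 * batch * out_features * in_features   # grad_x = grad_y @ W
    grad_weight = 2 * out_features * batch * in_features  # grad_W = grad_y.T @ x
    return grad_input + grad_weight


# Hypothetical layer: 4 sequences of 512 tokens, projecting 1024 -> 4096 features.
tokens = 4 * 512
forward = linear_forward_flops(tokens, 1024, 4096)
backward = linear_backward_flops(tokens, 1024, 4096)
print(f"backward / forward FLOPs: {backward / forward:.1f}")
```

The ratio is exactly 2 in this count regardless of the sizes chosen. Measured runtimes can deviate from it because the backward matrix shapes may map less efficiently onto the hardware.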

As you can see, there are a few places where we could potentially save GPU memory or speed up operations. Now that you understand what affects GPU utilization and computation speed, refer to the Methods and tools for efficient training on a single GPU documentation page to learn about performance optimization techniques.