๋ชจ๋ธ ํ์ต ํด๋ถํ๊ธฐ [[model-training-anatomy]]
๋ชจ๋ธ ํ๋ จ ์๋์ ๋ฉ๋ชจ๋ฆฌ ํ์ฉ์ ํจ์จ์ฑ์ ํฅ์์ํค๊ธฐ ์ํด ์ ์ฉํ ์ ์๋ ์ฑ๋ฅ ์ต์ ํ ๊ธฐ์ ์ ์ดํดํ๋ ค๋ฉด GPU๊ฐ ํ๋ จ ์ค์ ์ด๋ป๊ฒ ํ์ฉ๋๋์ง, ๊ทธ๋ฆฌ๊ณ ์ํ๋๋ ์ฐ์ฐ์ ๋ฐ๋ผ ์ฐ์ฐ ๊ฐ๋๊ฐ ์ด๋ป๊ฒ ๋ณํ๋์ง์ ์ต์ํด์ ธ์ผ ํฉ๋๋ค.
๋จผ์ GPU ํ์ฉ๊ณผ ๋ชจ๋ธ ํ๋ จ ์คํ์ ๋ํ ์์๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค. ๋ฐ๋ชจ๋ฅผ ์ํด ๋ช๋ช ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํด์ผ ํฉ๋๋ค:
pip install transformers datasets accelerate nvidia-ml-py3
nvidia-ml-py3 ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ Python ๋ด๋ถ์์ ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ๋ชจ๋ํฐ๋งํ ์ ์๊ฒ ํด์ค๋๋ค. ํฐ๋ฏธ๋์ nvidia-smi ๋ช
๋ น์ด์ ์ต์ํ ์ ์๋๋ฐ, ์ด ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ Python์์ ์ง์ ๋์ผํ ์ ๋ณด์ ์ ๊ทผํ ์ ์๊ฒ ํด์ค๋๋ค.
๊ทธ ๋ค์, 100๊ณผ 30000 ์ฌ์ด์ ๋ฌด์์ ํ ํฐ ID์ ๋ถ๋ฅ๊ธฐ๋ฅผ ์ํ ์ด์ง ๋ ์ด๋ธ์ธ ๋๋ฏธ ๋ฐ์ดํฐ๋ฅผ ์์ฑํฉ๋๋ค.
๊ธธ์ด๊ฐ ๊ฐ๊ฐ 512์ธ ์ด 512๊ฐ์ ์ํ์ค๋ฅผ ๊ฐ์ ธ์ PyTorch ํ์์ [~datasets.Dataset]์ ์ ์ฅํฉ๋๋ค.
>>> import numpy as np
>>> from datasets import Dataset
>>> seq_len, dataset_size = 512, 512
>>> dummy_data = {
... "input_ids": np.random.randint(100, 30000, (dataset_size, seq_len)),
... "labels": np.random.randint(0, 1, (dataset_size)),
... }
>>> ds = Dataset.from_dict(dummy_data)
>>> ds.set_format("pt")
GPU ํ์ฉ ๋ฐ [Trainer]๋ก ์คํํ ํ๋ จ ๊ณผ์ ์ ๋ํ ์์ฝ ํต๊ณ๋ฅผ ์ถ๋ ฅํ๊ธฐ ์ํด ๋ ๊ฐ์ ๋์ฐ๋ฏธ ํจ์๋ฅผ ์ ์ํ๊ฒ ์ต๋๋ค:
>>> from pynvml import *
>>> def print_gpu_utilization():
... nvmlInit()
... handle = nvmlDeviceGetHandleByIndex(0)
... info = nvmlDeviceGetMemoryInfo(handle)
... print(f"GPU memory occupied: {info.used//1024**2} MB.")
>>> def print_summary(result):
... print(f"Time: {result.metrics['train_runtime']:.2f}")
... print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
... print_gpu_utilization()
์์ํ ๋ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋น์ด ์๋์ง ํ์ธํด ๋ด ์๋ค:
>>> print_gpu_utilization()
GPU memory occupied: 0 MB.
์ข์ต๋๋ค. ๋ชจ๋ธ์ ๋ก๋ํ๊ธฐ ์ ์๋ ์์๋๋ก GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ ์ ๋์ง ์์์ต๋๋ค. ๊ทธ๋ ์ง ์๋ค๋ฉด ์ฌ์ฉ์์ ๊ธฐ๊ธฐ์์ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ ๋ชจ๋ ํ๋ก์ธ์ค๋ฅผ ์ค๋จํด์ผ ํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ฌ์ฉ์๋ ๋ชจ๋ ์ฌ์ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ ์๋ ์์ต๋๋ค. ๋ชจ๋ธ์ด GPU์ ๋ก๋๋ ๋ ์ปค๋๋ ๋ก๋๋๋ฏ๋ก 1-2GB์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฐจ์งํ ์ ์์ต๋๋ค. ์ผ๋ง๋ ๋๋์ง ํ์ธํ๊ธฐ ์ํด GPU์ ์์ ํ ์๋ฅผ ๋ก๋ํ์ฌ ์ปค๋์ด ๋ก๋๋๋๋ก ํธ๋ฆฌ๊ฑฐํฉ๋๋ค.
>>> import torch
>>> torch.ones((1, 1)).to("cuda")
>>> print_gpu_utilization()
GPU memory occupied: 1343 MB.
์ปค๋๋ง์ผ๋ก๋ GPU ๋ฉ๋ชจ๋ฆฌ์ 1.3GB๋ฅผ ์ฐจ์งํฉ๋๋ค. ์ด์ ๋ชจ๋ธ์ด ์ผ๋ง๋ ๋ง์ ๊ณต๊ฐ์ ์ฌ์ฉํ๋์ง ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค.
๋ชจ๋ธ ๋ก๋ [[load-model]]
์ฐ์ , google-bert/bert-large-uncased ๋ชจ๋ธ์ ๋ก๋ํฉ๋๋ค. ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ์ง์ GPU์ ๋ก๋ํด์ ๊ฐ์ค์น๋ง์ด ์ผ๋ง๋ ๋ง์ ๊ณต๊ฐ์ ์ฐจ์งํ๋์ง ํ์ธํ ์ ์์ต๋๋ค.
>>> from transformers import AutoModelForSequenceClassification
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-large-uncased").to("cuda")
>>> print_gpu_utilization()
GPU memory occupied: 2631 MB.
๋ชจ๋ธ์ ๊ฐ์ค์น๋ง์ผ๋ก๋ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ 1.3 GB ์ฐจ์งํ๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ์ ํํ ์ซ์๋ ์ฌ์ฉํ๋ GPU์ ๋ฐ๋ผ ๋ค๋ฆ
๋๋ค. ์ต์ GPU์์๋ ๋ชจ๋ธ ์ฌ์ฉ ์๋๋ฅผ ๋์ด๋ ์ต์ ํ๋ ๋ฐฉ์์ผ๋ก ๊ฐ์ค์น๊ฐ ๋ก๋๋๋ฏ๋ก, ๋ชจ๋ธ์ด ๋ ๋ง์ ๊ณต๊ฐ์ ์ฐจ์งํ ์ ์์ต๋๋ค. ์ด์ nvidia-smi CLI์ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ป๋์ง ๋น ๋ฅด๊ฒ ํ์ธํ ์ ์์ต๋๋ค:
nvidia-smi
Tue Jan 11 08:58:05 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03 Driver Version: 460.91.03 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla V100-SXM2... On | 00000000:00:04.0 Off | 0 |
| N/A 37C P0 39W / 300W | 2631MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 3721 C ...nvs/codeparrot/bin/python 2629MiB |
+-----------------------------------------------------------------------------+
์ด์ ๊ณผ ๋์ผํ ์ซ์๊ฐ ์ถ๋ ฅ๋๊ณ 16GB ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ฐ์ง V100 GPU๋ฅผ ์ฌ์ฉํ๊ณ ์๋ค๋ ๊ฒ๋ ๋ณผ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ฏ๋ก ์ด์ ๋ชจ๋ธ ํ๋ จ์ ์์ํ์ฌ GPU ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ์ด๋ป๊ฒ ๋ฌ๋ผ์ง๋์ง ๋ณผ ์ ์์ต๋๋ค. ์ฐ์ ๋ช๋ช ํ์ค ํ๋ จ ์ธ์๋ฅผ ์ค์ ํฉ๋๋ค:
default_args = {
"output_dir": "tmp",
"eval_strategy": "steps",
"num_train_epochs": 1,
"log_level": "error",
"report_to": "none",
}
์ฌ๋ฌ ์คํ์ ์คํํ ๊ณํ์ด๋ผ๋ฉด, ์คํ ๊ฐ์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ๋๋ก ๋น์ฐ๊ธฐ ์ํด์ Python ์ปค๋์ ์คํ ์ฌ์ด๋ง๋ค ์ฌ์์ํด์ผ ํฉ๋๋ค.
๊ธฐ๋ณธ ํ๋ จ์์์ ๋ฉ๋ชจ๋ฆฌ ํ์ฉ [[memory-utilization-at-vanilla-training]]
[Trainer]๋ฅผ ์ฌ์ฉํ์ฌ, GPU ์ฑ๋ฅ ์ต์ ํ ๊ธฐ์ ์ ์ฌ์ฉํ์ง ์๊ณ ๋ฐฐ์น ํฌ๊ธฐ๊ฐ 4์ธ ๋ชจ๋ธ์ ํ๋ จ์ํค๊ฒ ์ต๋๋ค:
>>> from transformers import TrainingArguments, Trainer, logging
>>> logging.set_verbosity_error()
>>> training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
>>> trainer = Trainer(model=model, args=training_args, train_dataset=ds)
>>> result = trainer.train()
>>> print_summary(result)
Time: 57.82
Samples/second: 8.86
GPU memory occupied: 14949 MB.
์ฐ๋ฆฌ๋ ๋น๊ต์ ์์ ๋ฐฐ์น ํฌ๊ธฐ๋ก๋ ์ ์ฒด GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ฑฐ์ ๋ค ์ฐจ์งํ๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ฐฐ์น ํฌ๊ธฐ๊ฐ ํด์๋ก ๋ชจ๋ธ ์๋ ด ์๋๊ฐ ๋นจ๋ผ์ง๊ณ ์ต์ข ์ฑ๋ฅ์ด ํฅ์๋๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. ๊ทธ๋์ ์ด์์ ์ผ๋ก๋ GPU ์ ํ์ด ์๋ ์ฐ๋ฆฌ ๋ชจ๋ธ์ ์๊ตฌ์ฌํญ์ ๋ง๊ฒ ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ์กฐ์ ํ๋ ค๊ณ ํฉ๋๋ค. ํฅ๋ฏธ๋กญ๊ฒ๋ ์ฐ๋ฆฌ๋ ๋ชจ๋ธ์ ํฌ๊ธฐ๋ณด๋ค ํจ์ฌ ๋ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ ์ด๋ฐ ํ์์ด ๋ฐ์ํ๋์ง ์กฐ๊ธ ๋ ์ ์ดํดํ๊ธฐ ์ํด ๋ชจ๋ธ์ ์ฐ์ฐ๊ณผ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
๋ชจ๋ธ์ ์ฐ์ฐ ํด๋ถํ๊ธฐ [[anatomy-of-models-operations]]
ํธ๋์คํฌ๋จธ ์ํคํ ์ฒ์๋ ์ฐ์ฐ ๊ฐ๋(compute-intensity)์ ๋ฐ๋ผ ๊ทธ๋ฃนํ๋ 3๊ฐ์ง ์ฃผ์ ์ฐ์ฐ ๊ทธ๋ฃน์ด ์์ต๋๋ค.
ํ ์ ์ถ์ฝ(Tensor Contractions)
์ ํ ๋ ์ด์ด์ ๋ฉํฐํค๋ ์ดํ ์ ์ ๊ตฌ์ฑ ์์๋ ๋ชจ๋ **ํ๋ ฌ-ํ๋ ฌ ๊ณฑ์ (matrix-matrix multiplications)**์ ์ผ๊ด์ ์ผ๋ก ์ฒ๋ฆฌํฉ๋๋ค. ์ด ์ฐ์ฐ์ ํธ๋์คํฌ๋จธ ํ๋ จ์์ ๊ฐ์ฅ ์ฐ์ฐ ๊ฐ๋๊ฐ ๋์ ๋ถ๋ถ์ ๋๋ค.
ํต๊ณ ์ ๊ทํ(Statistical Normalizations)
์ํํธ๋งฅ์ค์ ๋ ์ด์ด ์ ๊ทํ๋ ํ ์ ์ถ์ฝ๋ณด๋ค ์ฐ์ฐ ๊ฐ๋๊ฐ ๋ฎ์ต๋๋ค. ํ๋ ์ด์์ **๊ฐ์ ์ฐ์ฐ(reduction operations)**์ ํฌํจํ๋ฉฐ, ๊ทธ ๊ฒฐ๊ณผ๋ map์ ํตํด ์ ์ฉ๋ฉ๋๋ค.
์์๋ณ ์ฐ์ฐ์(Element-wise Operators)
๊ทธ ์ธ ์ฐ์ฐ์๋ค, **ํธํฅ(biases), ๋๋กญ์์(dropout), ํ์ฑํ ํจ์(activations), ์์ฐจ ์ฐ๊ฒฐ(residual connections)**์ด ์ฌ๊ธฐ์ ํด๋นํฉ๋๋ค. ์ด ์ฐ์ฐ๋ค์ ์ฐ์ฐ ๊ฐ๋๊ฐ ๊ฐ์ฅ ๋ฎ์ต๋๋ค.
์ด๋ฌํ ์ง์์ ์ฑ๋ฅ ๋ณ๋ชฉ ํ์์ ๋ถ์ํ ๋ ๋์์ด ๋ ์ ์์ต๋๋ค.
์ด ๋ด์ฉ์ Data Movement Is All You Need: A Case Study on Optimizing Transformers 2020์ ์ฐธ๊ณ ํ์์ต๋๋ค.
๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ ๊ตฌ์กฐ [[anatomy-of-models-memory]]
๋ชจ๋ธ์ ํ๋ จ์ํค๋ ๋ฐ๋ ๋จ์ํ GPU์ ๋ชจ๋ธ์ ์ฌ๋ฆฌ๋ ๊ฒ๋ณด๋ค ํจ์ฌ ๋ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ค๋ ๊ฒ์ ๋ณด์์ต๋๋ค. ์ด๋ ํ๋ จ ์ค GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ ๋ง์ ๊ตฌ์ฑ ์์๊ฐ ์๊ธฐ ๋๋ฌธ์ ๋๋ค. GPU ๋ฉ๋ชจ๋ฆฌ์ ๊ตฌ์ฑ ์์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ๋ชจ๋ธ ๊ฐ์ค์น
- ์ตํฐ๋ง์ด์ ์ํ
- ๊ทธ๋ผ๋์ธํธ
- ๊ทธ๋ผ๋์ธํธ ๊ณ์ฐ์ ์ํด ์ ์ฅ๋ ์๋ฐฉํฅ ํ์ฑํ
- ์์ ๋ฒํผ
- ๊ธฐ๋ฅ๋ณ ๋ฉ๋ชจ๋ฆฌ
AdamW๋ฅผ ์ฌ์ฉํ์ฌ ํผํฉ ์ ๋ฐ๋๋ก ํ๋ จ๋ ์ผ๋ฐ์ ์ธ ๋ชจ๋ธ์ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋น 18 ๋ฐ์ดํธ์ ํ์ฑํ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. ์ถ๋ก ๋จ๊ณ์์๋ ์ตํฐ๋ง์ด์ ์ ๊ทธ๋ผ๋์ธํธ๊ฐ ํ์ํ์ง ์์ผ๋ฏ๋ก ์ด๋ค์ ์ ์ธํฉ๋๋ค. ๋ฐ๋ผ์ ํผํฉ ์ ๋ฐ๋ ์ถ๋ก ์ ๊ฒฝ์ฐ ๋ชจ๋ธ ๋งค๊ฐ๋ณ์๋น 6 ๋ฐ์ดํธ์ ํ์ฑํ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค.
์์ธํ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
๋ชจ๋ธ ๊ฐ์ค์น:
- fp32 ํ๋ จ์ ๊ฒฝ์ฐ ๋งค๊ฐ ๋ณ์ ์ * 4 ๋ฐ์ดํธ
- ํผํฉ ์ ๋ฐ๋ ํ๋ จ์ ๊ฒฝ์ฐ ๋งค๊ฐ ๋ณ์ ์ * 6 ๋ฐ์ดํธ (๋ฉ๋ชจ๋ฆฌ์ fp32์ fp16 ๋ ๊ฐ์ง ๋ชจ๋ธ์ ์ ์ง)
์ตํฐ๋ง์ด์ ์ํ:
- ์ผ๋ฐ AdamW์ ๊ฒฝ์ฐ ๋งค๊ฐ ๋ณ์ ์ * 8 ๋ฐ์ดํธ (2๊ฐ์ง ์ํ ์ ์ง)
- bitsandbytes์ ๊ฐ์ 8๋นํธ AdamW ์ตํฐ๋ง์ด์ ์ ๊ฒฝ์ฐ ๋งค๊ฐ ๋ณ์ ์ * 2 ๋ฐ์ดํธ
- Momentum์ ๊ฐ์ง SGD์ ๊ฐ์ ์ตํฐ๋ง์ด์ ์ ๊ฒฝ์ฐ ๋งค๊ฐ ๋ณ์ ์ * 4 ๋ฐ์ดํธ (ํ๋์ ์ํ๋ง ์ ์ง)
๊ทธ๋ผ๋์ธํธ
- fp32 ๋๋ ํผํฉ ์ ๋ฐ๋ ํ๋ จ์ ๊ฒฝ์ฐ ๋งค๊ฐ ๋ณ์ ์ * 4 ๋ฐ์ดํธ (๊ทธ๋ผ๋์ธํธ๋ ํญ์ fp32์ผ๋ก ์ ์ง๋ฉ๋๋ค.)
์๋ฐฉํฅ ํ์ฑํ
- ํฌ๊ธฐ๋ ์ฌ๋ฌ ์์ธ์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋ฉฐ, ์ฃผ์ ์์ธ์ ์ํ์ค ๊ธธ์ด, ์๋ ์ํ์ ํฌ๊ธฐ ๋ฐ ๋ฐฐ์น ํฌ๊ธฐ์ ๋๋ค.
์๋ฐฉํฅ ๋ฐ ์ญ๋ฐฉํฅ ํจ์์์ ์ ๋ฌ ๋ฐ ๋ฐํ๋๋ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ด ์์ผ๋ฉฐ, ๊ทธ๋ผ๋์ธํธ ๊ณ์ฐ์ ์ํด ์ ์ฅ๋ ์๋ฐฉํฅ ํ์ฑํ๊ฐ ์์ต๋๋ค.
์์ ๋ฉ๋ชจ๋ฆฌ
๋๋ถ์ด ๋ชจ๋ ์ข ๋ฅ์ ์์ ๋ณ์๋ ์ฐ์ฐ์ด ์๋ฃ๋๋ฉด ๊ณง๋ฐ๋ก ํด์ ๋์ง๋ง, ๊ทธ ์๊ฐ์๋ ์ถ๊ฐ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํ ์ ์๊ณ OOM์ ์ ๋ฐํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ์ฝ๋ฉํ ๋ ์ด๋ฌํ ์์ ๋ณ์์ ๋ํด ์ ๋ต์ ์ผ๋ก ์๊ฐํ๊ณ ๋๋ก๋ ๋ ์ด์ ํ์ ์๋ ์์ ๋ณ์๋ฅผ ์ฆ์ ๋ช ์์ ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ์์ ์ ๊ฑฐํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค.
๊ธฐ๋ฅ๋ณ ๋ฉ๋ชจ๋ฆฌ
๊ทธ๋ฐ ๋ค์, ์ํํธ์จ์ด์๋ ํน๋ณํ ๋ฉ๋ชจ๋ฆฌ ์๊ตฌ ์ฌํญ์ด ์์ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ๋น ๊ฒ์์ ์ฌ์ฉํ์ฌ ํ ์คํธ๋ฅผ ์์ฑํ ๋ ์ํํธ์จ์ด๋ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ ์ฌ๋ณธ์ ์ฌ๋ฌ ๊ฐ ์ ์งํด์ผ ํฉ๋๋ค.
forward vs backward ์คํ ์๋
ํฉ์ฑ๊ณฑ๊ณผ ์ ํ ๋ ์ด์ด์ ๊ฒฝ์ฐ ์๋ฐฉํฅ์ ๋นํด ์ญ๋ฐฉํฅ์์๋ 2๋ฐฐ์ ํ๋กญ์ค๊ฐ ํ์ํ๋ฏ๋ก ์ผ๋ฐ์ ์ผ๋ก 2๋ฐฐ ์ ๋ ๋๋ฆฌ๊ฒ ๋ณํ๋ฉ๋๋ค(์ญ๋ฐฉํฅ์ ๊ฒฝ์ฐ ์ฌ์ด์ฆ๊ฐ ๋ถ์์ฐ์ค๋ฝ๊ธฐ ๋๋ฌธ์, ๋๋ก๋ ๋์ฑ ๋๋ฆด ์๋ ์์ต๋๋ค). ํ์ฑํ๋ ์ผ๋ฐ์ ์ผ๋ก ๋์ญํญ์ด ์ ํ๋์ด ์์ผ๋ฉฐ, ์ผ๋ฐ์ ์ผ๋ก ์๋ฐฉํฅ๋ณด๋ค ์ญ๋ฐฉํฅ์์ ๋ ๋ง์ ๋ฐ์ดํฐ๋ฅผ ์ฝ์ด์ผ ํฉ๋๋ค. (์๋ฅผ ๋ค์ด, ์๋ฐฉํฅ ํ์ฑํ ์ ํ ๋ฒ ์ฉ ์ฝ๊ณ ์ฐ์ง๋ง, ์ญ๋ฐฉํฅ ํ์ฑํ์์๋ ์๋ฐฉํฅ gradOutput๊ณผ ์ถ๋ ฅ์ ๋ํด ์ด ๋ ๋ฒ ์ฝ๊ณ gradInput์ ๋ํด ํ ๋ฒ ์๋๋ค.)
๋ณด๋ค์ํผ, GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ์ฝํ๊ฑฐ๋ ์์ ์๋๋ฅผ ๋์ผ ์ ์๋ ๋ช ๊ฐ์ง ๋ฐฉ๋ฒ์ด ์์ต๋๋ค. ์ด์ GPU ํ์ฉ๊ณผ ๊ณ์ฐ ์๋์ ์ํฅ์ ์ฃผ๋ ๊ฒ์ด ๋ฌด์์ธ์ง๋ฅผ ์ดํดํ์ผ๋ฏ๋ก, Methods and tools for efficient training on a single GPU ๋ฌธ์ ํ์ด์ง๋ฅผ ์ฐธ์กฐํ์ฌ ์ฑ๋ฅ ์ต์ ํ ๊ธฐ๋ฒ์ ๋ํด ์์๋ณด์ธ์.