๋ค์ค GPU์์ ํจ์จ์ ์ธ ํ๋ จ [[efficient-training-on-multiple-gpus]]
๋จ์ผ GPU์์์ ํ๋ จ์ด ๋๋ฌด ๋๋ฆฌ๊ฑฐ๋ ๋ชจ๋ธ ๊ฐ์ค์น๊ฐ ๋จ์ผ GPU์ ๋ฉ๋ชจ๋ฆฌ์ ๋ง์ง ์๋ ๊ฒฝ์ฐ, ๋ค์ค-GPU ์ค์ ์ ์ฌ์ฉํฉ๋๋ค. ๋จ์ผ GPU์์ ๋ค์ค GPU๋ก ์ ํํ๊ธฐ ์ํด์๋ ์์ ์ ๋ถ์ฐํด์ผ ํฉ๋๋ค. ๋ฐ์ดํฐ, ํ ์ ๋๋ ํ์ดํ๋ผ์ธ๊ณผ ๊ฐ์ ๋ณ๋ ฌํ ๊ธฐ๋ฒ์ ์ฌ์ฉํ์ฌ ์์ ์ ๋ณ๋ ฌ๋ก ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ฌํ ์ค์ ์ ๋ชจ๋์๊ฒ ์ ์ฉํ ์ ์๋ ์๋ฒฝํ ํด๊ฒฐ์ฑ ์ ์์ผ๋ฉฐ, ์ด๋ค ์ค์ ์ด ๊ฐ์ฅ ์ ํฉํ์ง๋ ์ฌ์ฉํ๋ ํ๋์จ์ด์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋๋ค. ์ด ๋ฌธ์๋ ์ฃผ๋ก PyTorch ๊ธฐ๋ฐ์ ๊ตฌํ์ ์ค์ฌ์ผ๋ก ์ค๋ช ํ๋ฉฐ, ๋๋ถ๋ถ์ ๊ฐ๋ ์ ๋ค๋ฅธ ํ๋ ์์ํฌ์๋ ์ ์ฉ๋ ์ ์์ ๊ฒ์ผ๋ก ์์๋ฉ๋๋ค.
์ฐธ๊ณ : ๋จ์ผ GPU ์น์ ์์ ์๊ฐ๋ ์ ๋ต(ํผํฉ ์ ๋ฐ๋ ํ๋ จ ๋๋ ๊ทธ๋๋์ธํธ ๋์ ๋ฑ)์ ์ผ๋ฐ์ ์ผ๋ก ๋ชจ๋ธ ํ๋ จ์ ์ ์ฉ๋๋ฉฐ, ๋ค์ค-GPU ๋๋ CPU ํ๋ จ๊ณผ ๊ฐ์ ๋ค์ ์น์ ์ผ๋ก ์ง์ ํ๊ธฐ ์ ์ ํด๋น ์น์ ์ ์ฐธ๊ณ ํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
๋จผ์ 1D ๋ณ๋ ฌํ ๊ธฐ์ ์ ๋ํด ์์ธํ ๋ ผ์ํ ํ, ์ด๋ฌํ ๊ธฐ์ ์ ๊ฒฐํฉํ์ฌ 2D ๋ฐ 3D ๋ณ๋ ฌํ๋ฅผ ๊ตฌํํ์ฌ ๋ ๋น ๋ฅธ ํ๋ จ๊ณผ ๋ ํฐ ๋ชจ๋ธ์ ์ง์ํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณผ ๊ฒ์ ๋๋ค. ๋ํ ๋ค๋ฅธ ํจ๊ณผ์ ์ธ ๋์ ๋ฐฉ์๋ ์๊ฐ๋ ์์ ์ ๋๋ค.
๊ฐ๋ [[concepts]]
๋ค์์ ์ด ๋ฌธ์์์ ์์ธํ ์ค๋ช ๋ ์ฃผ์ ๊ฐ๋ ์ ๋ํ ๊ฐ๋จํ ์ค๋ช ์ ๋๋ค.
- DataParallel (DP) - ๋์ผํ ์ค์ ์ด ์ฌ๋ฌ ๋ฒ ๋ณต์ ๋๊ณ , ๊ฐ ์ค์ ์ ๋ฐ์ดํฐ ์ผ๋ถ๋ฅผ ๋ฐ์ต๋๋ค. ์ฒ๋ฆฌ๋ ๋ณ๋ ฌ๋ก ์ํ๋๋ฉฐ ๋ชจ๋ ์ค์ ์ ๊ฐ ํ๋ จ ๋จ๊ณ์ ๋๋ ๋ ๋๊ธฐํ๋ฉ๋๋ค.
- TensorParallel (TP) - ๊ฐ ํ ์๋ ์ฌ๋ฌ ๊ฐ์ ๋ฌถ์์ผ๋ก ๋ถํ ๋๊ธฐ์, ์ ์ฒด ํ ์๊ฐ ๋จ์ผ GPU์ ์์ฃผํ๋ ๋์ ํ ์์ ๊ฐ ์ค๋๊ฐ ์ง์ ๋ GPU์ ์์ฃผํฉ๋๋ค. ์ฒ๋ฆฌํ๋ ๋์ ๊ฐ ์ค๋๋ ์๋ก ๋ค๋ฅธ GPU์์ ๊ฐ๋ณ์ ์ผ๋ก ๋ณ๋ ฌ ์ฒ๋ฆฌ๋๋ฉฐ ๊ฒฐ๊ณผ๋ ๋จ๊ณ๊ฐ ๋๋ ๋ ๋๊ธฐํ๋ฉ๋๋ค. ๋ถํ ์ด ์ํ ์์ค์์ ์ด๋ฃจ์ด์ง๊ธฐ ๋๋ฌธ์ ์ด๋ฅผ ์ํ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ผ๊ณ ๋ถ๋ฅผ ์ ์์ต๋๋ค.
- PipelineParallel (PP) - ๋ชจ๋ธ์ด ์์ง์ผ๋ก (๋ ์ด์ด ์์ค) ์ฌ๋ฌ GPU์ ๋ถํ ๋์ด ๋ชจ๋ธ์ ๋จ์ผ GPU์๋ ํ๋ ๋๋ ์ฌ๋ฌ ๋ ์ด์ด๊ฐ ๋ฐฐ์น๋ฉ๋๋ค. ๊ฐ GPU๋ ํ์ดํ๋ผ์ธ์ ์๋ก ๋ค๋ฅธ ๋จ๊ณ๋ฅผ ๋ณ๋ ฌ๋ก ์ฒ๋ฆฌํ๋ฉฐ ์์ ๋ฐฐ์น ๋ฌถ์์์ ์๋ํฉ๋๋ค.
- Zero Redundancy Optimizer (ZeRO) - TP์ ์ ์ฌํ๊ฒ ํ ์๋ฅผ ์ค๋ฉํ์ง๋ง, ์ ์ฒด ํ ์๋ ์๋ฐฉํฅ ๋๋ ์ญ๋ฐฉํฅ ๊ณ์ฐ์ ์ํด ์ฌ๊ตฌ์ฑ๋๋ฏ๋ก ๋ชจ๋ธ์ ์์ ํ ํ์๊ฐ ์์ต๋๋ค. ๋ํ ์ ํ๋ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ณด์ํ๊ธฐ ์ํด ๋ค์ํ ์คํ๋ก๋ ๊ธฐ์ ์ ์ง์ํฉ๋๋ค.
- Sharded DDP - ZeRO์ ๊ธฐ๋ณธ ๊ฐ๋ ์ผ๋ก ๋ค๋ฅธ ZeRO ๊ตฌํ์์๋ ์ฌ์ฉ๋๋ ์ฉ์ด์ ๋๋ค.
๊ฐ ๊ฐ๋ ์ ๊ตฌ์ฒด์ ์ธ ๋ด์ฉ์ ๋ํด ์์ธํ ๋ค์ด๊ฐ๊ธฐ ์ ์ ๋๊ท๋ชจ ์ธํ๋ผ์์ ๋๊ท๋ชจ ๋ชจ๋ธ์ ํ๋ จํ๋ ๊ฒฝ์ฐ์ ๋๋ต์ ์ธ ๊ฒฐ์ ๊ณผ์ ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
ํ์ฅ์ฑ ์ ๋ต [[scalability-strategy]]
โจ ๋จ์ผ ๋ ธ๋ / ๋ค์ค-GPU
๋ชจ๋ธ์ด ๋จ์ผ GPU์ ๋ง๋ ๊ฒฝ์ฐ:
- DDP - ๋ถ์ฐ DP
- ZeRO - ์ํฉ๊ณผ ๊ตฌ์ฑ์ ๋ฐ๋ผ ๋ ๋น ๋ฅผ ์๋ ์๊ณ ๊ทธ๋ ์ง ์์ ์๋ ์์
๋ชจ๋ธ์ด ๋จ์ผ GPU์ ๋ง์ง ์๋ ๊ฒฝ์ฐ:
- PP
- ZeRO
- TP
๋ ธ๋ ๋ด ์ฐ๊ฒฐ ์๋๊ฐ ๋งค์ฐ ๋น ๋ฅธ NVLINK ๋๋ NVSwitch์ ๊ฒฝ์ฐ ์ธ ๊ฐ์ง ๋ฐฉ๋ฒ์ ๋๋ถ๋ถ ๋น์ทํ ์ฑ๋ฅ์ ๋ณด์ฌ์ผ ํ๋ฉฐ, PP๊ฐ ์๋ ๊ฒฝ์ฐ TP ๋๋ ZeRO๋ณด๋ค ๋น ๋ฅผ ๊ฒ์ ๋๋ค. TP์ ์ ๋๋ ์ฐจ์ด๋ฅผ ๋ง๋ค ์ ์์ต๋๋ค. ํน์ ์ค์ ์์ ์น์๋ฅผ ์ฐพ๊ธฐ ์ํด ์คํํ๋ ๊ฒ์ด ๊ฐ์ฅ ์ข์ต๋๋ค.
TP๋ ๊ฑฐ์ ํญ์ ๋จ์ผ ๋ ธ๋ ๋ด์์ ์ฌ์ฉ๋ฉ๋๋ค. ์ฆ, TP ํฌ๊ธฐ <= ๋ ธ๋๋น GPU ์์ ๋๋ค.
๊ฐ์ฅ ํฐ ๋ ์ด์ด๊ฐ ๋จ์ผ GPU์ ๋ง์ง ์๋ ๊ฒฝ์ฐ:
- ZeRO๋ฅผ ์ฌ์ฉํ์ง ์๋ ๊ฒฝ์ฐ - PP๋ง์ผ๋ก๋ ๋ง์ง ์์ผ๋ฏ๋ก TP๋ฅผ ๋ฐ๋์ ์ฌ์ฉํด์ผ ํจ
- ZeRO๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ์๋ ์์ "๋จ์ผ GPU" ํญ๋ชฉ๊ณผ ๋์ผ
โจ ๋ค์ค ๋ ธ๋ / ๋ค์ค GPU
๋ ธ๋ ๊ฐ ์ฐ๊ฒฐ ์๋๊ฐ ๋น ๋ฅธ ๊ฒฝ์ฐ:
- ZeRO - ๋ชจ๋ธ์ ๋๋ถ๋ถ์ ์์ ์ ํ์๋ก ํ์ง ์์
- PP+TP+DP - ํต์ ์ด ์ ์ง๋ง ๋ชจ๋ธ์ ๋๋์ ์ธ ๋ณ๊ฒฝ์ด ํ์ํจ
๋ ธ๋ ๊ฐ ์ฐ๊ฒฐ ์๋๊ฐ ๋๋ฆฌ๋ฉฐ, GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ฌ์ ํ ๋ถ์กฑํ ๊ฒฝ์ฐ:
- DP+PP+TP+ZeRO-1
๋ฐ์ดํฐ ๋ณ๋ ฌํ [[data-parallelism]]
2๊ฐ์ GPU๋ง์ผ๋ก๋ ๋๋ถ๋ถ์ ์ฌ์ฉ์๋ค์ DataParallel (DP)๊ณผ DistributedDataParallel (DDP)์ ํตํด ํฅ์๋ ํ๋ จ ์๋๋ฅผ ๋๋ฆด ์ ์์ต๋๋ค. ์ด๋ PyTorch์ ๋ด์ฅ ๊ธฐ๋ฅ์
๋๋ค. ์ผ๋ฐ์ ์ผ๋ก DDP๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ผ๋ฉฐ, DP๋ ์ผ๋ถ ๋ชจ๋ธ์์ ์๋ํ์ง ์์ ์ ์์ผ๋ฏ๋ก ์ฃผ์ํด์ผ ํฉ๋๋ค. PyTorch ๋ฌธ์์์๋ DDP์ ์ฌ์ฉ์ ๊ถ์ฅํฉ๋๋ค.
DP vs DDP [[dp-vs-ddp]]
DistributedDataParallel (DDP)์ ์ผ๋ฐ์ ์ผ๋ก DataParallel (DP)๋ณด๋ค ๋น ๋ฅด์ง๋ง, ํญ์ ๊ทธ๋ ์ง๋ ์์ต๋๋ค:
- DP๋ ํ์ด์ฌ ์ค๋ ๋ ๊ธฐ๋ฐ์ธ ๋ฐ๋ฉด, DDP๋ ๋ค์ค ํ๋ก์ธ์ค ๊ธฐ๋ฐ์ด๊ธฐ ๋๋ฌธ์ GIL๊ณผ ๊ฐ์ ํ์ด์ฌ ์ค๋ ๋ ์ ํ์ด ์์ต๋๋ค.
- ๊ทธ๋ฌ๋ GPU ์นด๋ ๊ฐ์ ๋๋ฆฐ ์ํธ ์ฐ๊ฒฐ์ฑ์ DDP๋ก ์ธํด ์ค์ ๋ก ๋๋ฆฐ ๊ฒฐ๊ณผ๋ฅผ ๋ผ ์ ์์ต๋๋ค.
์ด ๋ ๋ชจ๋ ๊ฐ์ GPU ๊ฐ ํต์ ์ค๋ฒํค๋์ ์ฃผ์ ์ฐจ์ด์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
DDP:
- ์์ํ ๋, ์ฃผ ํ๋ก์ธ์ค๊ฐ ๋ชจ๋ธ์ gpu 0์์ ๋ค๋ฅธ ๋ชจ๋ gpu๋ก ๋ณต์ ํฉ๋๋ค.
- ๊ทธ๋ฐ ๋ค์ ๊ฐ ๋ฐฐ์น์ ๋ํด:
- ๊ฐ gpu๋ ์์ฒด ๋ฏธ๋ ๋ฐฐ์น ๋ฐ์ดํฐ๋ฅผ ์ง์ ์ฌ์ฉํฉ๋๋ค.
backward๋์ ๋ก์ปฌ ๊ทธ๋๋์ธํธ๊ฐ ์ค๋น๋๋ฉด, ๋ชจ๋ ํ๋ก์ธ์ค์ ํ๊ท ํ๋ฉ๋๋ค.
DP:
๊ฐ ๋ฐฐ์น์ ๋ํด:
- gpu 0์ ๋ฐ์ดํฐ ๋ฐฐ์น๋ฅผ ์ฝ๊ณ ๊ฐ gpu์ ๋ฏธ๋ ๋ฐฐ์น๋ฅผ ๋ณด๋ ๋๋ค.
- ์ ๋ฐ์ดํธ๋ ๋ชจ๋ธ์ gpu 0์์ ๊ฐ gpu๋ก ๋ณต์ ํฉ๋๋ค.
forward๋ฅผ ์คํํ๊ณ ๊ฐ gpu์ ์ถ๋ ฅ์ gpu 0์ผ๋ก ๋ณด๋ด๊ณ ์์ค์ ๊ณ์ฐํฉ๋๋ค.- gpu 0์์ ๋ชจ๋ gpu๋ก ์์ค์ ๋ถ์ฐํ๊ณ
backward๋ฅผ ์คํํฉ๋๋ค. - ๊ฐ gpu์์ ๊ทธ๋๋์ธํธ๋ฅผ gpu 0์ผ๋ก ๋ณด๋ด๊ณ ์ด๋ฅผ ํ๊ท ํํฉ๋๋ค.
DDP๋ ๊ฐ ๋ฐฐ์น๋ง๋ค ๊ทธ๋๋์ธํธ๋ฅผ ๋ณด๋ด๋ ํต์ ๋ง์ ์ํํ๋ฉฐ, DP๋ ๋ฐฐ์น๋ง๋ค 5๊ฐ์ ๋ค๋ฅธ ๋ฐ์ดํฐ ๊ตํ์ ์ํํฉ๋๋ค.
DP๋ ํ์ด์ฌ ์ค๋ ๋๋ฅผ ํตํด ํ๋ก์ธ์ค ๋ด์์ ๋ฐ์ดํฐ๋ฅผ ๋ณต์ ํ๋ฉฐ, DDP๋ torch.distributed๋ฅผ ํตํด ๋ฐ์ดํฐ๋ฅผ ๋ณต์ ํฉ๋๋ค.
DP์์๋ gpu 0์ด ๋ค๋ฅธ gpu๋ณด๋ค ํจ์ฌ ๋ ๋ง์ ์์ ์ ์ํํ๋ฏ๋ก, gpu์ ํ์ฉ๋๊ฐ ๋ฎ์์ง๋๋ค.
DDP๋ ์ฌ๋ฌ ๋์ ์ปดํจํฐ์์ ์ฌ์ฉํ ์ ์์ง๋ง, DP์ ๊ฒฝ์ฐ๋ ๊ทธ๋ ์ง ์์ต๋๋ค.
DP์ DDP ์ฌ์ด์๋ ๋ค๋ฅธ ์ฐจ์ด์ ์ด ์์ง๋ง, ์ด ํ ๋ก ๊ณผ๋ ๊ด๋ จ์ด ์์ต๋๋ค.
์ด 2๊ฐ์ง ๋ชจ๋๋ฅผ ๊น๊ฒ ์ดํดํ๊ณ ์ถ๋ค๋ฉด, ์ด ๋ฌธ์๋ฅผ ๊ฐ๋ ฅํ ์ถ์ฒํฉ๋๋ค. ์ด ๋ฌธ์๋ ๋ฉ์ง ๋ค์ด์ด๊ทธ๋จ์ ํฌํจํ๊ณ ์์ผ๋ฉฐ, ๋ค์ํ ํ๋์จ์ด์์ ์ฌ๋ฌ ๋ฒค์น๋งํฌ์ ํ๋กํ์ผ๋ฌ ์ถ๋ ฅ์ ์ค๋ช ํ์ฌ ํ์ํ ์ธ๋ถ ์ฌํญ์ ๋ชจ๋ ์ค๋ช ํฉ๋๋ค.
์ค์ ๋ฒค์น๋งํฌ๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค:
| Type | NVlink | Time |
|---|---|---|
| 2:DP | Y | 110s |
| 2:DDP | Y | 101s |
| 2:DDP | N | 131s |
๋ถ์:
์ฌ๊ธฐ์ DP๋ NVlink๊ฐ ์๋ DDP๋ณด๋ค ์ฝ 10% ๋๋ฆฝ๋๋ค. ๊ทธ๋ฌ๋ NVlink๊ฐ ์๋ DDP๋ณด๋ค ์ฝ 15% ๋น ๋ฆ ๋๋ค.
์ค์ ์ฐจ์ด๋ ๊ฐ GPU๊ฐ ๋ค๋ฅธ GPU์ ๋๊ธฐํํด์ผ ํ๋ ๋ฐ์ดํฐ ์์ ๋ฐ๋ผ ๋ฌ๋ผ์ง ๊ฒ์ ๋๋ค. ๋๊ธฐํํ ๋ฐ์ดํฐ๊ฐ ๋ง์์๋ก ๋๋ฆฐ ๋งํฌ๊ฐ ์ด ์คํ ์๊ฐ์ ๋ฆ์ถ ์ ์์ต๋๋ค.
๋ค์์ ์ ์ฒด ๋ฒค์น๋งํฌ ์ฝ๋์ ์ถ๋ ฅ์ ๋๋ค:
ํด๋น ๋ฒค์น๋งํฌ์์ NCCL_P2P_DISABLE=1์ ์ฌ์ฉํ์ฌ NVLink ๊ธฐ๋ฅ์ ๋นํ์ฑํํ์ต๋๋ค.
# DP
rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 \
python examples/pytorch/language-modeling/run_clm.py \
--model_name_or_path openai-community/gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--do_train --output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200
{'train_runtime': 110.5948, 'train_samples_per_second': 1.808, 'epoch': 0.69}
# DDP w/ NVlink
rm -r /tmp/test-clm; CUDA_VISIBLE_DEVICES=0,1 \
torchrun --nproc_per_node 2 examples/pytorch/language-modeling/run_clm.py \
--model_name_or_path openai-community/gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--do_train --output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200
{'train_runtime': 101.9003, 'train_samples_per_second': 1.963, 'epoch': 0.69}
# DDP w/o NVlink
rm -r /tmp/test-clm; NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1 \
torchrun --nproc_per_node 2 examples/pytorch/language-modeling/run_clm.py \
--model_name_or_path openai-community/gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--do_train --output_dir /tmp/test-clm --per_device_train_batch_size 4 --max_steps 200
{'train_runtime': 131.4367, 'train_samples_per_second': 1.522, 'epoch': 0.69}
ํ๋์จ์ด: ๊ฐ๊ฐ 24GB์ TITAN RTX 2๊ฐ + NVlink๊ณผ 2๊ฐ์ NVLink (nvidia-smi topo -m์์ NV2์
๋๋ค.)
์ํํธ์จ์ด: pytorch-1.8-to-be + cuda-11.0 / transformers==4.3.0.dev0
ZeRO ๋ฐ์ดํฐ ๋ณ๋ ฌํ [[zero-data-parallelism]]
ZeRO๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ ๋ฐ์ดํฐ ๋ณ๋ ฌํ (ZeRO-DP)๋ ๋ค์ ๋ธ๋ก๊ทธ ๊ธ์ ๋ค์ ๋ค์ด์ด๊ทธ๋จ์์ ์ค๋ช
๋๊ณ ์์ต๋๋ค.

์ด ๊ฐ๋
์ ์ดํดํ๊ธฐ ์ด๋ ค์ธ ์ ์์ง๋ง, ์ค์ ๋ก๋ ๋งค์ฐ ๊ฐ๋จํ ๊ฐ๋
์
๋๋ค. ์ด๋ ์ผ๋ฐ์ ์ธ DataParallel (DP)๊ณผ ๋์ผํ์ง๋ง, ์ ์ฒด ๋ชจ๋ธ ๋งค๊ฐ๋ณ์, ๊ทธ๋๋์ธํธ ๋ฐ ์ตํฐ๋ง์ด์ ์ํ๋ฅผ ๋ณต์ ํ๋ ๋์ ๊ฐ GPU๋ ๊ทธ ์ค ์ผ๋ถ๋ง ์ ์ฅํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์คํ ์๊ฐ์๋ ์ฃผ์ด์ง ๋ ์ด์ด์ ๋ํด ์ ์ฒด ๋ ์ด์ด ๋งค๊ฐ๋ณ์๊ฐ ํ์ํ ๋ ๊ฐ GPU๊ฐ ์๋ก์๊ฒ ํ์ํ ๋ถ๋ถ์ ์ ๊ณตํ๊ธฐ ์ํด ๋๊ธฐํ๋ฉ๋๋ค - ๊ทธ๊ฒ ์ ๋ถ์
๋๋ค.
๊ฐ๊ฐ 3๊ฐ์ ๋ ์ด์ด์ 3๊ฐ์ ๋งค๊ฐ๋ณ์๊ฐ ์๋ ๊ฐ๋จํ ๋ชจ๋ธ์ ์๊ฐํด ๋ด ์๋ค:
La | Lb | Lc
---|----|---
a0 | b0 | c0
a1 | b1 | c1
a2 | b2 | c2
๋ ์ด์ด La์๋ ๊ฐ์ค์น a0, a1 ๋ฐ a2๊ฐ ์์ต๋๋ค.
3๊ฐ์ GPU๊ฐ ์๋ ๊ฒฝ์ฐ, Sharded DDP (= Zero-DP)๋ ๋ค์๊ณผ ๊ฐ์ด ๋ชจ๋ธ์ 3๊ฐ์ GPU์ ๋ถํ ํฉ๋๋ค:
GPU0:
La | Lb | Lc
---|----|---
a0 | b0 | c0
GPU1:
La | Lb | Lc
---|----|---
a1 | b1 | c1
GPU2:
La | Lb | Lc
---|----|---
a2 | b2 | c2
์ผ๋ฐ์ ์ธ DNN ๋ค์ด์ด๊ทธ๋จ์ ์์ํด๋ณด๋ฉด ์ด๋ ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ์ ๊ฐ์ ์ํ ์ฌ๋ผ์ด์ฑ์ ๋๋ค. ์์ง ์ฌ๋ผ์ด์ฑ์ ์ ์ฒด ๋ ์ด์ด ๊ทธ๋ฃน์ ๋ค๋ฅธ GPU์ ๋ฐฐ์นํ๋ ๊ฒ์ ๋๋ค. ์ด๋ ์์์ ๋ถ๊ณผํฉ๋๋ค.
์ด์ ์ด๋ฌํ ๊ฐ๊ฐ์ GPU๋ DP์์ ์๋ํ๋ ๊ฒ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ์ผ๋ฐ์ ์ธ ๋ฏธ๋ ๋ฐฐ์น๋ฅผ ๋ฐ์ต๋๋ค:
x0 => GPU0
x1 => GPU1
x2 => GPU2
์ ๋ ฅ์ ์์ ๋์ง ์์ ์ํ๋ก ์ผ๋ฐ ๋ชจ๋ธ์ ์ํด ์ฒ๋ฆฌ๋ ๊ฒ์ผ๋ก ๊ฐ์ฃผํฉ๋๋ค.
๋จผ์ , ์ ๋ ฅ์ ๋ ์ด์ด La์ ๋๋ฌํฉ๋๋ค.
GPU0์๋ง ์ง์คํด ๋ณด๊ฒ ์ต๋๋ค. x0์ ์๋ฐฉํฅ ๊ฒฝ๋ก๋ฅผ ์ํํ๊ธฐ ์ํด a0, a1, a2 ํ๋ผ๋ฏธํฐ๊ฐ ํ์ํ์ง๋ง GPU0์๋ a0๋ง ์์ต๋๋ค. GPU1์์ a1์, GPU2์์ a2๋ฅผ ์ ์ก๋ฐ์ ๋ชจ๋ธ์ ๋ชจ๋ ์กฐ๊ฐ์ ํ๋๋ก ๋ชจ์๋๋ค.
๋ณ๋ ฌ์ ์ผ๋ก, GPU1์ ๋ฏธ๋ ๋ฐฐ์น x1์ ๋ฐ๊ณ a1๋ง ๊ฐ์ง๊ณ ์์ง๋ง, a0 ๋ฐ a2 ๋งค๊ฐ๋ณ์๊ฐ ํ์ํฉ๋๋ค. ๋ฐ๋ผ์ GPU0 ๋ฐ GPU2์์ ์ด๋ฅผ ๊ฐ์ ธ์ต๋๋ค.
GPU2๋ ๋์ผํ ์์ ์ ์ํํฉ๋๋ค. ์ ๋ ฅ x2๋ฅผ ๋ฐ๊ณ GPU0 ๋ฐ GPU1์์ ๊ฐ๊ฐ a0๊ณผ a1์, ๊ทธ๋ฆฌ๊ณ ์์ ์ a2์ ํจ๊ป ์ ์ฒด ํ ์๋ฅผ ๋ณต์ํฉ๋๋ค.
3๊ฐ์ GPU๋ ๋ณต์๋ ์ ์ฒด ํ ์๋ฅผ ๋ฐ๊ณ forward๊ฐ ์ํ๋ฉ๋๋ค.
๊ณ์ฐ์ด ์๋ฃ๋๋ฉด ๋ ์ด์ ํ์ํ์ง ์์ ๋ฐ์ดํฐ๋ ์ญ์ ๋๊ณ , ํด๋น ๋ฐ์ดํฐ๋ ๊ณ์ฐ ์ค์๋ง ์ฌ์ฉ๋ฉ๋๋ค. ๋ณต์์ ์ฌ์ ํจ์น๋ฅผ ํตํด ํจ์จ์ ์ผ๋ก ์ํ๋ฉ๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ ์ฒด ํ๋ก์ธ์ค๋ ๋ ์ด์ด Lb์ ๋ํด ๋ฐ๋ณต๋๊ณ , ๊ทธ ๋ค์ Lc๋ก ์๋ฐฉํฅ์ผ๋ก, ๊ทธ๋ค์์ ์ญ๋ฐฉํฅ์ผ๋ก Lc -> Lb -> La๋ก ๋ฐ๋ณต๋ฉ๋๋ค.
๊ฐ์ธ์ ์ผ๋ก ์ด๊ฒ์ ํจ์จ์ ์ธ ๊ทธ๋ฃน ๋ฐฐ๋ญ ์ฌํ์์ ์ค๋ ๋ถ๋ฐฐ ์ ๋ต์ฒ๋ผ ๋ค๋ฆฝ๋๋ค:
- ์ฌ๋ A๊ฐ ํ ํธ๋ฅผ ์ด๋ฐํฉ๋๋ค.
- ์ฌ๋ B๊ฐ ๋๋ก๋ฅผ ์ด๋ฐํฉ๋๋ค.
- ์ฌ๋ C๊ฐ ๋๋ผ๋ฅผ ์ด๋ฐํฉ๋๋ค.
์ด์ ๋งค์ผ ๋ฐค ๊ฐ์ ๊ฐ์ง ๊ฒ์ ๋ค๋ฅธ ์ฌ๋๋ค๊ณผ ๊ณต์ ํ๊ณ , ๊ฐ์ง์ง ์์ ๊ฒ์ ๋ค๋ฅธ ์ฌ๋๋ค๋ก๋ถํฐ ๋ฐ๊ณ , ์์นจ์๋ ํ ๋น๋ ์ ํ์ ์ฅ๋น๋ฅผ ์ธ๊ณ ๊ณ์ํด์ ์ฌํ์ ์งํํฉ๋๋ค. ์ด๊ฒ์ด Sharded DDP / Zero DP์ ๋๋ค.
์ด ์ ๋ต์ ๊ฐ๊ฐ ์์ ์ ํ ํธ, ๋๋ก ๋ฐ ๋๋ผ๋ฅผ ๊ฐ๋ณ์ ์ผ๋ก ์ด๋ฐํด์ผ ํ๋ ๋จ์ํ ์ ๋ต๊ณผ ๋น๊ตํด๋ณด๋ฉด ํจ์ฌ ๋นํจ์จ์ ์ผ ๊ฒ์ ๋๋ค. ์ด๊ฒ์ด Pytorch์ DataParallel (DP ๋ฐ DDP)์ ๋๋ค.
์ด ์ฃผ์ ์ ๋ํด ๋ ผ๋ฌธ์ ์ฝ์ ๋ ๋ค์ ๋์์ด๋ฅผ ๋ง๋ ์ ์์ต๋๋ค: Sharded, Partitioned.
ZeRO๊ฐ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ๋ถํ ํ๋ ๋ฐฉ์์ ์์ธํ ์ดํด๋ณด๋ฉด, ํ ์ ๋ณ๋ ฌํ์ ๋งค์ฐ ์ ์ฌํ ๊ฒ์ ์ ์ ์์ต๋๋ค. ์ด๋ ์ดํ์ ์ค๋ช ๋ ์์ง ๋ชจ๋ธ ๋ณ๋ ฌํ์๋ ๋ฌ๋ฆฌ ๊ฐ ๋ ์ด์ด์ ๊ฐ์ค์น๋ฅผ ๋ถํ /๋ถํ ํ๊ธฐ ๋๋ฌธ์ ๋๋ค.
๊ตฌํ:
- DeepSpeed๋ 1๋จ๊ณ + 2๋จ๊ณ + 3๋จ๊ณ์ ZeRO-DP๋ฅผ ์ ๊ณตํฉ๋๋ค.
- Fairscale์ 1๋จ๊ณ + 2๋จ๊ณ + 3๋จ๊ณ์ ZeRO-DP๋ฅผ ์ ๊ณตํฉ๋๋ค.
transformersํตํฉ
๋ค์ดํฐ๋ธ ๋ชจ๋ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ(์์ง์ ) ๋ฐ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ[[naive-model-parallelism-vertical-and-pipeline-parallelism]]
Naive Model Parallelism (MP)์ ๋ชจ๋ธ ๋ ์ด์ด ๊ทธ๋ฃน์ ๋ค์ค GPU์ ๋ถ์ฐํ๋ ๋ฐฉ์์
๋๋ค. ๋ฉ์ปค๋์ฆ์ ์๋์ ์ผ๋ก ๊ฐ๋จํฉ๋๋ค. ์ํ๋ ๋ ์ด์ด๋ฅผ .to()๋ฅผ ์ฌ์ฉํ์ฌ ์ํ๋ ์ฅ์น๋ก ์ ํํ๋ฉด ๋ฐ์ดํฐ๊ฐ ํด๋น ๋ ์ด์ด๋ก ๋ค์ด์ค๊ณ ๋๊ฐ ๋ ๋ฐ์ดํฐ๋ ๋ ์ด์ด์ ๋์ผํ ์ฅ์น๋ก ์ ํ๋๊ณ ๋๋จธ์ง๋ ์์ ๋์ง ์์ต๋๋ค.
๋๋ถ๋ถ์ ๋ชจ๋ธ์ด ๊ทธ๋ ค์ง๋ ๋ฐฉ์์ด ๋ ์ด์ด๋ฅผ ์ธ๋ก๋ก ์ฌ๋ผ์ด์คํ๊ธฐ ๋๋ฌธ์ ์ด๋ฅผ ์์ง ๋ชจ๋ธ ๋ณ๋ ฌํ๋ผ๊ณ ๋ถ๋ฆ ๋๋ค. ์๋ฅผ ๋ค์ด ๋ค์ ๋ค์ด์ด๊ทธ๋จ์ 8๋ ์ด์ด ๋ชจ๋ธ์ ๋ณด์ฌ์ค๋๋ค:
=================== ===================
| 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 |
=================== ===================
gpu0 gpu1
์ฐ๋ฆฌ๋ ๋ชจ๋ธ์ ์์ง์ผ๋ก 2๊ฐ๋ก ๋ถํ ํ์ฌ ๋ ์ด์ด 0-3์ GPU0์ ๋ฐฐ์นํ๊ณ ๋ ์ด์ด 4-7์ GPU1์ ๋ฐฐ์นํ์ต๋๋ค.
์ด์ ๋ฐ์ดํฐ๊ฐ ๋ ์ด์ด 0์์ 1๋ก, 1์์ 2๋ก, 2์์ 3์ผ๋ก ์ด๋ํ๋ ๋์์๋ ์ผ๋ฐ์ ์ธ ๋ชจ๋ธ์ ๋๋ค. ๊ทธ๋ฌ๋ ๋ฐ์ดํฐ๊ฐ ๋ ์ด์ด 3์์ ๋ ์ด์ด 4๋ก ์ ๋ฌ๋์ด์ผ ํ ๋๋ GPU0์์ GPU1๋ก ์ด๋ํด์ผ ํ๋ฏ๋ก ํต์ ์ค๋ฒํค๋๊ฐ ๋ฐ์ํฉ๋๋ค. ์ฐธ์ฌํ๋ GPU๊ฐ ๋์ผํ ์ปดํจํ ๋ ธ๋(์: ๋์ผํ ๋ฌผ๋ฆฌ์ ์ธ ๊ธฐ๊ณ)์ ์๋ ๊ฒฝ์ฐ ์ด ๋ณต์ฌ๋ ๋งค์ฐ ๋น ๋ฆ ๋๋ค. ๊ทธ๋ฌ๋ GPU๊ฐ ์๋ก ๋ค๋ฅธ ์ปดํจํ ๋ ธ๋(์: ์ฌ๋ฌ ๊ธฐ๊ณ)์ ์์นํ ๊ฒฝ์ฐ ํต์ ์ค๋ฒํค๋๋ ์๋นํ ํฌ๊ฒ ๋ ์ ์์ต๋๋ค.
๊ทธ๋ฐ ๋ค์ ๋ ์ด์ด 4๋ถํฐ 5๋ก, 6์ผ๋ก, 7๋ก ์งํ๋๋ ๊ฒ์ ์ผ๋ฐ์ ์ธ ๋ชจ๋ธ๊ณผ ๋์ผํ๊ฒ ์งํ๋๊ณ , 7๋ฒ์งธ ๋ ์ด์ด๊ฐ ์๋ฃ๋๋ฉด ๋ฐ์ดํฐ๋ฅผ ๋ค์ ๋ ์ด์ด 0์ผ๋ก ๋ณด๋ด๊ฑฐ๋ ๋๋ ๋ ์ด๋ธ์ ๋ง์ง๋ง ๋ ์ด์ด๋ก ๋ณด๋ด์ผ ํ ํ์๊ฐ ์์ต๋๋ค. ์ด์ ์์ค์ ๊ณ์ฐํ๊ณ ์ตํฐ๋ง์ด์ ๊ฐ ์๋ํ ์ ์์ต๋๋ค.
๋ฌธ์ ์ :
- ์ด ๋ฐฉ์์ "naive" MP๋ผ๊ณ ๋ถ๋ฅด๋ ์ด์ ๋ ์ฃผ์ด์ง ์ํฉ์ ํ๋์ GPU๋ฅผ ์ ์ธํ ๋ชจ๋ GPU๊ฐ ์ ํด ์ํ๋ผ๋ ์ ์ ๋๋ค. ๋ฐ๋ผ์ 4๊ฐ์ GPU๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ ๋จ์ผ GPU์ ๋ฉ๋ชจ๋ฆฌ ์์ 4๋ฐฐ๋ก ๋๋ฆฌ๊ณ ๋๋จธ์ง ํ๋์จ์ด๋ ๋ฌด์ํ๋ ๊ฒ๊ณผ ๊ฑฐ์ ๋์ผํฉ๋๋ค. ๋ํ ์ฅ์น ๊ฐ ๋ฐ์ดํฐ ๋ณต์ฌ์ ์ค๋ฒํค๋๋ ์์ต๋๋ค. ๋ฐ๋ผ์ 4๊ฐ์ 6GB ์นด๋๋ naive MP๋ฅผ ์ฌ์ฉํ์ฌ 1๊ฐ์ 24GB ์นด๋์ ๋์ผํ ํฌ๊ธฐ๋ฅผ ์์ฉํ ์ ์์ง๋ง, ํ์๋ ๋ฐ์ดํฐ ๋ณต์ฌ์ ์ค๋ฒํค๋๊ฐ ์์ผ๋ฏ๋ก ํ๋ จ์ ๋ ๋นจ๋ฆฌ ์๋ฃํฉ๋๋ค. ๊ทธ๋ฌ๋ ์๋ฅผ ๋ค์ด 40GB ์นด๋๊ฐ ์๊ณ 45GB ๋ชจ๋ธ์ ๋ง์ถ์ด์ผ ํ ๊ฒฝ์ฐ 4๊ฐ์ 40GB ์นด๋๋ก ๋ง์ถ ์ ์์ต๋๋ค (ํ์ง๋ง ๊ทธ๋๋์ธํธ์ ์ตํฐ๋ง์ด์ ์ํ ๋๋ฌธ์ ๊ฐ๊น์ค๋ก ๊ฐ๋ฅํฉ๋๋ค).
- ๊ณต์ ์๋ฒ ๋ฉ์ GPU ๊ฐ์ ๋ณต์ฌํด์ผ ํ ์๋ ์์ต๋๋ค.
ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ (PP)์ ๊ฑฐ์ naive MP์ ๋์ผํ์ง๋ง GPU ์ ํด ์ํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋ค์ด์ค๋ ๋ฐฐ์น๋ฅผ ๋ง์ดํฌ๋ก ๋ฐฐ์น๋ก ๋๋๊ณ ์ธ๊ณต์ ์ผ๋ก ํ์ดํ๋ผ์ธ์ ์์ฑํ์ฌ ์๋ก ๋ค๋ฅธ GPU๊ฐ ๋์์ ๊ณ์ฐ์ ์ฐธ์ฌํ ์ ์๊ฒ ํฉ๋๋ค.
GPipe ๋ ผ๋ฌธ์์ ๊ฐ์ ธ์จ ๊ทธ๋ฆผ์ ์๋จ์ naive MP๋ฅผ, ํ๋จ์๋ PP๋ฅผ ๋ณด์ฌ์ค๋๋ค:
ํ๋จ ๋ค์ด์ด๊ทธ๋จ์์ PP๊ฐ ์ ํด ์์ญ์ด ์ ์ ๊ฒ์ ์ฝ๊ฒ ๋ณผ ์ ์์ต๋๋ค. ์ ํด ๋ถ๋ถ์ "bubble"์ด๋ผ๊ณ ํฉ๋๋ค.
๋ค์ด์ด๊ทธ๋จ์ ์์ชฝ ๋ถ๋ถ์ ์ฐธ์ฌํ๋ GPU๊ฐ 4๊ฐ์ธ ๋ณ๋ ฌ์ฑ์ ๋ณด์ฌ์ค๋๋ค. ์ฆ, 4๊ฐ์ GPU๊ฐ ํ์ดํ๋ผ์ธ์ ์ฐธ์ฌํฉ๋๋ค. ๋ฐ๋ผ์ 4๊ฐ์ ํ์ดํ ๋จ๊ณ F0, F1, F2 ๋ฐ F3์ ์๋ฐฉํฅ ๊ฒฝ๋ก์ B3, B2, B1 ๋ฐ B0์ ์ญ๋ฐฉํฅ ๊ฒฝ๋ก๊ฐ ์์ต๋๋ค.
PP๋ ์กฐ์ ํด์ผ ํ ์๋ก์ด ํ์ดํผํ๋ผ๋ฏธํฐ์ธ chunks๋ฅผ ๋์
ํฉ๋๋ค. ์ด๋ ๋์ผํ ํ์ดํ ๋จ๊ณ๋ฅผ ํตํด ์ผ๋ จ์ ๋ฐ์ดํฐ๋ฅผ ๋ฌถ์ด์ ๋ณด๋ด๋ ๋ฐฉ์์ ์ ์ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์๋ ๋ค์ด์ด๊ทธ๋จ์์ chunks=4๋ฅผ ๋ณผ ์ ์์ต๋๋ค. GPU0์ 0, 1, 2 ๋ฐ 3 (F0,0, F0,1, F0,2, F0,3) ๋ฌถ์์์ ๋์ผํ ์๋ฐฉํฅ ๊ฒฝ๋ก๋ฅผ ์ํํ๊ณ , ๋ค๋ฅธ GPU๊ฐ ์์
์ ์ํํ๊ธฐ ์์ํ๊ณ ์๋ฃ๊ฐ ์์๋ ๋๋ง GPU0์ด ๋ฌถ์์ ์ญ์์ผ๋ก 3, 2, 1 ๋ฐ 0 (B0,3, B0,2, B0,1, B0,0) ๊ฒฝ๋ก๋ฅผ ์ํํฉ๋๋ค.
๊ฐ๋
์ ์ผ๋ก ์ด๋ ๊ทธ๋๋์ธํธ ๋์ ๋จ๊ณ (GAS)์ ๋์ผํ ๊ฐ๋
์
๋๋ค. ํ์ดํ ์น์์๋ chunks๋ฅผ ์ฌ์ฉํ๊ณ DeepSpeed์์๋ ๋์ผํ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ GAS๋ก ์ฐธ์กฐํฉ๋๋ค.
๋ฌถ์์ผ๋ก ์ธํด PP๋ ๋ง์ดํฌ๋ก ๋ฐฐ์น (MBS)์ ๊ฐ๋
์ ๋์
ํฉ๋๋ค. DP๋ ์ ์ญ ๋ฐ์ดํฐ ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ๋ฏธ๋ ๋ฐฐ์น๋ก ๋๋๋๋ค. ๋ฐ๋ผ์ DP ์ฐจ์๊ฐ 4์ด๊ณ ์ ์ญ ๋ฐฐ์น ํฌ๊ธฐ๊ฐ 1024์ด๋ฉด 256์ฉ 4๊ฐ์ ๋ฏธ๋ ๋ฐฐ์น๋ก ๋ถํ ๋ฉ๋๋ค (1024/4). ๊ทธ๋ฆฌ๊ณ chunks (๋๋ GAS)์ ์๊ฐ 32์ด๋ฉด ๋ง์ดํฌ๋ก ๋ฐฐ์น ํฌ๊ธฐ๋ 8์ด ๋ฉ๋๋ค (256/32). ๊ฐ ํ์ดํ๋ผ์ธ ๋จ๊ณ๋ ํ ๋ฒ์ ํ๋์ ๋ง์ดํฌ๋ก ๋ฐฐ์น์ ํจ๊ป ์๋ํฉ๋๋ค.
DP + PP ์ค์ ์ ์ ์ญ ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ๊ณ์ฐํ๋ ค๋ฉด mbs*chunks*dp_degree (8*32*4=1024)๋ฅผ ์ํํฉ๋๋ค.
๋ค์ด์ด๊ทธ๋จ์ผ๋ก ๋์๊ฐ ๋ณด๊ฒ ์ต๋๋ค.
chunks=1๋ก ์ค์ ํ๋ฉด ๋งค์ฐ ๋นํจ์จ์ ์ธ naive MP๊ฐ ์์ฑ๋๋ฉฐ, ๋งค์ฐ ํฐ chunks ๊ฐ์ผ๋ก ์ค์ ํ๋ฉด ์์ฃผ ์์ ๋ง์ดํฌ๋ก ๋ฐฐ์น ํฌ๊ธฐ๊ฐ ์์ฑ๋์ด ํจ์จ์ ์ด์ง ์์ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ๊ฐ์ฅ ํจ์จ์ ์ธ GPU ํ์ฉ์ ์ํด ์ด๋ค ๊ฐ์ด ๊ฐ์ฅ ์ ์ ํ์ง ์คํ์ ํด์ผ ํฉ๋๋ค.
๋ค์ด์ด๊ทธ๋จ์์ ๋ณด์ด๋ ๊ฒ์ฒ๋ผ "dead" ์๊ฐ์ ๋ฒ๋ธ์ด ์กด์ฌํ์ฌ ๋ง์ง๋ง forward ๋จ๊ณ๊ฐ backward ๋จ๊ณ๊ฐ ํ์ดํ๋ผ์ธ์ ์๋ฃํ๊ธฐ๋ฅผ ๊ธฐ๋ค๋ ค์ผ ํ๋ ์ํฉ์ด ๋ฐ์ํ์ง๋ง, chunks์ ๊ฐ์ฅ ์ ์ ํ ๊ฐ์ ์ฐพ๋ ๊ฒ์ ๋ชฉ์ ์ ๋ชจ๋ ์ฐธ์ฌํ๋ GPU์์ ๋์์ ๊ณ ๋๋ก ํ์ฉ๋๋ GPU ํ์ฉ์ ๊ฐ๋ฅํ๊ฒ ํ์ฌ ๋ฒ๋ธ์ ํฌ๊ธฐ๋ฅผ ์ต์ํํ๋ ๊ฒ์
๋๋ค.
ํด๊ฒฐ์ฑ ์ ์ ํต์ ์ธ ํ์ดํ๋ผ์ธ API์ ๋ ํ๋์ ์ธ ์๋ฃจ์ ์ผ๋ก ๋๋ฉ๋๋ค. ์ ํต์ ์ธ ํ์ดํ๋ผ์ธ API ์๋ฃจ์ ๊ณผ ํ๋์ ์ธ ์๋ฃจ์ ์ ๋ํด ์์๋ณด๊ฒ ์ต๋๋ค.
์ ํต์ ์ธ ํ์ดํ๋ผ์ธ API ์๋ฃจ์ :
- ํ์ดํ ์น
- FairScale
- DeepSpeed
- Megatron-LM
ํ๋์ ์ธ ์๋ฃจ์ :
- Varuna
- Sagemaker
์ ํต์ ์ธ ํ์ดํ๋ผ์ธ API ์๋ฃจ์ ์ ๋ฌธ์ ์ :
- ๋ชจ๋ธ์ ์๋นํ ์์ ํด์ผ ํ๋ค๋ ์ ์ด ๋ฌธ์ ์
๋๋ค. ํ์ดํ๋ผ์ธ์ ๋ชจ๋์ ์ ์์ ์ธ ํ๋ฆ์
nn.Sequential์ํ์ค๋ก ๋ค์ ์์ฑํด์ผ ํ๋ฏ๋ก ๋ชจ๋ธ์ ์ค๊ณ๋ฅผ ๋ณ๊ฒฝํด์ผ ํ ์ ์์ต๋๋ค. - ํ์ฌ ํ์ดํ๋ผ์ธ API๋ ๋งค์ฐ ์ ํ์ ์ ๋๋ค. ํ์ดํ๋ผ์ธ์ ๋งค์ฐ ์ฒซ ๋ฒ์งธ ๋จ๊ณ์์ ์ ๋ฌ๋๋ ๋ง์ ํ์ด์ฌ ๋ณ์๊ฐ ์๋ ๊ฒฝ์ฐ ์ด๋ฅผ ํด๊ฒฐํด์ผ ํฉ๋๋ค. ํ์ฌ ํ์ดํ๋ผ์ธ ์ธํฐํ์ด์ค๋ ํ๋์ ํ ์ ๋๋ ํ ์์ ํํ์ ์ ์ผํ ์ ๋ ฅ ๋ฐ ์ถ๋ ฅ์ผ๋ก ์๊ตฌํฉ๋๋ค. ์ด๋ฌํ ํ ์๋ ๋ง์ดํฌ๋ก ๋ฐฐ์น๋ก ๋ฏธ๋ ๋ฐฐ์น๋ก ๋ฌถ์ ๊ฒ์ด๋ฏ๋ก ์ฒซ ๋ฒ์งธ ์ฐจ์์ผ๋ก ๋ฐฐ์น ํฌ๊ธฐ๊ฐ ์์ด์ผ ํฉ๋๋ค. ๊ฐ๋ฅํ ๊ฐ์ ์ฌํญ์ ์ฌ๊ธฐ์์ ๋ ผ์๋๊ณ ์์ต๋๋ค. https://github.com/pytorch/pytorch/pull/50693
- ํ์ดํ ๋จ๊ณ ์์ค์์ ์กฐ๊ฑด๋ถ ์ ์ด ํ๋ฆ์ ๋ถ๊ฐ๋ฅํฉ๋๋ค. ์๋ฅผ ๋ค์ด, T5์ ๊ฐ์ ์ธ์ฝ๋-๋์ฝ๋ ๋ชจ๋ธ์ ์กฐ๊ฑด๋ถ ์ธ์ฝ๋ ๋จ๊ณ๋ฅผ ์ฒ๋ฆฌํ๊ธฐ ์ํด ํน๋ณํ ํด๊ฒฐ์ฑ ์ด ํ์ํฉ๋๋ค.
- ๊ฐ ๋ ์ด์ด๋ฅผ ์ ๋ ฌํ์ฌ ํ๋์ ๋ชจ๋ธ์ ์ถ๋ ฅ์ด ๋ค๋ฅธ ๋ชจ๋ธ์ ์ ๋ ฅ์ด ๋๋๋กํด์ผ ํฉ๋๋ค.
์ฐ๋ฆฌ๋ ์์ง Varuna์ SageMaker๋ก ์คํํ์ง ์์์ง๋ง, ํด๋น ๋ ผ๋ฌธ๋ค์ ์์์ ์ธ๊ธํ ๋ฌธ์ ๋ค์ ๋ชฉ๋ก์ ๊ทน๋ณตํ๊ณ ์ฌ์ฉ์์ ๋ชจ๋ธ์ ๋ํ ๋ณ๊ฒฝ ์ฌํญ์ด ํจ์ฌ ์ ๊ฒ ํ์ํ๋ค๊ณ ๋ณด๊ณ ํ๊ณ ์์ต๋๋ค.
๊ตฌํ:
- ํ์ดํ ์น (ํ์ดํ ์น-1.8์์ ์ด๊ธฐ ์ง์, 1.9์์ ์ ์ง์ ์ผ๋ก ๊ฐ์ ๋๊ณ 1.10์์ ๋ ๊ฐ์ ๋จ). ์์ ๋ ์ฐธ๊ณ ํ์ธ์.
- FairScale
- DeepSpeed
- Megatron-LM์ ๋ด๋ถ ๊ตฌํ์ ๊ฐ์ง๊ณ ์์ต๋๋ค - API ์์.
- Varuna
- SageMaker - ์ด๋ AWS์์๋ง ์ฌ์ฉํ ์ ์๋ ์์ ์๋ฃจ์ ์ ๋๋ค.
- OSLO - ์ด๋ Hugging Face Transformers๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ตฌํ๋ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ์ ๋๋ค.
๐ค Transformers ์ํ: ์ด ์์ฑ ์์ ์์ ๋ชจ๋ธ ์ค ์ด๋ ๊ฒ๋ ์์ ํ PP๋ฅผ ์ง์ํ์ง ์์ต๋๋ค. GPT2์ T5 ๋ชจ๋ธ์ naive MP๋ฅผ ์ง์ํฉ๋๋ค. ์ฃผ์ ์ฅ์ ๋ฌผ์ ๋ชจ๋ธ์ nn.Sequential๋ก ๋ณํํ๊ณ ๋ชจ๋ ์
๋ ฅ์ ํ
์๋ก ๊ฐ์ ธ์์ผ ํ๋ ๊ฒ์ ์ฒ๋ฆฌํ ์ ์๊ธฐ ๋๋ฌธ์
๋๋ค. ํ์ฌ ๋ชจ๋ธ์๋ ์ด๋ฌํ ๋ณํ์ ๋งค์ฐ ๋ณต์กํ๊ฒ ๋ง๋๋ ๋ง์ ๊ธฐ๋ฅ์ด ํฌํจ๋์ด ์์ด ์ ๊ฑฐํด์ผ ํฉ๋๋ค.
๊ธฐํ ์ ๊ทผ ๋ฐฉ๋ฒ:
DeepSpeed, Varuna ๋ฐ SageMaker๋ ๊ต์ฐจ ํ์ดํ๋ผ์ธ(Interleaved Pipeline) ๊ฐ๋
์ ์ฌ์ฉํฉ๋๋ค.

์ฌ๊ธฐ์๋ ๋ฒ๋ธ(์ ํด ์๊ฐ)์ ์ญ๋ฐฉํฅ ํจ์ค์ ์ฐ์ ์์๋ฅผ ๋ถ์ฌํ์ฌ ์ต์ํํฉ๋๋ค.
Varuna๋ ๊ฐ์ฅ ํจ์จ์ ์ธ ์ค์ผ์ค๋ง์ ์ฐพ๊ธฐ ์ํด ์๋ฎฌ๋ ์ด์ ์ ์ฌ์ฉํ์ฌ ์ค์ผ์ค์ ๊ฐ์ ํ๋ ค๊ณ ํฉ๋๋ค.
OSLO๋ nn.Sequential๋ก ๋ณํํ์ง ์๊ณ Transformers๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ๋ฅผ ๊ตฌํํ์ต๋๋ค.
ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ [[tensor-parallelism]]
ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ์์๋ ๊ฐ GPU๊ฐ ํ ์์ ์ผ๋ถ๋ถ๋ง ์ฒ๋ฆฌํ๊ณ ์ ์ฒด ํ ์๊ฐ ํ์ํ ์ฐ์ฐ์ ๋ํด์๋ง ์ ์ฒด ํ ์๋ฅผ ์ง๊ณํฉ๋๋ค.
์ด ์น์ ์์๋ Megatron-LM ๋ ผ๋ฌธ์ธ Efficient Large-Scale Language Model Training on GPU Clusters์์์ ๊ฐ๋ ๊ณผ ๋ค์ด์ด๊ทธ๋จ์ ์ฌ์ฉํฉ๋๋ค.
Transformer์ ์ฃผ์ ๊ตฌ์ฑ ์์๋ fully connected nn.Linear์ ๋น์ ํ ํ์ฑํ ํจ์์ธ GeLU์
๋๋ค.
Megatron ๋
ผ๋ฌธ์ ํ๊ธฐ๋ฒ์ ๋ฐ๋ผ ํ๋ ฌ์ ์ ๊ณฑ ๋ถ๋ถ์ Y = GeLU(XA)๋ก ํํํ ์ ์์ต๋๋ค. ์ฌ๊ธฐ์ X์ Y๋ ์
๋ ฅ ๋ฐ ์ถ๋ ฅ ๋ฒกํฐ์ด๊ณ A๋ ๊ฐ์ค์น ํ๋ ฌ์
๋๋ค.
ํ๋ ฌ ํํ๋ก ๊ณ์ฐ์ ์ดํด๋ณด๋ฉด, ํ๋ ฌ ๊ณฑ์
์ ๋ค์ค GPU๋ก ๋ถํ ํ ์ ์๋ ๋ฐฉ๋ฒ์ ์ฝ๊ฒ ์ ์ ์์ต๋๋ค:

๊ฐ์ค์น ํ๋ ฌ A๋ฅผ N๊ฐ์ GPU์ ๋ํด ์ด๋ณ๋ก ๋ถํ ํ๊ณ ๋ณ๋ ฌ๋ก ํ๋ ฌ ๊ณฑ์
XA_1์์ XA_n๊น์ง ์ํํ๋ฉด N๊ฐ์ ์ถ๋ ฅ ๋ฒกํฐ Y_1, Y_2, ..., Y_n๊ฐ ์์ฑ๋๋ฉฐ ๋
๋ฆฝ์ ์ผ๋ก GeLU์ ์ ๋ฌ๋ ์ ์์ต๋๋ค:

์ด ์๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ๋๊ธฐํ๊ฐ ํ์ํ์ง ์์ GPU ๊ฐ์ ์์ ๊น์ด์ MLP๋ฅผ ์
๋ฐ์ดํธํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๊ฒฐ๊ณผ ๋ฒกํฐ๋ฅผ ์ค๋๋ก๋ถํฐ ์ฌ๊ตฌ์ฑํด์ผ ํ๋ ๋ง์ง๋ง ๋จ๊ณ๊น์ง๋ GPU ๊ฐ์ ๋๊ธฐํ๊ฐ ํ์ํฉ๋๋ค. Megatron-LM ๋
ผ๋ฌธ์ ์ ์๋ค์ ์ด์ ๋ํ ์ ์ฉํ ๊ทธ๋ฆผ์ ์ ๊ณตํฉ๋๋ค:

๋ค์ค ํค๋ ์ดํ
์
๋ ์ด์ด์ ๋ณ๋ ฌํ๋ ๋์ฑ ๊ฐ๋จํฉ๋๋ค. ์ด๋ฏธ ๋
๋ฆฝ์ ์ธ ๋ค์ค ํค๋๋ฅผ ๊ฐ์ง๊ณ ์๊ธฐ ๋๋ฌธ์ ์ด๋ฏธ ๋ณ๋ ฌํ๋์ด ์์ต๋๋ค!

ํน๋ณ ๊ณ ๋ ค์ฌํญ: TP๋ ๋งค์ฐ ๋น ๋ฅธ ๋คํธ์ํฌ๊ฐ ํ์ํ๋ฏ๋ก ํ ๊ฐ ์ด์์ ๋ ธ๋์์ TP๋ฅผ ์ํํ๋ ๊ฒ์ ๊ถ์ฅ๋์ง ์์ต๋๋ค. ์ค์ ๋ก ๋ ธ๋์ 4๊ฐ์ GPU๊ฐ ์๋ ๊ฒฝ์ฐ TP์ ์ต๋ ์ฐจ์๋ 4์ ๋๋ค. TP ์ฐจ์๊ฐ 8์ธ ๊ฒฝ์ฐ ์ต์ํ 8๊ฐ์ GPU๊ฐ ์๋ ๋ ธ๋๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
์ด ์น์ ์ ์๋์ ๋ ์์ธํ TP ๊ฐ์๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํฉ๋๋ค. ์์ฑ์๋ @anton-l์ ๋๋ค.
SageMaker๋ ๋ ํจ์จ์ ์ธ ์ฒ๋ฆฌ๋ฅผ ์ํด TP์ DP๋ฅผ ๊ฒฐํฉํฉ๋๋ค.
๋์ฒด ์ด๋ฆ:
- DeepSpeed๋ ์ด๋ฅผ ํ ์ ์ฌ๋ผ์ด์ฑ์ด๋ผ๊ณ ๋ถ๋ฆ ๋๋ค.
๊ตฌํ:
- Megatron-LM์ ๋ด๋ถ ๊ตฌํ์ ๊ฐ์ง๊ณ ์์ผ๋ฏ๋ก ๋ชจ๋ธ์ ๋งค์ฐ ํนํ๋์ด ์์ต๋๋ค.
- parallelformers (ํ์ฌ๋ ์ถ๋ก ์๋ง ํด๋น)
- SageMaker - ์ด๋ AWS์์๋ง ์ฌ์ฉํ ์ ์๋ ์์ ์๋ฃจ์ ์ ๋๋ค.
- OSLO์ Transformers๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ ๊ตฌํ์ ๊ฐ์ง๊ณ ์์ต๋๋ค.
๐ค Transformers ํํฉ:
- core: ์์ง ํต์ฌ ๋ถ๋ถ์ ๊ตฌํ๋์ง ์์
- ๊ทธ๋ฌ๋ ์ถ๋ก ์ ํ๋ ค๋ฉด parallelformers๊ฐ ๋๋ถ๋ถ์ ๋ชจ๋ธ์ ์ง์ํฉ๋๋ค. ๋ฐ๋ผ์ ํต์ฌ ๋ถ๋ถ์ ๊ตฌํ๋๊ธฐ ์ ๊น์ง ๊ทธ๋ค์ ๊ฒ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ ํ๋ จ ๋ชจ๋๋ ์ง์๋ ์์ ์ ๋๋ค.
- Deepspeed-Inference๋ CUDA ์ปค๋์ ๊ธฐ๋ฐ์ผ๋ก ํ๋ ๋งค์ฐ ๋น ๋ฅธ ์ถ๋ก ๋ชจ๋์์ BERT, GPT-2 ๋ฐ GPT-Neo ๋ชจ๋ธ์ ์ง์ํฉ๋๋ค. ์์ธํ ๋ด์ฉ์ ์ฌ๊ธฐ๋ฅผ ์ฐธ์กฐํ์ธ์.
DP+PP [[dppp]]
DeepSpeed pipeline tutorial์์ ๋ค์ ๋ค์ด์ด๊ทธ๋จ์ DP์ PP๋ฅผ ๊ฒฐํฉํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
์ฌ๊ธฐ์ DP ๋ญํฌ 0์ GPU2๋ฅผ ๋ณด์ง ๋ชปํ๊ณ , DP ๋ญํฌ 1์ GPU3์ ๋ณด์ง ๋ชปํ๋ ๊ฒ์ด ์ค์ํฉ๋๋ค. DP์๊ฒ๋ ๋ฑ 2๊ฐ์ GPU์ธ ๊ฒ์ฒ๋ผ ๋ฐ์ดํฐ๋ฅผ ๊ณต๊ธํฉ๋๋ค. GPU0์ PP๋ฅผ ์ฌ์ฉํ์ฌ GPU2์๊ฒ ์ผ๋ถ ์์ ์ "๋น๋ฐ๋ฆฌ์" ํ ๋นํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ GPU1๋ GPU3์ ๋์์ผ๋ก ์ผ์ ๊ฐ์ ๋ฐฉ์์ผ๋ก ์์ ํฉ๋๋ค.
๊ฐ ์ฐจ์๋ง๋ค ์ ์ด๋ 2๊ฐ์ GPU๊ฐ ํ์ํ๋ฏ๋ก ์ต์ํ 4๊ฐ์ GPU๊ฐ ํ์ํฉ๋๋ค.
๊ตฌํ:
๐ค Transformers ํํฉ: ์์ง ๊ตฌํ๋์ง ์์
DP+PP+TP [[dppptp]]
๋ ํจ์จ์ ์ธ ํ๋ จ์ ์ํด PP์ TP ๋ฐ DP๋ฅผ ๊ฒฐํฉํ์ฌ 3D ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ฌ์ฉํฉ๋๋ค. ๋ค์ ๋ค์ด์ด๊ทธ๋จ์์ ์ด๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
์ด ๋ค์ด์ด๊ทธ๋จ์ 3D parallelism: Scaling to trillion-parameter models์ด๋ผ๋ ๋ธ๋ก๊ทธ ๊ธ์์ ํ์ธํ ์ ์์ต๋๋ค.
๊ฐ ์ฐจ์๋ง๋ค ์ ์ด๋ 2๊ฐ์ GPU๊ฐ ํ์ํ๋ฏ๋ก ์ต์ํ 8๊ฐ์ GPU๊ฐ ํ์ํฉ๋๋ค.
๊ตฌํ:
- DeepSpeed - DeepSpeed๋ ๋์ฑ ํจ์จ์ ์ธ DP์ธ ZeRO-DP๋ผ๊ณ ๋ ๋ถ๋ฆ ๋๋ค.
- Megatron-LM
- Varuna
- SageMaker
- OSLO
๐ค Transformers ํํฉ: ์์ง ๊ตฌํ๋์ง ์์. PP์ TP๊ฐ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
ZeRO DP+PP+TP [[zero-dppptp]]
DeepSpeed์ ์ฃผ์ ๊ธฐ๋ฅ ์ค ํ๋๋ DP์ ํ์ฅ์ธ ZeRO์ ๋๋ค. ZeRO-DP์ ๋ํด ์ด๋ฏธ ZeRO Data Parallelism์์ ๋ ผ์๋์์ต๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ์ด๋ PP๋ TP๋ฅผ ํ์๋กํ์ง ์๋ ๋ ๋ฆฝ์ ์ธ ๊ธฐ๋ฅ์ ๋๋ค. ๊ทธ๋ฌ๋ PP์ TP์ ๊ฒฐํฉํ ์๋ ์์ต๋๋ค.
ZeRO-DP๊ฐ PP์ (์ ํ์ ์ผ๋ก TP์) ๊ฒฐํฉ๋๋ฉด ์ผ๋ฐ์ ์ผ๋ก ZeRO ๋จ๊ณ 1(์ตํฐ๋ง์ด์ ๋ถํ )๋ง ํ์ฑํ๋ฉ๋๋ค.
์ด๋ก ์ ์ผ๋ก๋ ZeRO ๋จ๊ณ 2(๊ทธ๋ผ๋์ธํธ ๋ถํ )๋ฅผ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ์ ํจ๊ป ์ฌ์ฉํ ์๋ ์์ง๋ง, ์ด๋ ์ฑ๋ฅ์ ๋์ ์ํฅ์ ๋ฏธ์น ๊ฒ์ ๋๋ค. ๊ฐ ๋ง์ดํฌ๋ก ๋ฐฐ์น๋ง๋ค ๊ทธ๋ผ๋์ธํธ๋ฅผ ์ค๋ฉํ๊ธฐ ์ ์ ์ถ๊ฐ์ ์ธ ๋ฆฌ๋์ค-์ค์บํฐ ์ปฌ๋ ํฐ๋ธ๊ฐ ํ์ํ๋ฉฐ, ์ด๋ ์ ์ฌ์ ์ผ๋ก ์๋นํ ํต์ ์ค๋ฒํค๋๋ฅผ ์ถ๊ฐํฉ๋๋ค. ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ์ ํน์ฑ์ ์์ ๋ง์ดํฌ๋ก ๋ฐฐ์น๊ฐ ์ฌ์ฉ๋๋ฉฐ, ์ฐ์ ์ฐ์ฐ ๊ฐ๋(๋ง์ดํฌ๋ก ๋ฐฐ์น ํฌ๊ธฐ)๋ฅผ ๊ท ํ ์๊ฒ ์ ์งํ๋ฉด์ ํ์ดํ๋ผ์ธ ๋ฒ๋ธ(๋ง์ดํฌ๋ก ๋ฐฐ์น ์)์ ์ต์ํํ๋ ๊ฒ์ ์ค์ ์ ๋ก๋๋ค. ๋ฐ๋ผ์ ํด๋น ํต์ ๋น์ฉ์ ๋ฌธ์ ๊ฐ ๋ ๊ฒ์ ๋๋ค.
๋ํ, PP๋ก ์ธํด ์ ์๋ณด๋ค ์ ์ ์์ ๋ ์ด์ด๊ฐ ์์ผ๋ฏ๋ก ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ์ ํฌ์ง ์์ ๊ฒ์
๋๋ค. PP๋ ์ด๋ฏธ ๊ทธ๋๋์ธํธ ํฌ๊ธฐ๋ฅผ 1/PP๋ก ์ค์ด๊ธฐ ๋๋ฌธ์ ๊ทธ๋๋์ธํธ ์ค๋ฉ์ ์ ์ฝ ํจ๊ณผ๋ ์์ DP๋ณด๋ค๋ ๋ฏธ๋ฏธํฉ๋๋ค.
ZeRO ๋จ๊ณ 3๋ ๊ฐ์ ์ด์ ๋ก ์ข์ ์ ํ์ด ์๋๋๋ค - ๋ ๋ง์ ๋ ธ๋ ๊ฐ ํต์ ์ด ํ์ํฉ๋๋ค.
๊ทธ๋ฆฌ๊ณ ZeRO๊ฐ ์๊ธฐ ๋๋ฌธ์ ๋ค๋ฅธ ์ด์ ์ ZeRO-Offload์ ๋๋ค. ์ด๋ ๋จ๊ณ 1์ด๋ฏ๋ก ์ตํฐ๋ง์ด์ ์ํ๋ฅผ CPU๋ก ์คํ๋ก๋ํ ์ ์์ต๋๋ค.
๊ตฌํ:
- Megatron-DeepSpeed ๋ฐ BigScience์ Megatron-Deepspeed, ์ด์ ์ ์ฅ์์ ํฌํฌ์ ๋๋ค.
- OSLO
์ค์ํ ๋ ผ๋ฌธ:
๐ค Transformers ํํฉ: ์์ง ๊ตฌํ๋์ง ์์, PP์ TP๊ฐ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
FlexFlow [[flexflow]]
FlexFlow๋ ์ฝ๊ฐ ๋ค๋ฅธ ๋ฐฉ์์ผ๋ก ๋ณ๋ ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํฉ๋๋ค.
์ด๋ Sample-Operator-Attribute-Parameter๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ๋ ์ผ์ข ์ 4D ๋ณ๋ ฌํ๋ฅผ ์ํํฉ๋๋ค.
- Sample = ๋ฐ์ดํฐ ๋ณ๋ ฌํ (์ํ๋ณ ๋ณ๋ ฌ)
- Operator = ๋จ์ผ ์ฐ์ฐ์ ์ฌ๋ฌ ํ์ ์ฐ์ฐ์ผ๋ก ๋ณ๋ ฌํ
- Attribute = ๋ฐ์ดํฐ ๋ณ๋ ฌํ (๊ธธ์ด๋ณ ๋ณ๋ ฌ)
- Parameter = ๋ชจ๋ธ ๋ณ๋ ฌํ (์ํ ๋๋ ์์ง๊ณผ ๊ด๊ณ์์ด)
์์:
- Sample
512 ๊ธธ์ด์ 10๊ฐ์ ๋ฐฐ์น๋ฅผ ๊ฐ์ ํด ๋ด ์๋ค. ์ด๋ฅผ sample ์ฐจ์์ผ๋ก 2๊ฐ์ ์ฅ์น์ ๋ณ๋ ฌํํ๋ฉด, 10 x 512๋ 5 x 2 x 512๊ฐ ๋ฉ๋๋ค.
- Operator
๋ ์ด์ด ์ ๊ทํ๋ฅผ ์ํํ๋ค๋ฉด, ์ฐ์ std๋ฅผ ๊ณ์ฐํ๊ณ ๋ ๋ฒ์งธ๋ก mean์ ๊ณ์ฐํ ๋ค์ ๋ฐ์ดํฐ๋ฅผ ์ ๊ทํํ ์ ์์ต๋๋ค. Operator ๋ณ๋ ฌํ๋ std์ mean์ ๋ณ๋ ฌ๋ก ๊ณ์ฐํ ์ ์๋๋ก ํฉ๋๋ค. ๋ฐ๋ผ์ operator ์ฐจ์์ผ๋ก 2๊ฐ์ ์ฅ์น (cuda:0, cuda:1)์ ๋ณ๋ ฌํํ๋ฉด, ๋จผ์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋ ์ฅ์น๋ก ๋ณต์ฌํ ๋ค์ cuda:0์์ std๋ฅผ ๊ณ์ฐํ๊ณ cuda:1์์ ๋์์ mean์ ๊ณ์ฐํฉ๋๋ค.
- Attribute
512 ๊ธธ์ด์ 10๊ฐ์ ๋ฐฐ์น๊ฐ ์์ต๋๋ค. ์ด๋ฅผ attribute ์ฐจ์์ผ๋ก 2๊ฐ์ ์ฅ์น์ ๋ณ๋ ฌํํ๋ฉด, 10 x 512๋ 10 x 2 x 256์ด ๋ฉ๋๋ค.
- Parameter
์ด๋ tensor ๋ชจ๋ธ ๋ณ๋ ฌํ ๋๋ naive layer-wise ๋ชจ๋ธ ๋ณ๋ ฌํ์ ์ ์ฌํฉ๋๋ค.
์ด ํ๋ ์์ํฌ์ ์ค์ํ ์ ์ (1) GPU/TPU/CPU ๋ (2) RAM/DRAM ๋ (3) ๋น ๋ฅธ ์ธํธ๋ผ-์ปค๋ฅํธ ๋ ๋๋ฆฐ ์ธํฐ-์ปค๋ฅํธ์ ๊ฐ์ ๋ฆฌ์์ค๋ฅผ ๊ณ ๋ คํ์ฌ ์ด๋์์ ์ด๋ค ๋ณ๋ ฌํ๋ฅผ ์ฌ์ฉํ ์ง๋ฅผ ์๊ณ ๋ฆฌ์ฆ์ ์ผ๋ก ์๋์ผ๋ก ์ต์ ํํ๋ค๋ ๊ฒ์ ๋๋ค.
ํ๋ ๋งค์ฐ ์ค์ํ ์ธก๋ฉด์ FlexFlow๊ฐ ์ ์ ์ด๊ณ ๊ณ ์ ๋ ์ํฌ๋ก๋๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ ๋ํ DNN ๋ณ๋ ฌํ๋ฅผ ์ต์ ํํ๊ธฐ ์ํด ์ค๊ณ๋์๋ค๋ ๊ฒ์ ๋๋ค. ๋์ ์ธ ๋์์ ๊ฐ์ง ๋ชจ๋ธ์ ๋ฐ๋ณต๋ง๋ค ๋ค๋ฅธ ๋ณ๋ ฌํ ์ ๋ต์ ์ ํธํ ์ ์์ต๋๋ค.
๋ฐ๋ผ์ ์ด ํ๋ ์์ํฌ์ ์ฅ์ ์ ์ ํํ ํด๋ฌ์คํฐ์์ 30๋ถ ๋์ ์๋ฎฌ๋ ์ด์ ์ ์คํํ๊ณ ์ด ํน์ ํ๊ฒฝ์ ์ต์ ์ผ๋ก ํ์ฉํ๊ธฐ ์ํ ์ต์์ ์ ๋ต์ ์ ์ํ๋ค๋ ๊ฒ์ ๋๋ค. ๋ถํ์ ์ถ๊ฐ/์ ๊ฑฐ/๊ต์ฒดํ๋ฉด ์คํํ๊ณ ๊ทธ์ ๋ํ ๊ณํ์ ๋ค์ ์ต์ ํํ ํ ํ๋ จํ ์ ์์ต๋๋ค. ๋ค๋ฅธ ์ค์ ์ ์์ฒด์ ์ธ ์ฌ์ฉ์ ์ ์ ์ต์ ํ๋ฅผ ๊ฐ์ง ์ ์์ต๋๋ค.
๐ค Transformers ํํฉ: ์์ง ํตํฉ๋์ง ์์. ์ด๋ฏธ transformers.utils.fx๋ฅผ ํตํด ๋ชจ๋ธ์ FX-์ถ์ ํ ์ ์์ผ๋ฉฐ, ์ด๋ FlexFlow์ ์ ํ ์กฐ๊ฑด์ ๋๋ค. ๋ฐ๋ผ์ ์ด๋ค ์์ ์ ์ํํด์ผ FlexFlow๊ฐ ์ฐ๋ฆฌ์ ๋ชจ๋ธ๊ณผ ํจ๊ป ์๋ํ ์ ์๋์ง ํ์ ํด์ผ ํฉ๋๋ค.
์ด๋ค ์ ๋ต์ ์ฌ์ฉํด์ผ ํ ๊น์? [[which-strategy-to-use-when]]
๋ค์์ ์ด๋ค ๋ณ๋ ฌํ ์ ๋ต์ ์ธ์ ์ฌ์ฉํด์ผ ํ๋์ง์ ๋ํ ๋งค์ฐ ๋๋ต์ ์ธ ๊ฐ์์ ๋๋ค. ๊ฐ ๋ชฉ๋ก์ ์ฒซ ๋ฒ์งธ ์ ๋ต์ด ์ผ๋ฐ์ ์ผ๋ก ๋ ๋น ๋ฆ ๋๋ค.
โจ ๋จ์ผ GPU
๋ชจ๋ธ์ด ๋จ์ผ GPU์ ๋ง๋ ๊ฒฝ์ฐ:
- ์ผ๋ฐ์ ์ธ ์ฌ์ฉ
๋ชจ๋ธ์ด ๋จ์ผ GPU์ ๋ง์ง ์๋ ๊ฒฝ์ฐ:
- ZeRO + CPU ๋ฐ ์ต์ ์ผ๋ก NVMe ์ธ๋ก๋
- ์์ ๋์ผํ๊ฒ ์ฌ์ฉํ๋, ๊ฐ์ฅ ํฐ ๋ ์ด์ด๊ฐ ๋จ์ผ GPU์ ๋ง์ง ์๋ ๊ฒฝ์ฐ Memory Centric Tiling(์์ธํ ๋ด์ฉ์ ์๋ ์ฐธ์กฐ)์ ์ถ๊ฐ์ ์ผ๋ก ์ฌ์ฉ
๊ฐ์ฅ ํฐ ๋ ์ด์ด๊ฐ ๋จ์ผ GPU์ ๋ง์ง ์๋ ๊ฒฝ์ฐ:
- ZeRO - Memory Centric Tiling (MCT) ํ์ฑํ. ์ด๋ฅผ ํตํด ํฌ๊ธฐ๊ฐ ๋งค์ฐ ํฐ ๋ ์ด์ด๋ฅผ ์์๋ก ๋ถํ ํ์ฌ ์์ฐจ์ ์ผ๋ก ์คํํ ์ ์์ต๋๋ค. MCT๋ GPU์ ํ์ฑํ๋ ๋งค๊ฐ๋ณ์์ ์๋ฅผ ์ค์ด์ง๋ง ํ์ฑํ ๋ฉ๋ชจ๋ฆฌ์๋ ์ํฅ์ ์ฃผ์ง ์์ต๋๋ค. ํ์ฌ ์์ฑ ๊ธฐ์ค์ผ๋ก ์ด ์๊ตฌ์ฌํญ์ ๋งค์ฐ ๋๋ฌผ๊ธฐ ๋๋ฌธ์ ์ฌ์ฉ์๊ฐ
torch.nn.Linear๋ฅผ ์๋์ผ๋ก ์์ ํด์ผ ํฉ๋๋ค.
โจ ๋จ์ผ ๋ ธ๋ / ๋ค์ค GPU
๋ชจ๋ธ์ด ๋จ์ผ GPU์ ๋ง๋ ๊ฒฝ์ฐ:
- DDP - ๋ถ์ฐ DP
- ZeRO - ์ํฉ๊ณผ ๊ตฌ์ฑ์ ๋ฐ๋ผ ๋น ๋ฅผ ์๋ ์๊ณ ๊ทธ๋ ์ง ์์ ์๋ ์์ต๋๋ค.
๋ชจ๋ธ์ด ๋จ์ผ GPU์ ๋ง์ง ์๋ ๊ฒฝ์ฐ:
- PP
- ZeRO
- TP
NVLINK ๋๋ NVSwitch๋ฅผ ํตํ ๋งค์ฐ ๋น ๋ฅธ ์ธํธ๋ผ-๋ ธ๋ ์ฐ๊ฒฐ์ด ์๋ ๊ฒฝ์ฐ ์ด ์ธ ๊ฐ์ง ๋ฐฉ๋ฒ์ ๊ฑฐ์ ๋๋ฑํ ๊ฒ์ด๋ฉฐ, ์ด๋ฌํ ์ฐ๊ฒฐ์ด ์๋ ๊ฒฝ์ฐ PP๊ฐ TP๋ ZeRO๋ณด๋ค ๋น ๋ฅผ ๊ฒ์ ๋๋ค. ๋ํ TP์ ์ฐจ์๋ ์ํฅ์ ์ค ์ ์์ต๋๋ค. ํน์ ์ค์ ์์ ์ฐ์น์๋ฅผ ์ฐพ๊ธฐ ์ํด ์คํํ๋ ๊ฒ์ด ๊ฐ์ฅ ์ข์ต๋๋ค.
TP๋ ๊ฑฐ์ ํญ์ ๋จ์ผ ๋ ธ๋ ๋ด์์ ์ฌ์ฉ๋ฉ๋๋ค. ์ฆ, TP ํฌ๊ธฐ <= ๋ ธ๋๋น GPU ์์ ๋๋ค.
๊ฐ์ฅ ํฐ ๋ ์ด์ด๊ฐ ๋จ์ผ GPU์ ๋ง์ง ์๋ ๊ฒฝ์ฐ:
- ZeRO๋ฅผ ์ฌ์ฉํ์ง ์์ ๊ฒฝ์ฐ - PP๋ง ์ฌ์ฉํ ์ ์์ผ๋ฏ๋ก TP๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
- ZeRO๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ, "๋จ์ผ GPU"์ ํญ๋ชฉ๊ณผ ๋์ผํ ํญ๋ชฉ ์ฐธ์กฐ
โจ ๋ค์ค ๋ ธ๋ / ๋ค์ค GPU
๋น ๋ฅธ ๋ ธ๋ ๊ฐ ์ฐ๊ฒฐ์ด ์๋ ๊ฒฝ์ฐ:
- ZeRO - ๋ชจ๋ธ์ ๋ํ ์์ ์ด ๊ฑฐ์ ํ์ํ์ง ์์ต๋๋ค.
- PP+TP+DP - ํต์ ์ด ์ ์ง๋ง ๋ชจ๋ธ์ ๋ํ ๋๊ท๋ชจ ๋ณ๊ฒฝ์ด ํ์ํฉ๋๋ค.
๋๋ฆฐ ๋ ธ๋ ๊ฐ ์ฐ๊ฒฐ ๋ฐ GPU ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑํ ๊ฒฝ์ฐ:
- DP+PP+TP+ZeRO-1



