ํฐ ๋ชจ๋ธ ์ธ์คํด์คํ [[instantiating-a-big-model]]
๋งค์ฐ ํฐ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ค๋ฉด, RAM ์ฌ์ฉ์ ์ต์ํํด์ผ ํ๋ ๊ณผ์ ๊ฐ ์์ต๋๋ค. ์ผ๋ฐ์ ์ธ PyTorch ์ํฌํ๋ก์ฐ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
- ๋ฌด์์ ๊ฐ์ค์น๋ก ๋ชจ๋ธ์ ์์ฑํฉ๋๋ค.
- ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ฅผ ๋ถ๋ฌ์ต๋๋ค.
- ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ฅผ ๋ฌด์์ ๋ชจ๋ธ์ ์ ์ฉํฉ๋๋ค.
1๋จ๊ณ์ 2๋จ๊ณ ๋ชจ๋ ๋ชจ๋ธ์ ์ ์ฒด ๋ฒ์ ์ ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฌํด์ผ ํ๋ฉฐ, ๋๋ถ๋ถ ๋ฌธ์ ๊ฐ ์์ง๋ง ๋ชจ๋ธ์ด ๊ธฐ๊ฐ๋ฐ์ดํธ๊ธ์ ์ฉ๋์ ์ฐจ์งํ๊ธฐ ์์ํ๋ฉด ๋ณต์ฌ๋ณธ 2๊ฐ๊ฐ RAM์ ์ด๊ณผํ์ฌ ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ ์ด์๋ฅผ ์ผ๊ธฐํ ์ ์์ต๋๋ค. ๋ ์ฌ๊ฐํ ๋ฌธ์ ๋ ๋ถ์ฐ ํ์ต์ ์ํด torch.distributed๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ, ํ๋ก์ธ์ค๋ง๋ค ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๋ก๋ํ๊ณ ๋ณต์ฌ๋ณธ์ 2๊ฐ์ฉ RAM์ ์ ์ฅํ๋ค๋ ๊ฒ์
๋๋ค.
๋ฌด์์๋ก ์์ฑ๋ ๋ชจ๋ธ์ "๋น์ด ์๋" (์ฆ ๊ทธ๋ ๋ฉ๋ชจ๋ฆฌ์ ์๋ ๊ฒ์ผ๋ก ์ด๋ค์ง) ํ ์๋ก ์ด๊ธฐํ๋๋ฉฐ ๋ฉ๋ชจ๋ฆฌ ๊ณต๊ฐ์ ์ฐจ์งํฉ๋๋ค. ์ด๊ธฐํ๋ ๋ชจ๋ธ/ํ๋ผ๋ฏธํฐ์ ์ข ๋ฅ์ ์ ํฉํ ๋ถํฌ(์: ์ ๊ท ๋ถํฌ)์ ๋ฐ๋ฅธ ๋ฌด์์ ์ด๊ธฐํ๋ ๊ฐ๋ฅํ ํ ๋น ๋ฅด๊ฒ ํ๊ธฐ ์ํด ์ด๊ธฐํ๋์ง ์์ ๊ฐ์ค์น์ ๋ํด 3๋จ๊ณ ์ดํ์๋ง ์ํ๋ฉ๋๋ค!
์ด ์๋ด์์์๋ Transformers๊ฐ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ ๊ณตํ๋ ์๋ฃจ์ ์ ์ดํด๋ด ๋๋ค. ์ฃผ์ํ ์ ์ ์์ง ํ๋ฐํ ๊ฐ๋ฐ ์ค์ธ ๋ถ์ผ์ด๋ฏ๋ก ์ฌ๊ธฐ์ ์ค๋ช ํ๋ API๊ฐ ์์ผ๋ก ์ฝ๊ฐ ๋ณ๊ฒฝ๋ ์ ์๋ค๋ ๊ฒ์ ๋๋ค.
์ค๋ฉ๋ ์ฒดํฌํฌ์ธํธ [[sharded-checkpoints]]
4.18.0 ๋ฒ์ ์ดํ, 10GB ์ด์์ ๊ณต๊ฐ์ ์ฐจ์งํ๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ ์๋์ผ๋ก ์์ ์กฐ๊ฐ๋ค๋ก ์ค๋ฉ๋ฉ๋๋ค. model.save_pretrained(save_dir)๋ฅผ ์คํํ ๋ ํ๋์ ๋จ์ผ ์ฒดํฌํฌ์ธํธ๋ฅผ ๊ฐ์ง๊ฒ ๋ ๋์ , ์ฌ๋ฌ ๋ถ๋ถ ์ฒดํฌํฌ์ธํธ(๊ฐ๊ฐ์ ํฌ๊ธฐ๋ 10GB ๋ฏธ๋ง)์ ๋งค๊ฐ๋ณ์ ์ด๋ฆ์ ํด๋น ํ์ผ์ ๋งคํํ๋ ์ธ๋ฑ์ค๊ฐ ์์ฑ๋ฉ๋๋ค.
max_shard_size ๋งค๊ฐ๋ณ์๋ก ์ค๋ฉ ์ ์ต๋ ํฌ๊ธฐ๋ฅผ ์ ์ดํ ์ ์์ผ๋ฏ๋ก, ์ด ์์ ๋ฅผ ์ํด ์ค๋ ํฌ๊ธฐ๊ฐ ์์ ์ผ๋ฐ ํฌ๊ธฐ์ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ฒ ์ต๋๋ค: ์ ํต์ ์ธ BERT ๋ชจ๋ธ์ ์ฌ์ฉํด ๋ด
์๋ค.
from transformers import AutoModel
model = AutoModel.from_pretrained("google-bert/bert-base-cased")
[~PreTrainedModel.save_pretrained]์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ ์ฅํ๋ฉด, ๋ชจ๋ธ์ ๊ตฌ์ฑ๊ณผ ๊ฐ์ค์น๊ฐ ๋ค์ด์๋ ๋ ๊ฐ์ ํ์ผ์ด ์๋ ์ ํด๋๊ฐ ์์ฑ๋ฉ๋๋ค:
>>> import os
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir)
... print(sorted(os.listdir(tmp_dir)))
['config.json', 'pytorch_model.bin']
์ด์ ์ต๋ ์ค๋ ํฌ๊ธฐ๋ฅผ 200MB๋ก ์ฌ์ฉํด ๋ด ์๋ค:
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... print(sorted(os.listdir(tmp_dir)))
['config.json', 'pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin', 'pytorch_model.bin.index.json']
๋ชจ๋ธ์ ๊ตฌ์ฑ์ ๋ํด, ์ธ ๊ฐ์ ๋ค๋ฅธ ๊ฐ์ค์น ํ์ผ๊ณผ ํ๋ผ๋ฏธํฐ ์ด๋ฆ๊ณผ ํด๋น ํ์ผ์ ๋งคํ์ด ํฌํจ๋ index.json ํ์ผ์ ๋ณผ ์ ์์ต๋๋ค. ์ด๋ฌํ ์ฒดํฌํฌ์ธํธ๋ [~PreTrainedModel.from_pretrained] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ์์ ํ ๋ค์ ๋ก๋ํ ์ ์์ต๋๋ค:
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... new_model = AutoModel.from_pretrained(tmp_dir)
ํฐ ๋ชจ๋ธ์ ๊ฒฝ์ฐ ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ์ฒ๋ฆฌํ๋ ์ฃผ๋ ์ฅ์ ์ ์์์ ๋ณด์ฌ์ค ํ๋ฆ์ 2๋จ๊ณ์์, ๊ฐ ์ค๋๊ฐ ์ด์ ์ค๋ ๋ค์์ ๋ก๋๋๋ฏ๋ก ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋ชจ๋ธ ํฌ๊ธฐ์ ๊ฐ์ฅ ํฐ ์ค๋์ ํฌ๊ธฐ๋ฅผ ์ด๊ณผํ์ง ์๋๋ค๋ ์ ์ ๋๋ค.
์ด ์ธ๋ฑ์ค ํ์ผ์ ํค๊ฐ ์ฒดํฌํฌ์ธํธ์ ์๋์ง, ๊ทธ๋ฆฌ๊ณ ํด๋น ๊ฐ์ค์น๊ฐ ์ด๋์ ์ ์ฅ๋์ด ์๋์ง๋ฅผ ๊ฒฐ์ ํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. ์ด ์ธ๋ฑ์ค๋ฅผ json๊ณผ ๊ฐ์ด ๋ก๋ํ๊ณ ๋์ ๋๋ฆฌ๋ฅผ ์ป์ ์ ์์ต๋๋ค:
>>> import json
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json"), "r") as f:
... index = json.load(f)
>>> print(index.keys())
dict_keys(['metadata', 'weight_map'])
๋ฉํ๋ฐ์ดํฐ๋ ํ์ฌ ๋ชจ๋ธ์ ์ด ํฌ๊ธฐ๋ง ํฌํจ๋ฉ๋๋ค. ์์ผ๋ก ๋ค๋ฅธ ์ ๋ณด๋ฅผ ์ถ๊ฐํ ๊ณํ์ ๋๋ค:
>>> index["metadata"]
{'total_size': 433245184}
๊ฐ์ค์น ๋งต์ ์ด ์ธ๋ฑ์ค์ ์ฃผ์ ๋ถ๋ถ์ผ๋ก, ๊ฐ ๋งค๊ฐ๋ณ์ ์ด๋ฆ(PyTorch ๋ชจ๋ธ state_dict์์ ๋ณดํต ์ฐพ์ ์ ์๋)์ ํด๋น ํ์ผ์ ๋งคํํฉ๋๋ค:
>>> index["weight_map"]
{'embeddings.LayerNorm.bias': 'pytorch_model-00001-of-00003.bin',
'embeddings.LayerNorm.weight': 'pytorch_model-00001-of-00003.bin',
...
๋ง์ฝ [~PreTrainedModel.from_pretrained]๋ฅผ ์ฌ์ฉํ์ง ์๊ณ ๋ชจ๋ธ ๋ด์์ ์ด๋ฌํ ์ค๋ฉ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ง์ ๊ฐ์ ธ์ค๋ ค๋ฉด (์ ์ฒด ์ฒดํฌํฌ์ธํธ๋ฅผ ์ํด model.load_state_dict()๋ฅผ ์ํํ๋ ๊ฒ์ฒ๋ผ), [~modeling_utils.load_sharded_checkpoint]๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
>>> from transformers.modeling_utils import load_sharded_checkpoint
>>> with tempfile.TemporaryDirectory() as tmp_dir:
... model.save_pretrained(tmp_dir, max_shard_size="200MB")
... load_sharded_checkpoint(model, tmp_dir)
์ (ไฝ)๋ฉ๋ชจ๋ฆฌ ๋ก๋ฉ [[low-memory-loading]]
์ค๋ฉ๋ ์ฒดํฌํฌ์ธํธ๋ ์์์ ์ธ๊ธํ ์์ ํ๋ฆ์ 2๋จ๊ณ์์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด์ง๋ง, ์ (ไฝ)๋ฉ๋ชจ๋ฆฌ ์ค์ ์์ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ธฐ ์ํด ์ฐ๋ฆฌ์ Accelerate ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ ๋๊ตฌ๋ฅผ ํ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
์์ธํ ์ฌํญ์ ๋ค์ ๊ฐ์ด๋๋ฅผ ์ฐธ์กฐํด์ฃผ์ธ์: Accelerate๋ก ๋๊ท๋ชจ ๋ชจ๋ธ ๊ฐ์ ธ์ค๊ธฐ (์๋ฌธ)