
ํฐ ๋ชจ๋ธ ์ธ์Šคํ„ด์Šคํ™” [[instantiating-a-big-model]]

When you want to use a very big pretrained model, one challenge is keeping RAM usage to a minimum. The usual PyTorch workflow is:

  1. Create the model with random weights.
  2. Load the pretrained weights.
  3. Put the pretrained weights into the random model.

Steps 1 and 2 each require a full copy of the model in memory. In most cases this is not a problem, but once a model starts to weigh several gigabytes, the two copies can exceed the available RAM and cause out-of-memory errors. Even worse, if you use torch.distributed for distributed training, every process loads the pretrained model and stores these two copies in RAM.
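As a back-of-the-envelope illustration of the paragraph above, the sketch below multiplies the model size by the two copies and by the number of processes. The helper name `estimate_peak_ram_gb`, the 10 GB model, and the 8 processes are assumptions chosen for the example, and the estimate ignores everything else a process keeps in memory.

```python
def estimate_peak_ram_gb(model_size_gb, num_processes=1, copies_per_process=2):
    """Rough upper bound on the RAM needed while loading.

    Each process holds the randomly initialized model and the loaded
    checkpoint at the same time, hence two copies per process.
    """
    return model_size_gb * copies_per_process * num_processes


# Hypothetical 10 GB model, loaded in one process and then in eight.
single = estimate_peak_ram_gb(10)
distributed = estimate_peak_ram_gb(10, num_processes=8)

print(f"single process: ~{single:.0f} GB")
print(f"8 processes:    ~{distributed:.0f} GB")
```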

Note that the randomly created model is initialized with "empty" tensors, which take up space in memory without filling it (so their values are whatever was in that chunk of memory at the time). The random initialization that follows the appropriate distribution for the kind of model/parameters instantiated (a normal distribution, for instance) is performed only after step 3, and only on the weights that are still uninitialized, to be as fast as possible!

In this guide, we explore the solutions Transformers offers to deal with this issue. Note that this is an area of active development, so the APIs explained here may change slightly in the future.

Sharded checkpoints [[sharded-checkpoints]]

Since version 4.18.0, model checkpoints that take up more than 10GB of space are automatically sharded into smaller pieces. Instead of a single checkpoint when you run model.save_pretrained(save_dir), you end up with several partial checkpoints (each smaller than 10GB) and an index that maps parameter names to the files they are stored in.

You can control the maximum size before sharding with the max_shard_size parameter, so for the sake of this example we will use a normal-sized model with a small shard size: let's take a traditional BERT model.

from transformers import AutoModel

model = AutoModel.from_pretrained("google-bert/bert-base-cased")

If you save the model with [~PreTrainedModel.save_pretrained], you get a new folder containing two files, the model configuration and its weights:

>>> import os
>>> import tempfile

>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir)
...     print(sorted(os.listdir(tmp_dir)))
['config.json', 'pytorch_model.bin']

Now let's use a maximum shard size of 200MB:

>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="200MB")
...     print(sorted(os.listdir(tmp_dir)))
['config.json', 'pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin', 'pytorch_model.bin.index.json']

๋ชจ๋ธ์˜ ๊ตฌ์„ฑ์— ๋”ํ•ด, ์„ธ ๊ฐœ์˜ ๋‹ค๋ฅธ ๊ฐ€์ค‘์น˜ ํŒŒ์ผ๊ณผ ํŒŒ๋ผ๋ฏธํ„ฐ ์ด๋ฆ„๊ณผ ํ•ด๋‹น ํŒŒ์ผ์˜ ๋งคํ•‘์ด ํฌํ•จ๋œ index.json ํŒŒ์ผ์„ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์ฒดํฌํฌ์ธํŠธ๋Š” [~PreTrainedModel.from_pretrained] ๋ฉ”์„œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์™„์ „ํžˆ ๋‹ค์‹œ ๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="200MB")
...     new_model = AutoModel.from_pretrained(tmp_dir)
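To make the idea of a size cap more concrete, here is a simplified sketch of how a set of weights could be split greedily into shards and described by an index. It is not the library's actual implementation: `parse_size`, `shard_weights`, the weight names, and the byte counts are all hypothetical, and size strings are assumed to use decimal units.

```python
def parse_size(size):
    """Convert a string such as "200MB" into bytes (decimal units assumed)."""
    units = {"KB": 10**3, "MB": 10**6, "GB": 10**9}
    for suffix, factor in units.items():
        if size.upper().endswith(suffix):
            return int(size[: -len(suffix)]) * factor
    raise ValueError(f"unrecognized size string: {size!r}")


def shard_weights(weight_sizes, max_shard_size):
    """Greedily group weights into shards and build an index dictionary."""
    limit = parse_size(max_shard_size)
    shards = [[]]
    current = 0
    for name, nbytes in weight_sizes.items():
        # Open a new shard when the next weight would push this one over the cap.
        if shards[-1] and current + nbytes > limit:
            shards.append([])
            current = 0
        shards[-1].append(name)
        current += nbytes

    weight_map = {}
    for i, names in enumerate(shards, start=1):
        filename = f"pytorch_model-{i:05d}-of-{len(shards):05d}.bin"
        for name in names:
            weight_map[name] = filename
    return {
        "metadata": {"total_size": sum(weight_sizes.values())},
        "weight_map": weight_map,
    }


# Hypothetical weights and sizes in bytes, for illustration only.
weight_sizes = {
    "a.weight": 150_000_000,
    "a.bias": 40_000_000,
    "b.weight": 120_000_000,
    "b.bias": 100_000_000,
}
index = shard_weights(weight_sizes, "200MB")
for name, filename in index["weight_map"].items():
    print(name, "->", filename)
```

With a greedy split like this, a single weight larger than the cap still gets a shard of its own, so individual files can exceed the limit.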

ํฐ ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ ์ด๋Ÿฌํ•œ ๋ฐฉ์‹์œผ๋กœ ์ฒ˜๋ฆฌํ•˜๋Š” ์ฃผ๋œ ์žฅ์ ์€ ์œ„์—์„œ ๋ณด์—ฌ์ค€ ํ๋ฆ„์˜ 2๋‹จ๊ณ„์—์„œ, ๊ฐ ์ƒค๋“œ๊ฐ€ ์ด์ „ ์ƒค๋“œ ๋‹ค์Œ์— ๋กœ๋“œ๋˜๋ฏ€๋กœ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์ด ๋ชจ๋ธ ํฌ๊ธฐ์™€ ๊ฐ€์žฅ ํฐ ์ƒค๋“œ์˜ ํฌ๊ธฐ๋ฅผ ์ดˆ๊ณผํ•˜์ง€ ์•Š๋Š”๋‹ค๋Š” ์ ์ž…๋‹ˆ๋‹ค.

์ด ์ธ๋ฑ์Šค ํŒŒ์ผ์€ ํ‚ค๊ฐ€ ์ฒดํฌํฌ์ธํŠธ์— ์žˆ๋Š”์ง€, ๊ทธ๋ฆฌ๊ณ  ํ•ด๋‹น ๊ฐ€์ค‘์น˜๊ฐ€ ์–ด๋””์— ์ €์žฅ๋˜์–ด ์žˆ๋Š”์ง€๋ฅผ ๊ฒฐ์ •ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. ์ด ์ธ๋ฑ์Šค๋ฅผ json๊ณผ ๊ฐ™์ด ๋กœ๋“œํ•˜๊ณ  ๋”•์…”๋„ˆ๋ฆฌ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

>>> import json

>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="200MB")
...     with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json"), "r") as f:
...         index = json.load(f)

>>> print(index.keys())
dict_keys(['metadata', 'weight_map'])

For now, the metadata contains only the total size of the model. We plan to add other information in the future:

>>> index["metadata"]
{'total_size': 433245184}

The weight map is the main part of this index. It maps each parameter name (as usually found in a PyTorch model state_dict) to the file it is stored in:

>>> index["weight_map"]
{'embeddings.LayerNorm.bias': 'pytorch_model-00001-of-00003.bin',
 'embeddings.LayerNorm.weight': 'pytorch_model-00001-of-00003.bin',
 ...
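Because the index is a plain dictionary, it can be queried with ordinary Python. The sketch below looks up which file holds a given parameter and groups parameter names by shard. The `index` here is written by hand: the first two entries are the ones shown above, while the last two file assignments are hypothetical, and `locate` and `params_by_shard` are illustrative helper names.

```python
from collections import defaultdict

index = {
    "metadata": {"total_size": 433245184},
    "weight_map": {
        "embeddings.LayerNorm.bias": "pytorch_model-00001-of-00003.bin",
        "embeddings.LayerNorm.weight": "pytorch_model-00001-of-00003.bin",
        # Hypothetical assignments, added only to make the example richer.
        "encoder.layer.5.output.dense.weight": "pytorch_model-00002-of-00003.bin",
        "pooler.dense.bias": "pytorch_model-00003-of-00003.bin",
    },
}


def locate(index, name):
    """Return the shard file holding `name`, or None if it is not in the checkpoint."""
    return index["weight_map"].get(name)


def params_by_shard(index):
    """Group parameter names by the shard file that stores them."""
    grouped = defaultdict(list)
    for name, filename in index["weight_map"].items():
        grouped[filename].append(name)
    return dict(grouped)


by_shard = params_by_shard(index)
for filename in sorted(by_shard):
    print(filename, len(by_shard[filename]), "parameter(s)")
```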

If you want to load such a sharded checkpoint directly into a model without using [~PreTrainedModel.from_pretrained] (the way you would call model.load_state_dict() for a full checkpoint), you should use [~modeling_utils.load_sharded_checkpoint]:

>>> from transformers.modeling_utils import load_sharded_checkpoint

>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     model.save_pretrained(tmp_dir, max_shard_size="200MB")
...     load_sharded_checkpoint(model, tmp_dir)
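To show the general idea of loading one shard at a time, here is a stand-in sketch that uses JSON files in place of real checkpoint shards. It is an illustration under that assumption, not a description of what load_sharded_checkpoint does internally: the `load_sharded` helper, the file names, and the weight values are all hypothetical.

```python
import json
import os
import tempfile


def load_sharded(model_state, folder, index_name="index.json"):
    """Fill `model_state` shard by shard, keeping one shard in memory at a time."""
    with open(os.path.join(folder, index_name)) as f:
        index = json.load(f)
    # Open each shard file once, however many weights it holds.
    for shard_file in sorted(set(index["weight_map"].values())):
        with open(os.path.join(folder, shard_file)) as f:
            shard = json.load(f)
        model_state.update(shard)
        del shard  # release this shard before reading the next one


with tempfile.TemporaryDirectory() as tmp_dir:
    # Write two tiny hypothetical shards and the index that describes them.
    shards = {"shard-1.json": {"w1": [1.0, 2.0]}, "shard-2.json": {"w2": [3.0]}}
    weight_map = {}
    for filename, weights in shards.items():
        with open(os.path.join(tmp_dir, filename), "w") as f:
            json.dump(weights, f)
        for name in weights:
            weight_map[name] = filename
    with open(os.path.join(tmp_dir, "index.json"), "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f)

    # The "model" starts with placeholder values, then receives the stored weights.
    model_state = {"w1": None, "w2": None}
    load_sharded(model_state, tmp_dir)

print(model_state)
```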

Low-memory loading [[low-memory-loading]]

Sharded checkpoints reduce memory usage during step 2 of the workflow mentioned above, but to use the model in a low-memory setting we recommend the tools built on our Accelerate library.

์ž์„ธํ•œ ์‚ฌํ•ญ์€ ๋‹ค์Œ ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ์กฐํ•ด์ฃผ์„ธ์š”: Accelerate๋กœ ๋Œ€๊ทœ๋ชจ ๋ชจ๋ธ ๊ฐ€์ ธ์˜ค๊ธฐ (์˜๋ฌธ)