| <!--Copyright 2022 The HuggingFace Team. All rights reserved. | |
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | |
| the License. You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | |
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
| specific language governing permissions and limitations under the License. | |
| โ ๏ธ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | |
| rendered properly in your Markdown viewer. | |
| --> | |
| # ํฐ ๋ชจ๋ธ ์ธ์คํด์คํ [[instantiating-a-big-model]] | |
| ๋งค์ฐ ํฐ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ค๋ฉด, RAM ์ฌ์ฉ์ ์ต์ํํด์ผ ํ๋ ๊ณผ์ ๊ฐ ์์ต๋๋ค. ์ผ๋ฐ์ ์ธ PyTorch ์ํฌํ๋ก์ฐ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: | |
| 1. ๋ฌด์์ ๊ฐ์ค์น๋ก ๋ชจ๋ธ์ ์์ฑํฉ๋๋ค. | |
| 2. ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ฅผ ๋ถ๋ฌ์ต๋๋ค. | |
| 3. ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ฅผ ๋ฌด์์ ๋ชจ๋ธ์ ์ ์ฉํฉ๋๋ค. | |
| 1๋จ๊ณ์ 2๋จ๊ณ ๋ชจ๋ ๋ชจ๋ธ์ ์ ์ฒด ๋ฒ์ ์ ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฌํด์ผ ํ๋ฉฐ, ๋๋ถ๋ถ ๋ฌธ์ ๊ฐ ์์ง๋ง ๋ชจ๋ธ์ด ๊ธฐ๊ฐ๋ฐ์ดํธ๊ธ์ ์ฉ๋์ ์ฐจ์งํ๊ธฐ ์์ํ๋ฉด ๋ณต์ฌ๋ณธ 2๊ฐ๊ฐ RAM์ ์ด๊ณผํ์ฌ ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ ์ด์๋ฅผ ์ผ๊ธฐํ ์ ์์ต๋๋ค. ๋ ์ฌ๊ฐํ ๋ฌธ์ ๋ ๋ถ์ฐ ํ์ต์ ์ํด `torch.distributed`๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ, ํ๋ก์ธ์ค๋ง๋ค ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๋ก๋ํ๊ณ ๋ณต์ฌ๋ณธ์ 2๊ฐ์ฉ RAM์ ์ ์ฅํ๋ค๋ ๊ฒ์ ๋๋ค. | |
| <Tip> | |
| ๋ฌด์์๋ก ์์ฑ๋ ๋ชจ๋ธ์ "๋น์ด ์๋" (์ฆ ๊ทธ๋ ๋ฉ๋ชจ๋ฆฌ์ ์๋ ๊ฒ์ผ๋ก ์ด๋ค์ง) ํ ์๋ก ์ด๊ธฐํ๋๋ฉฐ ๋ฉ๋ชจ๋ฆฌ ๊ณต๊ฐ์ ์ฐจ์งํฉ๋๋ค. ์ด๊ธฐํ๋ ๋ชจ๋ธ/ํ๋ผ๋ฏธํฐ์ ์ข ๋ฅ์ ์ ํฉํ ๋ถํฌ(์: ์ ๊ท ๋ถํฌ)์ ๋ฐ๋ฅธ ๋ฌด์์ ์ด๊ธฐํ๋ ๊ฐ๋ฅํ ํ ๋น ๋ฅด๊ฒ ํ๊ธฐ ์ํด ์ด๊ธฐํ๋์ง ์์ ๊ฐ์ค์น์ ๋ํด 3๋จ๊ณ ์ดํ์๋ง ์ํ๋ฉ๋๋ค! | |
| </Tip> | |
| ์ด ์๋ด์์์๋ Transformers๊ฐ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ ๊ณตํ๋ ์๋ฃจ์ ์ ์ดํด๋ด ๋๋ค. ์ฃผ์ํ ์ ์ ์์ง ํ๋ฐํ ๊ฐ๋ฐ ์ค์ธ ๋ถ์ผ์ด๋ฏ๋ก ์ฌ๊ธฐ์ ์ค๋ช ํ๋ API๊ฐ ์์ผ๋ก ์ฝ๊ฐ ๋ณ๊ฒฝ๋ ์ ์๋ค๋ ๊ฒ์ ๋๋ค. | |
| ## ์ค๋ฉ๋ ์ฒดํฌํฌ์ธํธ [[sharded-checkpoints]] | |
| 4.18.0 ๋ฒ์ ์ดํ, 10GB ์ด์์ ๊ณต๊ฐ์ ์ฐจ์งํ๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ ์๋์ผ๋ก ์์ ์กฐ๊ฐ๋ค๋ก ์ค๋ฉ๋ฉ๋๋ค. `model.save_pretrained(save_dir)`๋ฅผ ์คํํ ๋ ํ๋์ ๋จ์ผ ์ฒดํฌํฌ์ธํธ๋ฅผ ๊ฐ์ง๊ฒ ๋ ๋์ , ์ฌ๋ฌ ๋ถ๋ถ ์ฒดํฌํฌ์ธํธ(๊ฐ๊ฐ์ ํฌ๊ธฐ๋ 10GB ๋ฏธ๋ง)์ ๋งค๊ฐ๋ณ์ ์ด๋ฆ์ ํด๋น ํ์ผ์ ๋งคํํ๋ ์ธ๋ฑ์ค๊ฐ ์์ฑ๋ฉ๋๋ค. | |
| `max_shard_size` ๋งค๊ฐ๋ณ์๋ก ์ค๋ฉ ์ ์ต๋ ํฌ๊ธฐ๋ฅผ ์ ์ดํ ์ ์์ผ๋ฏ๋ก, ์ด ์์ ๋ฅผ ์ํด ์ค๋ ํฌ๊ธฐ๊ฐ ์์ ์ผ๋ฐ ํฌ๊ธฐ์ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ฒ ์ต๋๋ค: ์ ํต์ ์ธ BERT ๋ชจ๋ธ์ ์ฌ์ฉํด ๋ด ์๋ค. | |
| ```py | |
| from transformers import AutoModel | |
| model = AutoModel.from_pretrained("google-bert/bert-base-cased") | |
| ``` | |
| [`~PreTrainedModel.save_pretrained`]์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ ์ฅํ๋ฉด, ๋ชจ๋ธ์ ๊ตฌ์ฑ๊ณผ ๊ฐ์ค์น๊ฐ ๋ค์ด์๋ ๋ ๊ฐ์ ํ์ผ์ด ์๋ ์ ํด๋๊ฐ ์์ฑ๋ฉ๋๋ค: | |
| ```py | |
| >>> import os | |
| >>> import tempfile | |
| >>> with tempfile.TemporaryDirectory() as tmp_dir: | |
| ... model.save_pretrained(tmp_dir) | |
| ... print(sorted(os.listdir(tmp_dir))) | |
| ['config.json', 'pytorch_model.bin'] | |
| ``` | |
| ์ด์ ์ต๋ ์ค๋ ํฌ๊ธฐ๋ฅผ 200MB๋ก ์ฌ์ฉํด ๋ด ์๋ค: | |
| ```py | |
| >>> with tempfile.TemporaryDirectory() as tmp_dir: | |
| ... model.save_pretrained(tmp_dir, max_shard_size="200MB") | |
| ... print(sorted(os.listdir(tmp_dir))) | |
| ['config.json', 'pytorch_model-00001-of-00003.bin', 'pytorch_model-00002-of-00003.bin', 'pytorch_model-00003-of-00003.bin', 'pytorch_model.bin.index.json'] | |
| ``` | |
| ๋ชจ๋ธ์ ๊ตฌ์ฑ์ ๋ํด, ์ธ ๊ฐ์ ๋ค๋ฅธ ๊ฐ์ค์น ํ์ผ๊ณผ ํ๋ผ๋ฏธํฐ ์ด๋ฆ๊ณผ ํด๋น ํ์ผ์ ๋งคํ์ด ํฌํจ๋ `index.json` ํ์ผ์ ๋ณผ ์ ์์ต๋๋ค. ์ด๋ฌํ ์ฒดํฌํฌ์ธํธ๋ [`~PreTrainedModel.from_pretrained`] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ์์ ํ ๋ค์ ๋ก๋ํ ์ ์์ต๋๋ค: | |
| ```py | |
| >>> with tempfile.TemporaryDirectory() as tmp_dir: | |
| ... model.save_pretrained(tmp_dir, max_shard_size="200MB") | |
| ... new_model = AutoModel.from_pretrained(tmp_dir) | |
| ``` | |
| ํฐ ๋ชจ๋ธ์ ๊ฒฝ์ฐ ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ์ฒ๋ฆฌํ๋ ์ฃผ๋ ์ฅ์ ์ ์์์ ๋ณด์ฌ์ค ํ๋ฆ์ 2๋จ๊ณ์์, ๊ฐ ์ค๋๊ฐ ์ด์ ์ค๋ ๋ค์์ ๋ก๋๋๋ฏ๋ก ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋ชจ๋ธ ํฌ๊ธฐ์ ๊ฐ์ฅ ํฐ ์ค๋์ ํฌ๊ธฐ๋ฅผ ์ด๊ณผํ์ง ์๋๋ค๋ ์ ์ ๋๋ค. | |
| ์ด ์ธ๋ฑ์ค ํ์ผ์ ํค๊ฐ ์ฒดํฌํฌ์ธํธ์ ์๋์ง, ๊ทธ๋ฆฌ๊ณ ํด๋น ๊ฐ์ค์น๊ฐ ์ด๋์ ์ ์ฅ๋์ด ์๋์ง๋ฅผ ๊ฒฐ์ ํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. ์ด ์ธ๋ฑ์ค๋ฅผ json๊ณผ ๊ฐ์ด ๋ก๋ํ๊ณ ๋์ ๋๋ฆฌ๋ฅผ ์ป์ ์ ์์ต๋๋ค: | |
| ```py | |
| >>> import json | |
| >>> with tempfile.TemporaryDirectory() as tmp_dir: | |
| ... model.save_pretrained(tmp_dir, max_shard_size="200MB") | |
| ... with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json"), "r") as f: | |
| ... index = json.load(f) | |
| >>> print(index.keys()) | |
| dict_keys(['metadata', 'weight_map']) | |
| ``` | |
| ๋ฉํ๋ฐ์ดํฐ๋ ํ์ฌ ๋ชจ๋ธ์ ์ด ํฌ๊ธฐ๋ง ํฌํจ๋ฉ๋๋ค. ์์ผ๋ก ๋ค๋ฅธ ์ ๋ณด๋ฅผ ์ถ๊ฐํ ๊ณํ์ ๋๋ค: | |
| ```py | |
| >>> index["metadata"] | |
| {'total_size': 433245184} | |
| ``` | |
| ๊ฐ์ค์น ๋งต์ ์ด ์ธ๋ฑ์ค์ ์ฃผ์ ๋ถ๋ถ์ผ๋ก, ๊ฐ ๋งค๊ฐ๋ณ์ ์ด๋ฆ(PyTorch ๋ชจ๋ธ `state_dict`์์ ๋ณดํต ์ฐพ์ ์ ์๋)์ ํด๋น ํ์ผ์ ๋งคํํฉ๋๋ค: | |
| ```py | |
| >>> index["weight_map"] | |
| {'embeddings.LayerNorm.bias': 'pytorch_model-00001-of-00003.bin', | |
| 'embeddings.LayerNorm.weight': 'pytorch_model-00001-of-00003.bin', | |
| ... | |
| ``` | |
| ๋ง์ฝ [`~PreTrainedModel.from_pretrained`]๋ฅผ ์ฌ์ฉํ์ง ์๊ณ ๋ชจ๋ธ ๋ด์์ ์ด๋ฌํ ์ค๋ฉ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ง์ ๊ฐ์ ธ์ค๋ ค๋ฉด (์ ์ฒด ์ฒดํฌํฌ์ธํธ๋ฅผ ์ํด `model.load_state_dict()`๋ฅผ ์ํํ๋ ๊ฒ์ฒ๋ผ), [`~modeling_utils.load_sharded_checkpoint`]๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค. | |
| ```py | |
| >>> from transformers.modeling_utils import load_sharded_checkpoint | |
| >>> with tempfile.TemporaryDirectory() as tmp_dir: | |
| ... model.save_pretrained(tmp_dir, max_shard_size="200MB") | |
| ... load_sharded_checkpoint(model, tmp_dir) | |
| ``` | |
| ## ์ (ไฝ)๋ฉ๋ชจ๋ฆฌ ๋ก๋ฉ [[low-memory-loading]] | |
| ์ค๋ฉ๋ ์ฒดํฌํฌ์ธํธ๋ ์์์ ์ธ๊ธํ ์์ ํ๋ฆ์ 2๋จ๊ณ์์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด์ง๋ง, ์ (ไฝ)๋ฉ๋ชจ๋ฆฌ ์ค์ ์์ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ธฐ ์ํด ์ฐ๋ฆฌ์ Accelerate ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ ๋๊ตฌ๋ฅผ ํ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค. | |
| ์์ธํ ์ฌํญ์ ๋ค์ ๊ฐ์ด๋๋ฅผ ์ฐธ์กฐํด์ฃผ์ธ์: [Accelerate๋ก ๋๊ท๋ชจ ๋ชจ๋ธ ๊ฐ์ ธ์ค๊ธฐ (์๋ฌธ)](../en/main_classes/model#large-model-loading) |