๋ชจ๋ธ ๋ก๋ํ๊ธฐ[[loading-models]]
Transformers๋ ํ ์ค์ ์ฝ๋๋ก ์ฌ์ฉํ ์ ์๋ ๋ง์ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ ๊ณตํฉ๋๋ค. ๋ชจ๋ธ ํด๋์ค์ [~PreTrainedModel.from_pretrained] ๋ฉ์๋๊ฐ ํ์ํฉ๋๋ค.
[~PreTrainedModel.from_pretrained]๋ฅผ ํธ์ถํ์ฌ Hugging Face Hub์ ์ ์ฅ๋ ๋ชจ๋ธ์ ๊ฐ์ค์น์ ๊ตฌ์ฑ์ ๋ค์ด๋ก๋ํ๊ณ ๋ก๋ํ์ธ์.
[
~PreTrainedModel.from_pretrained] ๋ฉ์๋๋ safetensors ํ์ผ ํ์์ผ๋ก ์ ์ฅ๋ ๊ฐ์ค์น๊ฐ ์์ผ๋ฉด ์ด๋ฅผ ๋ก๋ํฉ๋๋ค. ์ ํต์ ์ผ๋ก PyTorch ๋ชจ๋ธ ๊ฐ์ค์น๋ ๋ณด์์ ์ทจ์ฝํ ๊ฒ์ผ๋ก ์๋ ค์ง pickle ์ ํธ๋ฆฌํฐ๋ก ์ง๋ ฌํ๋ฉ๋๋ค. Safetensor ํ์ผ์ ๋ ์์ ํ๊ณ ๋ก๋ ์๋๊ฐ ๋น ๋ฆ ๋๋ค.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto", device_map="auto")
์ด ๊ฐ์ด๋๋ ๋ชจ๋ธ์ ๋ถ๋ฌ์ค๋ ๋ฐฉ๋ฒ, ๋ค์ํ ๋ก๋ฉ ๋ฐฉ์, ๋งค์ฐ ํฐ ๋ชจ๋ธ์์ ๋ฐ์ํ ์ ์๋ ๋ฉ๋ชจ๋ฆฌ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฐฉ๋ฒ, ๊ทธ๋ฆฌ๊ณ ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ์ ๋ถ๋ฌ์ค๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค.
๋ชจ๋ธ๊ณผ ๊ตฌ์ฑ[[models-and-configurations]]
๋ชจ๋ ๋ชจ๋ธ์๋ ์๋ ๋ ์ด์ด ์, ์ดํ ์ฌ์ ํฌ๊ธฐ, ํ์ฑํ ํจ์ ๋ฑ๊ณผ ๊ฐ์ ํน์ ์์ฑ์ด ํฌํจ๋ configuration.py ํ์ผ์ด ์์ต๋๋ค. ๋ํ ๊ฐ ๋ ์ด์ด์ ์ ์์ ๊ฐ๊ฐ์ ๋ ์ด์ด ์์์ ์ผ์ด๋๋ ์ํ์ ์ฐ์ฐ์ ์ ์ํ๋ modeling.py ํ์ผ๋ ์์ต๋๋ค. modeling.py ํ์ผ์ configuration.py์ ์ ์๋ ๋ชจ๋ธ ์์ฑ์ ๋ฐํ์ผ๋ก ๋ชจ๋ธ์ ๊ตฌ์ถํฉ๋๋ค. ์ด ๋จ๊ณ์์๋ ์์ง ํ์ต๋์ง ์์ ๋ฌด์์ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ์ํ์ด๊ธฐ ๋๋ฌธ์, ์๋ฏธ ์๋ ์ถ๋ ฅ์ ์ป๊ธฐ ์ํด์๋ ํ์ต์ด ํ์ํฉ๋๋ค.
*์ํคํ ์ฒ(Architecture)*๋ ๋ชจ๋ธ์ ๊ณจ๊ฒฉ์ ์๋ฏธํ๊ณ *์ฒดํฌํฌ์ธํธ(checkpoint)*๋ ์ฃผ์ด์ง ์ํคํ ์ฒ์ ๋ํ ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ์๋ฏธํฉ๋๋ค. ์๋ฅผ ๋ค์ด, BERT๋ ์ํคํ ์ฒ์ด๊ณ google-bert/bert-base-uncased๋ ํด๋น ์ํคํ ์ฒ์ ์ฒดํฌํฌ์ธํธ(checkpoint)์ ๋๋ค. ๋ชจ๋ธ์ด๋ผ๋ ์ฉ์ด๋ ์ํคํ ์ฒ ๋ฐ ์ฒดํฌํฌ์ธํธ(checkpoint)์ ํผ์ฉํ์ฌ ์ฌ์ฉ๋๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค.
๋ก๋ํ ์ ์๋ ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ผ๋ก ๋ ๊ฐ์ง ํ์ ์ด ์์ต๋๋ค.
- ์๋ ์ํ๋ฅผ ์ถ๋ ฅํ๋ [
AutoModel] ๋๋ [LlamaModel]๊ณผ ๊ฐ์ ๊ธฐ๋ณธ ๋ชจ๋ธ์ ๋๋ค. - ํน์ ์์
์ ์ํํ๊ธฐ ์ํด ํน์ ํค๋๊ฐ ๋ถ์ [
AutoModelForCausalLM] ๋๋ [LlamaForCausalLM]๊ณผ ๊ฐ์ ๋ชจ๋ธ์ ๋๋ค.
๊ฐ ๋ชจ๋ธ ํ์ ๋ง๋ค, ๊ฐ๊ฐ์ ๊ธฐ๊ณํ์ต ํ๋ ์์ํฌ(PyTorch, TensorFlow, Flax)๋ฅผ ์ํ ๋ณ๋์ ํด๋์ค๊ฐ ์์ต๋๋ค. ์ฌ์ฉ ์ค์ธ ํ๋ ์์ํฌ์ ํด๋นํ๋ ์ ๋์ด(prefix)๋ฅผ ์ ํํ์ธ์.
from transformers import AutoModelForCausalLM, MistralForCausalLM
# AutoClass ๋๋ ๋ชจ๋ธ๋ณ ํด๋์ค(model-specific class) ๋ฅผ ์ด์ฉํด ๋ก๋
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
from transformers import TFAutoModelForCausalLM, TFMistralForCausalLM
# AutoClass ๋๋ ๋ชจ๋ธ๋ณ ํด๋์ค(model-specific class) ๋ฅผ ์ด์ฉํด ๋ก๋
model = TFAutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model = TFMistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
from transformers import FlaxAutoModelForCausalLM, FlaxMistralForCausalLM
# AutoClass ๋๋ ๋ชจ๋ธ๋ณ ํด๋์ค(model-specific class) ๋ฅผ ์ด์ฉํด ๋ก๋
model = FlaxAutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model = FlaxMistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
๋ชจ๋ธ ํด๋์ค[[model-classes]]
์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๊ฐ์ ธ์ค๋ ค๋ฉด ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ๋ก๋ํด์ผ ํฉ๋๋ค. ์ด๋ Hugging Face Hub ๋๋ ๋ก์ปฌ ๋๋ ํฐ๋ฆฌ์์ ๊ฐ์ค์น๋ฅผ ๋ฐ์๋ค์ด๋ [~PreTrainedModel.from_pretrained]๋ฅผ ํธ์ถํ์ฌ ์ํ๋ฉ๋๋ค.
๋ ๊ฐ์ง ๋ชจ๋ธ ํด๋์ค๋ก AutoModel ํด๋์ค์ ๋ชจ๋ธ๋ณ ํด๋์ค๊ฐ ์์ต๋๋ค.
AutoModel ํด๋์ค๋ ์ ํํ ๋ชจ๋ธ ํด๋์ค ์ด๋ฆ์ ๋ชฐ๋ผ๋ ์ํคํ ์ฒ๋ฅผ ๋ถ๋ฌ์ฌ ์ ์๋ ํธ๋ฆฌํ ๋ฐฉ๋ฒ์ ๋๋ค. ๋ง์ ๋ชจ๋ธ์ด ์ ๊ณต๋๊ธฐ ๋๋ฌธ์, ์ด ํด๋์ค๋ ๊ตฌ์ฑ ํ์ผ์ ๊ธฐ๋ฐ์ผ๋ก ์ฌ๋ฐ๋ฅธ ๋ชจ๋ธ ํด๋์ค๋ฅผ ์๋์ผ๋ก ์ ํํด ์ค๋๋ค. ์ํ๋ ์์ ๊ณผ ์ฌ์ฉํ๋ ค๋ ์ฒดํฌํฌ์ธํธ๋ง ์๊ณ ์์ผ๋ฉด ๋ฉ๋๋ค.
์ฃผ์ด์ง ์์ ์ ์ํคํ ์ฒ๊ฐ ์ง์ํ๋ ํ, ๋ชจ๋ธ์ด๋ ์์ ์ ์ฝ๊ฒ ์ ํํ ์ ์์ต๋๋ค.
์๋ฅผ ๋ค์ด, ๋์ผํ ๋ชจ๋ธ์ ์๋ก ๋ค๋ฅธ ์์ ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoModelForQuestionAnswering
# ๋์ผํ API๋ฅผ 3๊ฐ์ง ๋ค๋ฅธ ์์
์ ์ฌ์ฉ
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForQuestionAnswering.from_pretrained("meta-llama/Llama-2-7b-hf")
๋ค๋ฅธ ๊ฒฝ์ฐ์๋, ํ๋์ ์์ ์ ๋ํด ์ฌ๋ฌ ๊ฐ์ง ๋ชจ๋ธ์ ๋น ๋ฅด๊ฒ ์ํํด๋ณด๊ณ ์ถ์ ์๋ ์์ต๋๋ค.
from transformers import AutoModelForCausalLM
# ๋์ผํ API๋ฅผ ์ฌ์ฉํ์ฌ 3๊ฐ์ง ๋ค๋ฅธ ๋ชจ๋ธ ๋ก๋
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")
AutoModel ํด๋์ค๋ ๋ชจ๋ธ๋ณ ํด๋์ค๋ค์ ๊ธฐ๋ฐ์ผ๋ก ๊ตฌ์ถ๋ฉ๋๋ค. ํน์ ์์
์ ์ง์ํ๋ ๋ชจ๋ ๋ชจ๋ธ ํด๋์ค๋ค์ ๊ฐ๊ฐ์ AutoModelFor ์์
ํด๋์ค์ ๋งคํ๋ฉ๋๋ค.
์ด๋ฏธ ์ฌ์ฉํ๋ ค๋ ๋ชจ๋ธ ํด๋์ค๋ฅผ ์๊ณ ์๋ค๋ฉด ํด๋น ๋ชจ๋ธ๋ณ ํด๋์ค๋ฅผ ์ง์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
from transformers import LlamaModel, LlamaForCausalLM
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
๋๊ท๋ชจ ๋ชจ๋ธ[[large-models]]
๋๊ท๋ชจ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๋ก๋ํ๋ ๋ฐ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. ๋ก๋ ๊ณผ์ ์ ๋ค์์ ํฌํจํฉ๋๋ค:
- ๋ฌด์์ ๊ฐ์ค์น๋ก ๋ชจ๋ธ ์์ฑ
- ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น ๋ก๋
- ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ฅผ ๋ชจ๋ธ์ ์ ์ฉ
๋ชจ๋ธ ๊ฐ์ค์น์ ๋ณต์ฌ๋ณธ ๋ ๊ฐ์ง(๋ฌด์์ ๊ฐ์ค์น์ ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น)๋ฅผ ๋ณด๊ดํ ์ ์๋ ์ถฉ๋ถํ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํ๋ฉฐ, ์ด๋ ๋ณด์ ํ ํ๋์จ์ด์ ๋ฐ๋ผ ๋ถ๊ฐ๋ฅํ ์ ์์ต๋๋ค. ๋ถ์ฐ ํ์ต ํ๊ฒฝ์์๋ ๊ฐ ํ๋ก์ธ์ค๊ฐ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๋ก๋ํ๊ธฐ ๋๋ฌธ์ ์ด๋ ๋์ฑ ์ด๋ ค์ด ๊ณผ์ ์ ๋๋ค.
transformers๋ ๋น ๋ฅธ ์ด๊ธฐํ, ๋ถํ ๋ ์ฒดํฌํฌ์ธํธ, Accelerate์ Big Model Inference ๊ธฐ๋ฅ, ๊ทธ๋ฆฌ๊ณ ๋ ๋ฎ์ ๋นํธ ๋ฐ์ดํฐ ํ์ ์ง์์ ํตํด ์ด๋ฌํ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ จ ๋ฌธ์ ๋ค์ ์ผ๋ถ ์ค์ฌ์ค๋๋ค.
๋ถํ ๋ ์ฒดํฌํฌ์ธํธ[[sharded-checkpoints]]
[~PreTrainedModel.save_pretrained] ๋ฉ์๋๋ 10GB๋ณด๋ค ํฐ ์ฒดํฌํฌ์ธํธ๋ฅผ ์๋์ผ๋ก ์ค๋ํฉ๋๋ค.
๊ฐ ์ค๋(shard)๋ ์ด์ ์ค๋๊ฐ ๋ก๋๋ ํ ์์ฐจ์ ์ผ๋ก ๋ก๋๋์ด, ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ๋ชจ๋ธ ํฌ๊ธฐ์ ๊ฐ์ฅ ํฐ ์ค๋ ํฌ๊ธฐ๋ก๋ง ์ ํํฉ๋๋ค.
max_shard_size ๋งค๊ฐ๋ณ์๋ ๊ฐ ์ค๋์ ๋ํด ๊ธฐ๋ณธ์ ์ผ๋ก 5GB๋ก ์ค์ ๋์ด ์๋๋ฐ, ์ด๋ ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ ์์ด ๋ฌด๋ฃ ๋ฑ๊ธ GPU ์ธ์คํด์ค์์ ๋ ์ฝ๊ฒ ์คํํ ์ ์๊ธฐ ๋๋ฌธ์
๋๋ค.
์๋ฅผ ๋ค์ด, [~PreTrainedModel.save_pretrained]์์ BioMistral/BioMistral-7B์ ๋ํ ๋ถํ ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์์ฑํด๋ณด๊ฒ ์ต๋๋ค.
from transformers import AutoModel
import tempfile
import os
model = AutoModel.from_pretrained("biomistral/biomistral-7b")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, max_shard_size="5GB")
print(sorted(os.listdir(tmp_dir)))
[~PreTrainedModel.from_pretrained]๋ก ๋ถํ ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ค์ ๋ก๋ํฉ๋๋ค.
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = AutoModel.from_pretrained(tmp_dir)
๋ถํ ๋ ์ฒดํฌํฌ์ธํธ๋ [~transformers.trainer_utils.load_sharded_checkpoint]๋ก๋ ์ง์ ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค.
from transformers.trainer_utils import load_sharded_checkpoint
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, max_shard_size="5GB")
load_sharded_checkpoint(model, tmp_dir)
[~PreTrainedModel.save_pretrained] ๋ฉ์๋๋ ๋งค๊ฐ๋ณ์ ์ด๋ฆ์ ์ ์ฅ๋ ํ์ผ์ ๋งคํํ๋ ์ธ๋ฑ์ค ํ์ผ์ ์์ฑํฉ๋๋ค. ์ธ๋ฑ์ค ํ์ผ์๋ metadata์ weight_map์ด๋ผ๋ ๋ ๊ฐ์ ํค๊ฐ ์์ต๋๋ค.
import json
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, max_shard_size="5GB")
with open(os.path.join(tmp_dir, "model.safetensors.index.json"), "r") as f:
index = json.load(f)
print(index.keys())
metadata ํค๋ ์ ์ฒด ๋ชจ๋ธ ํฌ๊ธฐ๋ฅผ ์ ๊ณตํฉ๋๋ค.
index["metadata"]
{'total_size': 28966928384}
weight_map ํค๋ ๊ฐ ๋งค๊ฐ๋ณ์๋ฅผ ์ ์ฅ๋ ์ค๋์ ๋งคํํฉ๋๋ค.
index["weight_map"]
{'lm_head.weight': 'model-00006-of-00006.safetensors',
'model.embed_tokens.weight': 'model-00001-of-00006.safetensors',
'model.layers.0.input_layernorm.weight': 'model-00001-of-00006.safetensors',
'model.layers.0.mlp.down_proj.weight': 'model-00001-of-00006.safetensors',
...
}
๋ํ ๋ชจ๋ธ ์ถ๋ก [[big-model-inference]]
์ด ๊ธฐ๋ฅ์ ์ฌ์ฉํ๋ ค๋ฉด Accelerate v0.9.0 ๋ฐ PyTorch v1.9.0 ์ด์์ด ์ค์น๋์ด ์๋์ง ํ์ธํ์ธ์!
[~PreTrainedModel.from_pretrained]๋ Accelerate์ ๋ํ ๋ชจ๋ธ ์ถ๋ก ๊ธฐ๋ฅ์ผ๋ก ๊ฐํ๋์์ต๋๋ค.
๋ํ ๋ชจ๋ธ ์ถ๋ก ์ PyTorch meta ์ฅ์น์์ ๋ชจ๋ธ ์ค์ผ๋ ํค์ ์์ฑํฉ๋๋ค. meta ์ฅ์น๋ ์ค์ ๋ฐ์ดํฐ๋ฅผ ์ ์ฅํ์ง ์๊ณ ๋ฉํ๋ฐ์ดํฐ๋ง ์ ์ฅํฉ๋๋ค.
๋ฌด์์๋ก ์ด๊ธฐํ๋ ๊ฐ์ค์น๋ ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๊ฐ ๋ก๋๋ ๋๋ง ์์ฑ๋์ด ๋ฉ๋ชจ๋ฆฌ์ ๋์์ ๋ชจ๋ธ์ ๋ ๋ณต์ฌ๋ณธ์ ์ ์งํ๋ ๊ฒ์ ๋ฐฉ์งํฉ๋๋ค. ์ต๋ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ๋ชจ๋ธ ํฌ๊ธฐ๋งํผ์ ๋๋ค.
์ฅ์น ํ ๋น์ ๋ํ ์์ธํ ๋ด์ฉ์ ์ฅ์น ๋งต ์ค๊ณํ๊ธฐ๋ฅผ ์ฐธ์กฐํ์ธ์.
๋ํ ๋ชจ๋ธ ์ถ๋ก ์ ๋ ๋ฒ์งธ ๊ธฐ๋ฅ์ ๋ถ๋ฌ์จ ๊ฐ์ค์น๊ฐ ๋ชจ๋ธ ์ค์ผ๋ ํค์ ํ ๋น๋๋ ๋ฐฉ์๊ณผ ๊ด๋ จ์ด ์์ต๋๋ค. ๋ชจ๋ธ ๊ฐ์ค์น๋ ์ฌ์ฉ ๊ฐ๋ฅํ ๋ชจ๋ ๋๋ฐ์ด์ค์ ๋ถ์ฐ๋๋ฉฐ, ๊ฐ์ฅ ๋น ๋ฅธ ๋๋ฐ์ด์ค(๋ณดํต GPU)๋ถํฐ ์์ํด ๋๋จธ์ง ๊ฐ์ค์น๋ ๋๋ฆฐ ๋๋ฐ์ด์ค(CPU ๋ฐ ํ๋ ๋์คํฌ)๋ก ์์ฐจ์ ์ผ๋ก ํ ๋น๋ฉ๋๋ค.
๋ ๊ธฐ๋ฅ์ ๊ฒฐํฉํ๋ฉด ๋ํ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋๊ณผ ๋ก๋ฉ ์๊ฐ์ด ์ค์ด๋ญ๋๋ค.
๋ํ ๋ชจ๋ธ ์ถ๋ก ์ ํ์ฑํํ๋ ค๋ฉด device_map์ "auto"๋ก ์ค์ ํฉ๋๋ค.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto")
device_map์ ์ฌ์ฉํ์ฌ ๋ ์ด์ด๋ฅผ ์ํ๋ ๋๋ฐ์ด์ค์ ์๋์ผ๋ก ํ ๋นํ ์๋ ์์ต๋๋ค. ๋ ์ด์ด ์ ์ฒด๊ฐ ๋์ผํ ๋๋ฐ์ด์ค์ ํ ๋น๋์ด ์๋ค๋ฉด, ํด๋น ๋ ์ด์ด์ ๋ชจ๋ ์๋ธ๋ชจ๋์ด ์ด๋์ ๋ฐฐ์น๋๋์ง ์ผ์ผ์ด ์ง์ ํ ํ์๋ ์์ต๋๋ค.
๋ชจ๋ธ์ด ๊ฐ ๋๋ฐ์ด์ค์ ์ด๋ป๊ฒ ๋ถ์ฐ๋์ด ์๋์ง ํ์ธํ๋ ค๋ฉด hf_device_map ์์ฑ์ ์กฐํํฉ๋๋ค.
device_map = {"model.layers.1": 0, "model.layers.14": 1, "model.layers.31": "cpu", "lm_head": "disk"}
model.hf_device_map
๋ชจ๋ธ ๋ฐ์ดํฐ ํ์ [[model-data-type]]
PyTorch ๋ชจ๋ธ ๊ฐ์ค์น๋ ๊ธฐ๋ณธ์ ์ผ๋ก torch.float32๋ก ์ด๊ธฐํ๋ฉ๋๋ค. torch.float16๊ณผ ๊ฐ์ด ๋ค๋ฅธ ๋ฐ์ดํฐ ํ์
์ผ๋ก ๋ชจ๋ธ์ ๋ก๋ํ๋ฉด ๋ชจ๋ธ์ด ์ํ๋ ๋ฐ์ดํฐ ํ์
์ผ๋ก ๋ค์ ๋ก๋๋๊ธฐ ๋๋ฌธ์ ์ถ๊ฐ ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค.
torch_dtype ๋งค๊ฐ๋ณ์๋ฅผ ๋ช
์์ ์ผ๋ก ์ค์ ํ์ฌ ๊ฐ์ค์น๋ฅผ ๋ ๋ฒ ๋ก๋ํ๋ ๋์ (torch.float32 ํ torch.float16) ์ํ๋ ๋ฐ์ดํฐ ํ์
์ผ๋ก ๋ชจ๋ธ์ ์ง์ ์ด๊ธฐํํ์ธ์. ๋ํ torch_dtype="auto"๋ฅผ ์ค์ ํ์ฌ ๊ฐ์ค์น๋ฅผ ์ ์ฅ๋ ๋ฐ์ดํฐ ํ์
์ผ๋ก ์๋์ผ๋ก ๋ก๋ํ ์๋ ์์ต๋๋ค.
import torch
from transformers import AutoModelForCausalLM
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)
from transformers import AutoModelForCausalLM
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype="auto")
๋ชจ๋ธ์ ์ฒ์๋ถํฐ ์ธ์คํด์คํํ๋ ๊ฒฝ์ฐ, torch_dtype ํ๋ผ๋ฏธํฐ๋ [AutoConfig]์์ ์ค์ ํ ์๋ ์์ต๋๋ค.
import torch
from transformers import AutoConfig, AutoModel
my_config = AutoConfig.from_pretrained("google/gemma-2b", torch_dtype=torch.float16)
model = AutoModel.from_config(my_config)
์ปค์คํ ๋ชจ๋ธ[[custom-models]]
์ปค์คํ
๋ชจ๋ธ์ ํธ๋์คํฌ๋จธ์ ๊ตฌ์ฑ ๋ฐ ๋ชจ๋ธ๋ง ํด๋์ค๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ตฌ์ถ๋๋ฉฐ, AutoClass API๋ฅผ ์ง์ํ๊ณ [~PreTrainedModel.from_pretrained]๋ก ๋ก๋๋ฉ๋๋ค. ์ฐจ์ด์ ์ ๋ชจ๋ธ๋ง ์ฝ๋๊ฐ ํธ๋์คํฌ๋จธ์์ ์ ๊ณต๋๋ ๊ฒ์ด ์๋๋ผ๋ ์ ์
๋๋ค.
์ปค์คํ ๋ชจ๋ธ์ ๋ก๋ํ ๋๋ ํน๋ณํ ์ฃผ์ํด์ผ ํฉ๋๋ค. Hub์๋ ๋ชจ๋ ์ ์ฅ์์ ๋ํ ์ ์ฑ์ฝ๋ ์ค์บ์ด ํฌํจ๋์ด ์์ง๋ง, ์ฌ์ ํ ์ค์๋ก ์ ์ฑ์ฝ๋๋ฅผ ์คํํ์ง ์๋๋ก ์ฃผ์ํด์ผ ํฉ๋๋ค.
์ปค์คํ
๋ชจ๋ธ์ ๋ก๋ํ๋ ค๋ฉด [~PreTrainedModel.from_pretrained]์์ trust_remote_code=True๋ฅผ ์ค์ ํ์ธ์.
from transformers import AutoModelForImageClassification
model = AutoModelForImageClassification.from_pretrained("sgugger/custom-resnet50d", trust_remote_code=True)
์ถ๊ฐ์ ์ธ ๋ณด์ ์กฐ์น๋ก, ๋ณ๊ฒฝ๋์์ ์๋ ์๋ ๋ชจ๋ธ ์ฝ๋๋ฅผ ๋ก๋ํ๋ ๊ฒ์ ํผํ๊ธฐ ์ํด ํน์ ๋ฆฌ๋น์ ์์ ์ปค์คํ ๋ชจ๋ธ์ ๋ก๋ํฉ๋๋ค. ์ปค๋ฐ ํด์๋ ๋ชจ๋ธ์ ์ปค๋ฐ ๊ธฐ๋ก์์ ๋ณต์ฌํ ์ ์์ต๋๋ค.
commit_hash = "ed94a7c6247d8aedce4647f00f20de6875b5b292"
model = AutoModelForImageClassification.from_pretrained(
"sgugger/custom-resnet50d", trust_remote_code=True, revision=commit_hash
)
์์ธํ ๋ด์ฉ์ ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ ๊ณต์ ํ๊ธฐ ๊ฐ์ด๋๋ฅผ ์ฐธ์กฐํ์ธ์.