
๋ชจ๋ธ ๋กœ๋“œํ•˜๊ธฐ[[loading-models]]

Transformers provides many pretrained models that are ready to use with a single line of code. All you need is a model class and the [~PreTrainedModel.from_pretrained] method.

Call [~PreTrainedModel.from_pretrained] to download and load a model's weights and configuration stored on the Hugging Face Hub.

The [~PreTrainedModel.from_pretrained] method loads weights stored in the safetensors file format if they are available. Traditionally, PyTorch model weights are serialized with the pickle utility, which is known to be insecure. Safetensors files are more secure and faster to load.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto", device_map="auto")

์ด ๊ฐ€์ด๋“œ๋Š” ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ค๋Š” ๋ฐฉ๋ฒ•, ๋‹ค์–‘ํ•œ ๋กœ๋”ฉ ๋ฐฉ์‹, ๋งค์šฐ ํฐ ๋ชจ๋ธ์—์„œ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋Š” ๋ฉ”๋ชจ๋ฆฌ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๋Š” ๋ฐฉ๋ฒ•, ๊ทธ๋ฆฌ๊ณ  ์‚ฌ์šฉ์ž ์ •์˜ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ค๋Š” ๋ฐฉ๋ฒ•์„ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค.

๋ชจ๋ธ๊ณผ ๊ตฌ์„ฑ[[models-and-configurations]]

Every model has a configuration.py file that contains specific attributes such as the number of hidden layers, vocabulary size, and activation function. There is also a modeling.py file that defines each layer and the math operations that happen inside each layer. The modeling.py file builds the model from the attributes defined in configuration.py. At this stage the model has random weights that haven't been trained yet, so it needs training before it produces meaningful output.
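For instance, instantiating a configuration class directly (a sketch, assuming Transformers is installed; no weights are involved) shows the attributes the modeling code is built from:

```python
from transformers import BertConfig

# A configuration holds the architecture hyperparameters, not the weights
config = BertConfig()
print(config.num_hidden_layers)  # 12
print(config.vocab_size)         # 30522
print(config.hidden_act)         # gelu
```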

*์•„ํ‚คํ…์ฒ˜(Architecture)*๋Š” ๋ชจ๋ธ์˜ ๊ณจ๊ฒฉ์„ ์˜๋ฏธํ•˜๊ณ  *์ฒดํฌํฌ์ธํŠธ(checkpoint)*๋Š” ์ฃผ์–ด์ง„ ์•„ํ‚คํ…์ฒ˜์— ๋Œ€ํ•œ ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, BERT๋Š” ์•„ํ‚คํ…์ฒ˜์ด๊ณ  google-bert/bert-base-uncased๋Š” ํ•ด๋‹น ์•„ํ‚คํ…์ฒ˜์˜ ์ฒดํฌํฌ์ธํŠธ(checkpoint)์ž…๋‹ˆ๋‹ค. ๋ชจ๋ธ์ด๋ผ๋Š” ์šฉ์–ด๋Š” ์•„ํ‚คํ…์ฒ˜ ๋ฐ ์ฒดํฌํฌ์ธํŠธ(checkpoint)์™€ ํ˜ผ์šฉํ•˜์—ฌ ์‚ฌ์šฉ๋˜๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

There are generally two types of models you can load:

  1. ์€๋‹‰ ์ƒํƒœ๋ฅผ ์ถœ๋ ฅํ•˜๋Š” [AutoModel] ๋˜๋Š” [LlamaModel]๊ณผ ๊ฐ™์€ ๊ธฐ๋ณธ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.
  2. ํŠน์ • ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ํŠน์ • ํ—ค๋“œ๊ฐ€ ๋ถ™์€ [AutoModelForCausalLM] ๋˜๋Š” [LlamaForCausalLM]๊ณผ ๊ฐ™์€ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.

๊ฐ ๋ชจ๋ธ ํƒ€์ž…๋งˆ๋‹ค, ๊ฐ๊ฐ์˜ ๊ธฐ๊ณ„ํ•™์Šต ํ”„๋ ˆ์ž„์›Œํฌ(PyTorch, TensorFlow, Flax)๋ฅผ ์œ„ํ•œ ๋ณ„๋„์˜ ํด๋ž˜์Šค๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ์‚ฌ์šฉ ์ค‘์ธ ํ”„๋ ˆ์ž„์›Œํฌ์— ํ•ด๋‹นํ•˜๋Š” ์ ‘๋‘์–ด(prefix)๋ฅผ ์„ ํƒํ•˜์„ธ์š”.

from transformers import AutoModelForCausalLM, MistralForCausalLM

# Load with an AutoClass or a model-specific class
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
from transformers import TFAutoModelForCausalLM, TFMistralForCausalLM

# Load with an AutoClass or a model-specific class
model = TFAutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model = TFMistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
from transformers import FlaxAutoModelForCausalLM, FlaxMistralForCausalLM

# Load with an AutoClass or a model-specific class
model = FlaxAutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model = FlaxMistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

๋ชจ๋ธ ํด๋ž˜์Šค[[model-classes]]

To get a pretrained model, you need to load its weights into the model. This is done by calling [~PreTrainedModel.from_pretrained], which accepts weights from the Hugging Face Hub or a local directory.

There are two model classes: the AutoModel class and model-specific classes.

The AutoModel class is a convenient way to load an architecture without knowing the exact model class name. Because there are many models available, this class automatically selects the correct model class based on the configuration file. You only need to know the task you want and the checkpoint you want to use.
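The automatic selection can be seen without downloading any weights. The sketch below (assuming Transformers and PyTorch are installed) builds a deliberately tiny, randomly initialized model from a configuration; the AutoClass resolves it to the matching model-specific class.

```python
from transformers import AutoModelForCausalLM, LlamaConfig

# A deliberately tiny config so instantiation is fast; the weights are random
config = LlamaConfig(
    hidden_size=16,
    intermediate_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    vocab_size=100,
)
# AutoModelForCausalLM picks the concrete class from the config type
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # LlamaForCausalLM
```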

You can easily switch between models or tasks, as long as the architecture supports the given task.

For example, the same model can be used for different tasks.

from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoModelForQuestionAnswering

# ๋™์ผํ•œ API๋ฅผ 3๊ฐ€์ง€ ๋‹ค๋ฅธ ์ž‘์—…์— ์‚ฌ์šฉ
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForQuestionAnswering.from_pretrained("meta-llama/Llama-2-7b-hf")

In other cases, you might want to quickly try out several different models for a single task.

from transformers import AutoModelForCausalLM

# ๋™์ผํ•œ API๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ 3๊ฐ€์ง€ ๋‹ค๋ฅธ ๋ชจ๋ธ ๋กœ๋“œ
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")

The AutoModel class builds on top of the model-specific classes. All model classes that support a given task are mapped to their respective AutoModelFor task class.

If you already know the model class you want to use, you can use its model-specific class directly.

from transformers import LlamaModel, LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

Large models[[large-models]]

Large pretrained models require a lot of memory to load. The loading process involves:

  1. Creating a model with random weights
  2. Loading the pretrained weights
  3. Placing the pretrained weights on the model

๋ชจ๋ธ ๊ฐ€์ค‘์น˜์˜ ๋ณต์‚ฌ๋ณธ ๋‘ ๊ฐ€์ง€(๋ฌด์ž‘์œ„ ๊ฐ€์ค‘์น˜์™€ ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๊ฐ€์ค‘์น˜)๋ฅผ ๋ณด๊ด€ํ•  ์ˆ˜ ์žˆ๋Š” ์ถฉ๋ถ„ํ•œ ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ํ•„์š”ํ•˜๋ฉฐ, ์ด๋Š” ๋ณด์œ ํ•œ ํ•˜๋“œ์›จ์–ด์— ๋”ฐ๋ผ ๋ถˆ๊ฐ€๋Šฅํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋ถ„์‚ฐ ํ•™์Šต ํ™˜๊ฒฝ์—์„œ๋Š” ๊ฐ ํ”„๋กœ์„ธ์Šค๊ฐ€ ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ ๋กœ๋“œํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ด๋Š” ๋”์šฑ ์–ด๋ ค์šด ๊ณผ์ œ์ž…๋‹ˆ๋‹ค.

Transformers reduces some of these memory-related challenges with fast initialization, sharded checkpoints, Accelerate's Big Model Inference feature, and support for lower-bit data types.

Sharded checkpoints[[sharded-checkpoints]]

The [~PreTrainedModel.save_pretrained] method automatically shards checkpoints larger than 10GB.

๊ฐ ์ƒค๋“œ(shard)๋Š” ์ด์ „ ์ƒค๋“œ๊ฐ€ ๋กœ๋“œ๋œ ํ›„ ์ˆœ์ฐจ์ ์œผ๋กœ ๋กœ๋“œ๋˜์–ด, ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ๋ชจ๋ธ ํฌ๊ธฐ์™€ ๊ฐ€์žฅ ํฐ ์ƒค๋“œ ํฌ๊ธฐ๋กœ๋งŒ ์ œํ•œํ•ฉ๋‹ˆ๋‹ค.

The max_shard_size parameter defaults to 5GB for each shard, which makes the checkpoint easier to run on free-tier GPU instances without running out of memory.

For example, let's create a sharded checkpoint for BioMistral/BioMistral-7B with [~PreTrainedModel.save_pretrained].

from transformers import AutoModel
import tempfile
import os

model = AutoModel.from_pretrained("biomistral/biomistral-7b")
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="5GB")
    print(sorted(os.listdir(tmp_dir)))

Reload the sharded checkpoint with [~PreTrainedModel.from_pretrained].

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="5GB")
    new_model = AutoModel.from_pretrained(tmp_dir)

Sharded checkpoints can also be loaded directly with [~transformers.modeling_utils.load_sharded_checkpoint].

from transformers.modeling_utils import load_sharded_checkpoint

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="5GB")
    load_sharded_checkpoint(model, tmp_dir)

The [~PreTrainedModel.save_pretrained] method creates an index file that maps parameter names to the files they're stored in. The index file has two keys, metadata and weight_map.

import json

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="5GB")
    with open(os.path.join(tmp_dir, "model.safetensors.index.json"), "r") as f:
        index = json.load(f)

print(index.keys())

The metadata key provides the total model size.

index["metadata"]
{'total_size': 28966928384}

The weight_map key maps each parameter to the shard it's stored in.

index["weight_map"]
{'lm_head.weight': 'model-00006-of-00006.safetensors',
 'model.embed_tokens.weight': 'model-00001-of-00006.safetensors',
 'model.layers.0.input_layernorm.weight': 'model-00001-of-00006.safetensors',
 'model.layers.0.mlp.down_proj.weight': 'model-00001-of-00006.safetensors',
 ...
}
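As an illustration of how the weight_map can be used, the mapping can be inverted to see which parameters each shard file holds. This pure-Python sketch uses a truncated copy of the map above:

```python
from collections import defaultdict

def params_per_shard(weight_map):
    """Invert a weight_map: shard file -> list of parameter names stored in it."""
    shards = defaultdict(list)
    for param_name, shard_file in weight_map.items():
        shards[shard_file].append(param_name)
    return dict(shards)

# Truncated weight_map from the index file above
weight_map = {
    "lm_head.weight": "model-00006-of-00006.safetensors",
    "model.embed_tokens.weight": "model-00001-of-00006.safetensors",
    "model.layers.0.input_layernorm.weight": "model-00001-of-00006.safetensors",
}
print(params_per_shard(weight_map))
```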

Big Model Inference[[big-model-inference]]

Make sure you have Accelerate v0.9.0 and PyTorch v1.9.0 or later installed to use this feature!

[~PreTrainedModel.from_pretrained] is supercharged with Accelerate's Big Model Inference feature.

Big Model Inference creates a model skeleton on the PyTorch meta device. The meta device doesn't store any actual data, only the metadata.

Randomly initialized weights are only created when the pretrained weights are loaded, which avoids keeping two copies of the model in memory at the same time. The maximum memory usage is only the size of the model.
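The meta device behavior can be demonstrated with plain PyTorch: a tensor created there has a shape and dtype but allocates no storage.

```python
import torch

# 7 billion float32 values would need ~28GB of RAM, but on the meta
# device only the metadata (shape, dtype) exists; nothing is allocated
t = torch.empty(7_000_000_000, dtype=torch.float32, device="meta")
print(t.shape, t.dtype, t.device)
```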

Refer to the Designing a device map guide for more details about device allocation.

The second feature of Big Model Inference relates to how loaded weights are assigned to the model skeleton. Model weights are distributed across all available devices, starting with the fastest device (usually the GPU) first, and any remaining weights are allocated sequentially to slower devices (CPU and hard disk).

Combining both features reduces the memory usage and loading time of large pretrained models.

Set device_map to "auto" to enable Big Model Inference.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto")

You can also use device_map to manually assign layers to devices. If a whole layer is assigned to the same device, you don't need to specify where every submodule of that layer goes.

๋ชจ๋ธ์ด ๊ฐ ๋””๋ฐ”์ด์Šค์— ์–ด๋–ป๊ฒŒ ๋ถ„์‚ฐ๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•˜๋ ค๋ฉด hf_device_map ์†์„ฑ์„ ์กฐํšŒํ•ฉ๋‹ˆ๋‹ค.

device_map = {"model.layers.1": 0, "model.layers.14": 1, "model.layers.31": "cpu", "lm_head": "disk"}
model.hf_device_map

๋ชจ๋ธ ๋ฐ์ดํ„ฐ ํƒ€์ž…[[model-data-type]]

PyTorch model weights are initialized in torch.float32 by default. Loading a model in a different data type, such as torch.float16, requires additional memory because the model is loaded again in the desired data type.
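The memory difference is easy to estimate: float32 uses 4 bytes per parameter while float16 uses 2, so a 7B-parameter model roughly halves its footprint. A back-of-the-envelope sketch:

```python
num_params = 7_000_000_000  # e.g. a 7B-parameter model

bytes_fp32 = num_params * 4  # torch.float32: 4 bytes per parameter
bytes_fp16 = num_params * 2  # torch.float16: 2 bytes per parameter

print(f"float32: {bytes_fp32 / 1e9:.0f} GB")  # float32: 28 GB
print(f"float16: {bytes_fp16 / 1e9:.0f} GB")  # float16: 14 GB
```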

Explicitly set the torch_dtype parameter to directly initialize the model in the desired data type instead of loading the weights twice (first in torch.float32, then in torch.float16). You can also set torch_dtype="auto" to automatically load the weights in the data type they were saved in.

import torch
from transformers import AutoModelForCausalLM

gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)
from transformers import AutoModelForCausalLM

gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype="auto")

๋ชจ๋ธ์„ ์ฒ˜์Œ๋ถ€ํ„ฐ ์ธ์Šคํ„ด์Šคํ™”ํ•˜๋Š” ๊ฒฝ์šฐ, torch_dtype ํŒŒ๋ผ๋ฏธํ„ฐ๋Š” [AutoConfig]์—์„œ ์„ค์ •ํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

import torch
from transformers import AutoConfig, AutoModel

my_config = AutoConfig.from_pretrained("google/gemma-2b", torch_dtype=torch.float16)
model = AutoModel.from_config(my_config)

Custom models[[custom-models]]

Custom models build on Transformers' configuration and modeling classes, support the AutoClass API, and are loaded with [~PreTrainedModel.from_pretrained]. The difference is that the modeling code is not provided by Transformers.

Take extra precautions when loading a custom model. While the Hub includes malware scanning for every repository, you should still be careful to avoid accidentally executing malicious code.

Set trust_remote_code=True in [~PreTrainedModel.from_pretrained] to load a custom model.

from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("sgugger/custom-resnet50d", trust_remote_code=True)

As an extra layer of security, load a custom model from a specific revision to avoid loading model code that may have changed. The commit hash can be copied from the model's commit history.

commit_hash = "ed94a7c6247d8aedce4647f00f20de6875b5b292"
model = AutoModelForImageClassification.from_pretrained(
    "sgugger/custom-resnet50d", trust_remote_code=True, revision=commit_hash
)

์ž์„ธํ•œ ๋‚ด์šฉ์€ ์‚ฌ์šฉ์ž ์ •์˜ ๋ชจ๋ธ ๊ณต์œ ํ•˜๊ธฐ ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.