
Distributed inference[[distributed-inference]]

๋ชจ๋ธ์ด ๋‹จ์ผ GPU์— ์˜ฌ๋ผ๊ฐ€์ง€ ์•Š๋Š” ๊ฒฝ์šฐ, ํ…์„œ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•œ ๋ถ„์‚ฐ ์ถ”๋ก ์ด ๋„์›€์ด ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ…์„œ ๋ณ‘๋ ฌํ™”๋Š” ๋ชจ๋ธ์„ ์—ฌ๋Ÿฌ ๊ฐ€์†๊ธฐ(CUDA GPU, Intel XPU ๋“ฑ)์— ๋ถ„ํ• ํ•˜์—ฌ ํ–‰๋ ฌ ๊ณฑ์…ˆ๊ณผ ๊ฐ™์€ ๊ณ„์‚ฐ์„ ๋ณ‘๋ ฌํ™”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ๋” ํฐ ๋ชจ๋ธ์„ ๋ฉ”๋ชจ๋ฆฌ์— ์˜ฌ๋ฆด ์ˆ˜ ์žˆ์œผ๋ฉฐ, ๊ฐ ๊ฐ€์†๊ธฐ๊ฐ€ ํ…์„œ์˜ ์ผ๋ถ€๋ฅผ ์ฒ˜๋ฆฌํ•˜๋ฏ€๋กœ ์ถ”๋ก  ์†๋„๊ฐ€ ํ–ฅ์ƒ๋ฉ๋‹ˆ๋‹ค.

However, tensor parallelism adds communication overhead, so it is most effective on multi-accelerator machines that can take advantage of fast intra-node communication. For multi-node setups, pipeline or data parallelism may be more efficient, depending on the use case.

Refer to the tensor parallelism section of the Ultra-Scale Playbook to learn more about tensor parallelism.

Check the list below for models that natively support tensor parallelism. Open a GitHub issue or pull request to add support for a new model.

์ง€์›๋˜๋Š” ๋ชจ๋ธ ๋ณด๊ธฐ

์ด ๊ฐ€์ด๋“œ๋Š” Transformers์—์„œ ๋‹ค์–‘ํ•œ ๋ถ„ํ•  ์ „๋žต์„ ์‚ฌ์šฉํ•˜์—ฌ ํ…์„œ ๋ณ‘๋ ฌํ™”๋ฅผ ํ™œ์„ฑํ™”ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค.

๋ชจ๋ธ ๋ถ„ํ• [[partitioning-a-model]]

Transformers supports tensor parallelism for models that accept the tp_plan parameter. There are two ways to partition a model.

  • The auto tensor parallelism plan partitions a model (one of the supported models above) automatically based on a predefined configuration.
  • You can define your own partitioning plan and pass it to the tp_plan parameter of the [~PreTrainedModel.from_pretrained] method.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # better to visualize all the possible strategies
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # better for smaller number of GPUs

model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan="auto")
print(model._tp_plan)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# distributed run
outputs = model(inputs)

์œ„์˜ ์ถ”๋ก  ์Šคํฌ๋ฆฝํŠธ๋ฅผ GPU๋‹น 4๊ฐœ ํ”„๋กœ์„ธ์Šค๋กœ torchrun์—์„œ ์‹คํ–‰ํ•˜์„ธ์š”.

torchrun --nproc-per-node 4 demo.py

๊ฐ ๋ ˆ์ด์–ด์— ๋Œ€ํ•œ ํ…์„œ ๋ณ‘๋ ฌ ๊ณ„ํš์„ tp_plan์— ์ •์˜ํ•œ ํ›„ [~PreTrainedModel.from_pretrained]์— ์ „๋‹ฌํ•˜์„ธ์š”. ์•„๋ž˜ ์˜ˆ์‹œ๋Š” ์—ด ๋ฐ ํ–‰ ๋ถ„ํ• ์„ ์กฐํ•ฉํ•˜์—ฌ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ง€์›๋˜๋Š” ๋‹ค๋ฅธ ๋ถ„ํ•  ์ „๋žต์€ ๋ถ„ํ•  ์ „๋žต ์„น์…˜์„ ์ฐธ๊ณ ํ•˜์„ธ์š”.

์‚ฌ์šฉ์ž ์ง€์ • ๋ถ„ํ•  ๊ณ„ํš์„ ์ˆ˜๋™์œผ๋กœ ์ง€์ •ํ•˜๋ ค๋ฉด ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜์™€ ๋ถ„ํ•  ์ „๋žต์ด ํ•จ๊ป˜ ์ƒํ˜ธ ์ž‘์šฉํ•˜๋Š” ๋ฐฉ์‹์— ๋Œ€ํ•œ ์ถฉ๋ถ„ํ•œ ์ดํ•ด๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ๋ถ„ํ•  ์ „๋žต์„ ์ž˜๋ชป ์„ค์ •ํ•˜๋ฉด ๋ชจ๋ธ์ด ๋งค์šฐ ๋А๋ ค์ง€๊ฑฐ๋‚˜, ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•˜๊ฑฐ๋‚˜, ๋ถ€์ •ํ™•ํ•œ ๊ฒฐ๊ณผ๋ฅผ ๋‚ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ž์„ธํžˆ ์•Œ์•„๋ณด๋ ค๋ฉด Ultra-Scale Playbook์„ ์ฐธ๊ณ ํ•˜์„ธ์š”.

import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    ...
}

model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan=tp_plan)
print(model._tp_plan)

Partitioning strategies[[partitioning-strategies]]

All partitioning strategies are defined in the [ParallelInterface] class, which maps a string to the strategy implementation. You don't need to interact with this class directly because every strategy is set through tp_plan in [~PreTrainedModel.from_pretrained], but it is useful for checking which strategies are available.

class ParallelInterface(MutableMapping):
    """
    ํ—ˆ์šฉ๋œ ์–ดํ…์…˜ ํ•จ์ˆ˜๋ฅผ ์ถ”์ ํ•˜๋Š” ๋”•์…”๋„ˆ๋ฆฌ ๊ฐ™์€ ๊ฐ์ฒด์ž…๋‹ˆ๋‹ค. `register()` ํ˜ธ์ถœ๋กœ ์ƒˆ๋กœ์šด ์–ดํ…์…˜ ํ•จ์ˆ˜๋ฅผ ์‰ฝ๊ฒŒ ์ถ”๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. 
    ๋ชจ๋ธ์ด ๊ธฐ์กด ์–ดํ…์…˜ ํ•จ์ˆ˜(์˜ˆ: `sdpa`)๋ฅผ ๋กœ์ปฌ์—์„œ ๋ฎ์–ด์“ฐ๋ ค๋ฉด `modeling_<model>.py` ๋‚ด๋ถ€์—์„œ ์ด ํด๋ž˜์Šค์˜ ์ƒˆ ์ธ์Šคํ„ด์Šค๋ฅผ ์„ ์–ธํ•˜๊ณ  
    ํ•ด๋‹น ์ธ์Šคํ„ด์Šค์—์„œ ์„ ์–ธํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
    """
    _global_mapping = {
        "colwise": ColwiseParallel(),
        "rowwise": RowwiseParallel(),
        "colwise_rep": ColwiseParallel(output_layouts=Replicate()),
        "rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
        "local_colwise": ColwiseParallel(use_dtensor=False),
        "local_rowwise": RowwiseParallel(use_dtensor=False),
        "local": IsolatedParallel(),
        "gather": GatherParallel(),
        "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
        "sequence_parallel": SequenceParallel(),
        "replicate": ReplicateParallel(),
    }

๊ฐ ์ „๋žต์— ๋Œ€ํ•ด ์ž์„ธํžˆ ์•Œ์•„๋ณด๋ ค๋ฉด ์•„๋ž˜ ํ‘œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”.

| Strategy | Description |
|---|---|
| ColwiseParallel | Column-wise partitioning of weights and biases. |
| RowwiseParallel | Row-wise partitioning of weights and biases. Also supports partitioning nn.Embedding modules. |
| SequenceParallel | Sequence parallel implementation to support LayerNorm and Dropout layers. Also supports Python implementations of RMSNorm. |
| PackedColwiseParallel | Variant of ColwiseParallel that supports packed weights (for example, packing up_proj and gate_proj together). Refer to the code for details. |
| PackedRowwiseParallel | Variant of RowwiseParallel that supports packed weights (refer to the code for details). |
| GatherParallel | Gathers the outputs of a module across devices. |
| IsolatedParallel | Used for the experts in a Mixture-of-Experts (MoE) layer to isolate a module from the other devices. |
| ReplicateParallel | Replicates a module across all devices to prevent torch.distributed APIs from breaking due to a partially sharded model. |

ํŒจํ‚น๋œ ์ „๋žต[[packed-strategies]]

Weight packing merges multiple linear layers into a single bigger layer. The packed strategies, PackedColwiseParallel and PackedRowwiseParallel, are used to shard packed weights. The basic ColwiseParallel and RowwiseParallel strategies don't shard packed weights correctly.
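A minimal sketch of the packing idea, with plain Python lists standing in for the weight matrices: fusing gate_proj and up_proj along the output dimension produces one matmul whose output can be split back in two, mirroring the gate_up.chunk(2, dim=-1) call shown further below.

```python
# Sketch of weight packing: gate_proj and up_proj fused into one gate_up_proj
# whose output is split back into the two original outputs.
gate_w = [[1, 2], [3, 4]]  # stand-in gate_proj weight, hidden=2 -> dim=2
up_w = [[5, 6], [7, 8]]    # stand-in up_proj weight

# Pack along the output dimension: each row is [gate columns | up columns]
gate_up_w = [g + u for g, u in zip(gate_w, up_w)]

h = [2, 3]  # one hidden-state vector
gate_up = [sum(h[i] * gate_up_w[i][j] for i in range(2)) for j in range(4)]

# chunk(2, dim=-1): first half is the gate output, second half the up output
gate, up = gate_up[:2], gate_up[2:]
assert gate == [sum(h[i] * gate_w[i][j] for i in range(2)) for j in range(2)]
assert up == [sum(h[i] * up_w[i][j] for i in range(2)) for j in range(2)]
```
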

The example below packs up_proj and gate_proj into a single gate_up_proj module, and the PackedRowwiseParallel strategy is required to shard gate_up_proj.

class Llama4TextExperts(nn.Module):
    ...
    self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))

Batch matrix multiplication can be used in the forward pass to compute the output of the gate_up_proj module.

def forward(self, hidden_states):
    ...
    gate_up = torch.bmm(hidden_states, self.gate_up_proj) # compute the output of the gate_up_proj module
    gate, up = gate_up.chunk(2, dim=-1) # split the output into gate and up

Refer to this comment for a visual representation of why Packed* needs to be used.
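The failure mode can also be sketched directly. With the packed columns laid out as [gate | up], a naive contiguous column shard gives one rank only gate columns and the other only up columns, so each rank's local chunk(2, dim=-1) no longer separates gate from up; a packed shard gives every rank half of each.

```python
# Why a packed weight needs a Packed* strategy, illustrated with column labels.
cols = ["g0", "g1", "g2", "g3", "u0", "u1", "u2", "u3"]  # packed gate_up columns

# Naive colwise sharding: contiguous halves across 2 ranks
naive = [cols[:4], cols[4:]]
rank0_gate, rank0_up = naive[0][:2], naive[0][2:]
assert rank0_up == ["g2", "g3"]  # wrong: rank 0's "up" half is still gate columns

# Packed sharding: each rank gets half of gate AND half of up
packed = [cols[0:2] + cols[4:6], cols[2:4] + cols[6:8]]
for shard in packed:
    gate_half, up_half = shard[:2], shard[2:]
    assert all(c.startswith("g") for c in gate_half)
    assert all(c.startswith("u") for c in up_half)
```
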

Local strategies[[local-strategies]]

The local strategies (local_colwise, local_rowwise, local_packed_rowwise) don't use DTensor because DTensor isn't supported for some operations such as torch.chunk. Instead, the local strategies use basic torch.Tensor and perform some of the distributed logic manually.
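The manual distributed logic in question is just ordinary tensor math; a minimal sketch of the row-wise case with plain Python lists (no torch or collectives): each "rank" holds a slice of the input features and the matching rows of the weight, and summing the partial outputs plays the role of the all-reduce.

```python
# Sketch of the manual logic behind a row-wise strategy on plain tensors.

def matvec(w_rows, x_slice):
    # (k, m) weight rows times (k,) input slice -> (m,) partial output
    return [
        sum(x_slice[i] * w_rows[i][j] for i in range(len(x_slice)))
        for j in range(len(w_rows[0]))
    ]

w = [[1, 2], [3, 4], [5, 6], [7, 8]]  # full 4x2 weight
x = [1, 2, 3, 4]                      # full input vector

# Rank 0 holds rows/features 0-1, rank 1 holds rows/features 2-3
partial0 = matvec(w[:2], x[:2])
partial1 = matvec(w[2:], x[2:])

# "All-reduce": summing the partials recovers the full matmul
y = [a + b for a, b in zip(partial0, partial1)]
assert y == matvec(w, x)
```
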

์‚ฌ์šฉ์ž ์ •์˜ ๋ถ„ํ•  ์ „๋žต[[custom-partitioning-strategies]]

์‚ฌ์šฉ์ž ์ •์˜ ๋ถ„ํ•  ์ „๋žต์€ TensorParallelLayer๋ฅผ ์ƒ์†ํ•˜๊ณ  partition_tensor, _prepare_input_fn, _prepare_output_fn์„ ๊ตฌํ˜„ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

๊ทธ๋Ÿฐ ๋‹ค์Œ tp_plan์—์„œ ํ•ด๋‹น ์ „๋žต์„ ์ง€์ •ํ–ˆ์„ ๋•Œ ๋””์ŠคํŒจ์นญ ๋กœ์ง์ด ์ฐพ์„ ์ˆ˜ ์žˆ๋„๋ก ParallelInterface ๋งคํ•‘์— ๋“ฑ๋กํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

The example below shows how to implement ColwiseParallel with this workflow.

  1. Inherit from TensorParallelLayer. In the __init__ method, define input_layouts and output_layouts to describe how the input and output tensors should be placed on the devices. The desired_input_layouts attribute is used to specify how the input must be placed on the devices.

    class ColwiseParallel(TensorParallelLayer):
        def __init__(
            self,
            *,
            input_layouts: Optional[Placement] = None, # the layout of the input coming from the previous layer
            output_layouts: Optional[Placement] = None, # the layout of the output we want to achieve
            use_local_output: bool = True, # whether to use local output or not
            use_dtensor=True, # whether to use DTensor or not
        ):
            self.input_layouts = (input_layouts or Replicate(),) # input sharding coming from the previous layer
            self.output_layouts = (output_layouts or Shard(-1),) # the desired output sharding
            self.desired_input_layouts = (Replicate(),) # the desired input sharding; inputs should be replicated across GPUs
            self.use_local_output = use_local_output
            self.use_dtensor = use_dtensor
    
  2. Implement the partition_tensor, _prepare_input_fn, and _prepare_output_fn methods.

    The partition_tensor method shards the tensor and fills empty_param with the sharded tensor. Use the utility function get_tensor_shard to get the correct shard of the original parameter for a given rank, and get_packed_weights for packed weights.

    def partition_tensor(
        self,
        param, # the full tensor of the parameter
        empty_param, # an empty tensor of the parameter, to be filled with the sharded tensor
        param_type, # the type of the parameter, `bias` or `weight`
        param_casting_dtype, # the dtype to cast the parameter to
        to_contiguous, # whether to convert the tensor to a contiguous memory layout
        rank, # the rank of the current device
        device_mesh, # the device mesh
    ) -> nn.Parameter: # returns the sharded parameter
        ...
    

    The _prepare_input_fn and _prepare_output_fn methods are used in the pre-forward and forward hooks. They redistribute the inputs and outputs to the desired layouts as specified in __init__.

    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        ...
        # do some custom logic, cast to DTensor, etc.
        ...
        return inputs.redistribute(placements=desired_input_layouts, device_mesh=device_mesh)
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        ...
        # do some custom logic, cast to DTensor, etc.
        ...
        return outputs.redistribute(placements=output_layouts, device_mesh=device_mesh)
    
  3. Register the strategy to [ParallelInterface] so it can be used with tp_plan.

    from transformers.integrations.tensor_parallel import ParallelInterface
    
    ParallelInterface.register_strategy("colwise_custom", ColwiseParallel)
    tp_plan = {
        "model.layers.*.self_attn.q_proj": "colwise_custom",
        ...
    }
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan=tp_plan)
    

๋ฒค์น˜๋งˆํฌ[[benchmarks]]

Tensor parallelism can considerably speed up inference, especially for inputs with large batch sizes or long sequences.

Refer to the chart below for the expected speedup for a single forward pass on Llama with a sequence length of 512.

Design implementation[[design-implementation]]

Transformers ํ…์„œ ๋ณ‘๋ ฌํ™” ๊ตฌํ˜„์€ ํ”„๋ ˆ์ž„์›Œํฌ์— ๊ตฌ์• ๋ฐ›์ง€ ์•Š์ง€๋งŒ, ๊ตฌ์ฒด์ ์ธ ๊ตฌํ˜„์„ ์œ„ํ•ด์„œ๋Š” DeviceMesh์™€ torch.distributed์˜ DTensor์— ์˜์กดํ•˜์—ฌ ๊ฐ„๋‹จํ•˜๊ณ  ํ™•์žฅ ๊ฐ€๋Šฅํ•œ ์ธํ„ฐํŽ˜์ด์Šค๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.

DeviceMesh[[devicemesh]]

Imagine a DeviceMesh as a multi-dimensional grid of devices that communicate together. Different parallelization strategies require different communication patterns, so you can create a DeviceMesh with multiple sub-meshes.

from torch.distributed.device_mesh import init_device_mesh

# create a 1D mesh of 4 GPUs
device_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=["tp"])

Most of the parallelization strategies defined in torch.distributed can be applied to the mesh itself or to one of its sub-meshes, and they automatically handle the communication patterns.
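The grouping a 2D mesh induces can be illustrated with plain Python (this is an illustration of the concept, not the torch API): with 4 ranks arranged as a 2x2 grid with hypothetical dimension names ("dp", "tp"), each row forms a tensor parallel communication group and each column a data parallel group.

```python
# Illustration of how a 2D device mesh groups ranks into sub-mesh
# communication groups; names "dp"/"tp" are for illustration only.
ranks = list(range(4))
dp_size, tp_size = 2, 2

# reshape the flat rank list into a dp_size x tp_size grid
mesh = [ranks[i * tp_size:(i + 1) * tp_size] for i in range(dp_size)]
tp_groups = mesh                               # each row communicates for tensor parallelism
dp_groups = [list(col) for col in zip(*mesh)]  # each column communicates for data parallelism

assert tp_groups == [[0, 1], [2, 3]]
assert dp_groups == [[0, 2], [1, 3]]
```
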

DTensor[[dtensor]]

DTensor (Distributed Tensor) is a tensor subclass that handles the distributed logic on top of the usual tensor operations. Most of the model weights in tensor parallelism are stored as DTensors.

DTensor์˜ ๊ฐ€์žฅ ์ค‘์š”ํ•œ ๋ถ€๋ถ„์€ placement ์†์„ฑ์ž…๋‹ˆ๋‹ค. ์ด๋Š” PyTorch์—๊ฒŒ ํ…์„œ๊ฐ€ DeviceMesh์˜ ๊ธฐ๊ธฐ์— ์–ด๋–ป๊ฒŒ ๋ฐฐ์น˜๋˜๋Š”์ง€ ์•Œ๋ ค์ฃผ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. placement ์†์„ฑ์€ ๋‹ค์Œ ๊ฐ’์„ ๊ฐ€์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  • Shard(dimension) - indicates how a DTensor is sharded across the given dimension over the DeviceMesh it was constructed on. The example below shows how to shard weights over different dimensions for column-wise partitioning.

    weight = ...
    weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(0)]) # shard across the first (column-wise) dimension
    bias = ...
    bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Shard(-1)]) # shard across the only dimension
    

    This example shows how to shard weights over different dimensions for row-wise partitioning.

    weight = ...
    weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(1)]) # shard across the second (row-wise) dimension
    bias = ...
    bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # replicate the bias across all GPUs
    
  • Replicate() - indicates a DTensor is replicated across the DeviceMesh. It only creates a full copy of the tensor on each device.

    bias = ...
    bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # replicate the bias across all GPUs
    
  • Partial() - indicates a tensor is pending a reduction operation (this is generally less relevant to the direct usage in Transformers).
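What "pending a reduction" means can be sketched with plain Python lists standing in for per-rank tensors: under a Partial placement, each rank holds a partial value of the same logical tensor, and the final value only exists after the pending reduction is applied (here an elementwise sum standing in for an all-reduce).

```python
# Sketch of a Partial() placement: per-rank partial values of one logical
# tensor, resolved by a pending reduction.
partials = [
    [7, 10],   # rank 0's partial output
    [43, 50],  # rank 1's partial output
]

# Resolving the Partial placement = applying the reduction (sum all-reduce)
reduced = [sum(vals) for vals in zip(*partials)]
assert reduced == [50, 60]
```
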