๋ถ์ฐ ์ถ๋ก [[distributed-inference]]
๋ชจ๋ธ์ด ๋จ์ผ GPU์ ์ฌ๋ผ๊ฐ์ง ์๋ ๊ฒฝ์ฐ, ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ฌ์ฉํ ๋ถ์ฐ ์ถ๋ก ์ด ๋์์ด ๋ ์ ์์ต๋๋ค. ํ ์ ๋ณ๋ ฌํ๋ ๋ชจ๋ธ์ ์ฌ๋ฌ ๊ฐ์๊ธฐ(CUDA GPU, Intel XPU ๋ฑ)์ ๋ถํ ํ์ฌ ํ๋ ฌ ๊ณฑ์ ๊ณผ ๊ฐ์ ๊ณ์ฐ์ ๋ณ๋ ฌํํฉ๋๋ค. ์ด๋ฅผ ํตํด ๋ ํฐ ๋ชจ๋ธ์ ๋ฉ๋ชจ๋ฆฌ์ ์ฌ๋ฆด ์ ์์ผ๋ฉฐ, ๊ฐ ๊ฐ์๊ธฐ๊ฐ ํ ์์ ์ผ๋ถ๋ฅผ ์ฒ๋ฆฌํ๋ฏ๋ก ์ถ๋ก ์๋๊ฐ ํฅ์๋ฉ๋๋ค.
๊ทธ๋ฌ๋ ํ ์ ๋ณ๋ ฌํ๋ ํต์ ์ค๋ฒํค๋๋ฅผ ๋ฐ์์ํค๋ฏ๋ก, ๋น ๋ฅธ ๋ ธ๋ ๋ด ํต์ ์ ํ์ฉํ ์ ์๋ ๋ค์ค ๊ฐ์๊ธฐ ํ๊ฒฝ์์ ์ฌ์ฉํ๋ ๊ฒ์ด ๊ฐ์ฅ ํจ๊ณผ์ ์ ๋๋ค. ๋ค์ค ๋ ธ๋ ํ์ต ํ๊ฒฝ์์๋ ์ฌ์ฉ ์ฌ๋ก์ ๋ฐ๋ผ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ๋ ๋ฐ์ดํฐ ๋ณ๋ ฌํ๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ๋ ํจ์จ์ ์ผ ์ ์์ต๋๋ค.
ํ ์ ๋ณ๋ ฌํ์ ๋ํด ๋ ์์ธํ ์์๋ณด๋ ค๋ฉด Ultra-Scale Playbook์ ํ ์ ๋ณ๋ ฌํ ์น์ ์ ์ฐธ์กฐํ์ธ์.
์๋ ๋ชฉ๋ก์์ ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ๊ธฐ๋ณธ์ ์ผ๋ก ์ง์ํ๋ ๋ชจ๋ธ์ ํ์ธํ ์ ์์ต๋๋ค. ์๋ก์ด ๋ชจ๋ธ์ ๋ํ ์ง์์ ์ถ๊ฐํ๋ ค๋ฉด GitHub ์ด์๋ ํ ๋ฆฌํ์คํธ๋ฅผ ์ด์ด์ฃผ์ธ์.
์ง์๋๋ ๋ชจ๋ธ ๋ณด๊ธฐ
์ด ๊ฐ์ด๋๋ Transformers์์ ๋ค์ํ ๋ถํ ์ ๋ต์ ์ฌ์ฉํ์ฌ ํ ์ ๋ณ๋ ฌํ๋ฅผ ํ์ฑํํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค.
๋ชจ๋ธ ๋ถํ [[partitioning-a-model]]
Transformers๋ tp_plan๋งค๊ฐ๋ณ์๋ฅผ ํ์ฉํ ์ ์๋ ๋ชจ๋ธ์ ๋ํด ํ
์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ง์ํฉ๋๋ค. ๋ชจ๋ธ ๋ถํ ๋ฐฉ์์ ๋ ๊ฐ์ง๊ฐ ์์ต๋๋ค.
autoํ ์ ๋ณ๋ ฌํ ๊ณํ์ ์ฌ์ ์ ์๋ ๊ตฌ์ฑ์ ๊ธฐ๋ฐ์ผ๋ก ๋ชจ๋ธ(์์ ์ธ๊ธ๋ ์ง์ ๋ชจ๋ธ)์ ์๋์ผ๋ก ๋ถํ ํฉ๋๋ค.- ์ฌ์ฉ์ ์ง์ ๋ถํ ๊ณํ์ ์ง์ ์ ์ํ์ฌ [~PreTrainedModel.from_pretrained] ๋ฉ์๋์
tp_plan๋งค๊ฐ๋ณ์๋ก ์ ๋ฌํ ์ ์์ต๋๋ค.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" # ๋ชจ๋ ๊ฐ๋ฅํ ์ ๋ต์ ์๊ฐํํ๊ธฐ์ ๋ ์ข์
model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # ์ ์ ์์ GPU์ ๋ ์ข์
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan="auto")
print(model._tp_plan)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
# ๋ถ์ฐ ์คํ
outputs = model(inputs)
์์ ์ถ๋ก ์คํฌ๋ฆฝํธ๋ฅผ GPU๋น 4๊ฐ ํ๋ก์ธ์ค๋ก torchrun์์ ์คํํ์ธ์.
torchrun --nproc-per-node 4 demo.py
๊ฐ ๋ ์ด์ด์ ๋ํ ํ
์ ๋ณ๋ ฌ ๊ณํ์ tp_plan์ ์ ์ํ ํ [~PreTrainedModel.from_pretrained]์ ์ ๋ฌํ์ธ์. ์๋ ์์๋ ์ด ๋ฐ ํ ๋ถํ ์ ์กฐํฉํ์ฌ ์ฌ์ฉํฉ๋๋ค. ์ง์๋๋ ๋ค๋ฅธ ๋ถํ ์ ๋ต์ ๋ถํ ์ ๋ต ์น์
์ ์ฐธ๊ณ ํ์ธ์.
์ฌ์ฉ์ ์ง์ ๋ถํ ๊ณํ์ ์๋์ผ๋ก ์ง์ ํ๋ ค๋ฉด ๋ชจ๋ธ ์ํคํ ์ฒ์ ๋ถํ ์ ๋ต์ด ํจ๊ป ์ํธ ์์ฉํ๋ ๋ฐฉ์์ ๋ํ ์ถฉ๋ถํ ์ดํด๊ฐ ํ์ํฉ๋๋ค. ๋ถํ ์ ๋ต์ ์๋ชป ์ค์ ํ๋ฉด ๋ชจ๋ธ์ด ๋งค์ฐ ๋๋ ค์ง๊ฑฐ๋, ์ค๋ฅ๊ฐ ๋ฐ์ํ๊ฑฐ๋, ๋ถ์ ํํ ๊ฒฐ๊ณผ๋ฅผ ๋ผ ์ ์์ต๋๋ค. ์์ธํ ์์๋ณด๋ ค๋ฉด Ultra-Scale Playbook์ ์ฐธ๊ณ ํ์ธ์.
from transformers import AutoModelForCausalLM
tp_plan = {
"model.layers.*.self_attn.q_proj": "colwise",
"model.layers.*.self_attn.k_proj": "colwise",
"model.layers.*.self_attn.v_proj": "colwise",
"model.layers.*.self_attn.o_proj": "rowwise",
...
}
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan=tp_plan)
print(model._tp_plan)
๋ถํ ์ ๋ต[[partitioning-strategies]]
๋ชจ๋ ๋ถํ ์ ๋ต์ ๋ฌธ์์ด์ ์ ๋ต ๊ตฌํ์ ๋งคํํ๋ [ParallelInterface] ํด๋์ค์์ ์ ์๋ฉ๋๋ค. ๋ชจ๋ ์ ๋ต์ [~PreTrainedModel.from_pretrained]์ tp_plan์ ํตํด ์ค์ ๋๋ฏ๋ก ์ด ํด๋์ค์ ์ง์ ์ํธ ์์ฉํ ํ์๋ ์์ง๋ง, ์ด๋ค ์ ๋ต์ ์ฌ์ฉํ ์ ์๋์ง ํ์ธํ ๋ ์ ์ฉํฉ๋๋ค.
class ParallelInterface(MutableMapping):
"""
ํ์ฉ๋ ์ดํ
์
ํจ์๋ฅผ ์ถ์ ํ๋ ๋์
๋๋ฆฌ ๊ฐ์ ๊ฐ์ฒด์
๋๋ค. `register()` ํธ์ถ๋ก ์๋ก์ด ์ดํ
์
ํจ์๋ฅผ ์ฝ๊ฒ ์ถ๊ฐํ ์ ์์ต๋๋ค.
๋ชจ๋ธ์ด ๊ธฐ์กด ์ดํ
์
ํจ์(์: `sdpa`)๋ฅผ ๋ก์ปฌ์์ ๋ฎ์ด์ฐ๋ ค๋ฉด `modeling_<model>.py` ๋ด๋ถ์์ ์ด ํด๋์ค์ ์ ์ธ์คํด์ค๋ฅผ ์ ์ธํ๊ณ
ํด๋น ์ธ์คํด์ค์์ ์ ์ธํด์ผ ํฉ๋๋ค.
"""
_global_mapping = {
"colwise": ColwiseParallel(),
"rowwise": RowwiseParallel(),
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
"local_colwise": ColwiseParallel(use_dtensor=False),
"local_rowwise": RowwiseParallel(use_dtensor=False),
"local": IsolatedParallel(),
"gather": GatherParallel(),
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
"sequence_parallel": SequenceParallel(),
"replicate": ReplicateParallel(),
}
๊ฐ ์ ๋ต์ ๋ํด ์์ธํ ์์๋ณด๋ ค๋ฉด ์๋ ํ๋ฅผ ์ฐธ๊ณ ํ์ธ์.
| ์ ๋ต | ์ค๋ช |
|---|---|
ColwiseParallel |
๊ฐ์ค์น์ ํธํฅ์ ์ด ๋ฐฉํฅ ๋ถํ . |
RowwiseParallel |
๊ฐ์ค์น์ ํธํฅ์ ํ ๋ฐฉํฅ ๋ถํ . nn.Embedding ๋ชจ๋ ๋ถํ ๋ ์ง์. |
SequenceParallel |
LayerNorm๊ณผ Dropout ๋ ์ด์ด๋ฅผ ์ง์ํ๋ ์ํ์ค ๋ณ๋ ฌ ๊ตฌํ. RMSNorm์ Python ๊ตฌํ๋ ์ง์. |
PackedColwiseParallel |
ํจํน๋ ๊ฐ์ค์น๋ฅผ ์ง์ํ๋ ColwiseParallel์ ๋ณํ(์: up_proj์ gate_proj๋ฅผ ํจ๊ป ํจํน). ์์ธํ ๋ด์ฉ์ ์ฝ๋๋ฅผ ์ฐธ์กฐํ์ธ์. |
PackedRowwiseParallel |
ํจํน๋ ๊ฐ์ค์น๋ฅผ ์ง์ํ๋ RowwiseParallel์ ๋ณํ(์ฝ๋ ์ฐธ์กฐ). |
GatherParallel |
๊ธฐ๊ธฐ ๊ฐ ๋ชจ๋์ ์ถ๋ ฅ์ ์์ง. |
IsolatedParallel |
Mixture-of-Experts(MoE) ๋ ์ด์ด์ ์ ๋ฌธ๊ฐ์ ์ฌ์ฉ๋์ด ๋ค๋ฅธ ๊ธฐ๊ธฐ๋ก๋ถํฐ ๋ชจ๋์ ๊ฒฉ๋ฆฌ. |
ReplicateParallel |
๋ถ๋ถ์ ์ผ๋ก ๋ถํ ๋ ๋ชจ๋ธ๋ก ์ธํด torch.distributed API๊ฐ ์ค๋จ๋๋ ๊ฒ์ ๋ฐฉ์งํ๊ธฐ ์ํด ๋ชจ๋ ๊ธฐ๊ธฐ์ ๋ชจ๋์ ๋ณต์ . |
ํจํน๋ ์ ๋ต[[packed-strategies]]
๊ฐ์ค์น ํจํน์ ์ฌ๋ฌ ์ ํ ๋ ์ด์ด๋ฅผ ํ๋์ ๋ ํฐ ๋ ์ด์ด๋ก ํฉ์น๋ ๊ธฐ๋ฒ์
๋๋ค. ํจํน๋ ์ ๋ต์ธ PackedColwiseParallel๊ณผ PackedRowwiseParallel์ ํจํน๋ ๊ฐ์ค์น๋ฅผ ๋ถํ ํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. ๊ธฐ๋ณธ์ ์ธ ColwiseParallel์ด๋ RowwiseParallel์ ํจํน๋ ๊ฐ์ค์น๋ฅผ ์ฌ๋ฐ๋ฅด๊ฒ ๋ถํ ํ์ง ๋ชปํฉ๋๋ค.
์๋ ์์๋ up_proj์ gate_proj๋ฅผ ๋จ์ผ gate_up_proj ๋ชจ๋๋ก ํจํนํ๊ณ gate_up_proj๋ฅผ ๋ถํ ํ๊ธฐ ์ํด PackedRowwiseParallel ์ ๋ต์ด ํ์ํฉ๋๋ค.
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
๋ฐฐ์น ํ๋ ฌ ๊ณฑ์
์ forward ํจ์ค์์ ์ฌ์ฉํ์ฌ gate_up_proj ๋ชจ๋์ ์ถ๋ ฅ์ ๊ณ์ฐํ ์ ์์ต๋๋ค.
def forward(self, hidden_states):
...
gate_up = torch.bmm(hidden_states, self.gate_up_proj) # gate_up_proj ๋ชจ๋์ ์ถ๋ ฅ ๊ณ์ฐ
gate, up = gate_up.chunk(2, dim=-1) # ์ถ๋ ฅ์ gate์ up์ผ๋ก ๋ถํ
Packed*๋ฅผ ์ฌ์ฉํด์ผ ํ๋ ์ด์ ์ ๋ํ ์๊ฐ์ ํํ์ ์ด ์ฃผ์์ ์ฐธ๊ณ ํ์ธ์.
๋ก์ปฌ ์ ๋ต[[local-strategies]]
๋ก์ปฌ ์ ๋ต(local_colwise, local_rowwise, local_packed_rowwise)์ torch.chunk์ ๊ฐ์ ์ผ๋ถ ์ฐ์ฐ์์ ์ง์๋์ง ์๊ธฐ ๋๋ฌธ์ DTensor๋ฅผ ์ฌ์ฉํ์ง ์์ต๋๋ค. ๋์ ๋ก์ปฌ ์ ๋ต์ ๊ธฐ๋ณธ torch.Tensor๋ฅผ ์ฌ์ฉํ๊ณ ์ผ๋ถ ๋ถ์ฐ ๋ก์ง์ ์๋์ผ๋ก ์ํํฉ๋๋ค.
์ฌ์ฉ์ ์ ์ ๋ถํ ์ ๋ต[[custom-partitioning-strategies]]
์ฌ์ฉ์ ์ ์ ๋ถํ ์ ๋ต์ TensorParallelLayer๋ฅผ ์์ํ๊ณ partition_tensor, _prepare_input_fn, _prepare_output_fn์ ๊ตฌํํด์ผ ํฉ๋๋ค.
๊ทธ๋ฐ ๋ค์ tp_plan์์ ํด๋น ์ ๋ต์ ์ง์ ํ์ ๋ ๋์คํจ์นญ ๋ก์ง์ด ์ฐพ์ ์ ์๋๋ก ParallelInterface ๋งคํ์ ๋ฑ๋กํด์ผ ํฉ๋๋ค.
์๋ ์์๋ ์ด ์ํฌํ๋ก์ฐ๋ก ColwiseParallel์ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
TensorParallelLayer๋ฅผ ์์ํฉ๋๋ค.__init__๋ฉ์๋์์ ์ ๋ ฅ ๋ฐ ์ถ๋ ฅ ํ ์๊ฐ ๊ธฐ๊ธฐ์ ์ด๋ป๊ฒ ๋ฐฐ์น๋์ด์ผ ํ๋์ง ์ค๋ช ํ๋input_layouts๊ณผoutput_layouts์ ์ ์ํฉ๋๋ค.desired_input_layouts์์ฑ์ ์ ๋ ฅ์ด ๊ธฐ๊ธฐ์ ์ด๋ป๊ฒ ๋ฐฐ์น๋์ด์ผ๋ง ํ๋์ง๋ฅผ ๋ช ์ํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.class ColwiseParallel(TensorParallelLayer): def __init__( self, *, input_layouts: Optional[Placement] = None, # ์ด์ ๋ ์ด์ด์์ ์ค๋ ์ ๋ ฅ ๋ ์ด์์ output_layouts: Optional[Placement] = None, # ๋ฌ์ฑํ๊ณ ์ ํ๋ ์ถ๋ ฅ ๋ ์ด์์ use_local_output: bool = True, # ๋ก์ปฌ ์ถ๋ ฅ ์ฌ์ฉ ์ฌ๋ถ use_dtensor=True, # DTensor ์ฌ์ฉ ์ฌ๋ถ ): self.input_layouts = (input_layouts or Replicate(),) # ์ด์ ๋ ์ด์ด์์ ์ค๋ ์ ๋ ฅ ๋ถํ self.output_layouts = (output_layouts or Shard(-1),) # ์ํ๋ ์ถ๋ ฅ ๋ถํ self.desired_input_layouts = (Replicate(),) # ์ํ๋ ์ ๋ ฅ ๋ถํ , ์ ๋ ฅ์ GPU ๊ฐ์ ๋ณต์ ๋์ด์ผ ํจ self.use_local_output = use_local_output self.use_dtensor = use_dtensorpartition_tensor,_prepare_input_fn,_prepare_output_fn๋ฉ์๋๋ฅผ ๊ตฌํํฉ๋๋ค.partition_tensor๋ฉ์๋๋ ํ ์๋ฅผ ๋ถํ ํ๊ณ ๋ถํ ๋ ํ ์๋กempty_param์ ์ฑ์๋๋ค. ์ ํธ๋ฆฌํฐ ํจ์get_tensor_shard๋ฅผ ์ฌ์ฉํ์ฌ ์ฃผ์ด์ง ๋ญํฌ์ ๋ํ ์๋ณธ ๋งค๊ฐ๋ณ์์ ์ฌ๋ฐ๋ฅธ ๋ถํ ์ ์ป๊ณ , ํจํน๋ ๊ฐ์ค์น์ ๋ํด์๋get_packed_weights๋ฅผ ์ฌ์ฉํ์ธ์.def partition_tensor( self, param, # ๋งค๊ฐ๋ณ์์ ์ ์ฒด ํ ์ empty_param, # ๋งค๊ฐ๋ณ์์ ๋น ํ ์, ๋ถํ ๋ ํ ์๋ก ์ฑ์์ง param_type, # ๋งค๊ฐ๋ณ์ ์ ํ, `bias` ๋๋ `weight` param_casting_dtype, # ๋งค๊ฐ๋ณ์๋ฅผ ์บ์คํ ํ ์ ํ to_contiguous, # ํ ์๋ฅผ ์ฐ์์ ์ธ ๋ฉ๋ชจ๋ฆฌ ๋ ์ด์์์ผ๋ก ๋ณํํ ์ง ์ฌ๋ถ rank, # ํ์ฌ ๊ธฐ๊ธฐ์ ๋ญํฌ device_mesh, # ๊ธฐ๊ธฐ ๋ฉ์ ) -> nn.Parameter: # ๋ถํ ๋ ๋งค๊ฐ๋ณ์ ๋ฐํ ..._prepare_input_fn๊ณผ_prepare_output_fn๋ฉ์๋๋ ์ฌ์ ํฌ์๋ ๋ฐ ํฌ์๋ ํ ์์ ์ฌ์ฉ๋ฉ๋๋ค.__init__์์ ์ง์ ๋ ๋๋ก ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ ์ํ๋ ๋ ์ด์์์ผ๋ก ์ฌ๋ถ๋ฐฐํฉ๋๋ค.def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ... # ์ฌ์ฉ์ ์ ์ ๋ก์ง ์ํ, DTensor๋ก ์บ์คํ ๋ฑ. ... return inputs.redistribute(placements=desired_input_layouts, device_mesh=device_mesh) def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): ... # ์ฌ์ฉ์ ์ ์ ๋ก์ง ์ํ, DTensor๋ก ์บ์คํ ๋ฑ. ... return outputs.redistribute(placements=output_layouts, device_mesh=device_mesh)tp_plan๊ณผ ํจ๊ป ์ฌ์ฉํ ์ ์๋๋ก ์ ๋ต์ [ParallelInterface]์ ๋ฑ๋กํฉ๋๋ค.from transformers.integrations.tensor_parallel import ParallelInterface ParallelInterface.register_strategy("colwise_custom", ColwiseParallel) tp_plan = { "model.layers.*.self_attn.q_proj": "colwise_custom", ... } model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, tp_plan=tp_plan)
๋ฒค์น๋งํฌ[[benchmarks]]
ํ ์ ๋ณ๋ ฌํ๋ ํนํ ํฐ ๋ฐฐ์น ํฌ๊ธฐ๋ ๊ธด ์ํ์ค๋ฅผ ๊ฐ์ง ์ ๋ ฅ์ ๋ํ ์ถ๋ก ์๋๋ฅผ ํฌ๊ฒ ํฅ์์ํฌ ์ ์์ต๋๋ค.
์ํ์ค ๊ธธ์ด๊ฐ 512์ธ Llama์์ ๋จ์ผ ํฌ์๋ ํจ์ค์ ๋ํ ์์ ์๋ ํฅ์ ์์น๋ ์๋ ์ฐจํธ๋ฅผ ์ฐธ์กฐํ์ธ์.
์ค๊ณ ๊ตฌํ[[design-implementation]]
Transformers ํ ์ ๋ณ๋ ฌํ ๊ตฌํ์ ํ๋ ์์ํฌ์ ๊ตฌ์ ๋ฐ์ง ์์ง๋ง, ๊ตฌ์ฒด์ ์ธ ๊ตฌํ์ ์ํด์๋ DeviceMesh์ torch.distributed์ DTensor์ ์์กดํ์ฌ ๊ฐ๋จํ๊ณ ํ์ฅ ๊ฐ๋ฅํ ์ธํฐํ์ด์ค๋ฅผ ์ ๊ณตํฉ๋๋ค.
DeviceMesh[[devicemesh]]
DeviceMesh๋ฅผ ํจ๊ป ํต์ ํ๋ ๊ธฐ๊ธฐ๋ค์ ๋ค์ฐจ์ ๊ทธ๋ฆฌ๋๋ก ์์ํด๋ณด์ธ์. ๋ณ๋ ฌ ์ฒ๋ฆฌ ์ ๋ต๋ง๋ค ๊ฐ๊ธฐ ๋ค๋ฅธ ํต์ ํจํด์ด ํ์ํ๋ฏ๋ก, ์ฌ๋ฌ ํ์ ๋ฉ์๋ฅผ ๊ฐ์ง DeviceMesh๋ฅผ ๋ง๋ค ์ ์์ต๋๋ค.
from torch.distributed.device_mesh import init_device_mesh
# 4๊ฐ GPU์ 1D ๋ฉ์ ์์ฑ
device_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=["tp"])
torch.distributed์์ ์ ์๋ ๋๋ถ๋ถ์ ๋ณ๋ ฌํ ์ ๋ต์ ๋ฉ์ ์์ฒด๋ ํ์ ๋ฉ์์ ์ ์ฉํ ์ ์์ผ๋ฉฐ, ์๋์ผ๋ก ํต์ ํจํด์ ์ฒ๋ฆฌํฉ๋๋ค.
DTensor[[dtensor]]
DTensor(๋ถ์ฐ ํ
์)๋ ์ผ๋ฐ์ ์ธ ํ
์ ์ฐ์ฐ ์์ ๋ถ์ฐ ๋ก์ง์ ์ฒ๋ฆฌํ๋ ํ
์ ํ์ ํด๋์ค์
๋๋ค. ํ
์ ๋ณ๋ ฌํ์ ๋๋ถ๋ถ์ ๋ชจ๋ธ ๊ฐ์ค์น๋ DTensor ํํ๋ก ์ ์ฅ๋ฉ๋๋ค.
DTensor์ ๊ฐ์ฅ ์ค์ํ ๋ถ๋ถ์ placement ์์ฑ์
๋๋ค. ์ด๋ PyTorch์๊ฒ ํ
์๊ฐ DeviceMesh์ ๊ธฐ๊ธฐ์ ์ด๋ป๊ฒ ๋ฐฐ์น๋๋์ง ์๋ ค์ฃผ๊ธฐ ๋๋ฌธ์
๋๋ค. placement ์์ฑ์ ๋ค์ ๊ฐ์ ๊ฐ์ง ์ ์์ต๋๋ค.
Shard(dimension)-DTensor๊ฐ ๊ตฌ์ฑ๋DeviceMesh์์ ์ฃผ์ด์ง ์ฐจ์์ ๊ฑธ์ณ ์ด๋ป๊ฒ ๋ถํ ๋๋์ง ๋ํ๋ ๋๋ค. ์๋ ์์๋ ์ด ๋ฐฉํฅ ๋ถํ ์ ์ํด ๋ค์ํ ์ฐจ์์ ๊ฑธ์ณ ๊ฐ์ค์น๋ฅผ ๋ถํ ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.weight = ... weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(0)]) # ์ฒซ ๋ฒ์งธ(์ด ๋ฐฉํฅ) ์ฐจ์์ ๊ฑธ์ณ ๋ถํ bias = ... bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Shard(-1)]) # ์ ์ผํ ์ฐจ์์ ๊ฑธ์ณ ๋ถํ์ด ์์๋ ํ ๋ฐฉํฅ ๋ถํ ์ ์ํด ์ฌ๋ฌ ์ฐจ์์ ๊ฑธ์ณ ๊ฐ์ค์น๋ฅผ ๋ถํ ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
weight = ... weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(1)]) # ๋ ๋ฒ์งธ(ํ ๋ฐฉํฅ) ์ฐจ์์ ๊ฑธ์ณ ๋ถํ bias = ... bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # ๋ชจ๋ GPU์ ํธํฅ ๋ณต์ Replicate()-DTensor๊ฐDeviceMesh์ ๊ฑธ์ณ ๋ณต์ ๋จ์ ๋ํ๋ ๋๋ค. ๊ฐ ๊ธฐ๊ธฐ์ ํ ์์ ์ ์ฒด ์ฌ๋ณธ๋ง ์์ฑํฉ๋๋ค.bias = ... bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # ๋ชจ๋ GPU์ ํธํฅ ๋ณต์ Partial()- ํ ์๊ฐ ๊ฐ์ ์ฐ์ฐ์ ๊ธฐ๋ค๋ฆฌ๊ณ ์๋ ์ํ์์ ๋ํ๋ ๋๋ค (์ผ๋ฐ์ ์ผ๋ก Transformers์์์ ์ฌ์ฉ ์ฌ๋ก์๋ ์ง์ ์ ์ธ ๊ด๋ จ์ด ์ ์ต๋๋ค).