μμ λΆν λ°μ΄ν° λ³λ ¬ μ²λ¦¬(FSDP) [[fully-sharded-data-parallel]]
Fully Sharded Data Parallel (FSDP)μ λͺ¨λΈμ λ§€κ°λ³μ, κ·Έλ μ΄λμΈνΈ λ° μ΅ν°λ§μ΄μ μνλ₯Ό μ¬μ© κ°λ₯ν GPU(μμ
μ λλ λν¬λΌκ³ λ ν¨) μμ λ°λΌ λΆν νλ λ°μ΄ν° λ³λ ¬ μ²λ¦¬ λ°©μμ
λλ€. DistributedDataParallel (DDP)μ λ¬λ¦¬, FSDPλ κ° GPUμ λͺ¨λΈμ 볡μ νκΈ° λλ¬Έμ λ©λͺ¨λ¦¬ μ¬μ©λμ μ€μ
λλ€. μ΄λ GPU λ©λͺ¨λ¦¬ ν¨μ¨μ±μ ν₯μμν€λ©° μ μ μμ GPUλ‘ ν¨μ¬ λ ν° λͺ¨λΈμ νλ ¨ν μ μκ² ν©λλ€. FSDPλ λΆμ° νκ²½μμμ νλ ¨μ μ½κ² κ΄λ¦¬ν μ μλ λΌμ΄λΈλ¬λ¦¬μΈ Accelerateμ ν΅ν©λμ΄ μμΌλ©°, λ°λΌμ [Trainer] ν΄λμ€μμ μ¬μ©ν μ μμ΅λλ€.
μμνκΈ° μ μ Accelerateκ° μ€μΉλμ΄ μκ³ μ΅μ PyTorch 2.1.0 μ΄μμ λ²μ μ΄ μ€μΉλμ΄ μλμ§ νμΈνμΈμ.
pip install accelerate
FSDP κ΅¬μ± [[fsdp-configuration]]
μμνλ €λ©΄ accelerate config λͺ
λ Ήμ μ€ννμ¬ νλ ¨ νκ²½μ λν κ΅¬μ± νμΌμ μμ±νμΈμ. Accelerateλ μ΄ κ΅¬μ± νμΌμ μ¬μ©νμ¬ accelerate configμμ μ νν νλ ¨ μ΅μ
μ λ°λΌ μλμΌλ‘ μ¬λ°λ₯Έ νλ ¨ νκ²½μ μ€μ ν©λλ€.
accelerate config
accelerate configλ₯Ό μ€ννλ©΄ νλ ¨ νκ²½μ ꡬμ±νκΈ° μν μΌλ ¨μ μ΅μ
λ€μ΄ λνλ©λλ€. μ΄ μΉμ
μμλ κ°μ₯ μ€μν FSDP μ΅μ
μ€ μΌλΆλ₯Ό λ€λ£Ήλλ€. λ€λ₯Έ μ¬μ© κ°λ₯ν FSDP μ΅μ
μ λν΄ λ μμλ³΄κ³ μΆλ€λ©΄ fsdp_config λ§€κ°λ³μλ₯Ό μ°Έμ‘°νμΈμ.
λΆν μ λ΅ [[sharding-strategy]]
FSDPλ μ¬λ¬ κ°μ§ λΆν μ λ΅μ μ 곡ν©λλ€:
FULL_SHARD- λͺ¨λΈ λ§€κ°λ³μ, κ·Έλ μ΄λμΈνΈ λ° μ΅ν°λ§μ΄μ μνλ₯Ό μμ μ κ°μ λΆν ; μ΄ μ΅μ μ μ ννλ €λ©΄1μ μ ννμΈμSHARD_GRAD_OP- κ·Έλ μ΄λμΈνΈ λ° μ΅ν°λ§μ΄μ μνλ₯Ό μμ μ κ°μ λΆν ; μ΄ μ΅μ μ μ ννλ €λ©΄2λ₯Ό μ ννμΈμNO_SHARD- μ무 κ²λ λΆν νμ§ μμ (DDPμ λμΌ); μ΄ μ΅μ μ μ ννλ €λ©΄3μ μ ννμΈμHYBRID_SHARD- κ° μμ μκ° μ 체 볡μ¬λ³Έμ κ°μ§κ³ μλ μνμμ λͺ¨λΈ λ§€κ°λ³μ, κ·Έλ μ΄λμΈνΈ λ° μ΅ν°λ§μ΄μ μνλ₯Ό μμ μ λ΄μμ λΆν ; μ΄ μ΅μ μ μ ννλ €λ©΄4λ₯Ό μ ννμΈμHYBRID_SHARD_ZERO2- κ° μμ μκ° μ 체 볡μ¬λ³Έμ κ°μ§κ³ μλ μνμμ κ·Έλ μ΄λμΈνΈ λ° μ΅ν°λ§μ΄μ μνλ₯Ό μμ μ λ΄μμ λΆν ; μ΄ μ΅μ μ μ ννλ €λ©΄5λ₯Ό μ ννμΈμ
μ΄κ²μ fsdp_sharding_strategy νλκ·Έλ‘ νμ±νλ©λλ€.
CPU μ€νλ‘λ [[cpu-offload]]
μ¬μ©νμ§ μλ λ§€κ°λ³μμ κ·Έλ μ΄λμΈνΈλ₯Ό CPUλ‘ μ€νλ‘λνμ¬ λ λ§μ GPU λ©λͺ¨λ¦¬λ₯Ό μ μ½νκ³ FSDPλ‘λ μΆ©λΆνμ§ μμ ν° λͺ¨λΈμ GPUμ μ μ¬ν μ μλλ‘ ν μ μμ΅λλ€. μ΄λ accelerate configλ₯Ό μ€νν λ fsdp_offload_params: trueλ‘ μ€μ νμ¬ νμ±νλ©λλ€.
λν μ μ± [[wrapping-policy]]
FSDPλ λ€νΈμν¬μ κ° λ μ΄μ΄λ₯Ό λννμ¬ μ μ©λ©λλ€. λνμ μΌλ°μ μΌλ‘ μ€μ²© λ°©μμΌλ‘ μ μ©λλ©° κ°κ° μλ°©ν₯μΌλ‘ μ§λκ° ν μ 체 κ°μ€μΉλ₯Ό μμ νμ¬ λ€μ λ μ΄μ΄μμ μ¬μ©ν λ©λͺ¨λ¦¬λ₯Ό μ μ½ν©λλ€. μλ λν μ μ±
μ μ΄λ₯Ό ꡬννλ κ°μ₯ κ°λ¨ν λ°©λ²μ΄λ©° μ½λλ₯Ό λ³κ²½ν νμκ° μμ΅λλ€. Transformer λ μ΄μ΄λ₯Ό λννλ €λ©΄ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAPλ₯Ό μ ννκ³ λνν λ μ΄μ΄λ₯Ό μ§μ νλ €λ©΄ fsdp_transformer_layer_cls_to_wrapλ₯Ό μ ννμΈμ (μ: BertLayer).
λλ νΉμ λ§€κ°λ³μ μλ₯Ό μ΄κ³Όν κ²½μ° FSDPκ° λ μ΄μ΄μ μ μ©λλ ν¬κΈ° κΈ°λ° λν μ μ±
μ μ νν μ μμ΅λλ€. μ΄λ fsdp_wrap_policy: SIZE_BASED_WRAP λ° min_num_paramμ μνλ ν¬κΈ°μ μκ³κ°μΌλ‘ μ€μ νμ¬ νμ±νλ©λλ€.
체ν¬ν¬μΈνΈ [[checkpointing]]
μ€κ° 체ν¬ν¬μΈνΈλ fsdp_state_dict_type: SHARDED_STATE_DICTλ‘ μ μ₯ν΄μΌ ν©λλ€. CPU μ€νλ‘λκ° νμ±νλ λν¬ 0μμ μ 체 μν λμ
λ리λ₯Ό μ μ₯νλ λ° μκ°μ΄ λ§μ΄ κ±Έλ¦¬κ³ , λΈλ‘λμΊμ€ν
μ€ λ¬΄κΈ°ν λκΈ°νμ¬ NCCL Timeout μ€λ₯κ° λ°μν μ μκΈ° λλ¬Έμ
λλ€. [~accelerate.Accelerator.load_state] λ©μλλ₯Ό μ¬μ©νμ¬ λΆν λ μν λμ
λλ¦¬λ‘ νλ ¨μ μ¬κ°ν μ μμ΅λλ€.
# κ²½λ‘κ° λ΄μ¬λ 체ν¬ν¬μΈνΈ
accelerator.load_state("ckpt")
κ·Έλ¬λ νλ ¨μ΄ λλλ©΄ μ 체 μν λμ λ리λ₯Ό μ μ₯ν΄μΌ ν©λλ€. λΆν λ μν λμ λ리λ FSDPμλ§ νΈνλκΈ° λλ¬Έμ λλ€.
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(script_args.output_dir)
TPU [[tpu]]
PyTorch XLAλ TPUμ λν FSDP νλ ¨μ μ§μνλ©° accelerate configλ‘ μμ±λ FSDP κ΅¬μ± νμΌμ μμ νμ¬ νμ±νν μ μμ΅λλ€. μμμ μ§μ ν λΆν μ λ΅ λ° λν μ΅μ
μΈμλ μλμ νμλ λ§€κ°λ³μλ₯Ό νμΌμ μΆκ°ν μ μμ΅λλ€.
xla: True # PyTorch/XLAλ₯Ό νμ±ννλ €λ©΄ Trueλ‘ μ€μ ν΄μΌ ν©λλ€
xla_fsdp_settings: # XLA νΉμ FSDP λ§€κ°λ³μ
xla_fsdp_grad_ckpt: True # gradient checkpointingμ μ¬μ©ν©λλ€
xla_fsdp_settingsλ FSDPμ λν μΆκ°μ μΈ XLA νΉμ λ§€κ°λ³μλ₯Ό ꡬμ±ν μ μκ² ν©λλ€.
νλ ¨ μμ [[launch-training]]
μμ FSDP κ΅¬μ± νμΌμ λ€μκ³Ό κ°μ μ μμ΅λλ€:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch_policy: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: true
fsdp_sharding_strategy: 1
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_transformer_layer_cls_to_wrap: BertLayer
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
νλ ¨μ μμνλ €λ©΄ accelerate launch λͺ
λ Ήμ μ€ννμΈμ. μ΄ λ μ μ accelerate configλ‘ μμ±ν κ΅¬μ± νμΌμ μλμΌλ‘ μ¬μ©ν©λλ€.
accelerate launch my-trainer-script.py
accelerate launch --fsdp="full shard" --fsdp_config="path/to/fsdp_config/ my-trainer-script.py
λ€μ λ¨κ³ [[next-steps]]
FSDPλ λ§€μ° ν° λͺ¨λΈμ νλ ¨ν λ κ°λ ₯ν λκ΅¬κ° λ μ μμΌλ©°, μ¬λ¬ κ°μ GPUλ TPUλ₯Ό μ¬μ©ν μ μμ΅λλ€. λͺ¨λΈ λ§€κ°λ³μ, μ΅ν°λ§μ΄μ λ° κ·Έλ μ΄λμΈνΈ μνλ₯Ό λΆν νκ³ λΉνμ± μνμΌ λ, CPUλ‘ μ€νλ‘λνλ©΄ FSDPλ λκ·λͺ¨ νλ ¨μ λμ μ°μ° λΉμ©μ μ€μΌ μ μμ΅λλ€. λ μμλ³΄κ³ μΆλ€λ©΄ λ€μ μλ£κ° λμμ΄ λ μ μμ΅λλ€:
- FSDPμ λν λ κΉμ΄ μλ Accelerate κ°μ΄λλ₯Ό λ°λΌκ° 보μΈμ.
- PyTorchμ μμ λΆν λ°μ΄ν° λ³λ ¬ μ²λ¦¬ (FSDP) APIλ₯Ό μκ°ν©λλ€ λΈλ‘κ·Έ κΈμ μ½μ΄λ³΄μΈμ.
- FSDPλ₯Ό μ¬μ©νμ¬ ν΄λΌμ°λ TPUμμ PyTorch λͺ¨λΈ ν¬κΈ° μ‘°μ νκΈ° λΈλ‘κ·Έ κΈμ μ½μ΄λ³΄μΈμ.