File size: 5,225 Bytes
54c5666 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
#!/usr/bin/env python3
"""
Distributed Training Script for Multi-GPU/Multi-Node Training
Supports FSDP, DeepSpeed, and DDP
"""
import os
import sys
import argparse
import yaml
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
from models.architecture import AdvancedGPTModel, ModelConfig, TransformerBlock
# DeepSpeed is an optional dependency: record whether the import succeeded so
# that setup_deepspeed_model can check this flag and raise ImportError itself.
try:
    import deepspeed
    DEEPSPEED_AVAILABLE = True
except ImportError:
    DEEPSPEED_AVAILABLE = False
def setup_distributed():
    """Initialize the default process group from launcher environment variables.

    Reads ``RANK`` and ``WORLD_SIZE`` (and ``LOCAL_RANK`` when present) from
    the environment. When ``RANK`` or ``WORLD_SIZE`` is missing, nothing is
    initialized and single-process defaults are returned.

    Returns:
        Tuple ``(rank, world_size, local_rank)``; ``(0, 1, 0)`` when not
        running in distributed mode.
    """
    if 'RANK' not in os.environ or 'WORLD_SIZE' not in os.environ:
        print("Not running in distributed mode")
        return 0, 1, 0

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Only RANK and WORLD_SIZE are checked above, so indexing LOCAL_RANK
    # directly would raise KeyError under a launcher that does not export it.
    # Fall back to spreading global ranks over the GPUs visible on this node.
    local_rank_env = os.environ.get('LOCAL_RANK')
    if local_rank_env is not None:
        local_rank = int(local_rank_env)
    else:
        local_rank = rank % max(torch.cuda.device_count(), 1)

    # Select this process's GPU before creating the NCCL process group so the
    # group is not set up against the default device.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')

    return rank, world_size, local_rank
def cleanup_distributed():
    """Destroy the default process group if one has been initialized."""
    if not dist.is_initialized():
        # No process group exists (e.g. a single-process run): nothing to do.
        return
    dist.destroy_process_group()
def setup_fsdp_model(model, config):
    """Wrap ``model`` in FullyShardedDataParallel (FSDP).

    Args:
        model: Module to wrap. ``TransformerBlock`` is passed to the auto-wrap
            policy as the transformer layer class.
        config: Mapping whose ``config['training']['mixed_precision']`` value
            selects the FSDP mixed-precision dtype: ``'bf16'`` or ``'fp16'``.
            Any other value disables the mixed-precision policy.

    Returns:
        The FSDP-wrapped model.

    Raises:
        KeyError: If ``config['training']['mixed_precision']`` is missing.
    """
    # Function-scope imports, in the same style as the MixedPrecision import
    # this function already used.
    import functools

    from torch.distributed.fsdp import MixedPrecision

    # transformer_auto_wrap_policy is a callback that also requires
    # (module, recurse, nonwrapped_numel) positional arguments. Calling it
    # with only transformer_layer_cls raises TypeError, so bind that one
    # argument and pass the resulting callable to FSDP.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},
    )

    # Mixed precision policy: the same dtype is used for parameters,
    # gradient reduction, and buffers.
    precision = config['training']['mixed_precision']
    if precision == 'bf16':
        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
    elif precision == 'fp16':
        mp_policy = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
    else:
        mp_policy = None

    # Wrap model with FSDP on the currently selected CUDA device.
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
        param_init_fn=None,
    )
    return model
def setup_deepspeed_model(model, config, optimizer=None):
    """Wrap ``model`` in a DeepSpeed engine.

    The DeepSpeed configuration is loaded from the file named by
    ``config['training']['deepspeed_config']`` when that path exists.
    Otherwise a default ZeRO stage-3 configuration with CPU offload of
    optimizer state and parameters is built from the training settings.

    Args:
        model: Model passed to ``deepspeed.initialize``.
        config: Mapping with a ``'training'`` section. The default
            configuration reads ``batch_size``,
            ``gradient_accumulation_steps`` and ``mixed_precision`` from it,
            plus the optional ``micro_batch_size`` (default 1).
        optimizer: Optional optimizer passed through to
            ``deepspeed.initialize``.

    Returns:
        Tuple ``(model_engine, optimizer)`` taken from the result of
        ``deepspeed.initialize``.

    Raises:
        ImportError: If DeepSpeed could not be imported.
        KeyError: If a required training setting is missing while building
            the default configuration.
    """
    if not DEEPSPEED_AVAILABLE:
        raise ImportError("DeepSpeed not available")

    training_cfg = config['training']
    ds_config_path = training_cfg.get('deepspeed_config')

    if ds_config_path and os.path.exists(ds_config_path):
        with open(ds_config_path, 'r', encoding='utf-8') as f:
            ds_config = yaml.safe_load(f)
    else:
        if ds_config_path:
            # A path was configured but is missing: say so instead of
            # silently training with different settings than requested.
            print(
                f"DeepSpeed config '{ds_config_path}' not found; "
                "using default DeepSpeed config"
            )
        # Default DeepSpeed config.
        # NOTE(review): all three batch-size fields are set here; confirm the
        # configured values are mutually consistent for the world size in use.
        ds_config = {
            "train_batch_size": training_cfg['batch_size'],
            "train_micro_batch_size_per_gpu": training_cfg.get('micro_batch_size', 1),
            "gradient_accumulation_steps": training_cfg['gradient_accumulation_steps'],
            "zero_optimization": {
                "stage": 3,
                "offload_optimizer": {"device": "cpu"},
                "offload_param": {"device": "cpu"},
            },
            "fp16": {"enabled": training_cfg['mixed_precision'] == 'fp16'},
            "bf16": {"enabled": training_cfg['mixed_precision'] == 'bf16'},
        }

    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config=ds_config,
    )
    return model_engine, optimizer
def main():
    """Parse CLI arguments, build the model, and wrap it for the chosen backend.

    Command-line arguments:
        --config: Path to the YAML config file (required).
        --backend: One of ``fsdp``, ``deepspeed`` or ``ddp`` (default ``fsdp``).
        --local_rank: Accepted but not read by this function; the local rank
            actually used comes from ``setup_distributed``.

    The process group is cleaned up even when config loading or model setup
    raises.
    """
    parser = argparse.ArgumentParser(description="Distributed Training")
    parser.add_argument("--config", type=str, required=True, help="Config file path")
    parser.add_argument("--backend", type=str, choices=['fsdp', 'deepspeed', 'ddp'], default='fsdp')
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training")
    args = parser.parse_args()

    # Setup distributed
    rank, world_size, local_rank = setup_distributed()

    # Everything after the process group is created runs under try/finally so
    # a failure (bad config path, model construction error, ...) does not
    # skip the cleanup.
    try:
        print(f"Rank {rank}/{world_size}, Local rank: {local_rank}")

        # Load config
        with open(args.config, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # Create model and move it to this process's GPU
        model_config = ModelConfig(**config['model_config_dict'])
        model = AdvancedGPTModel(model_config)
        model = model.cuda(local_rank)

        # Setup distributed model
        if args.backend == 'fsdp':
            model = setup_fsdp_model(model, config)
            print("Using FSDP")
        elif args.backend == 'deepspeed':
            # Note: DeepSpeed initialization happens in the training script
            print("Using DeepSpeed")
        elif args.backend == 'ddp':
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank
            )
            print("Using DDP")

        print(f"Model setup complete on rank {rank}")
    finally:
        # Cleanup
        cleanup_distributed()
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|