Omini3D / tests /test_ccl_stress.py
maxmo2009's picture
Sync from local: code + epoch-110 checkpoint, clean README
2af0e94 verified
#!/usr/bin/env python
"""
CCL Stress Test: Diagnose epoch-boundary hangs in DDP training on Intel XPU.
Hypothesis: After ~200 CCL collective operations (one epoch), CCL's internal
state (IPC handles, Level Zero resources) gets corrupted, causing the next
collective to deadlock.
This test isolates which factor triggers the hang:
Phase 1: 200 (default) DDP forward+backward passes (simulating one epoch of collectives)
Phase 2: Rank 0 saves a checkpoint (model state plus filler tensors) to the
system temp directory and deletes it while the other ranks wait at a barrier
(simulating the end-of-epoch save; the actual file size is logged)
Phase 3: 10 (default) more DDP forward+backward passes (does the CCL hang?)
Phase 4: Reinit DataLoader with new DistributedSampler shuffle, 10 more passes
Phase 5: Repeated dist.all_reduce (MAX) and dist.broadcast on small tensors
(simulating NaN-flag sync and parameter broadcast)
If any phase hangs, the srun timeout kills the process, and the last logged
START message reveals the failing phase.
Launch: srun (see bash_test_ccl.sh) or torchrun for single-node testing.
Environment variables (set by srun/SLURM or torchrun):
"""
import os
import sys
import time
import argparse
import datetime
import tempfile
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
# XPU support
try:
import intel_extension_for_pytorch as ipex
except ImportError:
ipex = None
try:
import oneccl_bindings_for_pytorch
except (ImportError, Exception):
pass
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
def log(rank, msg):
    """Print one timestamped, rank-tagged line to stdout and flush immediately.

    Format: ``[YYYY-MM-DD HH:MM:SS.mmm] [rank <rank>] <msg>``. The flush
    keeps the last line visible even if the process is killed mid-phase.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    stamp = stamp[:-3]  # drop the last 3 digits: microseconds -> milliseconds
    print("[{}] [rank {}] {}".format(stamp, rank, msg), flush=True)
# ---------------------------------------------------------------------------
# Dummy model: stack of resolution-preserving 3D convolutions (~9.3M parameters)
# ---------------------------------------------------------------------------
class DummyConv3DModel(nn.Module):
    """A stack of resolution-preserving Conv3d blocks used as a DDP workload.

    Every convolution uses kernel=3, stride=1, padding=1, so the spatial
    dimensions of the input are preserved. Each conv except the last is
    followed by GroupNorm (``min(32, C)`` groups) and an in-place ReLU; the
    final conv emits ``channels[-1]`` output channels (3 by default, DVF-like).

    With the default channels ``[1, 64, 128, 256, 512, 256, 128, 64, 3]`` the
    model has 9,300,867 parameters (~35.5 MiB in float32), dominated by the
    256->512 and 512->256 convolutions.

    Args:
        ndims: Unused; only 3D convolutions are built. Kept so existing
            callers that pass it keep working.
        channels: Optional channel progression. ``channels[0]`` is the input
            channel count and ``channels[-1]`` the output channel count.
            Defaults to the list above.

    Raises:
        ValueError: If ``channels`` has fewer than two entries.
    """

    DEFAULT_CHANNELS = (1, 64, 128, 256, 512, 256, 128, 64, 3)

    def __init__(self, ndims=3, channels=None):
        super().__init__()
        channels = list(self.DEFAULT_CHANNELS if channels is None else channels)
        if len(channels) < 2:
            raise ValueError(
                f"channels must have at least 2 entries (in, out), got {channels!r}"
            )
        last_hidden = len(channels) - 2
        layers = []
        for i, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
            layers.append(nn.Conv3d(c_in, c_out, 3, 1, 1))
            # No norm/activation after the output conv so predictions are unbounded.
            if i < last_hidden:
                layers.append(nn.GroupNorm(min(32, c_out), c_out))
                layers.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv stack; the output keeps the spatial size of ``x``."""
        return self.net(x)
# ---------------------------------------------------------------------------
# Dummy dataset
# ---------------------------------------------------------------------------
class DummyVolumeDataset(Dataset):
    """Synthetic dataset of random single-channel cubic volumes.

    Item ``idx`` is a freshly sampled ``randn`` tensor of shape
    ``(1, spatial_size, spatial_size, spatial_size)`` paired with the integer
    label ``idx % 10``. Nothing is stored; each access draws new data.
    """

    def __init__(self, num_samples, spatial_size=128):
        self.num_samples = num_samples
        self.spatial_size = spatial_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        side = self.spatial_size
        volume = torch.randn((1, side, side, side))
        return volume, idx % 10
# ---------------------------------------------------------------------------
# Core: forward + backward pass
# ---------------------------------------------------------------------------
def do_forward_backward(model, optimizer, data, device, rank, step_label=""):
    """Run one forward/backward/optimizer step on a batch and return the loss.

    Args:
        model: Module mapping inputs to predictions (DDP-wrapped in this script).
        optimizer: Optimizer over ``model``'s parameters.
        data: ``(inputs, labels)`` pair; labels are ignored.
        device: Device the inputs are moved to before the forward pass.
        rank: Unused here; accepted so call sites can pass it uniformly.
        step_label: Unused here; accepted so call sites can pass it uniformly.

    Returns:
        The loss (mean of all prediction values) as a Python float.
    """
    inputs, _labels = data
    predictions = model(inputs.to(device))
    loss = predictions.mean()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.item()
# ---------------------------------------------------------------------------
# Phases
# ---------------------------------------------------------------------------
def phase_1(rank, model, optimizer, dataloader, device, num_steps=200):
    """Phase 1: run ``num_steps`` DDP training steps (roughly one epoch).

    Each step goes through ``do_forward_backward`` on the DDP model, so this
    phase issues the bulk of the collectives that precede the suspected
    epoch-boundary hang. If ``dataloader`` runs out before ``num_steps``, a
    fresh iterator is started and training continues. Progress is logged
    every 50 steps and on the final step; the phase ends with a barrier.
    """
    log(rank, f"PHASE 1 START: {num_steps} DDP forward+backward passes")
    started = time.time()
    batches = iter(dataloader)
    final_step = num_steps - 1
    for step in range(num_steps):
        try:
            batch = next(batches)
        except StopIteration:
            # Loader exhausted: restart it so the step count is honoured.
            batches = iter(dataloader)
            batch = next(batches)
        loss_val = do_forward_backward(
            model, optimizer, batch, device, rank, step_label=f"P1 step {step}"
        )
        should_report = step % 50 == 0 or step == final_step
        if should_report:
            log(rank, f" Phase 1 step {step}/{num_steps}, loss={loss_val:.6f}")
    duration = time.time() - started
    log(rank, f"PHASE 1 END: completed {num_steps} steps in {duration:.1f}s")
    dist.barrier()
    log(rank, "PHASE 1 BARRIER passed")
def phase_2(rank, model, device):
    """Phase 2: checkpoint save on rank 0 (simulating the end-of-epoch save).

    Rank 0 writes ``model.module.state_dict()`` plus ~20 MB of float32 filler
    tensors (five 1000x1000 ``randn`` tensors) to a uniquely named file in the
    system temp directory, logs the on-disk size, and deletes the file. The
    other ranks go straight to the trailing barrier, so they wait inside a
    collective for as long as rank 0 spends on file I/O.

    Args:
        rank: Global rank of this process.
        model: DDP-wrapped model; its ``.module`` is what gets saved.
        device: Unused; kept for interface compatibility.
    """
    log(rank, "PHASE 2 START: Checkpoint save")
    t0 = time.time()
    # Only rank 0 saves (matches training script behavior)
    if rank == 0:
        state = {
            "model_state_dict": model.module.state_dict(),
            "dummy_optimizer": {f"key_{i}": torch.randn(1000, 1000) for i in range(5)},
            "epoch": 1,
        }
        # A unique name (instead of a fixed one) so concurrent jobs sharing
        # this node's temp dir cannot clobber or delete each other's file.
        fd, ckpt_path = tempfile.mkstemp(prefix="ccl_stress_test_ckpt_", suffix=".pth")
        os.close(fd)
        try:
            torch.save(state, ckpt_path)
            ckpt_size_gb = os.path.getsize(ckpt_path) / (1024 ** 3)
            log(rank, f" Saved checkpoint: {ckpt_path} ({ckpt_size_gb:.2f} GB)")
        finally:
            # Remove the file even if the save failed part-way, so aborted
            # runs do not leave checkpoints behind in the temp dir.
            os.remove(ckpt_path)
        log(rank, " Cleaned up checkpoint file")
    else:
        log(rank, " (non-rank-0, waiting at barrier)")
    elapsed = time.time() - t0
    log(rank, f"PHASE 2 END: checkpoint save completed in {elapsed:.1f}s")
    dist.barrier()
    log(rank, "PHASE 2 BARRIER passed")
def phase_3(rank, model, optimizer, dataloader, device, num_steps=10):
    """Phase 3: a few DDP training steps after the (optional) checkpoint save.

    This is where the epoch-boundary hang is expected to surface if the
    hypothesis holds. A new iterator is drawn from ``dataloader`` (and
    restarted if exhausted), every step's loss is logged, and the phase ends
    with a barrier.
    """
    log(rank, f"PHASE 3 START: {num_steps} DDP forward+backward passes AFTER checkpoint save")
    started = time.time()
    loader_iter = iter(dataloader)
    for step in range(num_steps):
        try:
            batch = next(loader_iter)
        except StopIteration:
            loader_iter = iter(dataloader)
            batch = next(loader_iter)
        label = f"P3 step {step}"
        loss_val = do_forward_backward(model, optimizer, batch, device, rank, step_label=label)
        log(rank, f" Phase 3 step {step}/{num_steps}, loss={loss_val:.6f}")
    duration = time.time() - started
    log(rank, f"PHASE 3 END: completed {num_steps} steps in {duration:.1f}s")
    dist.barrier()
    log(rank, "PHASE 3 BARRIER passed")
def phase_4(rank, world_size, model, optimizer, device, num_steps=10, spatial_size=128, batch_size=2):
    """Phase 4: build a brand-new DataLoader/DistributedSampler, then train.

    Mimics DataLoader re-creation at an epoch boundary. It builds a new
    200-sample dataset and a shuffling DistributedSampler with
    ``set_epoch(2)``, so the ordering differs from an epoch-1 shuffle. The
    loader uses 2 worker processes and drops the last partial batch. The
    phase then runs ``num_steps`` DDP steps (restarting the iterator if it
    runs out), logs each one, and ends with a barrier.
    """
    log(rank, f"PHASE 4 START: DataLoader reinit + {num_steps} DDP passes")
    started = time.time()
    fresh_dataset = DummyVolumeDataset(num_samples=200, spatial_size=spatial_size)
    fresh_sampler = DistributedSampler(
        fresh_dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    fresh_sampler.set_epoch(2)
    fresh_loader = DataLoader(
        fresh_dataset,
        batch_size=batch_size,
        sampler=fresh_sampler,
        num_workers=2,
        pin_memory=False,
        drop_last=True,
    )
    log(rank, f" New DataLoader created with {len(fresh_dataset)} samples, sampler epoch=2")
    batches = iter(fresh_loader)
    for step in range(num_steps):
        try:
            batch = next(batches)
        except StopIteration:
            batches = iter(fresh_loader)
            batch = next(batches)
        loss_val = do_forward_backward(
            model, optimizer, batch, device, rank, step_label=f"P4 step {step}"
        )
        log(rank, f" Phase 4 step {step}/{num_steps}, loss={loss_val:.6f}")
    duration = time.time() - started
    log(rank, f"PHASE 4 END: completed {num_steps} steps in {duration:.1f}s")
    dist.barrier()
    log(rank, "PHASE 4 BARRIER passed")
def phase_5(rank, device, num_broadcasts=50):
    """Phase 5: explicit small collectives issued outside of DDP.

    Each of the ``num_broadcasts`` iterations issues two collectives:
    an ``all_reduce`` (MAX) of a one-element zero tensor, like a cross-rank
    NaN/Inf flag check, and a ``broadcast`` from rank 0 of a 1024-element
    random tensor, like a parameter sync. Progress is logged every 10
    iterations and on the last one; the phase ends with a barrier.
    """
    log(rank, f"PHASE 5 START: {num_broadcasts} dist.broadcast calls")
    started = time.time()
    last = num_broadcasts - 1
    for i in range(num_broadcasts):
        nan_flag = torch.tensor([0.0], device=device)
        dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX)
        payload = torch.randn(1024, device=device)
        dist.broadcast(payload, src=0)
        if i % 10 == 0 or i == last:
            log(rank, f" Phase 5 broadcast {i}/{num_broadcasts} completed")
    duration = time.time() - started
    log(rank, f"PHASE 5 END: completed {num_broadcasts} broadcasts in {duration:.1f}s")
    dist.barrier()
    log(rank, "PHASE 5 BARRIER passed")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _parse_skip_phases(spec):
    """Parse a ``--skip-phases`` value such as ``"2,4"`` into a set of ints.

    Blank items are ignored, so ``""``, ``"2,"`` and ``"2,,4"`` are accepted
    instead of crashing on ``int("")``.

    Raises:
        ValueError: If a non-blank item is not an integer.
    """
    return {int(item) for item in spec.split(",") if item.strip()}


def _log_memory(rank, device_type, device, label):
    """Log allocated/reserved memory (GB) on ``device``, tagged with ``label``.

    Does nothing if the device type has no usable memory-stat API.
    """
    if device_type == "xpu":
        api, name = torch.xpu, "XPU"
        # Both calls below must exist; older XPU builds may lack them.
        if not (hasattr(api, "memory_allocated") and hasattr(api, "memory_reserved")):
            return
    elif device_type == "cuda":
        api, name = torch.cuda, "CUDA"
    else:
        return
    gib = 1024 ** 3
    alloc = api.memory_allocated(device) / gib
    reserved = api.memory_reserved(device) / gib
    log(rank, f"{name} memory {label}: allocated={alloc:.2f} GB, reserved={reserved:.2f} GB")


def main():
    """Run the CCL stress test end to end on this rank.

    Steps:
    - Parse CLI args and pick the backend from the available hardware
      (XPU -> ``ccl``, CUDA -> ``nccl``).
    - Bind this process to its local device, then join the process group
      using RANK/LOCAL_RANK/WORLD_SIZE from the environment.
    - Build the DDP model, optimizer and DataLoader, then warm up.
    - Run the enabled phases in order. Each phase logs START/END/BARRIER so
      a hang can be localised from the logs.

    Exits with status 1 if neither an XPU nor a CUDA device is available.
    """
    parser = argparse.ArgumentParser(description="CCL stress test for XPU DDP epoch-boundary hangs")
    parser.add_argument("--spatial-size", type=int, default=64,
                        help="Spatial size of 3D volumes (default: 64, use 128 for full-scale test)")
    parser.add_argument("--batch-size", type=int, default=2,
                        help="Batch size per rank (default: 2)")
    parser.add_argument("--phase1-steps", type=int, default=200,
                        help="Number of steps in Phase 1 (default: 200)")
    parser.add_argument("--phase3-steps", type=int, default=10,
                        help="Number of steps in Phase 3 (default: 10)")
    parser.add_argument("--phase4-steps", type=int, default=10,
                        help="Number of steps in Phase 4 (default: 10)")
    parser.add_argument("--phase5-broadcasts", type=int, default=50,
                        help="Number of broadcasts in Phase 5 (default: 50)")
    parser.add_argument("--skip-phases", type=str, default="",
                        help="Comma-separated phases to skip, e.g. '2,4'")
    parser.add_argument("--warmup-steps", type=int, default=5,
                        help="Number of warmup steps before Phase 1 (default: 5)")
    args = parser.parse_args()
    skip_phases = _parse_skip_phases(args.skip_phases)

    # -----------------------------------------------------------------------
    # Device detection
    # -----------------------------------------------------------------------
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device_type, backend = "xpu", "ccl"
    elif torch.cuda.is_available():
        device_type, backend = "cuda", "nccl"
    else:
        print("ERROR: No XPU or CUDA device found. This test requires GPU/XPU.", flush=True)
        sys.exit(1)

    # -----------------------------------------------------------------------
    # DDP init (expects RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT from srun)
    # -----------------------------------------------------------------------
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size < 2:
        print("WARNING: WORLD_SIZE < 2. DDP collectives are trivial with 1 rank. "
              "Use srun with multiple tasks for meaningful CCL stress testing.", flush=True)
    log(rank, f"Initializing DDP: backend={backend}, device={device_type}, "
              f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
    log(rank, f"MASTER_ADDR={os.environ.get('MASTER_ADDR', 'unset')}, "
              f"MASTER_PORT={os.environ.get('MASTER_PORT', 'unset')}")
    log(rank, f"CCL_ZE_CACHE_OPEN_IPC_HANDLES_THRESHOLD="
              f"{os.environ.get('CCL_ZE_CACHE_OPEN_IPC_HANDLES_THRESHOLD', 'unset')}")
    # Bind to the local device BEFORE creating the process group so the
    # backend (and any barrier issued before the first device collective)
    # uses this rank's device rather than a default or guessed one.
    if device_type == "xpu":
        torch.xpu.set_device(local_rank)
    else:
        torch.cuda.set_device(local_rank)
    device = torch.device(f"{device_type}:{local_rank}")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    log(rank, f"DDP initialized on {device}")

    # -----------------------------------------------------------------------
    # Model + DDP wrapper
    # -----------------------------------------------------------------------
    log(rank, "Creating model...")
    model = DummyConv3DModel(ndims=3).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    param_size_gb = num_params * 4 / (1024 ** 3)  # float32
    log(rank, f"Model created: {num_params:,} parameters ({param_size_gb:.2f} GB in float32)")
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    log(rank, "DDP wrapper and optimizer created")

    # -----------------------------------------------------------------------
    # DataLoader (initial)
    # -----------------------------------------------------------------------
    # DistributedSampler hands each rank ~num_samples / world_size samples.
    # Scaling by world_size therefore gives every rank at least phase1_steps
    # full batches, so Phase 1 never restarts the iterator mid-phase.
    # DataLoader re-creation is what Phase 4 isolates.
    num_samples = max(500, args.phase1_steps * args.batch_size * world_size)
    dataset = DummyVolumeDataset(num_samples=num_samples, spatial_size=args.spatial_size)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    sampler.set_epoch(1)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler,
                            num_workers=2, pin_memory=False, drop_last=True)
    log(rank, f"DataLoader created: {len(dataset)} samples, batch_size={args.batch_size}, "
              f"spatial_size={args.spatial_size}")

    # -----------------------------------------------------------------------
    # Warmup: a few steps to stabilize the memory allocator
    # -----------------------------------------------------------------------
    log(rank, f"WARMUP START: {args.warmup_steps} steps")
    data_iter = iter(dataloader)
    for _ in range(args.warmup_steps):
        try:
            data = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            data = next(data_iter)
        do_forward_backward(model, optimizer, data, device, rank)
    # Release the warmup iterator so its worker processes shut down instead of
    # staying alive (with prefetched batches) for the rest of the test.
    del data_iter
    dist.barrier()
    log(rank, "WARMUP END")
    _log_memory(rank, device_type, device, "after warmup")

    # -----------------------------------------------------------------------
    # Run phases (in order; skipped ones are logged so the log stays complete)
    # -----------------------------------------------------------------------
    phases = (
        (1, lambda: phase_1(rank, model, optimizer, dataloader, device,
                            num_steps=args.phase1_steps)),
        (2, lambda: phase_2(rank, model, device)),
        (3, lambda: phase_3(rank, model, optimizer, dataloader, device,
                            num_steps=args.phase3_steps)),
        (4, lambda: phase_4(rank, world_size, model, optimizer, device,
                            num_steps=args.phase4_steps,
                            spatial_size=args.spatial_size,
                            batch_size=args.batch_size)),
        (5, lambda: phase_5(rank, device, num_broadcasts=args.phase5_broadcasts)),
    )
    total_t0 = time.time()
    for number, run_phase in phases:
        if number in skip_phases:
            log(rank, f"PHASE {number} SKIPPED")
        else:
            run_phase()
    total_elapsed = time.time() - total_t0
    log(rank, f"ALL PHASES COMPLETE in {total_elapsed:.1f}s")
    _log_memory(rank, device_type, device, "final")

    dist.barrier()
    dist.destroy_process_group()
    log(rank, "Process group destroyed. Test PASSED.")
# Script entry point: run the stress test only when this file is executed
# directly (e.g. launched by srun or torchrun), not when it is imported.
if __name__ == "__main__":
    main()