| |
| """ |
| CCL Stress Test: Diagnose epoch-boundary hangs in DDP training on Intel XPU. |
| |
| Hypothesis: After ~200 CCL collective operations (one epoch), CCL's internal |
| state (IPC handles, Level Zero resources) gets corrupted, causing the next |
| collective to deadlock. |
| |
| This test isolates which factor triggers the hang: |
| Phase 1: 200 DDP forward+backward passes (simulating one epoch of collectives) |
| Phase 2: Save a checkpoint to the temp directory (file I/O pressure; roughly 60 MB with the current model, the actual size is logged) |
| Phase 3: 10 more DDP forward+backward passes (does the CCL hang?) |
| Phase 4: Reinit DataLoader with new DistributedSampler shuffle, 10 more passes |
| Phase 5: Explicit dist.all_reduce + dist.broadcast on small tensors (simulating NaN sync) |
| |
| If any phase hangs, the srun timeout kills the process and the last logged |
| START message reveals the failing phase. |
| |
| Launch: srun (see bash_test_ccl.sh) or torchrun for single-node testing. |
| |
| Environment variables (set by srun/SLURM): |
| RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT |
| """ |
|
|
| import os |
| import sys |
| import time |
| import argparse |
| import datetime |
| import tempfile |
|
|
| import torch |
| import torch.nn as nn |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import Dataset, DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
|
|
| |
| try: |
| import intel_extension_for_pytorch as ipex |
| except ImportError: |
| ipex = None |
| try: |
| import oneccl_bindings_for_pytorch |
| except (ImportError, Exception): |
| pass |
|
|
|
|
| |
| |
| |
|
|
def log(rank, msg):
    """Print ``msg`` to stdout with a millisecond timestamp and rank prefix.

    The line has the form ``[YYYY-MM-DD HH:MM:SS.mmm] [rank R] msg`` and is
    flushed immediately.
    """
    # timespec="milliseconds" truncates (does not round) the microseconds,
    # which matches slicing the last three digits off a "%f" field.
    stamp = datetime.datetime.now().isoformat(sep=" ", timespec="milliseconds")
    line = "[{}] [rank {}] {}".format(stamp, rank, msg)
    print(line, flush=True)
|
|
|
|
| |
| |
| |
|
|
class DummyConv3DModel(nn.Module):
    """Stack of Conv3d blocks used as a synthetic DDP workload.

    Every convolution uses kernel=3, stride=1, padding=1, so the spatial
    dimensions of the input are preserved. Each convolution except the last
    is followed by GroupNorm and an in-place ReLU; the last convolution emits
    the output channels directly (3 by default, DVF-like).

    With the default channels [1, 64, 128, 256, 512, 256, 128, 64, 3] the
    model has 9,300,867 parameters (about 37 MB in float32), dominated by the
    256->512 and 512->256 convolutions. Pass a wider ``channels`` list to
    build a larger model.

    Args:
        ndims: Accepted for call-site compatibility but not used; the layers
            are always ``nn.Conv3d``.
        channels: Optional channel progression ``[in, hidden..., out]``.
            ``None`` selects ``DEFAULT_CHANNELS``. GroupNorm uses
            ``min(32, c)`` groups, so hidden widths above 32 must be
            divisible by 32 or torch rejects them.

    Raises:
        ValueError: If ``channels`` has fewer than two entries.
    """

    DEFAULT_CHANNELS = (1, 64, 128, 256, 512, 256, 128, 64, 3)

    def __init__(self, ndims=3, channels=None):
        super().__init__()
        widths = list(self.DEFAULT_CHANNELS if channels is None else channels)
        if len(widths) < 2:
            raise ValueError(
                "channels must contain at least an input and an output width, "
                f"got {widths!r}"
            )

        layers = []
        final_conv = len(widths) - 2  # index of the last convolution
        for i, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv3d(c_in, c_out, 3, 1, 1))
            # The output convolution is left without normalisation/activation.
            if i < final_conv:
                layers.append(nn.GroupNorm(min(32, c_out), c_out))
                layers.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x`` and return the result."""
        return self.net(x)
|
|
|
|
| |
| |
| |
|
|
class DummyVolumeDataset(Dataset):
    """Synthetic dataset of random cubic volumes.

    Each item is a ``(volume, label)`` pair. ``volume`` is a freshly drawn
    ``torch.randn`` tensor of shape ``(1, S, S, S)`` with ``S = spatial_size``
    and ``label`` is the integer ``idx % 10``.
    """

    def __init__(self, num_samples, spatial_size=128):
        self.spatial_size = spatial_size
        self.num_samples = num_samples

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a new random volume and the label derived from ``idx``."""
        edge = self.spatial_size
        shape = (1, edge, edge, edge)
        return torch.randn(*shape), idx % 10
|
|
|
|
| |
| |
| |
|
|
def do_forward_backward(model, optimizer, data, device, rank, step_label=""):
    """Run one forward/backward/optimizer step and return the loss as a float.

    ``data`` must unpack into ``(inputs, labels)``; the labels are ignored and
    the loss is simply the mean of the model output. ``rank`` and
    ``step_label`` are accepted but not used here.
    """
    inputs, _labels = data
    outputs = model(inputs.to(device))
    loss = outputs.mean()

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    return loss.item()
|
|
|
|
| |
| |
| |
|
|
def phase_1(rank, model, optimizer, dataloader, device, num_steps=200):
    """Phase 1: run ``num_steps`` training steps, then synchronise on a barrier.

    Batches are drawn from ``dataloader``; when it is exhausted a new iterator
    is started. Progress is logged every 50 steps and on the final step.
    """
    log(rank, f"PHASE 1 START: {num_steps} DDP forward+backward passes")
    started = time.time()
    batches = iter(dataloader)

    def fetch():
        """Return the next batch, restarting the loader once it runs dry."""
        nonlocal batches
        try:
            return next(batches)
        except StopIteration:
            batches = iter(dataloader)
            return next(batches)

    last_step = num_steps - 1
    for step in range(num_steps):
        loss_val = do_forward_backward(
            model, optimizer, fetch(), device, rank, step_label=f"P1 step {step}"
        )
        if step == last_step or step % 50 == 0:
            log(rank, f" Phase 1 step {step}/{num_steps}, loss={loss_val:.6f}")

    elapsed = time.time() - started
    log(rank, f"PHASE 1 END: completed {num_steps} steps in {elapsed:.1f}s")
    dist.barrier()
    log(rank, "PHASE 1 BARRIER passed")
|
|
|
|
def phase_2(rank, model, device):
    """Phase 2: rank 0 writes a checkpoint to the temp directory, then all ranks barrier.

    The checkpoint holds the wrapped model's ``state_dict`` (``model.module``,
    so ``model`` must be a DDP-style wrapper), five 1000x1000 random tensors
    standing in for optimizer state, and an epoch number. Its size depends on
    the model and is measured and logged rather than assumed.

    The file is written inside a private temporary directory so that
    concurrent jobs sharing a node cannot overwrite each other's checkpoint,
    and so that it is removed even if saving or measuring it fails.

    Args:
        rank: Global rank of this process; only rank 0 writes.
        model: DDP-wrapped model exposing ``.module``.
        device: Unused; kept for a signature parallel to the other phases.
    """
    log(rank, "PHASE 2 START: Checkpoint save")
    t0 = time.time()

    if rank == 0:
        state = {
            "model_state_dict": model.module.state_dict(),
            "dummy_optimizer": {f"key_{i}": torch.randn(1000, 1000) for i in range(5)},
            "epoch": 1,
        }
        # TemporaryDirectory defaults to tempfile.gettempdir() and deletes the
        # directory (and the checkpoint in it) on exit, including on error.
        with tempfile.TemporaryDirectory(prefix="ccl_stress_test_") as scratch_dir:
            ckpt_path = os.path.join(scratch_dir, "ccl_stress_test_ckpt.pth")
            torch.save(state, ckpt_path)
            ckpt_size_gb = os.path.getsize(ckpt_path) / (1024 ** 3)
            log(rank, f" Saved checkpoint: {ckpt_path} ({ckpt_size_gb:.2f} GB)")
        log(rank, " Cleaned up checkpoint file")
    else:
        log(rank, " (non-rank-0, waiting at barrier)")

    elapsed = time.time() - t0
    log(rank, f"PHASE 2 END: checkpoint save completed in {elapsed:.1f}s")
    dist.barrier()
    log(rank, "PHASE 2 BARRIER passed")
|
|
|
|
def phase_3(rank, model, optimizer, dataloader, device, num_steps=10):
    """Phase 3: run ``num_steps`` training steps after the checkpoint phase.

    Every step is logged. The loader is restarted when it is exhausted, and
    the phase ends with a barrier.
    """
    log(rank, f"PHASE 3 START: {num_steps} DDP forward+backward passes AFTER checkpoint save")
    started = time.time()

    exhausted = object()  # sentinel: distinguishes "no more batches" from any real batch
    stream = iter(dataloader)
    for step in range(num_steps):
        batch = next(stream, exhausted)
        if batch is exhausted:
            stream = iter(dataloader)
            batch = next(stream)
        loss_val = do_forward_backward(
            model, optimizer, batch, device, rank, step_label=f"P3 step {step}"
        )
        log(rank, f" Phase 3 step {step}/{num_steps}, loss={loss_val:.6f}")

    elapsed = time.time() - started
    log(rank, f"PHASE 3 END: completed {num_steps} steps in {elapsed:.1f}s")
    dist.barrier()
    log(rank, "PHASE 3 BARRIER passed")
|
|
|
|
def phase_4(rank, world_size, model, optimizer, device, num_steps=10, spatial_size=128, batch_size=2):
    """Phase 4: build a fresh dataset/sampler/DataLoader, then run ``num_steps`` steps.

    A new ``DummyVolumeDataset`` is wrapped in a shuffling
    ``DistributedSampler`` (epoch set to 2) and a two-worker ``DataLoader``
    with ``drop_last=True``. Every step is logged and the phase ends with a
    barrier.

    Args:
        rank: Global rank of this process.
        world_size: Number of participating ranks.
        model: DDP-wrapped model.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.
        num_steps: Number of forward/backward steps to run.
        spatial_size: Edge length of the random volumes.
        batch_size: Per-rank batch size.
    """
    log(rank, f"PHASE 4 START: DataLoader reinit + {num_steps} DDP passes")
    t0 = time.time()

    # The sampler gives each rank ceil(N / world_size) samples and the loader
    # drops incomplete batches, so N must be at least world_size * batch_size
    # or every rank gets zero batches and the loop below fails with a bare
    # StopIteration. 200 is kept as the floor for ordinary job sizes.
    num_samples = max(200, world_size * batch_size)
    new_dataset = DummyVolumeDataset(num_samples=num_samples, spatial_size=spatial_size)
    new_sampler = DistributedSampler(new_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    new_sampler.set_epoch(2)
    new_dataloader = DataLoader(new_dataset, batch_size=batch_size, sampler=new_sampler,
                                num_workers=2, pin_memory=False, drop_last=True)
    log(rank, f" New DataLoader created with {len(new_dataset)} samples, sampler epoch=2")

    data_iter = iter(new_dataloader)
    for step in range(num_steps):
        try:
            data = next(data_iter)
        except StopIteration:
            # Loader exhausted: start another pass over the same data.
            data_iter = iter(new_dataloader)
            data = next(data_iter)
        loss_val = do_forward_backward(model, optimizer, data, device, rank,
                                       step_label=f"P4 step {step}")
        log(rank, f" Phase 4 step {step}/{num_steps}, loss={loss_val:.6f}")

    elapsed = time.time() - t0
    log(rank, f"PHASE 4 END: completed {num_steps} steps in {elapsed:.1f}s")
    dist.barrier()
    log(rank, "PHASE 4 BARRIER passed")
|
|
|
|
def phase_5(rank, device, num_broadcasts=50):
    """Phase 5: issue explicit small collectives, then synchronise on a barrier.

    Each of the ``num_broadcasts`` iterations performs a MAX all-reduce on a
    one-element flag tensor followed by a broadcast of a 1024-element random
    tensor from rank 0. Progress is logged every 10 iterations and on the
    last one.
    """
    log(rank, f"PHASE 5 START: {num_broadcasts} dist.broadcast calls")
    started = time.time()

    final_index = num_broadcasts - 1
    for i in range(num_broadcasts):
        # Flag-style reduction across all ranks.
        nan_flag = torch.tensor([0.0], device=device)
        dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX)

        # Rank 0's values overwrite the locally drawn ones on every other rank.
        payload = torch.randn(1024, device=device)
        dist.broadcast(payload, src=0)

        if i == final_index or i % 10 == 0:
            log(rank, f" Phase 5 broadcast {i}/{num_broadcasts} completed")

    elapsed = time.time() - started
    log(rank, f"PHASE 5 END: completed {num_broadcasts} broadcasts in {elapsed:.1f}s")
    dist.barrier()
    log(rank, "PHASE 5 BARRIER passed")
|
|
|
|
| |
| |
| |
|
|
def _parse_skip_phases(spec):
    """Parse a comma-separated phase list such as ``"2,4"`` into a set of ints.

    Empty tokens (for example from a trailing comma) are ignored. Raises
    ``ValueError`` if a non-empty token is not an integer.
    """
    phases = set()
    for token in spec.split(","):
        token = token.strip()
        if token:
            phases.add(int(token))
    return phases


def _log_device_memory(rank, device_type, device, when):
    """Log allocated/reserved memory of ``device`` in GB; ``when`` labels the sample point."""
    if device_type == "xpu" and hasattr(torch.xpu, "memory_allocated"):
        api, label = torch.xpu, "XPU"
    elif device_type == "cuda":
        api, label = torch.cuda, "CUDA"
    else:
        return
    alloc = api.memory_allocated(device) / (1024 ** 3)
    reserved = api.memory_reserved(device) / (1024 ** 3)
    log(rank, f"{label} memory {when}: allocated={alloc:.2f} GB, reserved={reserved:.2f} GB")


def main():
    """Parse CLI options, initialise DDP on XPU or CUDA, and run warmup plus phases 1-5.

    Rank information is read from the RANK, LOCAL_RANK and WORLD_SIZE
    environment variables (defaulting to 0, 0 and 1). The process exits with
    status 1 when neither an XPU nor a CUDA device is available, and with an
    argparse usage error when ``--skip-phases`` is not a list of integers.
    """
    parser = argparse.ArgumentParser(description="CCL stress test for XPU DDP epoch-boundary hangs")
    parser.add_argument("--spatial-size", type=int, default=64,
                        help="Spatial size of 3D volumes (default: 64, use 128 for full-scale test)")
    parser.add_argument("--batch-size", type=int, default=2,
                        help="Batch size per rank (default: 2)")
    parser.add_argument("--phase1-steps", type=int, default=200,
                        help="Number of steps in Phase 1 (default: 200)")
    parser.add_argument("--phase3-steps", type=int, default=10,
                        help="Number of steps in Phase 3 (default: 10)")
    parser.add_argument("--phase4-steps", type=int, default=10,
                        help="Number of steps in Phase 4 (default: 10)")
    parser.add_argument("--phase5-broadcasts", type=int, default=50,
                        help="Number of broadcasts in Phase 5 (default: 50)")
    parser.add_argument("--skip-phases", type=str, default="",
                        help="Comma-separated phases to skip, e.g. '2,4'")
    args = parser.parse_args()

    try:
        skip_phases = _parse_skip_phases(args.skip_phases)
    except ValueError:
        # Report a usage error instead of an int() traceback.
        parser.error(f"--skip-phases expects comma-separated integers, got {args.skip_phases!r}")

    # ---- Device / backend selection -------------------------------------
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device_type = "xpu"
        backend = "ccl"
    elif torch.cuda.is_available():
        device_type = "cuda"
        backend = "nccl"
    else:
        print("ERROR: No XPU or CUDA device found. This test requires GPU/XPU.", flush=True)
        sys.exit(1)

    # ---- Distributed setup ----------------------------------------------
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size < 2:
        print("WARNING: WORLD_SIZE < 2. DDP collectives are trivial with 1 rank. "
              "Use srun with multiple tasks for meaningful CCL stress testing.", flush=True)

    log(rank, f"Initializing DDP: backend={backend}, device={device_type}, "
              f"rank={rank}, local_rank={local_rank}, world_size={world_size}")
    log(rank, f"MASTER_ADDR={os.environ.get('MASTER_ADDR', 'unset')}, "
              f"MASTER_PORT={os.environ.get('MASTER_PORT', 'unset')}")
    log(rank, f"CCL_ZE_CACHE_OPEN_IPC_HANDLES_THRESHOLD="
              f"{os.environ.get('CCL_ZE_CACHE_OPEN_IPC_HANDLES_THRESHOLD', 'unset')}")

    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    if device_type == "xpu":
        torch.xpu.set_device(local_rank)
        device = torch.device(f"xpu:{local_rank}")
    else:
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")

    log(rank, f"DDP initialized on {device}")

    # ---- Model ------------------------------------------------------------
    log(rank, "Creating model...")
    model = DummyConv3DModel(ndims=3).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    param_size_gb = num_params * 4 / (1024 ** 3)
    log(rank, f"Model created: {num_params:,} parameters ({param_size_gb:.2f} GB in float32)")

    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    log(rank, "DDP wrapper and optimizer created")

    # ---- Data -------------------------------------------------------------
    # The sampler splits the dataset across ranks and the loader drops
    # incomplete batches, so at least world_size * batch_size samples are
    # needed for every rank to receive one batch; otherwise the warmup loop
    # fails with a bare StopIteration.
    num_samples = max(500,
                      args.phase1_steps * args.batch_size,
                      world_size * args.batch_size)
    dataset = DummyVolumeDataset(num_samples=num_samples, spatial_size=args.spatial_size)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    sampler.set_epoch(1)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler,
                            num_workers=2, pin_memory=False, drop_last=True)
    log(rank, f"DataLoader created: {len(dataset)} samples, batch_size={args.batch_size}, "
              f"spatial_size={args.spatial_size}")

    # ---- Warmup -----------------------------------------------------------
    log(rank, "WARMUP START: 5 steps")
    data_iter = iter(dataloader)
    for _ in range(5):
        try:
            data = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            data = next(data_iter)
        do_forward_backward(model, optimizer, data, device, rank)
    dist.barrier()
    log(rank, "WARMUP END")

    _log_device_memory(rank, device_type, device, "after warmup")

    # ---- Phases -----------------------------------------------------------
    total_t0 = time.time()

    # Each entry is (phase number, zero-argument callable running that phase).
    phases = [
        (1, lambda: phase_1(rank, model, optimizer, dataloader, device,
                            num_steps=args.phase1_steps)),
        (2, lambda: phase_2(rank, model, device)),
        (3, lambda: phase_3(rank, model, optimizer, dataloader, device,
                            num_steps=args.phase3_steps)),
        (4, lambda: phase_4(rank, world_size, model, optimizer, device,
                            num_steps=args.phase4_steps,
                            spatial_size=args.spatial_size,
                            batch_size=args.batch_size)),
        (5, lambda: phase_5(rank, device, num_broadcasts=args.phase5_broadcasts)),
    ]
    for number, run_phase in phases:
        if number in skip_phases:
            log(rank, f"PHASE {number} SKIPPED")
        else:
            run_phase()

    total_elapsed = time.time() - total_t0
    log(rank, f"ALL PHASES COMPLETE in {total_elapsed:.1f}s")

    _log_device_memory(rank, device_type, device, "final")

    dist.barrier()
    dist.destroy_process_group()
    log(rank, "Process group destroyed. Test PASSED.")
|
|
|
|
| # Script entry point: run the stress test only when this file is executed directly, not when it is imported. |
| if __name__ == "__main__": |
| main() |
|
|