"""
Benchmark latency comparison between different all-reduce implementations.

Compares:
- NCCL all-reduce (may be non-deterministic)
- Reduce-scatter + all-gather (RS+AG, deterministic but slower)
- Deterministic 1-stage kernel (forces fixed accumulation order, deterministic)

Note: The "deterministic kernel" is NOT RS+AG. It uses the 1-stage kernel where
each GPU reads all data from all GPUs and reduces locally in a fixed order.

Usage:
    python bench_amd_deterministic_allreduce.py
"""
|
|
import multiprocessing as mp
import os
import queue
import socket
import statistics
import sys
import time

import torch
import torch.distributed as dist
|
|
| |
# Put the "python" directory that sits next to this script at the front of
# sys.path so the sglang imports below are resolved from there first.
script_dir = os.path.dirname(os.path.abspath(__file__))
python_dir = os.path.join(script_dir, "python")
sys.path.insert(0, python_dir)

# The custom all-reduce extension is optional. If any of the imports fails, or
# the ops module has no IS_CUSTOM_AR_AVAILABLE attribute, fall back to
# placeholders so the rest of the script can still run.
try:
    import sglang.srt.distributed.device_communicators.custom_all_reduce_ops as custom_ar_ops
    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )
    from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
        is_weak_contiguous,
    )

    CUSTOM_AR_AVAILABLE = custom_ar_ops.IS_CUSTOM_AR_AVAILABLE
except (ImportError, AttributeError):
    CUSTOM_AR_AVAILABLE = False
    CustomAllreduce = None
    is_weak_contiguous = None

# NOTE(review): this flag is never read in the visible part of the file.
SGLANG_AVAILABLE = False
|
|
|
|
def get_open_port():
    """Return a TCP port number that was free on 127.0.0.1 at call time.

    The socket is bound to port 0 so the OS picks the port, and it is closed
    before returning. Nothing reserves the port afterwards, so another
    process could claim it before the caller binds to it.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("127.0.0.1", 0))
        _host, port = sock.getsockname()
        return port
    finally:
        sock.close()
|
|
|
|
def init_custom_ar_if_available(rank, world_size, device):
    """Report whether the custom all-reduce path should be attempted.

    Despite the name, nothing is initialized here. The function returns True
    only when the optional extension imported successfully (module-level
    ``CUSTOM_AR_AVAILABLE`` is truthy and ``CustomAllreduce`` is not None) and
    ``world_size`` is an even number no larger than 8.

    ``rank`` and ``device`` are accepted but not used.
    """
    if CUSTOM_AR_AVAILABLE and CustomAllreduce is not None:
        # Only even world sizes up to 8 are accepted.
        return world_size <= 8 and world_size % 2 == 0
    return False
|
|
|
|
def reduce_scatter_then_all_gather(tensor, rank, world_size, custom_ar=None):
    """
    Sum ``tensor`` across all ranks in place using reduce-scatter + all-gather.

    When the element count is divisible by ``world_size`` the tensor is cut
    into ``world_size`` equal chunks, reduce-scattered so each rank holds one
    reduced chunk, and the chunks are all-gathered and copied back into
    ``tensor``. Otherwise every rank gathers all full tensors and sums them
    locally.

    Returns None; the result is written into ``tensor``. Requires an
    initialized default process group.

    NOTE(review): ``rank`` and ``custom_ar`` are accepted but never used, so
    passing a custom all-reduce object does not change which collectives run.
    """
    per_rank, remainder = divmod(tensor.numel(), world_size)

    if remainder:
        # Not evenly divisible: gather every rank's full tensor and sum locally.
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        tensor.copy_(torch.stack(gathered, dim=0).sum(dim=0))
        return

    flat = tensor.view(-1)

    # Independent copies of each chunk are handed to reduce_scatter.
    send_chunks = [
        flat.narrow(0, idx * per_rank, per_rank).clone() for idx in range(world_size)
    ]
    reduced = flat.new_empty(per_rank)
    dist.reduce_scatter(reduced, send_chunks)

    # Collect every rank's reduced chunk and reassemble the full result.
    pieces = [flat.new_empty(per_rank) for _ in range(world_size)]
    dist.all_gather(pieces, reduced)
    tensor.copy_(torch.cat(pieces, dim=0).view(tensor.shape))
|
|
|
|
def _run_timed_trials(op, base_input, num_trials, *, flatten, in_place):
    """Time ``op`` on fresh clones of ``base_input`` and fingerprint each result.

    Args:
        op: Callable taking the (optionally flattened) clone. It either
            mutates the clone in place or returns the result tensor.
        base_input: Tensor cloned before every trial so trials are independent.
        num_trials: Number of timed repetitions.
        flatten: If True, ``op`` receives a 1-D view of the clone.
        in_place: If True the fingerprint is taken from the clone, otherwise
            from the value returned by ``op``.

    Returns:
        ``(fingerprints, latencies)`` where each fingerprint is
        ``(sum_as_float, first_five_elements)`` and each latency is wall-clock
        seconds measured between two ``torch.cuda.synchronize()`` calls.
    """
    fingerprints = []
    latencies = []
    for _ in range(num_trials):
        inp = base_input.clone()
        if flatten:
            inp = inp.view(-1)

        torch.cuda.synchronize()
        start = time.perf_counter()
        out = op(inp)
        torch.cuda.synchronize()
        end = time.perf_counter()
        latencies.append(end - start)

        flat = (inp if in_place else out).view(-1)
        fingerprints.append((flat.sum().item(), flat[:5].clone()))
    return fingerprints, latencies


def _summarize_trials(fingerprints, latencies):
    """Reduce per-trial fingerprints and latencies to one results entry.

    Every trial is compared with the first one. A run counts as deterministic
    when each checksum is within 1e-3 of the reference and the first five
    elements satisfy ``torch.allclose(..., rtol=1e-3)``. Only the checksum and
    those five elements are compared, not the full tensor.

    Returns:
        Dict with ``latency_median`` (seconds), ``deterministic`` (bool) and
        ``max_variance`` (largest absolute checksum deviation, 0.0 when there
        is only one trial).
    """
    ref_sum, ref_vals = fingerprints[0]
    deterministic = True
    deviations = []
    for checksum, vals in fingerprints[1:]:
        if abs(ref_sum - checksum) > 1e-3 or not torch.allclose(
            ref_vals, vals, rtol=1e-3
        ):
            deterministic = False
        deviations.append(abs(ref_sum - checksum))
    return {
        "latency_median": statistics.median(latencies),
        "deterministic": deterministic,
        "max_variance": max(deviations) if deviations else 0.0,
    }


def worker(world_size, rank, port, results_queue):
    """Per-GPU benchmark process: time and fingerprint each all-reduce variant.

    Joins an NCCL process group on ``tcp://localhost:{port}``, optionally
    creates a ``CustomAllreduce`` communicator on a gloo group, then for each
    batch size runs 10 trials of:

    * ``dist.all_reduce`` (NCCL),
    * ``reduce_scatter_then_all_gather``,
    * ``reduce_scatter_then_all_gather`` with ``custom_ar`` passed (only when
      the communicator exists),
    * ``custom_ar.deterministic_all_reduce`` (only when the communicator has
      that method and the input fits in ``custom_ar.max_size``).

    Rank 0 prints per-batch statistics and finally puts a dict keyed by batch
    size on ``results_queue``. The ``"optimized_rs_ag"`` entry is always None;
    the key is kept so the results schema stays stable for consumers.

    NOTE(review): the "Custom AR" measurement calls
    ``reduce_scatter_then_all_gather``; whether that routine actually uses the
    communicator must be checked there.
    NOTE(review): an exception in the deterministic kernel on a single rank is
    caught locally only; the other ranks are not notified.

    Args:
        world_size: Number of participating processes/GPUs.
        rank: This process's rank; also selects ``cuda:{rank}``.
        port: TCP port of the rendezvous on localhost.
        results_queue: Queue receiving the results dict from rank 0.
    """
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://localhost:{port}",
        rank=rank,
        world_size=world_size,
    )

    custom_ar = None
    if init_custom_ar_if_available(rank, world_size, device) and CUSTOM_AR_AVAILABLE:
        init_error = None
        try:
            dist.barrier()
            ar_group = dist.new_group(backend="gloo")
            dist.barrier()
            custom_ar = CustomAllreduce(group=ar_group, device=device)
        except Exception as e:
            init_error = e
            custom_ar = None

        # Every rank must take the same code path afterwards, otherwise the
        # collectives below would be mismatched. Use the communicator only if
        # construction succeeded on all ranks.
        ready = torch.tensor(
            [0 if custom_ar is None else 1], dtype=torch.int32, device=device
        )
        dist.all_reduce(ready, op=dist.ReduceOp.MIN)
        if ready.item() == 0:
            custom_ar = None

        if rank == 0:
            if custom_ar is not None:
                print(" Using custom all-reduce (deterministic)")
            elif init_error is not None:
                print(f" Custom AR init failed: {init_error}, using NCCL fallback")
            else:
                print(" Custom AR init failed on another rank, using NCCL fallback")

    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256, 512]
    hidden_dim = 16384
    num_trials = 10

    # Different seed per rank so each rank contributes different data.
    torch.manual_seed(42 + rank)

    results = {}

    for bs in batch_sizes:
        base_input = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device)

        dist.barrier()

        if rank == 0:
            print(f"\nBatch size {bs:4d}:")
            print(f" Testing determinism across {num_trials} trials...")

        # NCCL all-reduce.
        ar_runs = _run_timed_trials(
            lambda t: dist.all_reduce(t, op=dist.ReduceOp.SUM),
            base_input,
            num_trials,
            flatten=True,
            in_place=True,
        )

        # Reduce-scatter + all-gather.
        rs_ag_runs = _run_timed_trials(
            lambda t: reduce_scatter_then_all_gather(
                t, rank, world_size, custom_ar=None
            ),
            base_input,
            num_trials,
            flatten=True,
            in_place=True,
        )

        # Same routine with the custom communicator passed through.
        custom_runs = None
        if custom_ar is not None:
            custom_runs = _run_timed_trials(
                lambda t: reduce_scatter_then_all_gather(
                    t, rank, world_size, custom_ar=custom_ar
                ),
                base_input,
                num_trials,
                flatten=True,
                in_place=True,
            )

        # Deterministic kernel, when the communicator offers it and the input
        # fits into its buffer.
        kernel_runs = None
        if custom_ar is not None and hasattr(custom_ar, "deterministic_all_reduce"):
            input_size_bytes = base_input.numel() * base_input.element_size()
            if input_size_bytes > custom_ar.max_size:
                if rank == 0:
                    print(
                        f" Deterministic kernel skipped: input size ({input_size_bytes/(1024*1024):.1f} MB) > buffer size ({custom_ar.max_size/(1024*1024):.1f} MB)"
                    )
            else:
                try:
                    kernel_runs = _run_timed_trials(
                        lambda t: custom_ar.deterministic_all_reduce(
                            t, registered=False
                        ),
                        base_input,
                        num_trials,
                        flatten=False,
                        in_place=False,
                    )
                except Exception as e:
                    if rank == 0:
                        print(
                            f" Deterministic kernel test failed for batch size {bs}: {e}"
                        )
                    kernel_runs = None

        dist.barrier()

        if rank != 0:
            continue

        ar_stats = _summarize_trials(*ar_runs)
        rs_ag_stats = _summarize_trials(*rs_ag_runs)
        custom_stats = (
            _summarize_trials(*custom_runs) if custom_runs is not None else None
        )
        kernel_stats = (
            _summarize_trials(*kernel_runs)
            if kernel_runs is not None and kernel_runs[0]
            else None
        )

        lat_ar = ar_stats["latency_median"]
        lat_rs_ag = rs_ag_stats["latency_median"]
        overhead_rs_ag = ((lat_rs_ag - lat_ar) / lat_ar) * 100

        results[bs] = {
            "all_reduce": ar_stats,
            "rs_ag": rs_ag_stats,
            "custom_ar": custom_stats,
            "deterministic_kernel": kernel_stats,
            "optimized_rs_ag": None,
            "overhead_rs_ag_pct": overhead_rs_ag,
        }

        ar_det = ar_stats["deterministic"]
        ar_var = ar_stats["max_variance"]
        print(
            f" All-Reduce: {lat_ar*1000:.3f}ms, Deterministic: {ar_det}, Max variance: {ar_var:.6f}"
        )
        rs_ag_det = rs_ag_stats["deterministic"]
        rs_ag_var = rs_ag_stats["max_variance"]
        print(
            f" RS+All-Gather: {lat_rs_ag*1000:.3f}ms, Deterministic: {rs_ag_det}, Max variance: {rs_ag_var:.6f}"
        )
        if custom_stats is not None:
            lat_custom = custom_stats["latency_median"]
            custom_det = custom_stats["deterministic"]
            custom_var = custom_stats["max_variance"]
            overhead_custom = ((lat_custom - lat_ar) / lat_ar) * 100
            print(
                f" Custom AR: {lat_custom*1000:.3f}ms, Deterministic: {custom_det}, Max variance: {custom_var:.6f}, Overhead: {overhead_custom:+.1f}%"
            )
        if kernel_stats is not None:
            lat_kernel = kernel_stats["latency_median"]
            kernel_det = kernel_stats["deterministic"]
            kernel_var = kernel_stats["max_variance"]
            overhead_kernel = ((lat_kernel - lat_ar) / lat_ar) * 100
            speedup_kernel_vs_rs_ag = ((lat_rs_ag - lat_kernel) / lat_rs_ag) * 100
            print(
                f" Deterministic Kernel: {lat_kernel*1000:.3f}ms, Deterministic: {kernel_det}, Max variance: {kernel_var:.6f}, Overhead: {overhead_kernel:+.1f}%, Speedup vs RS+AG: {speedup_kernel_vs_rs_ag:+.1f}%"
            )
        print(f" RS+AG Overhead: {overhead_rs_ag:+.1f}%")

    if rank == 0:
        results_queue.put(results)

    dist.destroy_process_group()
|
|
|
|
def _overhead_pct(latency, baseline):
    """Percent by which ``latency`` exceeds ``baseline`` (negative if faster)."""
    return ((latency - baseline) / baseline) * 100


def _speedup_pct(latency, baseline):
    """Percent of ``baseline`` latency saved by ``latency``."""
    return ((baseline - latency) / baseline) * 100


def _print_stats(title, values):
    """Print a titled average/median/min/max block for percentages."""
    print(f"\n{title}")
    print(f" Average: {statistics.mean(values):.1f}%")
    print(f" Median: {statistics.median(values):.1f}%")
    print(f" Min: {min(values):.1f}%")
    print(f" Max: {max(values):.1f}%")


def _collect_results(results_queue, procs, poll_interval=1.0):
    """Fetch the results payload while workers run; None if none arrives.

    The queue is read before the workers are joined because a process that
    has put data on a multiprocessing queue may not exit until that data has
    been consumed. Polling stops once every worker has exited and the queue
    is still empty.
    """
    while True:
        try:
            return results_queue.get(timeout=poll_interval)
        except queue.Empty:
            if any(p.is_alive() for p in procs):
                continue
            # All workers are gone: one last attempt in case the payload was
            # flushed just before the producer exited.
            try:
                return results_queue.get(timeout=poll_interval)
            except queue.Empty:
                return None


def _print_summary(results):
    """Print the summary table and aggregate statistics for ``results``.

    ``results`` maps batch size to the per-batch dict produced by rank 0
    (keys: all_reduce, rs_ag, custom_ar, deterministic_kernel,
    optimized_rs_ag, overhead_rs_ag_pct). Optional entries may be None.
    """
    rows = list(results.values())

    def mark(flag):
        return "✓" if flag else "✗"

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    header = f"{'Batch':<8} {'AR (ms)':<12} {'AR Det':<8} {'RS+AG (ms)':<15} {'RS+AG Det':<10} {'RS+AG Ovh':<12}"
    if any(r.get("custom_ar") is not None for r in rows):
        header += (
            f" {'Custom AR (ms)':<18} {'Custom AR Det':<15} {'Custom AR Ovh':<15}"
        )
    if any(r.get("deterministic_kernel") is not None for r in rows):
        header += f" {'Det Kernel (ms)':<18} {'Det Kernel Det':<15} {'Det Kernel Ovh':<15} {'Speedup':<10}"
    if any(r.get("optimized_rs_ag") is not None for r in rows):
        header += f" {'Opt RS+AG (ms)':<18} {'Opt RS+AG Det':<15} {'Opt RS+AG Ovh':<15} {'Speedup':<10}"
    print(header)
    print("-" * 150)

    for bs in sorted(results):
        r = results[bs]
        ar_lat = r["all_reduce"]["latency_median"]
        rs_lat = r["rs_ag"]["latency_median"]
        ar_det_str = mark(r["all_reduce"]["deterministic"])
        rs_ag_det_str = mark(r["rs_ag"]["deterministic"])
        line = (
            f"{bs:<8} {ar_lat*1000:<12.3f} {ar_det_str:<8} "
            f"{rs_lat*1000:<15.3f} {rs_ag_det_str:<10} "
            f"{r['overhead_rs_ag_pct']:<12.1f}"
        )

        custom = r.get("custom_ar")
        if custom is not None:
            lat = custom["latency_median"]
            det_str = mark(custom["deterministic"])
            line += f" {lat*1000:<18.3f} {det_str:<15} {_overhead_pct(lat, ar_lat):<15.1f}"

        kernel = r.get("deterministic_kernel")
        if kernel is not None:
            lat = kernel["latency_median"]
            det_str = mark(kernel["deterministic"])
            line += f" {lat*1000:<18.3f} {det_str:<15} {_overhead_pct(lat, ar_lat):<15.1f} {_speedup_pct(lat, rs_lat):<10.1f}"

        optimized = r.get("optimized_rs_ag")
        if optimized is not None:
            lat = optimized["latency_median"]
            det_str = mark(optimized["deterministic"])
            line += f" {lat*1000:<18.3f} {det_str:<15} {_overhead_pct(lat, ar_lat):<15.1f} {_speedup_pct(lat, rs_lat):<10.1f}"

        print(line)

    print("=" * 80)

    # How many batch sizes each variant reproduced across trials.
    ar_deterministic_count = sum(1 for r in rows if r["all_reduce"]["deterministic"])
    rs_ag_deterministic_count = sum(1 for r in rows if r["rs_ag"]["deterministic"])
    custom_rows = [r for r in rows if r.get("custom_ar") is not None]
    kernel_rows = [r for r in rows if r.get("deterministic_kernel") is not None]
    custom_ar_deterministic_count = sum(
        1 for r in rows if r.get("custom_ar") and r["custom_ar"]["deterministic"]
    )
    deterministic_kernel_deterministic_count = sum(
        1
        for r in rows
        if r.get("deterministic_kernel")
        and r["deterministic_kernel"]["deterministic"]
    )

    print("\nDeterminism Summary:")
    print(
        f" All-Reduce deterministic: {ar_deterministic_count}/{len(results)} batch sizes"
    )
    print(
        f" RS+All-Gather deterministic: {rs_ag_deterministic_count}/{len(results)} batch sizes"
    )
    if custom_rows:
        print(
            f" Custom AR deterministic: {custom_ar_deterministic_count}/{len(custom_rows)} batch sizes"
        )
    if kernel_rows:
        print(
            f" Deterministic Kernel deterministic: {deterministic_kernel_deterministic_count}/{len(kernel_rows)} batch sizes"
        )

    _print_stats(
        "Latency Overhead Statistics (RS+AG vs All-Reduce):",
        [r["overhead_rs_ag_pct"] for r in rows],
    )

    if custom_rows:
        _print_stats(
            "Latency Overhead Statistics (Custom AR vs All-Reduce):",
            [
                _overhead_pct(
                    r["custom_ar"]["latency_median"],
                    r["all_reduce"]["latency_median"],
                )
                for r in custom_rows
            ],
        )

    if kernel_rows:
        _print_stats(
            "Latency Overhead Statistics (Deterministic Kernel vs All-Reduce):",
            [
                _overhead_pct(
                    r["deterministic_kernel"]["latency_median"],
                    r["all_reduce"]["latency_median"],
                )
                for r in kernel_rows
            ],
        )
        _print_stats(
            "Speedup Statistics (Deterministic Kernel vs RS+AG):",
            [
                _speedup_pct(
                    r["deterministic_kernel"]["latency_median"],
                    r["rs_ag"]["latency_median"],
                )
                for r in kernel_rows
            ],
        )

    print("\nVariance Analysis (non-deterministic cases):")
    for bs in sorted(results):
        r = results[bs]
        if not r["all_reduce"]["deterministic"]:
            print(
                f" Batch {bs}: All-Reduce max variance: {r['all_reduce']['max_variance']:.6f}"
            )
        if not r["rs_ag"]["deterministic"]:
            print(
                f" Batch {bs}: RS+All-Gather max variance: {r['rs_ag']['max_variance']:.6f}"
            )
        if r.get("custom_ar") is not None and not r["custom_ar"]["deterministic"]:
            print(
                f" Batch {bs}: Custom AR max variance: {r['custom_ar']['max_variance']:.6f}"
            )
        if (
            r.get("deterministic_kernel") is not None
            and not r["deterministic_kernel"]["deterministic"]
        ):
            print(
                f" Batch {bs}: Deterministic Kernel max variance: {r['deterministic_kernel']['max_variance']:.6f}"
            )


def main():
    """Spawn one benchmark worker per GPU and print a summary of the results.

    Uses up to 8 GPUs (fewer if fewer are visible) and returns early with an
    error message when fewer than 2 are available. The results payload is
    read from the queue before the workers are joined; if no payload arrives
    an error is printed instead of a summary.
    """
    world_size = 8
    available_gpus = torch.cuda.device_count()

    print("=" * 80)
    print("All-Reduce vs Reduce-Scatter + All-Gather Determinism & Latency Benchmark")
    print("=" * 80)
    print(f"Available GPUs: {available_gpus}")
    print(f"Using world_size: {world_size}")
    print("Hidden dimension: 16384")
    print("Tensor dtype: bfloat16")
    print("Trials per batch size: 10 (testing determinism)")
    print("Testing batch sizes: [1, 4, 8, 16, 32, 64, 128, 256, 512]")
    print("=" * 80)

    if available_gpus < world_size:
        print(
            f"WARNING: Only {available_gpus} GPUs available, using {available_gpus} instead"
        )
        world_size = available_gpus

    if world_size < 2:
        print("ERROR: Need at least 2 GPUs for this benchmark")
        return

    mp.set_start_method("spawn", force=True)
    port = get_open_port()

    results_queue = mp.Queue()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(world_size, rank, port, results_queue))
        p.start()
        procs.append(p)

    # Drain the queue first, then join (see _collect_results).
    results = _collect_results(results_queue, procs)

    for p in procs:
        p.join()

    failed_ranks = [rank for rank, p in enumerate(procs) if p.exitcode != 0]
    if failed_ranks:
        print(f"WARNING: worker rank(s) {failed_ranks} exited abnormally")

    if not results:
        print("ERROR: No benchmark results were received from rank 0")
        return

    _print_summary(results)
|
|
|
|
# Run the benchmark only when executed as a script. main() selects the
# "spawn" start method, which re-imports this module in every worker, so the
# guard keeps workers from launching the benchmark again.
if __name__ == "__main__":
    main()
|
|