# Source: Hanrui/sglang — sgl-kernel/benchmark/bench_amd_deterministic_allreduce.py
# Uploaded by Lekr0 in commit a402b9b ("Add files using upload-large-folder tool").
"""
Benchmark latency comparison between different all-reduce implementations.
Compares:
- NCCL all-reduce (may be non-deterministic)
- Reduce-scatter + all-gather (RS+AG, deterministic but slower)
- Deterministic 1-stage kernel (forces fixed accumulation order, deterministic)
Note: The "deterministic kernel" is NOT RS+AG. It uses the 1-stage kernel where
each GPU reads all data from all GPUs and reduces locally in a fixed order.
Usage:
python bench_amd_deterministic_allreduce.py
"""
import multiprocessing as mp
import os
import socket
import statistics
import sys
import time
import torch
import torch.distributed as dist
# Add a "python" directory next to this script to sys.path so sglang modules can
# be imported when the script is run directly (this must happen before the
# imports below).
# NOTE(review): this resolves to "<directory of this script>/python"; confirm that
# directory exists — the sglang package may instead live under the repository
# root's python/ directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
python_dir = os.path.join(script_dir, "python")
sys.path.insert(0, python_dir)
# Try to import custom all-reduce if available. Any ImportError, or a missing
# IS_CUSTOM_AR_AVAILABLE attribute, disables the custom-AR and
# deterministic-kernel measurements by leaving CustomAllreduce as None.
try:
    import sglang.srt.distributed.device_communicators.custom_all_reduce_ops as custom_ar_ops
    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )

    # NOTE(review): is_weak_contiguous is not referenced anywhere else in this file.
    from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
        is_weak_contiguous,
    )

    CUSTOM_AR_AVAILABLE = custom_ar_ops.IS_CUSTOM_AR_AVAILABLE
except (ImportError, AttributeError):
    CUSTOM_AR_AVAILABLE = False
    CustomAllreduce = None
    is_weak_contiguous = None
# Note: sglang's optimized all-reduce requires full runtime initialization
# and won't work in standalone benchmarks, so we skip it.
# NOTE(review): SGLANG_AVAILABLE is not read anywhere else in this file.
SGLANG_AVAILABLE = False
def get_open_port():
    """Return a TCP port number that was free on 127.0.0.1 a moment ago.

    Binding to port 0 makes the OS choose an unused ephemeral port. The probe
    socket is closed before returning, so another process could grab the port
    before the caller binds it (an unavoidable race for this technique).
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _addr, chosen_port = probe.getsockname()
    finally:
        probe.close()
    return chosen_port
def init_custom_ar_if_available(rank, world_size, device):
    """Decide whether the custom all-reduce path should be attempted.

    Despite the name, nothing is constructed here. ``rank`` and ``device`` are
    accepted for call-site symmetry but not used.

    Returns:
        True only when the custom AR module imported successfully and
        ``world_size`` is an even number no larger than 8; otherwise False.
    """
    import_ok = CUSTOM_AR_AVAILABLE and CustomAllreduce is not None
    if not import_ok:
        return False
    # Custom AR targets a single node with an even GPU count of at most 8.
    return world_size <= 8 and world_size % 2 == 0
def reduce_scatter_then_all_gather(tensor, rank, world_size, custom_ar=None):
    """All-reduce (SUM) ``tensor`` in place via reduce-scatter followed by all-gather.

    Each rank first receives the fully reduced values of its own contiguous
    ``1/world_size`` slice (reduce-scatter); then every slice is distributed
    to all ranks (all-gather). When ``tensor.numel()`` is not divisible by
    ``world_size``, every rank's full input is all-gathered instead and the
    stacked copies are summed locally.

    Intended as a deterministic reference path (no atomics in this code).
    Whether the underlying NCCL reduce-scatter is bit-reproducible is what the
    benchmark measures; it is not guaranteed here.

    Args:
        tensor: Input tensor; overwritten with the element-wise sum across ranks.
        rank: This process's rank (unused; kept for interface compatibility).
        world_size: Number of ranks. Assumed to equal the size of the default
            process group used by the collectives.
        custom_ar: Currently unused. NOTE(review): callers passing a
            CustomAllreduce here still time the plain RS+AG path.
    """
    total_size = tensor.numel()
    if total_size % world_size != 0:
        # Fallback: gather every rank's full tensor, then reduce locally.
        gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gather_list, tensor)
        stacked = torch.stack(gather_list, dim=0)
        tensor.copy_(stacked.sum(dim=0))
        return

    chunk_size = total_size // world_size
    # contiguous() is a no-op for contiguous inputs and makes view(-1) valid for
    # non-contiguous ones (the original view(-1) raised on those).
    tensor_flat = tensor.contiguous().view(-1)

    # Tensor-based collectives work directly on the flat buffer. This avoids the
    # per-chunk clones and the internal flatten/unflatten + cat copies of the
    # list-based APIs, which would otherwise be counted in the timed latency.
    output_chunk = torch.empty(chunk_size, dtype=tensor.dtype, device=tensor.device)
    dist.reduce_scatter_tensor(output_chunk, tensor_flat, op=dist.ReduceOp.SUM)

    gathered = torch.empty_like(tensor_flat)
    dist.all_gather_into_tensor(gathered, output_chunk)
    tensor.copy_(gathered.view(tensor.shape))
def _timed_trials(op, base_input, num_trials, flatten=True):
    """Time ``op`` on fresh clones of ``base_input`` and fingerprint each result.

    The clone (and the optional flatten to 1-D) happen outside the timed region.
    The timed region is bracketed by ``torch.cuda.synchronize()`` so it covers
    the GPU work launched by ``op``. Exceptions from ``op`` propagate.

    Args:
        op: Callable taking the prepared input and returning the reduced tensor
            (in-place ops return their argument).
        base_input: Tensor cloned for every trial so each trial sees identical data.
        num_trials: Number of repetitions.
        flatten: Pass ``clone().view(-1)`` when True, the plain clone otherwise.

    Returns:
        ``(fingerprints, latencies)``. Each fingerprint is
        ``(result.sum().item(), first five result elements)``; latencies are
        wall-clock seconds.
    """
    fingerprints = []
    latencies = []
    for _ in range(num_trials):
        inp = base_input.clone()
        if flatten:
            inp = inp.view(-1)
        torch.cuda.synchronize()
        start = time.perf_counter()
        out = op(inp)
        torch.cuda.synchronize()
        end = time.perf_counter()
        latencies.append(end - start)
        out_flat = out.view(-1)
        fingerprints.append((out_flat.sum().item(), out_flat[:5].clone()))
    return fingerprints, latencies


def _summarize_determinism(fingerprints):
    """Compare every trial's fingerprint with the first trial's.

    A method counts as deterministic when every checksum is within 1e-3 of the
    first one and the first five values pass ``torch.allclose(rtol=1e-3)``.

    Returns:
        ``(deterministic, max_abs_checksum_deviation)``; the deviation is 0.0
        when there is only a single trial.
    """
    ref_sum, ref_vals = fingerprints[0]
    deterministic = True
    max_deviation = 0.0
    for checksum, vals in fingerprints[1:]:
        deviation = abs(ref_sum - checksum)
        if deviation > 1e-3 or not torch.allclose(ref_vals, vals, rtol=1e-3):
            deterministic = False
        max_deviation = max(max_deviation, deviation)
    return deterministic, max_deviation


def _method_stats(fingerprints, latencies):
    """Build the per-method result dict (median latency, determinism, max variance)."""
    deterministic, max_variance = _summarize_determinism(fingerprints)
    return {
        "latency_median": statistics.median(latencies),
        "deterministic": deterministic,
        "max_variance": max_variance,
    }


def _init_custom_ar(rank, world_size, device):
    """Try to build a CustomAllreduce on a dedicated gloo group.

    Must be called by every rank. The outcome is agreed across ranks: if
    construction fails on any rank, every rank returns None. Otherwise a rank
    that failed locally would skip the custom-AR collectives the others enter,
    hanging the benchmark.
    """
    if not init_custom_ar_if_available(rank, world_size, device):
        return None

    custom_ar = None
    try:
        # Custom AR needs a non-NCCL group; all ranks create it identically.
        dist.barrier()
        ar_group = dist.new_group(backend="gloo")
        dist.barrier()
        custom_ar = CustomAllreduce(group=ar_group, device=device)
    except Exception as e:
        if rank == 0:
            print(f" Custom AR init failed: {e}, using NCCL fallback")
        custom_ar = None

    # MIN over ranks: 1 only if every rank constructed its instance.
    all_ok = torch.tensor(
        [0 if custom_ar is None else 1], dtype=torch.int32, device=device
    )
    dist.all_reduce(all_ok, op=dist.ReduceOp.MIN)
    if all_ok.item() == 0:
        if rank == 0 and custom_ar is not None:
            print(" Custom AR init failed on another rank, using NCCL fallback")
        return None

    if rank == 0:
        print(" Using custom all-reduce (deterministic)")
    return custom_ar


def worker(world_size, rank, port, results_queue):
    """Run the benchmark on one rank; rank 0 prints and returns the results.

    Launched once per rank (see ``main``). Every rank issues the same sequence
    of collectives. Rank 0 prints per-batch results and puts a
    ``{batch_size: {...}}`` dict on ``results_queue``.
    """
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://localhost:{port}",
        rank=rank,
        world_size=world_size,
    )

    custom_ar = _init_custom_ar(rank, world_size, device)

    # Test different batch sizes - similar to test_ar.py
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256, 512]
    hidden_dim = 16384  # Fixed hidden dimension
    num_trials = 10  # Same as test_ar.py
    # Different seed per rank - each GPU has DIFFERENT input (like test_ar.py)
    torch.manual_seed(42 + rank)

    def all_reduce_op(t):
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return t

    def rs_ag_op(t):
        reduce_scatter_then_all_gather(t, rank, world_size, custom_ar=None)
        return t

    def custom_ar_op(t):
        # NOTE(review): as currently written, reduce_scatter_then_all_gather does
        # not use custom_ar, so this row times the same path as rs_ag_op.
        reduce_scatter_then_all_gather(t, rank, world_size, custom_ar=custom_ar)
        return t

    def det_kernel_op(t):
        return custom_ar.deterministic_all_reduce(t, registered=False)

    results = {}
    for bs in batch_sizes:
        # Fixed input reused (cloned) by every trial of every method.
        base_input = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device)
        dist.barrier()
        if rank == 0:
            print(f"\nBatch size {bs:4d}:")
            print(f" Testing determinism across {num_trials} trials...")

        ar_fps, ar_lats = _timed_trials(all_reduce_op, base_input, num_trials)
        rs_ag_fps, rs_ag_lats = _timed_trials(rs_ag_op, base_input, num_trials)

        custom_fps, custom_lats = [], []
        if custom_ar is not None:
            custom_fps, custom_lats = _timed_trials(
                custom_ar_op, base_input, num_trials
            )

        kernel_fps, kernel_lats = [], []
        kernel_ok = False
        if custom_ar is not None and hasattr(custom_ar, "deterministic_all_reduce"):
            input_size_bytes = base_input.numel() * base_input.element_size()
            if input_size_bytes > custom_ar.max_size:
                if rank == 0:
                    print(
                        f" Deterministic kernel skipped: input size ({input_size_bytes/(1024*1024):.1f} MB) > buffer size ({custom_ar.max_size/(1024*1024):.1f} MB)"
                    )
            else:
                try:
                    # The kernel takes the original (unflattened) shape.
                    kernel_fps, kernel_lats = _timed_trials(
                        det_kernel_op, base_input, num_trials, flatten=False
                    )
                    kernel_ok = True
                except Exception as e:
                    if rank == 0:
                        print(
                            f" Deterministic kernel test failed for batch size {bs}: {e}"
                        )
        dist.barrier()

        if rank != 0:
            continue

        ar = _method_stats(ar_fps, ar_lats)
        rs_ag = _method_stats(rs_ag_fps, rs_ag_lats)
        custom = (
            _method_stats(custom_fps, custom_lats) if custom_ar is not None else None
        )
        kernel = (
            _method_stats(kernel_fps, kernel_lats)
            if kernel_ok and kernel_fps
            else None
        )

        lat_ar = ar["latency_median"]
        lat_rs_ag = rs_ag["latency_median"]
        overhead_rs_ag = ((lat_rs_ag - lat_ar) / lat_ar) * 100

        results[bs] = {
            "all_reduce": ar,
            "rs_ag": rs_ag,
            "custom_ar": custom,
            "deterministic_kernel": kernel,
            # sglang's optimized RS+AG needs full runtime init; never measured here.
            "optimized_rs_ag": None,
            "overhead_rs_ag_pct": overhead_rs_ag,
        }

        print(
            f" All-Reduce: {lat_ar*1000:.3f}ms, Deterministic: {ar['deterministic']}, Max variance: {ar['max_variance']:.6f}"
        )
        print(
            f" RS+All-Gather: {lat_rs_ag*1000:.3f}ms, Deterministic: {rs_ag['deterministic']}, Max variance: {rs_ag['max_variance']:.6f}"
        )
        if custom is not None:
            lat_custom = custom["latency_median"]
            overhead_custom = ((lat_custom - lat_ar) / lat_ar) * 100
            print(
                f" Custom AR: {lat_custom*1000:.3f}ms, Deterministic: {custom['deterministic']}, Max variance: {custom['max_variance']:.6f}, Overhead: {overhead_custom:+.1f}%"
            )
        if kernel is not None:
            lat_kernel = kernel["latency_median"]
            overhead_kernel = ((lat_kernel - lat_ar) / lat_ar) * 100
            speedup_kernel_vs_rs_ag = ((lat_rs_ag - lat_kernel) / lat_rs_ag) * 100
            print(
                f" Deterministic Kernel: {lat_kernel*1000:.3f}ms, Deterministic: {kernel['deterministic']}, Max variance: {kernel['max_variance']:.6f}, Overhead: {overhead_kernel:+.1f}%, Speedup vs RS+AG: {speedup_kernel_vs_rs_ag:+.1f}%"
            )
        print(f" RS+AG Overhead: {overhead_rs_ag:+.1f}%")

    if rank == 0:
        results_queue.put(results)
    dist.destroy_process_group()
def _overhead_pct(latency, baseline):
    """Percent by which ``latency`` exceeds ``baseline`` (negative means faster)."""
    return ((latency - baseline) / baseline) * 100


def _speedup_pct(latency, baseline):
    """Percent by which ``latency`` is below ``baseline`` (positive means faster)."""
    return ((baseline - latency) / baseline) * 100


def _mark(flag):
    """Check mark for True, cross for False (summary-table cell)."""
    return "✓" if flag else "✗"


def _print_pct_stats(title, values):
    """Print mean/median/min/max of a non-empty list of percentages under ``title``."""
    print(title)
    print(f" Average: {statistics.mean(values):.1f}%")
    print(f" Median: {statistics.median(values):.1f}%")
    print(f" Min: {min(values):.1f}%")
    print(f" Max: {max(values):.1f}%")


def _print_summary(results):
    """Print the per-batch table, determinism counts, latency stats and variance notes.

    ``results`` maps batch size to the per-batch dict produced by rank 0 of
    ``worker``. It must be non-empty.
    """
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    rows = list(results.values())

    header = f"{'Batch':<8} {'AR (ms)':<12} {'AR Det':<8} {'RS+AG (ms)':<15} {'RS+AG Det':<10} {'RS+AG Ovh':<12}"
    if any(r.get("custom_ar") is not None for r in rows):
        header += (
            f" {'Custom AR (ms)':<18} {'Custom AR Det':<15} {'Custom AR Ovh':<15}"
        )
    if any(r.get("deterministic_kernel") is not None for r in rows):
        header += f" {'Det Kernel (ms)':<18} {'Det Kernel Det':<15} {'Det Kernel Ovh':<15} {'Speedup':<10}"
    if any(r.get("optimized_rs_ag") is not None for r in rows):
        header += f" {'Opt RS+AG (ms)':<18} {'Opt RS+AG Det':<15} {'Opt RS+AG Ovh':<15} {'Speedup':<10}"
    print(header)
    print("-" * 150)

    for bs in sorted(results):
        r = results[bs]
        ar_lat = r["all_reduce"]["latency_median"]
        rs_ag_lat = r["rs_ag"]["latency_median"]
        line = (
            f"{bs:<8} {ar_lat*1000:<12.3f} {_mark(r['all_reduce']['deterministic']):<8} "
            f"{rs_ag_lat*1000:<15.3f} {_mark(r['rs_ag']['deterministic']):<10} "
            f"{r['overhead_rs_ag_pct']:<12.1f}"
        )
        custom = r.get("custom_ar")
        if custom is not None:
            line += (
                f" {custom['latency_median']*1000:<18.3f} {_mark(custom['deterministic']):<15} "
                f"{_overhead_pct(custom['latency_median'], ar_lat):<15.1f}"
            )
        kernel = r.get("deterministic_kernel")
        if kernel is not None:
            line += (
                f" {kernel['latency_median']*1000:<18.3f} {_mark(kernel['deterministic']):<15} "
                f"{_overhead_pct(kernel['latency_median'], ar_lat):<15.1f} "
                f"{_speedup_pct(kernel['latency_median'], rs_ag_lat):<10.1f}"
            )
        opt = r.get("optimized_rs_ag")
        if opt is not None:
            line += (
                f" {opt['latency_median']*1000:<18.3f} {_mark(opt['deterministic']):<15} "
                f"{_overhead_pct(opt['latency_median'], ar_lat):<15.1f} "
                f"{_speedup_pct(opt['latency_median'], rs_ag_lat):<10.1f}"
            )
        print(line)
    print("=" * 80)

    customs = [r for r in rows if r.get("custom_ar") is not None]
    kernels = [r for r in rows if r.get("deterministic_kernel") is not None]
    ar_det_count = sum(1 for r in rows if r["all_reduce"]["deterministic"])
    rs_ag_det_count = sum(1 for r in rows if r["rs_ag"]["deterministic"])

    print("\nDeterminism Summary:")
    print(
        f" All-Reduce deterministic: {ar_det_count}/{len(results)} batch sizes"
    )
    print(
        f" RS+All-Gather deterministic: {rs_ag_det_count}/{len(results)} batch sizes"
    )
    if customs:
        custom_det_count = sum(1 for r in customs if r["custom_ar"]["deterministic"])
        print(
            f" Custom AR deterministic: {custom_det_count}/{len(customs)} batch sizes"
        )
    if kernels:
        kernel_det_count = sum(
            1 for r in kernels if r["deterministic_kernel"]["deterministic"]
        )
        print(
            f" Deterministic Kernel deterministic: {kernel_det_count}/{len(kernels)} batch sizes"
        )

    _print_pct_stats(
        "\nLatency Overhead Statistics (RS+AG vs All-Reduce):",
        [r["overhead_rs_ag_pct"] for r in rows],
    )
    if customs:
        _print_pct_stats(
            "\nLatency Overhead Statistics (Custom AR vs All-Reduce):",
            [
                _overhead_pct(
                    r["custom_ar"]["latency_median"], r["all_reduce"]["latency_median"]
                )
                for r in customs
            ],
        )
    if kernels:
        _print_pct_stats(
            "\nLatency Overhead Statistics (Deterministic Kernel vs All-Reduce):",
            [
                _overhead_pct(
                    r["deterministic_kernel"]["latency_median"],
                    r["all_reduce"]["latency_median"],
                )
                for r in kernels
            ],
        )
        _print_pct_stats(
            "\nSpeedup Statistics (Deterministic Kernel vs RS+AG):",
            [
                _speedup_pct(
                    r["deterministic_kernel"]["latency_median"],
                    r["rs_ag"]["latency_median"],
                )
                for r in kernels
            ],
        )

    # Show variance for non-deterministic cases
    print("\nVariance Analysis (non-deterministic cases):")
    for bs in sorted(results):
        r = results[bs]
        if not r["all_reduce"]["deterministic"]:
            print(
                f" Batch {bs}: All-Reduce max variance: {r['all_reduce']['max_variance']:.6f}"
            )
        if not r["rs_ag"]["deterministic"]:
            print(
                f" Batch {bs}: RS+All-Gather max variance: {r['rs_ag']['max_variance']:.6f}"
            )
        if r.get("custom_ar") is not None and not r["custom_ar"]["deterministic"]:
            print(
                f" Batch {bs}: Custom AR max variance: {r['custom_ar']['max_variance']:.6f}"
            )
        if (
            r.get("deterministic_kernel") is not None
            and not r["deterministic_kernel"]["deterministic"]
        ):
            print(
                f" Batch {bs}: Deterministic Kernel max variance: {r['deterministic_kernel']['max_variance']:.6f}"
            )


def main():
    """Spawn one ``worker`` per GPU (up to 8, at least 2) and print a summary.

    Only rank 0 sends results back, through a multiprocessing queue. Reports
    workers that exit abnormally, and reports when no results arrive, instead
    of silently printing nothing.
    """
    import queue

    world_size = 8
    available_gpus = torch.cuda.device_count()
    print("=" * 80)
    print("All-Reduce vs Reduce-Scatter + All-Gather Determinism & Latency Benchmark")
    print("=" * 80)
    print(f"Available GPUs: {available_gpus}")
    # Clamp before printing "Using world_size" so the banner shows the real value.
    if available_gpus < world_size:
        print(
            f"WARNING: Only {available_gpus} GPUs available, using {available_gpus} instead"
        )
        world_size = available_gpus
    print(f"Using world_size: {world_size}")
    print("Hidden dimension: 16384")
    print("Tensor dtype: bfloat16")
    print("Trials per batch size: 10 (testing determinism)")
    print("Testing batch sizes: [1, 4, 8, 16, 32, 64, 128, 256, 512]")
    print("=" * 80)
    if world_size < 2:
        print("ERROR: Need at least 2 GPUs for this benchmark")
        return

    # CUDA in child processes requires the spawn start method.
    mp.set_start_method("spawn", force=True)
    port = get_open_port()
    results_queue = mp.Queue()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(world_size, rank, port, results_queue))
        p.start()
        procs.append(p)

    # Drain the queue before joining: a child that has put data on a
    # multiprocessing.Queue does not exit until that data is flushed, and
    # Queue.empty() is documented as unreliable.
    results = None
    while results is None:
        try:
            results = results_queue.get(timeout=1.0)
        except queue.Empty:
            if not any(p.is_alive() for p in procs):
                break
    if results is None:
        # All workers have exited; pick up anything flushed just before exit.
        try:
            results = results_queue.get(timeout=1.0)
        except queue.Empty:
            pass

    for p in procs:
        p.join()

    failed = [(rank, p.exitcode) for rank, p in enumerate(procs) if p.exitcode != 0]
    if failed:
        print(f"WARNING: worker(s) exited abnormally (rank, exitcode): {failed}")
    if not results:
        print("ERROR: no benchmark results received from rank 0")
        return

    _print_summary(results)
if __name__ == "__main__":
main()