# Copied from deepseek-ai/DeepEP/tests/test_utils.py
| import os | |
| import sys | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
def init_dist(local_rank: int, num_local_ranks: int):
    """Join the NCCL process group and set torch defaults for this process.

    Cluster layout is read from the environment:
      * MASTER_ADDR / MASTER_PORT: rendezvous address (default 127.0.0.1:8361),
      * WORLD_SIZE: number of nodes (default 1),
      * RANK: index of this node (default 0).

    The global world size is ``num_nodes * num_local_ranks`` and this process's
    global rank is ``node_rank * num_local_ranks + local_rank``.

    Args:
        local_rank: Index of this process within its node; also used as the
            CUDA device index.
        num_local_ranks: Processes per node. Must be 8 unless running on a
            single node, where fewer than 8 is allowed.

    Returns:
        ``(rank, world_size, group)`` where ``group`` is a new process group
        containing every rank.
    """
    # NOTES: you may rewrite this function with your own cluster settings
    master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
    master_port = int(os.getenv("MASTER_PORT", "8361"))
    num_nodes = int(os.getenv("WORLD_SIZE", 1))
    node_rank = int(os.getenv("RANK", 0))
    # Either one node with fewer than 8 local ranks, or exactly 8 per node.
    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

    world_size = num_nodes * num_local_ranks
    global_rank = node_rank * num_local_ranks + local_rank
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_addr}:{master_port}",
        world_size=world_size,
        rank=global_rank,
    )
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.cuda.set_device(local_rank)

    every_rank = list(range(world_size))
    return dist.get_rank(), dist.get_world_size(), dist.new_group(every_rank)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a relative-difference score between two tensors as a float.

    Computes ``1 - 2 * <a, b> / (|a|^2 + |b|^2)`` where ``a = x + 1`` and
    ``b = y + 1`` in float64. Identical inputs give 0.0.
    """
    # The +1 shift presumably keeps the denominator non-zero for all-zero inputs.
    a = x.double() + 1
    b = y.double() + 1
    norm_sum = torch.sum(a * a + b * b)
    similarity = 2 * torch.sum(a * b) / norm_sum
    return (1 - similarity).item()
| def per_token_cast_to_fp8(x: torch.Tensor): | |
| assert x.dim() == 2 and x.size(1) % 128 == 0 | |
| m, n = x.shape | |
| x_view = x.view(m, -1, 128) | |
| x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) | |
| return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( | |
| m, n | |
| ), (x_amax / 448.0).view(m, -1) | |
| def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): | |
| x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128) | |
| x_scales = x_scales.view(x_fp8.size(0), -1, 1) | |
| return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16) | |
def inplace_unique(x: torch.Tensor, num_slots: int):
    """Replace each row of ``x`` in place with its distinct non-negative values.

    After the call, each row holds the distinct non-negative values it
    contained, sorted in descending order and padded with -1. At most
    ``min(num_slots, x.size(1))`` values are written per row. Negative entries
    are treated as empty. Values are expected to lie in ``[0, num_slots)``.

    Args:
        x: 2-D integer tensor, modified in place.
        num_slots: Number of distinct value slots to count over.
    """
    assert x.dim() == 2
    # Park negative entries in an extra overflow bin at index num_slots.
    padded = x.masked_fill(x < 0, num_slots)
    counts = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    counts.scatter_add_(1, padded, torch.ones_like(padded))
    counts = counts[:, :num_slots]  # drop the overflow bin

    # Mark absent values with -1, then order ids descending so the -1s sink.
    counts_desc, ids_by_count = counts.sort(dim=-1, descending=True)
    ids_by_count.masked_fill_(counts_desc == 0, -1)
    ids_desc = ids_by_count.sort(dim=-1, descending=True).values

    x.fill_(-1)
    keep = min(num_slots, x.size(1))
    x[:, :keep] = ids_desc[:, :keep]
def create_grouped_scores(
    scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
):
    """Zero every score whose expert group was not selected for that token.

    ``scores`` of shape ``(num_tokens, num_experts)`` is viewed as
    ``(num_tokens, num_groups, num_experts // num_groups)``. Groups listed in
    ``group_idx`` are kept and all others are multiplied by zero.

    Args:
        scores: 2-D score tensor; ``num_experts`` must be divisible by
            ``num_groups``.
        group_idx: Integer index tensor of selected group ids, scattered
            along dim 1 (presumably shape ``(num_tokens, k)``; confirm at
            call sites).
        num_groups: Number of expert groups.

    Returns:
        Masked scores with the original ``(num_tokens, num_experts)`` shape.
    """
    num_tokens, num_experts = scores.shape
    grouped = scores.view(num_tokens, num_groups, -1)
    keep = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=grouped.device)
    keep.scatter_(1, group_idx, True)
    keep = keep.unsqueeze(-1).expand_as(grouped)
    masked = grouped * keep
    return masked.view(num_tokens, num_experts)
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    """Benchmark ``fn`` with CUDA event timing.

    ``fn`` is called ``num_warmups`` times untimed. It is then called
    ``num_tests`` times with a CUDA event pair recorded around each call.
    ``post_fn``, if given, runs after each timed call, after the end event has
    been recorded.

    Args:
        fn: Zero-argument callable to benchmark.
        num_warmups: Number of untimed warmup calls.
        num_tests: Number of timed calls. The first timed sample is
            discarded from the statistics.
        post_fn: Optional zero-argument callable run after each timed call.

    Returns:
        ``(average, min, max)`` per-call time in seconds, computed over the
        timed calls after the first.

    Raises:
        ValueError: If ``num_tests < 2``. The first sample is dropped, so
            fewer than two timed calls would leave nothing to aggregate.
    """
    # Fail fast: previously this only surfaced as numpy's "zero-size array"
    # error after the whole benchmark had already run.
    if num_tests < 2:
        raise ValueError(
            "num_tests must be at least 2 because the first timed sample is "
            f"discarded, got {num_tests}"
        )

    # Allocate 256 MB of data used to flush the L2 cache
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2 once, right before the timed section
    cache.zero_()

    # Testing
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for start, end in zip(start_events, end_events):
        start.record()
        fn()
        end.record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    # Event.elapsed_time() is in milliseconds; convert to seconds and drop the
    # first sample.
    times = np.array(
        [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
    )[1:]
    return np.average(times), np.min(times), np.max(times)
class empty_suppress:
    """No-op context manager used in place of ``suppress_stdout_stderr``.

    It lets callers write ``with suppressor():`` unconditionally. Entering
    yields the instance itself. Exiting neither silences output nor swallows
    exceptions.
    """

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # A falsy return value lets any in-flight exception propagate.
        return None
class suppress_stdout_stderr:
    """Silence stdout and stderr at both the Python and file-descriptor level.

    On entry, ``sys.stdout``/``sys.stderr`` are replaced with handles to
    ``os.devnull``. The file descriptors the original streams write to are
    also ``dup2``-ed onto ``os.devnull``, so output written straight to those
    descriptors is hidden too. On exit, the original stream objects and
    descriptors are restored and the temporary handles are closed.
    Exceptions raised inside the block propagate.
    """

    def __enter__(self):
        # Flush pending text first. Otherwise it could later be flushed into
        # /dev/null (e.g. via a cached reference to the stream) and lost.
        sys.stdout.flush()
        sys.stderr.flush()

        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        # Descriptors the current streams write to; these get redirected.
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicates that keep the original targets alive for restoration.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        # Text written during the block through cached references to the
        # original stream objects must still land in /dev/null, so flush it
        # before the descriptors are pointed back at their real targets.
        for stream in (self.old_stdout, self.old_stderr):
            try:
                stream.flush()
            except ValueError:
                # The stream was closed inside the block; nothing to flush.
                # Still restore the descriptors below.
                pass

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
def _kineto_time_to_seconds(time_str: str) -> float:
    """Convert a profiler-table time cell such as ``"12.345us"`` to seconds.

    Accepts ``ms``, ``us`` and ``s`` suffixes (the units torch's profiler table
    formatter is expected to print). The two-letter suffixes are checked before
    ``s`` because both of them also end in ``s``.

    Raises:
        ValueError: If the cell has none of the known suffixes.
    """
    for unit, scale in (("ms", 1e3), ("us", 1e6), ("s", 1.0)):
        if time_str.endswith(unit):
            return float(time_str[: -len(unit)]) / scale
    raise ValueError(f"Unrecognized time format in profiling table: {time_str!r}")


def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: Optional[str] = None,
    barrier_comm_profiling: bool = False,
):
    """Profile ``fn`` with the torch (Kineto) profiler and return kernel times.

    The profiler runs two steps of ``num_tests`` calls each. Its schedule
    (``wait=0, warmup=1, active=1``) makes the first step a warmup, so only
    the second step is recorded.

    Args:
        fn: Zero-argument callable to profile.
        kernel_names: A kernel name (``str``) or a ``tuple`` of names. Each
            must appear, as a substring, in exactly one line of the profiler
            table.
        num_tests: Calls to ``fn`` per profiler step.
        suppress_kineto_output: Silence stdout/stderr while profiling.
        trace_path: If given, export a Chrome trace to this path.
        barrier_comm_profiling: Before each step, run a large matmul and an
            ``all_reduce`` to even out CPU launch skew across ranks.

    Returns:
        Time in seconds taken from the matched table line: a ``float`` for a
        ``str`` argument, or a ``tuple`` in ``kernel_names`` order.

    Raises:
        AssertionError: If ``kernel_names`` has the wrong type, or a name does
            not match exactly one table line.
        ValueError: If a matched time cell cannot be parsed.
    """
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
        ) as prof:
            for _ in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    lhs @ rhs  # result intentionally discarded; only the kernel matters
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, (str, tuple))
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = (
        prof.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all(isinstance(name, str) for name in kernel_names)
    for name in kernel_names:
        assert (
            sum(name in line for line in prof_lines) == 1
        ), f"Errors of the kernel {name} in the profiling table"

    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times. The second-to-last column is presumably the
    # per-call device time average; confirm against the torch version in use.
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                kernel_times.append(_kineto_time_to_seconds(line.split()[-2]))
                break
    return tuple(kernel_times) if is_tupled else kernel_times[0]
def hash_tensor(t: torch.Tensor):
    """Return a cheap integer fingerprint of ``t``'s raw bytes.

    The tensor's storage is reinterpreted as int64 (subject to the size and
    stride requirements of ``Tensor.view(dtype)``), and the elements are
    summed and returned as a Python int.
    """
    as_int64 = t.view(torch.int64)
    total = torch.sum(as_int64)
    return total.item()
# Xet Storage Details
# - Size: 7.45 kB
# - Xet hash: c449ca3f699952659444e42583ce00ba6ed1ddbab0312adf2de91559bfdebbac
#
# Xet efficiently stores files, intelligently splitting them into unique chunks
# and accelerating uploads and downloads. More info.