File size: 12,166 Bytes
f4cade0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 |
import random
from typing import Any
import torch
from torch._C._distributed_c10d import (
_resolve_process_group,
FakeWork,
ProcessGroup,
Work,
)
from torch.utils._pytree import tree_map_only
torch.distributed.batch_isend_irecv
c10d = torch.ops.c10d
_c10d_functional = torch.ops._c10d_functional
_c10d_functional_autograd = torch.ops._c10d_functional_autograd
_dtensor = torch.ops._dtensor
used_ids: set[int] = set()
def generate_unique_id() -> int:
while True:
new_id = random.randint(1, 10**9)
if new_id not in used_ids:
used_ids.add(new_id)
return new_id
# Function to create and return FakeWork object
def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def]
work = FakeWork()
work.seq_id = generate_unique_id()
fakework_script_obj = work.boxed()
return (args[0], fakework_script_obj) if return_first_arg else fakework_script_obj
# Dictionary mapping collective operations to their meta functions
# All 20 ops from torch.csrc.distributed.c10d.Ops.cpp are included
# _DEPRECATED_META_FUNCTIONS = {
# "allreduce_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# "allgather_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# "allgather_into_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# "reduce_scatter_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False),
# }
_META_FUNCTIONS = {
"broadcast_": lambda *args: create_fakework(args),
"allreduce_": lambda *args: create_fakework(args),
"allgather_": lambda *args: create_fakework(args),
"_allgather_base_": lambda *args: create_fakework(args),
"reduce_scatter_": lambda *args: create_fakework(args),
"_reduce_scatter_base_": lambda *args: create_fakework(args),
"reduce_": lambda *args: create_fakework(args, return_first_arg=False),
"gather_": lambda *args: create_fakework(args, return_first_arg=False),
"scatter_": lambda *args: create_fakework(args),
"alltoall_": lambda *args: create_fakework(args),
"alltoall_base_": lambda *args: create_fakework(args, return_first_arg=False),
"barrier": lambda *args: create_fakework(args, return_first_arg=False),
"monitored_barrier_": lambda *args: None,
"send": lambda *args: create_fakework(args, return_first_arg=False),
"recv_": lambda *args: create_fakework(args, return_first_arg=False),
"recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False),
}
if not torch._running_with_deploy():
lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901
for op, meta_func in _META_FUNCTIONS.items():
lib_impl.impl(op, meta_func, "Meta")
# List of collective operation functions including functional collectives
# Note: The following collectives might be deprecated soon hence not adding them
# depcreated_non_functional_collectives = [
# c10d.allreduce_coalesced_.default,
# c10d.reduce_scatter_tensor_coalesced_.default,
# c10d.allgather_into_tensor_coalesced_.default,
# c10d.allgather_coalesced_.default,
# ]
non_functional_collectives: set[torch._ops.OpOverload] = {
c10d.broadcast_.default,
c10d.allreduce_.default,
c10d.reduce_.default,
c10d.send.default,
c10d.recv_.default,
c10d.recv_any_source_.default,
c10d.allgather_.default,
c10d.reduce_scatter_.default,
c10d._reduce_scatter_base_.default,
c10d._allgather_base_.default,
c10d.gather_.default,
c10d.scatter_.default,
c10d.alltoall_.default,
c10d.alltoall_base_.default,
c10d.barrier.default,
c10d.monitored_barrier_.default,
}
functional_collectives: set[torch._ops.OpOverload] = {
_c10d_functional.broadcast.default,
_c10d_functional.all_reduce.default,
_c10d_functional.all_gather_into_tensor.default,
_c10d_functional.reduce_scatter_tensor.default,
_c10d_functional.all_to_all_single.default,
_c10d_functional_autograd.all_to_all_single.default,
_c10d_functional.wait_tensor.default,
_c10d_functional.all_reduce_.default,
_c10d_functional.all_reduce_coalesced.default,
_c10d_functional.all_reduce_coalesced_.default,
_c10d_functional.all_gather_into_tensor_out.default,
_c10d_functional.all_gather_into_tensor_coalesced.default,
_c10d_functional_autograd.all_gather_into_tensor.default,
_c10d_functional.reduce_scatter_tensor_coalesced.default,
_c10d_functional_autograd.reduce_scatter_tensor.default,
_c10d_functional.broadcast_.default,
_dtensor.shard_dim_alltoall.default,
}
sync_ops: set[torch._ops.OpOverload] = {
c10d.barrier.default,
c10d.monitored_barrier_.default,
_c10d_functional.wait_tensor.default,
}
collective_ops = set.union(functional_collectives, non_functional_collectives)
class CollectiveOp:
# Static sets for performance optimization
PG_ARG_1 = {
c10d.broadcast_.default,
c10d.allreduce_.default,
c10d.reduce_.default,
c10d.send.default,
c10d.recv_.default,
c10d.recv_any_source_.default,
c10d.barrier.default,
# c10d.allreduce_coalesced_.default
}
PG_ARG_2 = {
c10d.allgather_.default,
c10d._allgather_base_.default,
c10d.reduce_scatter_.default,
c10d._reduce_scatter_base_.default,
c10d.gather_.default,
c10d.scatter_.default,
c10d.alltoall_.default,
c10d.alltoall_base_.default,
# c10d.allgather_coalesced_.default,
# c10d.allgather_into_tensor_coalesced_.default
# c10d.reduce_scatter_tensor_coalesced_.default
}
PG_ARG_3 = {
_c10d_functional.broadcast.default,
_c10d_functional.broadcast_.default,
_c10d_functional.all_reduce.default,
_c10d_functional.all_reduce_.default,
_c10d_functional.all_reduce_coalesced.default,
_c10d_functional.all_reduce_coalesced_.default,
_c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor_out.default,
_c10d_functional_autograd.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor_coalesced.default,
}
PG_ARG_4 = {
_c10d_functional.reduce_scatter_tensor.default,
_c10d_functional.reduce_scatter_tensor_coalesced.default,
_c10d_functional_autograd.reduce_scatter_tensor.default,
_c10d_functional.all_to_all_single.default,
_c10d_functional_autograd.all_to_all_single.default,
_dtensor.shard_dim_alltoall.default,
}
WK_ARG_1 = {
c10d.broadcast_.default,
c10d.allreduce_.default,
c10d.allgather_.default,
c10d.reduce_scatter_.default,
c10d._reduce_scatter_base_.default,
c10d._allgather_base_.default,
c10d.scatter_.default,
c10d.alltoall_.default,
}
WK = {
c10d.send.default,
c10d.recv_.default,
c10d.recv_any_source_.default,
c10d.reduce_.default,
c10d.gather_.default,
c10d.alltoall_base_.default,
c10d.barrier.default,
}
COMM_TENSOR_ARG_0 = {
c10d.allreduce_.default,
c10d.send.default,
c10d.recv_.default,
c10d.recv_any_source_.default,
c10d.allgather_.default,
c10d.gather_.default,
c10d.reduce_.default,
c10d.broadcast_.default,
_c10d_functional.all_reduce_coalesced.default,
_c10d_functional.all_reduce_coalesced_.default,
# c10d.allreduce_coalesced_.default
# c10d.allgather_coalesced_.default
# c10d.allgather_into_tensor_coalesced_.default,
}
COMM_TENSOR_ARG_1 = {
c10d.reduce_scatter_.default,
c10d.scatter_.default,
# c10d.reduce_scatter_tensor_coalesced_.default,
}
COMM_TENSOR_ARG_RES = {
_c10d_functional.all_gather_into_tensor.default,
_c10d_functional_autograd.all_gather_into_tensor.default,
}
COMM_TENSOR_SINGLE_UNTYPED_STORAGE = {
c10d._allgather_base_.default,
_c10d_functional.broadcast.default,
_c10d_functional.broadcast_.default,
_c10d_functional.all_reduce.default,
_c10d_functional.all_reduce_.default,
_c10d_functional.reduce_scatter_tensor.default,
_c10d_functional_autograd.reduce_scatter_tensor.default,
}
COMM_TENSOR_ARG_0_AND_RES = {
_c10d_functional.all_to_all_single.default,
_c10d_functional_autograd.all_to_all_single.default,
_dtensor.shard_dim_alltoall.default,
}
COMM_TENSOR_RES_SUM = {
_c10d_functional.all_gather_into_tensor_coalesced.default,
_c10d_functional.reduce_scatter_tensor_coalesced.default,
}
@staticmethod
def sum_tensors(arg: Any) -> int:
"""Calculate total memory consumed by the tensors in the argument."""
total_memory = 0
def sum_bytes(t: torch.Tensor) -> None:
nonlocal total_memory
total_memory += t.untyped_storage().nbytes()
tree_map_only(torch.Tensor, sum_bytes, arg)
return total_memory
@staticmethod
def get_process_group(func, args) -> ProcessGroup: # type: ignore[no-untyped-def]
"""Retrieve the process group for collective operations, except `wait_tensor`."""
if func in CollectiveOp.PG_ARG_1:
return ProcessGroup.unbox(args[1])
if func in CollectiveOp.PG_ARG_2:
return ProcessGroup.unbox(args[2])
if func in CollectiveOp.PG_ARG_3:
return _resolve_process_group(args[2])
if func in CollectiveOp.PG_ARG_4:
return _resolve_process_group(args[3])
raise TypeError(f"Func {func} not found in {collective_ops}")
@staticmethod
def get_comm_tensor_size(func, res, args, kwargs) -> int: # type: ignore[no-untyped-def]
"""Compute the communication tensor size, except for `wait_tensor`, `barrier`, and `monitored_barrier`."""
if func in CollectiveOp.COMM_TENSOR_ARG_0:
return CollectiveOp.sum_tensors(args[0])
if func in CollectiveOp.COMM_TENSOR_ARG_1:
return CollectiveOp.sum_tensors(args[1])
if func in CollectiveOp.COMM_TENSOR_ARG_RES:
return res.untyped_storage().nbytes()
if func in CollectiveOp.COMM_TENSOR_SINGLE_UNTYPED_STORAGE:
return args[0].untyped_storage().nbytes()
if func == c10d._reduce_scatter_base_.default:
return args[1].untyped_storage().nbytes()
if func == c10d.alltoall_.default:
# TODO(@sanketpurandare) - Confirm size computation
return max(
CollectiveOp.sum_tensors(args[0]), CollectiveOp.sum_tensors(args[1])
)
if func == c10d.alltoall_base_.default:
# TODO(@sanketpurandare) - Confirm size computation
return max(
args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes()
)
if func == _c10d_functional.all_gather_into_tensor_out.default:
return args[-1].untyped_storage().nbytes()
if func in CollectiveOp.COMM_TENSOR_RES_SUM:
return CollectiveOp.sum_tensors(res)
if func in CollectiveOp.COMM_TENSOR_ARG_0_AND_RES:
# TODO(@sanketpurandare) - Confirm size computation
return args[0].untyped_storage().nbytes() + res.untyped_storage().nbytes()
raise TypeError(f"Unknown function: {func} in {collective_ops}")
@staticmethod
def get_work(func, res) -> Work: # type: ignore[no-untyped-def]
if func in CollectiveOp.WK:
return FakeWork.unbox(res)
elif func in CollectiveOp.WK_ARG_1:
return FakeWork.unbox(res[1])
raise TypeError(f"Func {func} not found in {collective_ops}")
|