delete state in split_func
Browse files
torch-ext/optimizer/matmul_transpose_triton.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
| 1 |
# MIT License
|
| 2 |
-
#
|
| 3 |
# Copyright (c) 2025 Tianyang Lin
|
| 4 |
-
#
|
| 5 |
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
# in the Software without restriction, including without limitation the rights
|
| 8 |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
# furnished to do so, subject to the following conditions:
|
| 11 |
-
#
|
| 12 |
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
# copies or substantial portions of the Software.
|
| 14 |
-
#
|
| 15 |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
|
|
| 1 |
# MIT License
|
| 2 |
+
#
|
| 3 |
# Copyright (c) 2025 Tianyang Lin
|
| 4 |
+
#
|
| 5 |
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
# in the Software without restriction, including without limitation the rights
|
| 8 |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
# furnished to do so, subject to the following conditions:
|
| 11 |
+
#
|
| 12 |
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
# copies or substantial portions of the Software.
|
| 14 |
+
#
|
| 15 |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -121,7 +121,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 121 |
state = param_to_state[id(p)]
|
| 122 |
dst = state.worker_rank
|
| 123 |
assert dst < num_ranks
|
| 124 |
-
shard_elems = split_elems_for_src(p,
|
| 125 |
g = p.grad
|
| 126 |
g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
|
| 127 |
assert g.numel() == shard_elems
|
|
@@ -145,7 +145,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 145 |
for p in owned_params:
|
| 146 |
state = param_to_state[id(p)]
|
| 147 |
assert state.worker_rank == rank
|
| 148 |
-
total += split_elems_for_src(p,
|
| 149 |
recv_counts[src] = total
|
| 150 |
|
| 151 |
recv_total = sum(recv_counts)
|
|
@@ -186,7 +186,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 186 |
for p in owned_params:
|
| 187 |
state = param_to_state[id(p)]
|
| 188 |
assert state.worker_rank == rank
|
| 189 |
-
n = split_elems_for_src(p,
|
| 190 |
assert n > 0
|
| 191 |
|
| 192 |
sg = recv_buf.narrow(0, off + inner_off, n)
|
|
@@ -278,7 +278,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 278 |
|
| 279 |
offset = 0
|
| 280 |
for dst in range(num_ranks):
|
| 281 |
-
n = split_elems_for_src(p,
|
| 282 |
assert n > 0
|
| 283 |
|
| 284 |
su = u_full.narrow(0, offset, n)
|
|
@@ -304,7 +304,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 304 |
state = param_to_state[id(p)]
|
| 305 |
if state.worker_rank != src:
|
| 306 |
continue
|
| 307 |
-
total += split_elems_for_src(p,
|
| 308 |
recv_counts[src] = total
|
| 309 |
|
| 310 |
recv_total = sum(recv_counts)
|
|
@@ -348,7 +348,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 348 |
state = param_to_state[id(p)]
|
| 349 |
if state.worker_rank != src:
|
| 350 |
continue
|
| 351 |
-
n = split_elems_for_src(p,
|
| 352 |
assert n > 0
|
| 353 |
|
| 354 |
flat_local = recv_buf.narrow(0, off + inner_off,
|
|
|
|
| 121 |
state = param_to_state[id(p)]
|
| 122 |
dst = state.worker_rank
|
| 123 |
assert dst < num_ranks
|
| 124 |
+
shard_elems = split_elems_for_src(p, rank, num_ranks)
|
| 125 |
g = p.grad
|
| 126 |
g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
|
| 127 |
assert g.numel() == shard_elems
|
|
|
|
| 145 |
for p in owned_params:
|
| 146 |
state = param_to_state[id(p)]
|
| 147 |
assert state.worker_rank == rank
|
| 148 |
+
total += split_elems_for_src(p, src, num_ranks)
|
| 149 |
recv_counts[src] = total
|
| 150 |
|
| 151 |
recv_total = sum(recv_counts)
|
|
|
|
| 186 |
for p in owned_params:
|
| 187 |
state = param_to_state[id(p)]
|
| 188 |
assert state.worker_rank == rank
|
| 189 |
+
n = split_elems_for_src(p, src, num_ranks)
|
| 190 |
assert n > 0
|
| 191 |
|
| 192 |
sg = recv_buf.narrow(0, off + inner_off, n)
|
|
|
|
| 278 |
|
| 279 |
offset = 0
|
| 280 |
for dst in range(num_ranks):
|
| 281 |
+
n = split_elems_for_src(p, dst, num_ranks)
|
| 282 |
assert n > 0
|
| 283 |
|
| 284 |
su = u_full.narrow(0, offset, n)
|
|
|
|
| 304 |
state = param_to_state[id(p)]
|
| 305 |
if state.worker_rank != src:
|
| 306 |
continue
|
| 307 |
+
total += split_elems_for_src(p, rank, num_ranks)
|
| 308 |
recv_counts[src] = total
|
| 309 |
|
| 310 |
recv_total = sum(recv_counts)
|
|
|
|
| 348 |
state = param_to_state[id(p)]
|
| 349 |
if state.worker_rank != src:
|
| 350 |
continue
|
| 351 |
+
n = split_elems_for_src(p, rank, num_ranks)
|
| 352 |
assert n > 0
|
| 353 |
|
| 354 |
flat_local = recv_buf.narrow(0, off + inner_off,
|