Kernels
ca1207 committed on
Commit
15336dc
·
1 Parent(s): 678578a

delete state in split_func

Browse files
torch-ext/optimizer/matmul_transpose_triton.py CHANGED
@@ -1,17 +1,17 @@
1
  # MIT License
2
- #
3
  # Copyright (c) 2025 Tianyang Lin
4
- #
5
  # Permission is hereby granted, free of charge, to any person obtaining a copy
6
  # of this software and associated documentation files (the "Software"), to deal
7
  # in the Software without restriction, including without limitation the rights
8
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
  # copies of the Software, and to permit persons to whom the Software is
10
  # furnished to do so, subject to the following conditions:
11
- #
12
  # The above copyright notice and this permission notice shall be included in all
13
  # copies or substantial portions of the Software.
14
- #
15
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 
1
  # MIT License
2
+ #
3
  # Copyright (c) 2025 Tianyang Lin
4
+ #
5
  # Permission is hereby granted, free of charge, to any person obtaining a copy
6
  # of this software and associated documentation files (the "Software"), to deal
7
  # in the Software without restriction, including without limitation the rights
8
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
  # copies of the Software, and to permit persons to whom the Software is
10
  # furnished to do so, subject to the following conditions:
11
+ #
12
  # The above copyright notice and this permission notice shall be included in all
13
  # copies or substantial portions of the Software.
14
+ #
15
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
torch-ext/optimizer/muon.py CHANGED
@@ -121,7 +121,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
- shard_elems = split_elems_for_src(p, state, rank, num_ranks)
125
  g = p.grad
126
  g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
@@ -145,7 +145,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
145
  for p in owned_params:
146
  state = param_to_state[id(p)]
147
  assert state.worker_rank == rank
148
- total += split_elems_for_src(p, state, src, num_ranks)
149
  recv_counts[src] = total
150
 
151
  recv_total = sum(recv_counts)
@@ -186,7 +186,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
186
  for p in owned_params:
187
  state = param_to_state[id(p)]
188
  assert state.worker_rank == rank
189
- n = split_elems_for_src(p, state, src, num_ranks)
190
  assert n > 0
191
 
192
  sg = recv_buf.narrow(0, off + inner_off, n)
@@ -278,7 +278,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
278
 
279
  offset = 0
280
  for dst in range(num_ranks):
281
- n = split_elems_for_src(p, state, dst, num_ranks)
282
  assert n > 0
283
 
284
  su = u_full.narrow(0, offset, n)
@@ -304,7 +304,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
304
  state = param_to_state[id(p)]
305
  if state.worker_rank != src:
306
  continue
307
- total += split_elems_for_src(p, state, rank, num_ranks)
308
  recv_counts[src] = total
309
 
310
  recv_total = sum(recv_counts)
@@ -348,7 +348,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
348
  state = param_to_state[id(p)]
349
  if state.worker_rank != src:
350
  continue
351
- n = split_elems_for_src(p, state, rank, num_ranks)
352
  assert n > 0
353
 
354
  flat_local = recv_buf.narrow(0, off + inner_off,
 
121
  state = param_to_state[id(p)]
122
  dst = state.worker_rank
123
  assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
  g = p.grad
126
  g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
  assert g.numel() == shard_elems
 
145
  for p in owned_params:
146
  state = param_to_state[id(p)]
147
  assert state.worker_rank == rank
148
+ total += split_elems_for_src(p, src, num_ranks)
149
  recv_counts[src] = total
150
 
151
  recv_total = sum(recv_counts)
 
186
  for p in owned_params:
187
  state = param_to_state[id(p)]
188
  assert state.worker_rank == rank
189
+ n = split_elems_for_src(p, src, num_ranks)
190
  assert n > 0
191
 
192
  sg = recv_buf.narrow(0, off + inner_off, n)
 
278
 
279
  offset = 0
280
  for dst in range(num_ranks):
281
+ n = split_elems_for_src(p, dst, num_ranks)
282
  assert n > 0
283
 
284
  su = u_full.narrow(0, offset, n)
 
304
  state = param_to_state[id(p)]
305
  if state.worker_rank != src:
306
  continue
307
+ total += split_elems_for_src(p, rank, num_ranks)
308
  recv_counts[src] = total
309
 
310
  recv_total = sum(recv_counts)
 
348
  state = param_to_state[id(p)]
349
  if state.worker_rank != src:
350
  continue
351
+ n = split_elems_for_src(p, rank, num_ranks)
352
  assert n > 0
353
 
354
  flat_local = recv_buf.narrow(0, off + inner_off,