Kernels
ca1207 committed on
Commit
2a8631f
·
1 Parent(s): ff6d675

use COMM_DTYPE instead of hardcoded dtype

Browse files
torch-ext/optimizer/matmul_transpose_triton.py CHANGED
@@ -28,7 +28,7 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
28
  GROUP_SIZE_M: tl.constexpr):
29
  """
30
  Core kernel jit function of matmul_transpose that computes y = x @ x.T
31
- The code is a simple adaptation from the triton `matmul` tutorial:
32
  https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
33
  """
34
  pid = tl.program_id(axis=0)
 
28
  GROUP_SIZE_M: tl.constexpr):
29
  """
30
  Core kernel jit function of matmul_transpose that computes y = x @ x.T
31
+ The code is a simple adaptation from the triton `matmul` tutorial:
32
  https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
33
  """
34
  pid = tl.program_id(axis=0)
torch-ext/optimizer/muon.py CHANGED
@@ -12,6 +12,8 @@ from .matmul_transpose_triton import matmul_transpose_assign
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
 
15
 
16
  # This code snippet is a modified version adapted from the following GitHub repositories:
17
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
@@ -30,8 +32,7 @@ def _zeropower_via_newtonschulz5(G, steps):
30
  performance at all relative to UV^T, where USV^T = G is the SVD.
31
  """
32
  assert len(G.shape) == 2
33
- assert G.dtype == torch.bfloat16
34
- G = G.to(thorch.float32)
35
  X = G # no manual typecast
36
 
37
  if G.size(0) > G.size(1):
@@ -55,7 +56,6 @@ def _zeropower_via_newtonschulz5(G, steps):
55
 
56
  if G.size(0) > G.size(1):
57
  X = X.T
58
- X = X.to(torch.bfloat16)
59
  return X
60
 
61
 
@@ -89,7 +89,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
89
  if rank == state.worker_rank:
90
  num_ranks = dist.get_world_size(group=state.process_group)
91
  state.gathered_grad = torch.empty(p.grad.numel(),
92
- dtype=torch.bfloat16,
93
  device="cuda")
94
  else:
95
  state.gathered_grad = None
@@ -114,7 +114,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
114
  dst = state.worker_rank
115
  shard_elems = split_elems_for_src(p, state, rank, num_ranks)
116
  g = p.grad
117
- g = g.to_local().to(torch.bfloat16).contiguous().view(-1)
118
  assert g.numel() == shard_elems
119
  per_dst[dst].append(g)
120
  send_counts[dst] += shard_elems
@@ -139,7 +139,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
139
  recv_counts[src] = total
140
 
141
  recv_total = sum(recv_counts)
142
- recv_buf = torch.empty(recv_total, dtype=torch.bfloat16, device="cuda")
143
  dist.all_to_all_single(
144
  recv_buf,
145
  send_buf,
@@ -225,7 +225,7 @@ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
225
  for p in params:
226
  state = param_to_state[id(p)]
227
  state.scattered_u = torch.empty_like(p.to_local(),
228
- dtype=torch.bfloat16)
229
 
230
  alloc_event = torch.cuda.Event()
231
  alloc_event.record(compute_stream)
@@ -254,8 +254,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
254
 
255
  assert state.computed_u is not None
256
 
257
- u_full = state.computed_u.to(
258
- torch.bfloat16).contiguous().view(-1)
259
 
260
  offset = 0
261
  for dst in range(num_ranks):
@@ -274,7 +273,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
274
  else:
275
  # all_to_all requires participation from all ranks
276
  # Even non-owner ranks must join the collective call
277
- send_buf = torch.empty(0, dtype=torch.bfloat16, device="cuda")
278
 
279
  recv_counts = [0] * num_ranks
280
  for src in range(num_ranks):
@@ -288,7 +287,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
288
 
289
  recv_total = sum(recv_counts)
290
  assert recv_total > 0
291
- recv_buf = torch.empty(recv_total, dtype=torch.bfloat16, device="cuda")
292
 
293
  dist.all_to_all_single(
294
  recv_buf,
@@ -636,7 +635,7 @@ class Muon(torch.optim.Optimizer):
636
  else:
637
  g = buf
638
 
639
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
640
  steps=group["ns_steps"])
641
 
642
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
 
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
59
  return X
60
 
61
 
 
89
  if rank == state.worker_rank:
90
  num_ranks = dist.get_world_size(group=state.process_group)
91
  state.gathered_grad = torch.empty(p.grad.numel(),
92
+ dtype=COMM_DTYPE,
93
  device="cuda")
94
  else:
95
  state.gathered_grad = None
 
114
  dst = state.worker_rank
115
  shard_elems = split_elems_for_src(p, state, rank, num_ranks)
116
  g = p.grad
117
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
118
  assert g.numel() == shard_elems
119
  per_dst[dst].append(g)
120
  send_counts[dst] += shard_elems
 
139
  recv_counts[src] = total
140
 
141
  recv_total = sum(recv_counts)
142
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
143
  dist.all_to_all_single(
144
  recv_buf,
145
  send_buf,
 
225
  for p in params:
226
  state = param_to_state[id(p)]
227
  state.scattered_u = torch.empty_like(p.to_local(),
228
+ dtype=COMM_DTYPE)
229
 
230
  alloc_event = torch.cuda.Event()
231
  alloc_event.record(compute_stream)
 
254
 
255
  assert state.computed_u is not None
256
 
257
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
 
258
 
259
  offset = 0
260
  for dst in range(num_ranks):
 
273
  else:
274
  # all_to_all requires participation from all ranks
275
  # Even non-owner ranks must join the collective call
276
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
277
 
278
  recv_counts = [0] * num_ranks
279
  for src in range(num_ranks):
 
287
 
288
  recv_total = sum(recv_counts)
289
  assert recv_total > 0
290
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
291
 
292
  dist.all_to_all_single(
293
  recv_buf,
 
635
  else:
636
  g = buf
637
 
638
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
639
  steps=group["ns_steps"])
640
 
641
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)