use COMM_DTYPE instead of hardcoded dtype
Browse files
torch-ext/optimizer/matmul_transpose_triton.py
CHANGED
|
@@ -28,7 +28,7 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
|
|
| 28 |
GROUP_SIZE_M: tl.constexpr):
|
| 29 |
"""
|
| 30 |
Core kernel jit function of matmul_transpose that computes y = x @ x.T
|
| 31 |
-
The code is a simple adaptation from the triton `matmul` tutorial:
|
| 32 |
https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
|
| 33 |
"""
|
| 34 |
pid = tl.program_id(axis=0)
|
|
|
|
| 28 |
GROUP_SIZE_M: tl.constexpr):
|
| 29 |
"""
|
| 30 |
Core kernel jit function of matmul_transpose that computes y = x @ x.T
|
| 31 |
+
The code is a simple adaptation from the triton `matmul` tutorial:
|
| 32 |
https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
|
| 33 |
"""
|
| 34 |
pid = tl.program_id(axis=0)
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -12,6 +12,8 @@ from .matmul_transpose_triton import matmul_transpose_assign
|
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 17 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
@@ -30,8 +32,7 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 30 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 31 |
"""
|
| 32 |
assert len(G.shape) == 2
|
| 33 |
-
assert G.dtype == torch.bfloat16
|
| 34 |
-
G = G.to(torch.float32)
|
| 35 |
X = G # no manual typecast
|
| 36 |
|
| 37 |
if G.size(0) > G.size(1):
|
|
@@ -55,7 +56,6 @@ def _zeropower_via_newtonschulz5(G, steps):
|
|
| 55 |
|
| 56 |
if G.size(0) > G.size(1):
|
| 57 |
X = X.T
|
| 58 |
-
X = X.to(torch.bfloat16)
|
| 59 |
return X
|
| 60 |
|
| 61 |
|
|
@@ -89,7 +89,7 @@ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
|
|
| 89 |
if rank == state.worker_rank:
|
| 90 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 91 |
state.gathered_grad = torch.empty(p.grad.numel(),
|
| 92 |
-
dtype=torch.bfloat16,
|
| 93 |
device="cuda")
|
| 94 |
else:
|
| 95 |
state.gathered_grad = None
|
|
@@ -114,7 +114,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 114 |
dst = state.worker_rank
|
| 115 |
shard_elems = split_elems_for_src(p, state, rank, num_ranks)
|
| 116 |
g = p.grad
|
| 117 |
-
g = g.to_local().to(torch.bfloat16).contiguous().view(-1)
|
| 118 |
assert g.numel() == shard_elems
|
| 119 |
per_dst[dst].append(g)
|
| 120 |
send_counts[dst] += shard_elems
|
|
@@ -139,7 +139,7 @@ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
|
|
| 139 |
recv_counts[src] = total
|
| 140 |
|
| 141 |
recv_total = sum(recv_counts)
|
| 142 |
-
recv_buf = torch.empty(recv_total, dtype=torch.bfloat16, device="cuda")
|
| 143 |
dist.all_to_all_single(
|
| 144 |
recv_buf,
|
| 145 |
send_buf,
|
|
@@ -225,7 +225,7 @@ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
|
|
| 225 |
for p in params:
|
| 226 |
state = param_to_state[id(p)]
|
| 227 |
state.scattered_u = torch.empty_like(p.to_local(),
|
| 228 |
-
dtype=torch.bfloat16)
|
| 229 |
|
| 230 |
alloc_event = torch.cuda.Event()
|
| 231 |
alloc_event.record(compute_stream)
|
|
@@ -254,8 +254,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 254 |
|
| 255 |
assert state.computed_u is not None
|
| 256 |
|
| 257 |
-
u_full = state.computed_u.to(
|
| 258 |
-
torch.bfloat16).contiguous().view(-1)
|
| 259 |
|
| 260 |
offset = 0
|
| 261 |
for dst in range(num_ranks):
|
|
@@ -274,7 +273,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 274 |
else:
|
| 275 |
# all_to_all requires participation from all ranks
|
| 276 |
# Even non-owner ranks must join the collective call
|
| 277 |
-
send_buf = torch.empty(0, dtype=torch.bfloat16, device="cuda")
|
| 278 |
|
| 279 |
recv_counts = [0] * num_ranks
|
| 280 |
for src in range(num_ranks):
|
|
@@ -288,7 +287,7 @@ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
|
|
| 288 |
|
| 289 |
recv_total = sum(recv_counts)
|
| 290 |
assert recv_total > 0
|
| 291 |
-
recv_buf = torch.empty(recv_total, dtype=torch.bfloat16, device="cuda")
|
| 292 |
|
| 293 |
dist.all_to_all_single(
|
| 294 |
recv_buf,
|
|
@@ -636,7 +635,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 636 |
else:
|
| 637 |
g = buf
|
| 638 |
|
| 639 |
-
u = _zeropower_via_newtonschulz5(g.to(torch.bfloat16),
|
| 640 |
steps=group["ns_steps"])
|
| 641 |
|
| 642 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
|
|
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
+
COMM_DTYPE = torch.bfloat16
|
| 16 |
+
|
| 17 |
|
| 18 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
| 19 |
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
|
|
| 32 |
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 33 |
"""
|
| 34 |
assert len(G.shape) == 2
|
| 35 |
+
assert G.dtype == COMM_DTYPE
|
|
|
|
| 36 |
X = G # no manual typecast
|
| 37 |
|
| 38 |
if G.size(0) > G.size(1):
|
|
|
|
| 56 |
|
| 57 |
if G.size(0) > G.size(1):
|
| 58 |
X = X.T
|
|
|
|
| 59 |
return X
|
| 60 |
|
| 61 |
|
|
|
|
| 89 |
if rank == state.worker_rank:
|
| 90 |
num_ranks = dist.get_world_size(group=state.process_group)
|
| 91 |
state.gathered_grad = torch.empty(p.grad.numel(),
|
| 92 |
+
dtype=COMM_DTYPE,
|
| 93 |
device="cuda")
|
| 94 |
else:
|
| 95 |
state.gathered_grad = None
|
|
|
|
| 114 |
dst = state.worker_rank
|
| 115 |
shard_elems = split_elems_for_src(p, state, rank, num_ranks)
|
| 116 |
g = p.grad
|
| 117 |
+
g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
|
| 118 |
assert g.numel() == shard_elems
|
| 119 |
per_dst[dst].append(g)
|
| 120 |
send_counts[dst] += shard_elems
|
|
|
|
| 139 |
recv_counts[src] = total
|
| 140 |
|
| 141 |
recv_total = sum(recv_counts)
|
| 142 |
+
recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
|
| 143 |
dist.all_to_all_single(
|
| 144 |
recv_buf,
|
| 145 |
send_buf,
|
|
|
|
| 225 |
for p in params:
|
| 226 |
state = param_to_state[id(p)]
|
| 227 |
state.scattered_u = torch.empty_like(p.to_local(),
|
| 228 |
+
dtype=COMM_DTYPE)
|
| 229 |
|
| 230 |
alloc_event = torch.cuda.Event()
|
| 231 |
alloc_event.record(compute_stream)
|
|
|
|
| 254 |
|
| 255 |
assert state.computed_u is not None
|
| 256 |
|
| 257 |
+
u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
|
|
|
|
| 258 |
|
| 259 |
offset = 0
|
| 260 |
for dst in range(num_ranks):
|
|
|
|
| 273 |
else:
|
| 274 |
# all_to_all requires participation from all ranks
|
| 275 |
# Even non-owner ranks must join the collective call
|
| 276 |
+
send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
|
| 277 |
|
| 278 |
recv_counts = [0] * num_ranks
|
| 279 |
for src in range(num_ranks):
|
|
|
|
| 287 |
|
| 288 |
recv_total = sum(recv_counts)
|
| 289 |
assert recv_total > 0
|
| 290 |
+
recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
|
| 291 |
|
| 292 |
dist.all_to_all_single(
|
| 293 |
recv_buf,
|
|
|
|
| 635 |
else:
|
| 636 |
g = buf
|
| 637 |
|
| 638 |
+
u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
|
| 639 |
steps=group["ns_steps"])
|
| 640 |
|
| 641 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|