|
|
|
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import time |
|
|
|
|
|
import torch |
|
|
import torch.cuda |
|
|
import torch.mps |
|
|
|
|
|
|
|
|
def int_or_list(x):
    """Parse ``x`` as a single int, or as a comma-separated list of ints.

    Args:
        x: String such as ``"3"`` or ``"0,1"``.

    Returns:
        An ``int`` when ``x`` parses directly as one, otherwise a list of
        ints obtained by splitting ``x`` on commas.

    Raises:
        ValueError: If a comma-separated piece is not a valid integer.
    """
    try:
        parsed = int(x)
    except ValueError:
        # Not a plain integer: treat it as "a,b,c".
        parsed = list(map(int, x.split(",")))
    return parsed
|
|
|
|
|
|
|
|
def none_or_list(x):
    """Parse a comma-separated list of ints, mapping the empty string to None.

    Args:
        x: String such as ``"1,0"`` or ``""``.

    Returns:
        ``None`` when ``x`` is the empty string, otherwise the list of ints
        obtained by splitting ``x`` on commas.

    Raises:
        ValueError: If a comma-separated piece is not a valid integer.
    """
    return None if x == "" else list(map(int, x.split(",")))
|
|
|
|
|
|
|
|
def dtype_from_str(x):
    """Convert a dtype name such as ``"float16"`` into a ``torch.dtype``.

    Args:
        x: Attribute name of a dtype on the ``torch`` module; the empty
            string selects ``torch.float32``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If ``x`` does not name a ``torch.dtype``. This includes
            names that do not exist on ``torch`` at all, so that argparse
            (which uses this function as a ``type=`` converter) reports a
            usage error instead of an ``AttributeError`` traceback.
    """
    if x == "":
        return torch.float32

    # A missing attribute yields None, which fails the isinstance check
    # below and is reported through the same ValueError path.
    dt = getattr(torch, x, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{x} is not a torch dtype")
    return dt
|
|
|
|
|
|
|
|
def bench(f, *args):
    """Time 100 calls of ``f(*args)`` after 10 untimed warm-up calls.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to every call of ``f``.

    Returns:
        Total elapsed time in seconds (float) for the 100 timed calls.
    """
    # Warm-up calls are excluded from the measurement.
    for _ in range(10):
        f(*args)

    # perf_counter is monotonic and high resolution; time.time() follows the
    # wall clock, which can be adjusted while the benchmark is running.
    start = time.perf_counter()
    for _ in range(100):
        f(*args)
    end = time.perf_counter()
    return end - start
|
|
|
|
|
|
|
|
def sync_if_needed(x):
    """Wait for pending device work on the accelerator that holds ``x``.

    Calls ``torch.mps.synchronize()`` for MPS tensors and
    ``torch.cuda.synchronize()`` for CUDA tensors; does nothing for any
    other device type.

    Args:
        x: Object exposing a ``device`` attribute (a ``torch.device``).
    """
    # Compare the device *type*: a tensor's device carries an index
    # (e.g. "cuda:0"), and torch.device("cuda:0") != torch.device("cuda"),
    # so an equality test against the index-less device never matches.
    device_type = x.device.type
    if device_type == "mps":
        torch.mps.synchronize()
    elif device_type == "cuda":
        torch.cuda.synchronize(x.device)
|
|
|
|
|
|
|
|
@torch.no_grad()
def matmul_square(x):
    """Chain ten matrix products of ``x`` with itself, then synchronize.

    Starts from ``x`` and right-multiplies the running product by ``x`` ten
    times. Nothing is returned.
    """
    product = x
    for _ in range(10):
        product = product @ x
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def matmul(x, y):
    """Compute ``x @ y`` ten times, then synchronize on ``x``'s device.

    The ten results are held in a local list until the function returns.
    Nothing is returned.
    """
    products = [x @ y for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def conv1d(x, y):
    """Run ``torch.nn.functional.conv1d`` ten times, then synchronize.

    Both operands have their last two axes swapped before the convolution
    (presumably converting a channels-last layout to the channels-first
    layout torch expects -- TODO confirm against the caller). Nothing is
    returned.

    Args:
        x: Convolution input, before the axis swap.
        y: Convolution weight, before the axis swap.
    """
    inp = torch.transpose(x, -1, -2)
    weight = torch.transpose(y, -1, -2)
    outputs = [torch.nn.functional.conv1d(inp, weight) for _ in range(10)]
    sync_if_needed(inp)
|
|
|
|
|
|
|
|
@torch.no_grad()
def conv2d(x, y):
    """Run ``torch.nn.functional.conv2d`` ten times, then synchronize.

    Both operands are permuted with ``(0, 3, 1, 2)`` before the convolution
    (presumably NHWC -> NCHW -- TODO confirm against the caller). Nothing is
    returned.

    Args:
        x: 4-D convolution input, before the permutation.
        y: 4-D convolution weight, before the permutation.
    """
    layout = (0, 3, 1, 2)
    inp = torch.permute(x, layout)
    weight = torch.permute(y, layout)
    outputs = [torch.nn.functional.conv2d(inp, weight) for _ in range(10)]
    sync_if_needed(inp)
|
|
|
|
|
|
|
|
@torch.no_grad()
def binary(op, x, y):
    """Apply the binary function ``torch.<op>`` 100 times, then synchronize.

    Each iteration computes ``torch.<op>(x, acc)``, where ``acc`` starts as
    ``y`` and is replaced by the previous result. Nothing is returned.

    Args:
        op: Name of a two-argument function on the ``torch`` module
            (e.g. ``"add"``).
        x: First operand, reused on every iteration.
        y: Initial second operand.
    """
    acc = y
    for _ in range(100):
        acc = getattr(torch, op)(x, acc)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def reduction(op, axis, x):
    """Call the tensor method ``x.<op>(axis)`` 100 times, then synchronize.

    The results are held in a local list until the function returns.
    Nothing is returned.

    Args:
        op: Name of a method on ``x`` (e.g. ``"sum"`` or ``"argmax"``).
        axis: Value passed as the method's single positional argument.
        x: Tensor to reduce.
    """
    results = [getattr(x, op)(axis) for _ in range(100)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def sum_and_add(axis, x, y):
    """Alternate a broadcast add with a keep-dims sum, then synchronize.

    Reduces ``x`` along ``axis`` (keeping the reduced dimension), then 50
    times adds ``y`` to the running value and reduces again along ``axis``.
    Nothing is returned.
    """
    running = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        running = (running + y).sum(axis=axis, keepdims=True)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def softmax(axis, x):
    """Compute an unfused softmax of ``x`` along ``axis`` 100 times.

    Each iteration subtracts the per-``axis`` maximum, exponentiates, and
    divides by the per-``axis`` sum. The results are held in a local list
    until the function returns, and the device is synchronized once after
    the loop. Nothing is returned.
    """
    outputs = []
    for _ in range(100):
        shifted_exp = torch.exp(x - torch.max(x, dim=axis, keepdims=True).values)
        normalizer = torch.sum(shifted_exp, dim=axis, keepdims=True)
        outputs.append(shifted_exp / normalizer)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def softmax_fused(axis, x):
    """Call ``torch.nn.functional.softmax(x, dim=axis)`` 100 times.

    The results are held in a local list until the function returns, and
    the device is synchronized once after the loop. Nothing is returned.
    """
    outputs = [torch.nn.functional.softmax(x, dim=axis) for _ in range(100)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def relu(x):
    """Apply ``torch.nn.functional.relu`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``. The
    device is synchronized once after the loop. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = torch.nn.functional.relu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def leaky_relu(x):
    """Apply ``torch.nn.functional.leaky_relu`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``. The
    device is synchronized once after the loop. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = torch.nn.functional.leaky_relu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def elu(x):
    """Apply ``torch.nn.functional.elu`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``. The
    device is synchronized once after the loop. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = torch.nn.functional.elu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def celu(x):
    """Apply ``torch.nn.functional.celu`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``. The
    device is synchronized once after the loop. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = torch.nn.functional.celu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def relu6(x):
    """Apply ``torch.nn.functional.relu6`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``. The
    device is synchronized once after the loop. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = torch.nn.functional.relu6(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def softplus(x):
    """Apply ``torch.nn.functional.softplus`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``. The
    device is synchronized once after the loop. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = torch.nn.functional.softplus(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def log_sigmoid(x):
    """Apply ``torch.nn.functional.logsigmoid`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``. The
    device is synchronized once after the loop. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = torch.nn.functional.logsigmoid(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def prelu(x: torch.Tensor) -> None:
    """Apply ``torch.nn.functional.prelu`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``, with
    a single-element weight of ones. The device is synchronized once after
    the loop. Nothing is returned.

    Args:
        x: Input tensor; the weight is created on its device with its dtype.
    """
    # Build the weight once, directly on the target device and in the input
    # dtype: re-creating it on the CPU and transferring it every iteration
    # adds allocation/copy cost to the timed loop, and a float32 weight does
    # not match non-float32 inputs.
    weight = torch.ones(1, device=x.device, dtype=x.dtype)
    y = x
    for _ in range(100):
        y = torch.nn.functional.prelu(y, weight)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def mish(x: torch.Tensor) -> None:
    """Apply ``torch.nn.functional.mish`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``. The
    device is synchronized once after the loop.

    Args:
        x: Input tensor.

    Returns:
        None. The return annotation is corrected accordingly: the function
        never returned a tensor.
    """
    y = x
    for _ in range(100):
        y = torch.nn.functional.mish(y)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def scalar_mult(x):
    """Multiply a running tensor by the scalars 1/1, 1/2, ..., 1/100.

    Starts from ``x``; each multiplication consumes the previous result. The
    device is synchronized once after the loop. Nothing is returned.
    """
    out = x
    for divisor in range(1, 101):
        out = out * (1.0 / divisor)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def cross_entropy(targets, x):
    """Call ``torch.nn.functional.cross_entropy(x, targets)`` 100 times.

    The results are held in a local list until the function returns, and
    the device is synchronized once after the loop. Nothing is returned.

    Args:
        targets: Target tensor passed as the second argument.
        x: Input tensor passed as the first argument.
    """
    losses = [torch.nn.functional.cross_entropy(x, targets) for _ in range(100)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def logsumexp(axis, x):
    """Call ``torch.logsumexp(x, dim=axis)`` 100 times, then synchronize.

    The results are held in a local list until the function returns.
    Nothing is returned.
    """
    outputs = [torch.logsumexp(x, dim=axis) for _ in range(100)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def linear_fused(w, b, x):
    """Call ``torch.nn.functional.linear(x, w, b)`` ten times.

    The results are held in a local list until the function returns, and
    the device is synchronized once after the loop. Nothing is returned.

    Args:
        w: Weight tensor.
        b: Bias tensor.
        x: Input tensor.
    """
    outputs = [torch.nn.functional.linear(x, w, b) for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def linear(w, b, x):
    """Compute ``x @ w^T + b`` ten times with separate matmul and add ops.

    ``w`` is transposed over its last two axes inside every iteration. The
    results are held in a local list until the function returns, and the
    device is synchronized once after the loop. Nothing is returned.

    Args:
        w: Weight tensor.
        b: Bias tensor.
        x: Input tensor.
    """
    outputs = [(x @ torch.transpose(w, -2, -1)) + b for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def rope(x):
    """Apply a rotary-style rotation to ``x`` ten times, then synchronize.

    ``x`` is viewed as ``(-1, N, D)``, where ``N`` and ``D`` are its last two
    dimensions. Even- and odd-indexed features along the last axis are
    rotated by angles ``position * 10000 ** linspace(0, 1, D // 2)`` and
    interleaved back into shape ``(-1, N, D)``. Positions, frequencies and
    the sin/cos tables are recomputed on every iteration. Nothing is
    returned.

    Args:
        x: Tensor with at least two dimensions. It is reshaped with
            ``Tensor.view``, so its layout must be compatible with a view.
    """
    N, D = x.shape[-2:]
    rotated = []
    for _ in range(10):
        x = x.view(-1, N, D)

        # Rotation angle for every (position, frequency) pair.
        pos = torch.arange(N, device=x.device)
        scales = 10000 ** torch.linspace(0, 1, D // 2, device=x.device)
        angles = pos[:, None] * scales[None]
        cos_t = torch.cos(angles)
        sin_t = torch.sin(angles)

        # Rotate each (even, odd) feature pair.
        even = x[..., ::2]
        odd = x[..., 1::2]
        rot_even = even * cos_t - odd * sin_t
        rot_odd = even * sin_t + odd * cos_t

        # Re-interleave the rotated halves along the feature axis.
        pairs = torch.cat([rot_even[..., None], rot_odd[..., None]], dim=-1)
        rotated.append(pairs.reshape(-1, N, D))
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def concatenate(axis, x, y):
    """Call ``torch.cat([x, y], dim=axis)`` ten times, then synchronize.

    The results are held in a local list until the function returns.
    Nothing is returned.
    """
    joined = [torch.cat([x, y], dim=axis) for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def cumsum(axis, x):
    """Call ``x.cumsum(axis)`` ten times, then synchronize.

    The results are held in a local list until the function returns.
    Nothing is returned.
    """
    sums = [x.cumsum(axis) for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def sort(axis, x):
    """Sort ``x`` along ``axis`` ten times, then synchronize.

    Only the sorted values of each ``torch.sort`` call are kept (the indices
    are dropped); they are held in a local list until the function returns.
    Nothing is returned.
    """
    sorted_values = [torch.sort(x, dim=axis).values for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def topk(axis, x):
    """Take the top third of ``x`` along ``axis`` ten times, then synchronize.

    ``k`` is the size of ``x`` along ``axis`` integer-divided by 3. Only the
    values of each ``torch.topk`` call are kept; they are held in a local
    list until the function returns. Nothing is returned.
    """
    k = x.shape[axis] // 3
    top_values = [torch.topk(x, k, dim=axis).values for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def step_function(x):
    """Apply a 0/1 step via ``torch.where`` 100 times in a chain.

    Each iteration maps negative entries of the running tensor to 0 and all
    other entries to 1, starting from ``x``. The device is synchronized once
    after the loop. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = torch.where(out < 0, 0, 1)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
@torch.no_grad()
def selu(x):
    """Apply ``torch.nn.functional.selu`` 100 times in a chain.

    Each call consumes the previous call's output, starting from ``x``. The
    device is synchronized once after the loop. Nothing is returned.
    """
    out = x
    for _ in range(100):
        out = torch.nn.functional.selu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line driver: build the input tensors, choose a device, run the
    # requested benchmark through `bench`, and print the elapsed seconds.
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")

    args = parser.parse_args()

    # With action="append" the user's values are added after the default
    # entry, so drop that placeholder once at least one value was supplied.
    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)

    torch.set_num_threads(1)

    # --cpu wins; otherwise prefer CUDA, then MPS. Fall back to the CPU when
    # no accelerator is available instead of failing on a missing "mps".
    if args.cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # One dtype per size; missing entries repeat the first dtype.
    types = args.dtype
    if not types:
        types = [torch.float32]
    if len(types) < len(args.size):
        types = types + [types[0]] * (len(args.size) - len(types))

    xs = []
    for size, dtype in zip(args.size, types):
        xs.append(torch.randn(*size).to(device).to(dtype))
    # The i-th --transpose applies to the i-th tensor; an empty value
    # (parsed as None) leaves that tensor untouched.
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = xs[i].permute(*t)
    x = xs[0]
    axis = args.axis[0]

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    if args.benchmark == "cross_entropy":
        # Check the tensor that is actually benchmarked, not the loop
        # variable left over from building `xs` (which is the *last* size).
        if x.ndim != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")

        targets = torch.zeros(len(x), dtype=torch.long).to(x.device)
        fn, fn_args = cross_entropy, (targets, x)
    else:
        # Benchmark name -> (function, positional arguments).
        benchmarks = {
            "matmul_square": (matmul_square, (x,)),
            "matmul": (matmul, tuple(xs)),
            "linear": (linear_fused if args.fused else linear, tuple(xs)),
            "sum_axis": (reduction, ("sum", axis, x)),
            "sum_all": (reduction, ("sum", None, x)),
            "argmax": (reduction, ("argmax", axis, x)),
            "add": (binary, ("add", *xs)),
            "mul": (binary, ("mul", *xs)),
            "softmax": (softmax_fused if args.fused else softmax, (axis, x)),
            "relu": (relu, (x,)),
            "leaky_relu": (leaky_relu, (x,)),
            "elu": (elu, (x,)),
            "relu6": (relu6, (x,)),
            "softplus": (softplus, (x,)),
            "celu": (celu, (x,)),
            "log_sigmoid": (log_sigmoid, (x,)),
            "prelu": (prelu, (x,)),
            "mish": (mish, (x,)),
            "scalar_mul": (scalar_mult, (x,)),
            "logsumexp": (logsumexp, (axis, x)),
            "rope": (rope, (x,)),
            "concatenate": (concatenate, (axis, *xs)),
            # cumsum takes a single tensor, so pass `x` rather than `*xs`.
            "cumsum": (cumsum, (axis, x)),
            "conv1d": (conv1d, tuple(xs)),
            "conv2d": (conv2d, tuple(xs)),
            "sort": (sort, (axis, x)),
            "topk": (topk, (axis, x)),
            "step": (step_function, (x,)),
            "selu": (selu, (x,)),
            "sum_and_add": (sum_and_add, (axis, *xs)),
        }
        if args.benchmark not in benchmarks:
            raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
        fn, fn_args = benchmarks[args.benchmark]

    print(bench(fn, *fn_args))
|
|
|