|
|
import torch |
|
|
import torch.nn as nn |
|
|
try: |
|
|
import AmpSelGemm |
|
|
except ImportError: |
|
|
AmpSelGemm = None |
|
|
|
|
|
from HybridTensor.triton.gather_gemm_col import gather_matmul_col |
|
|
from HybridTensor.triton.gather_gemm_row import gather_matmul_row |
|
|
from HybridTensor.utils.utils import arg_parser, sparse_index, create_results_directory |
|
|
from HybridTensor.utils.profiling import benchmark_mlp_fwd, generate_index_sizes, save_results_to_csv, plot_results |
|
|
|
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
class StandardMLPBlock(nn.Module):
    """Dense fp16 two-layer MLP (fc1 -> ReLU -> fc2) used as the baseline.

    Weights are created directly in float16 on the requested device.
    """

    def __init__(self, in_features, hidden_features=None, bias=False, device='cuda'):
        super().__init__()
        # Falsy hidden_features (None or 0) falls back to the usual 4x expansion.
        width = hidden_features if hidden_features else in_features * 4
        self.fc1 = nn.Linear(in_features, width, bias=bias, device=device, dtype=torch.float16)
        self.fc2 = nn.Linear(width, in_features, bias=bias, device=device, dtype=torch.float16)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Standard dense forward: fc2(relu(fc1(x)))."""
        return self.fc2(self.relu(self.fc1(x)))

    def sparsify(self, zero_index):
        """Zero the fc1 rows and matching fc2 columns listed in `zero_index`.

        This makes the dense block compute the same function as the selective
        implementations that skip those hidden neurons entirely.
        """
        self.fc1.weight.data[zero_index, :] = 0.0
        self.fc2.weight.data[:, zero_index] = 0.0
|
|
|
|
|
class SelectiveMLP(nn.Module):
    """MLP computed only over a selected subset of hidden neurons using the
    AmpSelGemm CUTLASS kernels.

    Both weight tensors are stored with shape (hidden_features, in_features);
    fc2 is kept pre-transposed so the row-gather kernel can index it by
    hidden-neuron row.
    """

    def __init__(self, in_features, hidden_features=None, bias=False, requires_grad=False, device='cuda', dtype=torch.float16, activation='relu'):
        super().__init__()
        if hidden_features is None:
            hidden_features = in_features * 4
        factory_kwargs = {'device': torch.device(device), 'dtype': dtype}
        self.fc1_w = torch.empty((hidden_features, in_features), requires_grad=requires_grad, **factory_kwargs)
        self.fc2_w = torch.empty((hidden_features, in_features), requires_grad=requires_grad, **factory_kwargs)
        self.act = nn.ReLU()

    def forward(self, x, index_vec):
        """Compute fc1 only for the hidden rows in `index_vec`, apply ReLU,
        then project back through the matching fc2 rows."""
        k = index_vec.size(0)
        batch = x.size(0)
        hidden = AmpSelGemm.run_col(A=x, B=self.fc1_w, index_vec=index_vec, M=batch, N=k, K=self.fc1_w.size(1), index_size=k)
        hidden = self.act(hidden)
        return AmpSelGemm.run_row1(A=hidden, B=self.fc2_w, index_vec=index_vec, M=batch, N=self.fc2_w.size(1), K=k, index_size=k)

    def load_from_MLP(self, mlp):
        """Adopt the weights of a StandardMLPBlock-style module.

        fc2 is transposed to (hidden, in) and made contiguous for the
        row-gather kernel. Returns self so calls can be chained.
        """
        self.fc1_w = mlp.fc1.weight
        self.fc2_w = mlp.fc2.weight.t().contiguous()
        return self
|
|
|
|
|
class SelectiveMLPTriton(SelectiveMLP):
    """SelectiveMLP variant that runs the gather GEMMs with Triton kernels
    instead of the AmpSelGemm CUTLASS extension."""

    def __init__(self, in_features, hidden_features=None, bias=False, requires_grad=False, device='cuda', dtype=torch.float16, activation='relu'):
        super().__init__(in_features, hidden_features, bias, requires_grad, device, dtype, activation)
        # The activation name is forwarded to the fused column-gather kernel.
        self.activation = activation

    def forward(self, x, index_vec):
        """Column-gathered fc1 with fused activation, then row-gathered fc2."""
        hidden = gather_matmul_col(x, self.fc1_w, index_vec, activations=self.activation)
        return gather_matmul_row(hidden, self.fc2_w, index_vec)
|
|
|
|
|
def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, bias1=None, bias2=None, activation='relu'):
    """Functional form of SelectiveMLPTriton.forward with optional per-layer biases.

    Runs the Triton column-gather GEMM (with fused activation) followed by the
    row-gather GEMM over the hidden neurons selected by `index_vec`.
    """
    hidden = gather_matmul_col(x, fc1_w, index_vec, bias=bias1, activations=activation)
    return gather_matmul_row(hidden, fc2_w, index_vec, bias=bias2)
|
|
|
|
|
def profile_mlps(args, index_size):
    """Benchmark the dense MLP against the selective implementations for one index size.

    Args:
        args: parsed CLI namespace (in_features, hidden_features, batch_size,
            iterations, check_results).
        index_size: number of active hidden neurons to benchmark with.

    Returns:
        dict with the timings and speedups for this index size. The CUTLASS
        entries are None when the optional AmpSelGemm extension is not
        installed.

    Bug fixed: previously, when AmpSelGemm was unavailable, `selectiveMLP_time`,
    `cutlass_out` and `cutlass_speedup` were never bound and the return
    statement raised NameError. The Triton path is also built unconditionally
    now, since it only depends on the HybridTensor Triton kernels.
    """
    standardmlp = StandardMLPBlock(args.in_features, args.hidden_features)
    # Triton variant needs no optional extension; CUTLASS variant does.
    selectiveMLPTriton = SelectiveMLPTriton(args.in_features, args.hidden_features).load_from_MLP(standardmlp)
    selectiveMLP = None
    if AmpSelGemm is not None:
        selectiveMLP = SelectiveMLP(args.in_features, args.hidden_features).load_from_MLP(standardmlp)

    x = torch.randn(args.batch_size, args.in_features, dtype=torch.float16, device='cuda')
    index_vec, zero_index = sparse_index(index_size, args.hidden_features or args.in_features * 4)
    # Zero the unselected neurons so every implementation computes the same function.
    standardmlp.sparsify(zero_index)

    std_out, standardmlp_time = benchmark_mlp_fwd(x, standardmlp, index_vec=None, iterations=args.iterations, print_result=False)
    triton_out, selectiveMLPTriton_time = benchmark_mlp_fwd(x, selectiveMLPTriton, index_vec=index_vec, iterations=args.iterations, print_result=False)

    # CUTLASS results default to None so the returned record is always complete.
    cutlass_out = None
    selectiveMLP_time = None
    cutlass_speedup = None
    if selectiveMLP is not None:
        cutlass_out, selectiveMLP_time = benchmark_mlp_fwd(x, selectiveMLP, index_vec=index_vec, iterations=args.iterations, print_result=False)
        cutlass_speedup = standardmlp_time / selectiveMLP_time if selectiveMLP_time > 0 else float('inf')

    if args.check_results:
        print('Standard MLP output:', std_out[0])
        if cutlass_out is not None:
            print('Selective MLP Cutlass output:', cutlass_out)
        print('Selective MLP Triton output:', triton_out)

    triton_speedup = standardmlp_time / selectiveMLPTriton_time if selectiveMLPTriton_time > 0 else float('inf')

    return {
        'index_size': index_size,
        'standard_time': standardmlp_time,
        'selective_cutlass_time': selectiveMLP_time,
        'selective_triton_time': selectiveMLPTriton_time,
        'cutlass_speedup': cutlass_speedup,
        'triton_speedup': triton_speedup,
    }
|
|
|
|
|
def run_profiling_over_index_sizes(args, index_sizes):
    """Profile every index size and collect the per-size records into a DataFrame."""
    records = [profile_mlps(args, size) for size in tqdm(index_sizes, desc="Profiling MLPs")]
    return pd.DataFrame(records)
|
|
|
|
|
|
|
|
''' |
|
|
if __name__ == '__main__': |
|
|
args = arg_parser() |
|
|
|
|
|
index_sizes = generate_index_sizes(args.hidden_features) |
|
|
|
|
|
# create standard MLP block |
|
|
standardmlp = StandardMLPBlock(args.in_features, args.hidden_features) |
|
|
selectiveMLP = SelectiveMLP(args.in_features, args.hidden_features).load_from_MLP(standardmlp) |
|
|
selectiveMLPTriton = SelectiveMLPTriton(args.in_features, args.hidden_features).load_from_MLP(standardmlp) |
|
|
|
|
|
|
|
|
# test input |
|
|
x = torch.randn(args.batch_size, args.in_features, dtype=torch.float16, device='cuda') |
|
|
index_vec, zero_index = sparse_index(args.index_size, args.hidden_features or args.in_features*4) |
|
|
standardmlp.sparsify(zero_index) |
|
|
|
|
|
# measure execution time |
|
|
std_out, standardmlp_time = benchmark_mlp_fwd(x, standardmlp, index_vec= None, iterations=args.iterations, print_result=True) |
|
|
cutlass_out, selectiveMLP_time = benchmark_mlp_fwd(x, selectiveMLP, index_vec= index_vec, iterations=args.iterations, print_result=True) |
|
|
triton_out, selectiveMLPTriton_time = benchmark_mlp_fwd(x, selectiveMLPTriton, index_vec= index_vec, iterations=args.iterations, print_result=True) |
|
|
|
|
|
if args.check_results: |
|
|
print('Standard MLP output:', std_out[0]) |
|
|
print('Selective MLP Cutlass output:', cutlass_out) |
|
|
print('Selective MLP Triton output:', triton_out) |
|
|
|
|
|
triton_speedup = standardmlp_time/selectiveMLPTriton_time |
|
|
cutlass_speedup = standardmlp_time/selectiveMLP_time |
|
|
|
|
|
print(f"Speedup of Cutlass implementation over standard MLP: {cutlass_speedup}") |
|
|
print(f"Speedup of Triton implementation over standard MLP: {triton_speedup}") |
|
|
|
|
|
''' |
|
|
|
|
|
if __name__ == '__main__':
    args = arg_parser()
    print("Profiling MLPs")

    # All CSV and plot artifacts go under the configured results directory.
    out_dir = create_results_directory(args.results_dir)

    # Sweep a range of active-neuron counts derived from the hidden width.
    sizes = generate_index_sizes(args.hidden_features)
    results_df = run_profiling_over_index_sizes(args, sizes)

    save_results_to_csv(results_df, filename_prefix='mlp_profiling_results', results_dir=out_dir)
    plot_results(results_df, output_prefix='mlp_profiling', results_dir=out_dir)

    if args.check_results:
        print(results_df)