# Susav's picture
# Upload folder using huggingface_hub
# b3a3b15 verified
import torch
import torch.nn as nn
try:
import AmpSelGemm
except ImportError:
AmpSelGemm = None
from HybridTensor.triton.gather_gemm_col import gather_matmul_col
from HybridTensor.triton.gather_gemm_row import gather_matmul_row
from HybridTensor.utils.utils import arg_parser, sparse_index, create_results_directory
from HybridTensor.utils.profiling import benchmark_mlp_fwd, generate_index_sizes, save_results_to_csv, plot_results
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm # For progress bars
# implement standard MLP block with ReLU activation
class StandardMLPBlock(nn.Module):
    """Baseline dense two-layer MLP: fc1 -> ReLU -> fc2, all in float16.

    Serves as the reference implementation that the selective (gather-GEMM)
    variants are benchmarked against.
    """

    def __init__(self, in_features, hidden_features=None, bias=False, device='cuda'):
        super(StandardMLPBlock, self).__init__()
        # Falsy hidden_features (None/0) defaults to the usual 4x expansion.
        hidden_features = hidden_features or in_features * 4
        # fc1 weight is already in the layout the CUTLASS kernel wants.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device, dtype=torch.float16)
        # fc2 weight is row-major; the CUTLASS path transposes a copy of it.
        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias, device=device, dtype=torch.float16)
        self.relu = nn.ReLU()

    def forward(self, x):
        """(B, d) -> fc1 -> (B, 4d) -> ReLU -> fc2 -> (B, d)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

    def sparsify(self, zero_index):
        """Zero the fc1 rows and matching fc2 columns for the given hidden-unit indices,
        so the dense block computes the same function as a selective one."""
        self.fc1.weight.data[zero_index] = 0.0
        self.fc2.weight.data[:, zero_index] = 0.0
class SelectiveMLP(nn.Module):
    """MLP that evaluates only a subset of hidden units per call, via the
    AmpSelGemm CUTLASS extension.

    Weights are held as raw tensors (not nn.Parameters) in the layout the
    kernels expect; populate them with load_from_MLP.
    """

    def __init__(self, in_features, hidden_features=None, bias=False, requires_grad=False,
                 device='cuda', dtype=torch.float16, activation='relu'):
        super(SelectiveMLP, self).__init__()
        if hidden_features is None:
            hidden_features = in_features * 4
        factory_kwargs = {'device': torch.device(device), 'dtype': dtype}
        # Both buffers are (hidden, in); fc2 is stored pre-transposed for the row kernel.
        self.fc1_w = torch.empty((hidden_features, in_features), requires_grad=requires_grad, **factory_kwargs)
        self.fc2_w = torch.empty((hidden_features, in_features), requires_grad=requires_grad, **factory_kwargs)
        self.act = nn.ReLU()

    def forward(self, x, index_vec):
        """Gathered fc1 over the rows in index_vec, ReLU, then the matching fc2 columns."""
        active = index_vec.size(0)
        batch = x.size(0)
        hidden = AmpSelGemm.run_col(A=x, B=self.fc1_w, index_vec=index_vec,
                                    M=batch, N=active, K=self.fc1_w.size(1), index_size=active)
        # TODO: fuse the activation into the fc1 kernel in a later iteration.
        hidden = self.act(hidden)
        return AmpSelGemm.run_row1(A=hidden, B=self.fc2_w, index_vec=index_vec,
                                   M=batch, N=self.fc2_w.size(1), K=active, index_size=active)

    def load_from_MLP(self, mlp):
        """Adopt weights from a StandardMLPBlock-style module; returns self for chaining."""
        self.fc1_w = mlp.fc1.weight
        self.fc2_w = mlp.fc2.weight.t().contiguous()
        return self
class SelectiveMLPTriton(SelectiveMLP):
    """SelectiveMLP variant that dispatches to the Triton gather-GEMM kernels."""

    def __init__(self, in_features, hidden_features=None, bias=False, requires_grad=False,
                 device='cuda', dtype=torch.float16, activation='relu'):
        super(SelectiveMLPTriton, self).__init__(in_features, hidden_features, bias,
                                                 requires_grad, device, dtype, activation)
        # The column kernel fuses the activation, so remember which one to request.
        self.activation = activation

    def forward(self, x, index_vec):
        # Column-gather fc1 with fused activation, then row-gather fc2.
        hidden = gather_matmul_col(x, self.fc1_w, index_vec, activations=self.activation)
        return gather_matmul_row(hidden, self.fc2_w, index_vec)
def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, bias1=None, bias2=None, activation='relu'):
    """Functional selective MLP using the Triton kernels.

    Column-gather GEMM over fc1_w (adding bias1, fusing `activation`), then
    row-gather GEMM over fc2_w (adding bias2). Returns the (batch, in_features)
    output tensor.
    """
    hidden = gather_matmul_col(x, fc1_w, index_vec, bias=bias1, activations=activation)
    return gather_matmul_row(hidden, fc2_w, index_vec, bias=bias2)
def profile_mlps(args, index_size):
    """Benchmark the dense MLP against the selective Triton (and, when the
    AmpSelGemm extension is available, CUTLASS) implementations for one index size.

    Returns a dict of timings and speedups. CUTLASS entries are NaN when
    AmpSelGemm is not installed — the original code left `selectiveMLP_time`
    and `cutlass_speedup` unbound in that case and raised NameError at the
    return statement.
    """
    standardmlp = StandardMLPBlock(args.in_features, args.hidden_features)
    selectiveMLPTriton = SelectiveMLPTriton(args.in_features, args.hidden_features).load_from_MLP(standardmlp)
    if AmpSelGemm is not None:
        selectiveMLP = SelectiveMLP(args.in_features, args.hidden_features).load_from_MLP(standardmlp)

    # Random active-index set; the dense block is sparsified so that all
    # implementations compute the same function.
    x = torch.randn(args.batch_size, args.in_features, dtype=torch.float16, device='cuda')
    index_vec, zero_index = sparse_index(index_size, args.hidden_features or args.in_features * 4)
    standardmlp.sparsify(zero_index)

    # Measure execution time.
    std_out, standardmlp_time = benchmark_mlp_fwd(x, standardmlp, index_vec=None, iterations=args.iterations, print_result=False)
    triton_out, selectiveMLPTriton_time = benchmark_mlp_fwd(x, selectiveMLPTriton, index_vec=index_vec, iterations=args.iterations, print_result=False)
    # NaN fallback keeps the result dict fully populated without the extension.
    selectiveMLP_time = float('nan')
    if AmpSelGemm is not None:
        cutlass_out, selectiveMLP_time = benchmark_mlp_fwd(x, selectiveMLP, index_vec=index_vec, iterations=args.iterations, print_result=False)

    # Optionally check results.
    if args.check_results:
        print('Standard MLP output:', std_out[0])
        if AmpSelGemm is not None:
            print('Selective MLP Cutlass output:', cutlass_out)
        print('Selective MLP Triton output:', triton_out)

    # Calculate speedups (inf guards against a zero measured time).
    triton_speedup = standardmlp_time / selectiveMLPTriton_time if selectiveMLPTriton_time > 0 else float('inf')
    if AmpSelGemm is not None:
        cutlass_speedup = standardmlp_time / selectiveMLP_time if selectiveMLP_time > 0 else float('inf')
    else:
        cutlass_speedup = float('nan')

    return {
        'index_size': index_size,
        'standard_time': standardmlp_time,
        'selective_cutlass_time': selectiveMLP_time,
        'selective_triton_time': selectiveMLPTriton_time,
        'cutlass_speedup': cutlass_speedup,
        'triton_speedup': triton_speedup,
    }
def run_profiling_over_index_sizes(args, index_sizes):
    """Profile each index size in turn and collect the per-size result dicts
    into a pandas DataFrame (one row per index size)."""
    rows = [profile_mlps(args, size) for size in tqdm(index_sizes, desc="Profiling MLPs")]
    return pd.DataFrame(rows)
# NOTE: a superseded single-index-size __main__ driver used to live here as a
# dead triple-quoted string. It has been removed; its functionality is covered
# by profile_mlps / run_profiling_over_index_sizes and the __main__ block below.
if __name__ == '__main__':
    args = arg_parser()
    # Plain string: the original used an f-string with no placeholders (ruff F541).
    print("Profiling MLPs")

    # Directory where the CSV and plot outputs are written.
    results_dir = create_results_directory(args.results_dir)

    # Sweep of index sizes derived from the hidden dimension.
    index_sizes = generate_index_sizes(args.hidden_features)

    # Run profiling over the different index sizes.
    profiling_results = run_profiling_over_index_sizes(args, index_sizes)

    # Persist and visualize the sweep.
    save_results_to_csv(profiling_results, filename_prefix='mlp_profiling_results', results_dir=results_dir)
    plot_results(profiling_results, output_prefix='mlp_profiling', results_dir=results_dir)

    # Optionally print the DataFrame.
    if args.check_results:
        print(profiling_results)