|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.cuda.amp import autocast |
|
|
from torch.optim import Adam |
|
|
import cupy as cp |
|
|
import cudf |
|
|
import flash_attn |
|
|
import onnx |
|
|
import onnxruntime as ort |
|
|
import tensorrt as trt |
|
|
from nemo.collections.nlp.models import GPTModel |
|
|
from nemo.collections.tts.models import FastPitchModel |
|
|
from nemo.collections.asr.models import EncDecCTCModel |
|
|
from torch2trt import torch2trt |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
import apex |
|
|
from apex import amp |
|
|
from apex.optimizers import FusedAdam |
|
|
|
|
|
|
|
|
import fused_ops |
|
|
|
|
|
class SparseLinear(nn.Module):
    """
    Sparse linear layer with tensor-core-friendly FP16 parameters and
    optional dynamic (per-forward) magnitude pruning.

    On CUDA inputs with FP16 enabled, dispatches to a fused sparse
    GEMM + ReLU CUDA kernel (``fused_ops``); otherwise falls back to a
    masked ``F.linear`` followed by ReLU, computed in FP32.
    """

    def __init__(self, in_features, out_features, sparsity=0.5, use_fp16=True, dynamic_pruning=False):
        """
        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            sparsity: fraction of weights to prune (quantile of |w|).
            use_fp16: store parameters in float16 and use the fused CUDA path.
            dynamic_pruning: recompute the mask from current weight
                magnitudes on every forward pass.
        """
        super(SparseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sparsity = sparsity
        self.use_fp16 = use_fp16
        self.dynamic_pruning = dynamic_pruning

        dtype = torch.float16 if use_fp16 else torch.float32
        self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(out_features, dtype=dtype))

        # Buffer (not Parameter): moves with .to()/.cuda() and is persisted
        # in state_dict, but receives no gradients.
        self.register_buffer("mask", self.generate_mask())

    def generate_mask(self):
        """
        Generates a binary mask based on weight magnitude for structured sparsity.

        With dynamic pruning the initial mask is all-ones; ``update_mask``
        recomputes it before each forward pass.
        """
        if self.dynamic_pruning:
            return torch.ones_like(self.weight)
        weights_abs = self.weight.abs()
        # Bug fix: torch.quantile does not support float16 -- compute the
        # threshold in fp32 so fp16 layers don't crash at construction.
        threshold = torch.quantile(weights_abs.float().flatten(), self.sparsity)
        return (weights_abs > threshold).to(self.weight.dtype)

    def update_mask(self):
        """Update mask dynamically based on current weight magnitudes."""
        if self.dynamic_pruning:
            weights_abs = self.weight.abs()
            # Threshold in fp32 (quantile lacks fp16 support).
            threshold = torch.quantile(weights_abs.float().flatten(), self.sparsity)
            self.mask.data = (weights_abs > threshold).to(self.weight.dtype)

    def forward(self, x):
        if self.dynamic_pruning:
            self.update_mask()

        # Bug fix: Tensor.is_cuda is a property, not a method -- the original
        # `x.is_cuda()` raised TypeError on every CUDA input.
        if self.use_fp16 and x.is_cuda:
            # Fused sparse GEMM + ReLU via the custom CUDA extension.
            return fused_ops.fused_sparse_gemm_relu(x, self.weight, self.mask, self.bias)
        else:
            # FP32 fallback for CPU / non-fp16 execution.
            pruned_weight = self.weight.float() * self.mask.float()
            return F.relu(F.linear(x.float(), pruned_weight, self.bias.float()))
|
class SparseConv2d(nn.Module):
    """
    Sparse 2D Convolution with structured sparsity and block sparsity support.

    Prunes the smallest-magnitude weights (or whole kh x kw blocks) by a
    quantile threshold, optionally recomputing the mask each forward pass.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 sparsity=0.5, use_fp16=True, block_size=None, dynamic_pruning=False):
        """
        Args:
            sparsity: fraction of weights (or blocks) to prune.
            use_fp16: store conv parameters in float16 and run under autocast.
            block_size: (kh, kw) block shape for block sparsity; the kernel
                spatial dims must be divisible by kh/kw. None -> unstructured.
            dynamic_pruning: recompute the mask on every forward pass.
        """
        super(SparseConv2d, self).__init__()
        self.use_fp16 = use_fp16
        self.sparsity = sparsity
        self.dynamic_pruning = dynamic_pruning
        self.block_size = block_size

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dtype=torch.float16 if use_fp16 else torch.float32,
        )
        self.register_buffer("mask", self.generate_mask())

    def _magnitude_mask(self):
        """Compute a 0/1 mask from current weight magnitudes (quantile threshold)."""
        weights = self.conv.weight
        weights_abs = weights.abs()
        if self.block_size:
            kh, kw = self.block_size
            # (out, in, H//kh, kh, W//kw, kw)
            weights_reshaped = weights_abs.view(weights_abs.size(0), weights_abs.size(1),
                                                weights_abs.size(2) // kh, kh,
                                                weights_abs.size(3) // kw, kw)
            # Bug fix: the block dims are 3 (kh) and 5 (kw); the original
            # reduced over (3, 4) = (kh, W//kw), mixing block and grid axes.
            block_magnitudes = weights_reshaped.norm(p=2, dim=(3, 5))
            # Threshold in fp32: torch.quantile does not support float16.
            threshold = torch.quantile(block_magnitudes.float().flatten(), self.sparsity)
            block_mask = (block_magnitudes > threshold).float()
            # Broadcast each block decision back over its kh x kw elements.
            # Bug fix: unsqueeze at dims 3 and 5 so block axes line up with
            # weights_reshaped (the old trailing unsqueezes only ran when
            # kh == kw, and mapped blocks incorrectly even then).
            mask = block_mask.unsqueeze(3).unsqueeze(5).expand_as(weights_reshaped).reshape_as(weights)
        else:
            threshold = torch.quantile(weights_abs.float().flatten(), self.sparsity)
            mask = (weights_abs > threshold).float()
        return mask

    def generate_mask(self):
        """Initial mask: all-ones in dynamic mode (pruned per forward), else magnitude-based."""
        if self.dynamic_pruning:
            return torch.ones_like(self.conv.weight)
        return self._magnitude_mask()

    def update_mask(self):
        """Recompute the mask from current magnitudes (dynamic mode only).

        Bug fix: previously delegated to generate_mask(), which returns an
        all-ones mask whenever dynamic_pruning is True -- so dynamic pruning
        never actually pruned anything.
        """
        if self.dynamic_pruning:
            self.mask.data = self._magnitude_mask()

    def forward(self, x):
        if self.dynamic_pruning:
            self.update_mask()

        if self.use_fp16:
            with autocast():
                pruned_weight = self.conv.weight * self.mask
                return F.conv2d(x, pruned_weight, self.conv.bias, self.conv.stride, self.conv.padding)
        else:
            pruned_weight = self.conv.weight.float() * self.mask.float()
            return F.conv2d(x.float(), pruned_weight, self.conv.bias.float(),
                            self.conv.stride, self.conv.padding)
|
|
class SparseMLP(nn.Module):
    """Two-layer MLP built from SparseLinear blocks.

    When FP16 is enabled, both layers run inside a single
    ``torch.cuda.amp.autocast`` region so the fused tensor-core kernels
    can be used end to end.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, sparsity=0.5,
                 use_fp16=True, dynamic_pruning=False):
        super(SparseMLP, self).__init__()
        self.fc1 = SparseLinear(input_dim, hidden_dim, sparsity, use_fp16, dynamic_pruning)
        self.fc2 = SparseLinear(hidden_dim, output_dim, sparsity, use_fp16, dynamic_pruning)
        self.use_fp16 = use_fp16

    def forward(self, x):
        # FP32 path: plain sequential application of both sparse layers.
        if not self.use_fp16:
            return self.fc2(self.fc1(x))
        # FP16 path: run the whole stack under mixed precision.
        with autocast():
            return self.fc2(self.fc1(x))
|
def train_sparse_mlp():
    """Train a SparseMLP on synthetic data, then export it for deployment.

    Runs 100 optimizer steps of apex O1 mixed-precision training on random
    inputs/labels, exports the model to ONNX, and builds a TensorRT engine
    via torch2trt.

    Returns:
        The TensorRT-optimized model. Requires a CUDA device plus the
        apex and torch2trt packages.
    """
    model = SparseMLP(784, 256, 10, sparsity=0.5, use_fp16=True).cuda()
    optimizer = FusedAdam(model.parameters(), lr=0.001)

    # apex O1: patches selected ops to fp16 and manages master weights.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # Synthetic batch: 32 flattened 28x28 "images" with random class labels.
    inputs = torch.randn(32, 784).cuda()
    targets = torch.randint(0, 10, (32,)).cuda()

    for step in range(100):
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)

        # Scale the loss so fp16 gradients do not underflow.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

    # Deployment artifacts: ONNX graph, then a TensorRT fp16 engine.
    torch.onnx.export(model, inputs, "sparse_mlp.onnx", opset_version=12)

    model_trt = torch2trt(model, [inputs], fp16_mode=True)
    return model_trt
|
|
|
|
|
if __name__ == "__main__":

    # Entry point: run the end-to-end demo (train, export to ONNX, build
    # the TensorRT engine). Requires a CUDA device -- see train_sparse_mlp.
    trt_model = train_sparse_mlp()