GeminiFan207 committed
Commit d21b0aa · verified · 1 Parent(s): 99d7145

Create sparse_ops.py

Files changed (1)
  1. core/data_architecture/sparse_ops.py +163 -0
core/data_architecture/sparse_ops.py ADDED
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

class SparseLinear(nn.Module):
    """
    Sparse linear layer with fp16 (Tensor Core friendly) execution and
    magnitude-based pruning, either static or dynamic.
    """
    def __init__(self, in_features, out_features, sparsity=0.5, use_fp16=True, dynamic_pruning=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sparsity = sparsity
        self.use_fp16 = use_fp16
        self.dynamic_pruning = dynamic_pruning  # Toggle dynamic vs. static pruning

        # Initialize dense weight and bias
        dtype = torch.float16 if use_fp16 else torch.float32
        self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(out_features, dtype=dtype))

        # Sparse mask (static unless dynamic_pruning is enabled)
        self.register_buffer("mask", self.generate_mask())

    def generate_mask(self):
        """
        Generates a binary mask that keeps the largest-magnitude weights
        (unstructured sparsity).
        """
        if self.dynamic_pruning:
            # Dynamic pruning recomputes the mask in the forward pass
            return torch.ones_like(self.weight)
        # torch.quantile does not support fp16, so compute the threshold in fp32
        weights_abs = self.weight.detach().abs().float()
        threshold = torch.quantile(weights_abs.flatten(), self.sparsity)
        return (weights_abs > threshold).to(self.weight.dtype)

    def update_mask(self):
        """Update the mask dynamically based on current weight magnitudes."""
        if self.dynamic_pruning:
            weights_abs = self.weight.detach().abs().float()
            threshold = torch.quantile(weights_abs.flatten(), self.sparsity)
            self.mask = (weights_abs > threshold).to(self.weight.dtype)

    def forward(self, x):
        if self.dynamic_pruning:
            self.update_mask()  # Recompute mask if dynamic pruning is enabled

        if self.use_fp16:
            with autocast():
                pruned_weight = self.weight * self.mask
                return F.linear(x, pruned_weight, self.bias)
        else:
            pruned_weight = self.weight.float() * self.mask.float()
            return F.linear(x.float(), pruned_weight, self.bias.float())

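# A minimal usage sketch for SparseLinear (illustrative; the helper name
# `_demo_sparse_linear` and all shapes below are arbitrary, not part of the
# module's API). Runs on CPU by setting use_fp16=False.
def _demo_sparse_linear():
    layer = SparseLinear(512, 256, sparsity=0.5, use_fp16=False)
    x = torch.randn(8, 512)
    y = layer(x)
    assert y.shape == (8, 256)
    # Roughly half of the weights should be masked out at sparsity=0.5
    print("SparseLinear zero fraction:", (layer.mask == 0).float().mean().item())
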
class SparseConv2d(nn.Module):
    """
    Sparse 2D convolution with unstructured or block-structured sparsity.
    Reduces computation by pruning less important weights.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 sparsity=0.5, use_fp16=True, block_size=None, dynamic_pruning=False):
        super().__init__()
        self.use_fp16 = use_fp16
        self.sparsity = sparsity
        self.dynamic_pruning = dynamic_pruning
        self.block_size = block_size  # Optional block sparsity, e.g. (2, 2); must divide the kernel dims

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dtype=torch.float16 if use_fp16 else torch.float32,
        )
        self.register_buffer("mask", self.generate_mask())

    def generate_mask(self):
        """
        Generate a mask based on weight magnitude, optionally with block sparsity.
        """
        weights = self.conv.weight.detach()
        weights_abs = weights.abs().float()  # torch.quantile requires fp32

        if self.block_size:  # Block sparsity
            # Score each (kh, kw) block by its L2 norm and prune whole blocks
            kh, kw = self.block_size
            assert weights.size(2) % kh == 0 and weights.size(3) % kw == 0, \
                "block_size must evenly divide the kernel dimensions"
            weights_reshaped = weights_abs.view(weights_abs.size(0), weights_abs.size(1),
                                                weights_abs.size(2) // kh, kh,
                                                weights_abs.size(3) // kw, kw)
            # Reduce over the intra-block dims (3 = kh, 5 = kw)
            block_magnitudes = weights_reshaped.norm(p=2, dim=(3, 5))
            threshold = torch.quantile(block_magnitudes.flatten(), self.sparsity)
            block_mask = (block_magnitudes > threshold).float()
            # Expand the block mask back to the full weight shape
            mask = (block_mask.unsqueeze(3).unsqueeze(5)
                    .expand_as(weights_reshaped)
                    .reshape_as(weights))
        else:
            threshold = torch.quantile(weights_abs.flatten(), self.sparsity)
            mask = (weights_abs > threshold).float()
        return mask.to(weights.dtype)

    def update_mask(self):
        """Update the mask dynamically based on current weight magnitudes."""
        if self.dynamic_pruning:
            self.mask = self.generate_mask()

    def forward(self, x):
        if self.dynamic_pruning:
            self.update_mask()

        if self.use_fp16:
            with autocast():
                pruned_weight = self.conv.weight * self.mask
                return F.conv2d(x, pruned_weight, self.conv.bias,
                                self.conv.stride, self.conv.padding)
        else:
            pruned_weight = self.conv.weight.float() * self.mask.float()
            return F.conv2d(x.float(), pruned_weight, self.conv.bias.float(),
                            self.conv.stride, self.conv.padding)

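# An illustrative sketch of SparseConv2d with block sparsity; the helper name
# `_demo_sparse_conv2d` and the sizes are arbitrary. With a 3x3 kernel and
# block_size=(3, 3), each (out, in) channel pair's kernel is kept or pruned
# as a whole.
def _demo_sparse_conv2d():
    conv = SparseConv2d(3, 16, kernel_size=3, padding=1, sparsity=0.5,
                        use_fp16=False, block_size=(3, 3))
    x = torch.randn(2, 3, 32, 32)
    y = conv(x)
    assert y.shape == (2, 16, 32, 32)
    print("SparseConv2d zero fraction:", (conv.mask == 0).float().mean().item())
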
class SparseMLP(nn.Module):
    """
    Sparse MLP with Tensor Core acceleration and optional dynamic pruning.
    Stacks two sparse linear layers to reduce computation.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, sparsity=0.5,
                 use_fp16=True, dynamic_pruning=False):
        super().__init__()
        self.fc1 = SparseLinear(input_dim, hidden_dim, sparsity, use_fp16, dynamic_pruning)
        self.fc2 = SparseLinear(hidden_dim, output_dim, sparsity, use_fp16, dynamic_pruning)
        self.use_fp16 = use_fp16

    def forward(self, x):
        if self.use_fp16:
            with autocast():
                x = F.relu(self.fc1(x))
                return self.fc2(x)
        else:
            x = F.relu(self.fc1(x.float()))
            return self.fc2(x)
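
# End-to-end sketch: one training step through SparseMLP with dynamic pruning.
# Purely illustrative; the data, dimensions, and learning rate are placeholders.
if __name__ == "__main__":
    _demo_sparse_linear()
    _demo_sparse_conv2d()

    mlp = SparseMLP(128, 64, 10, sparsity=0.5, use_fp16=False, dynamic_pruning=True)
    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
    x, target = torch.randn(32, 128), torch.randint(0, 10, (32,))
    loss = F.cross_entropy(mlp(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("SparseMLP loss after one step:", loss.item())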