camenduru commited on
Commit
9dc3a34
·
1 Parent(s): c03d1cc

Delete models.py

Browse files
Files changed (1) hide show
  1. models.py +0 -980
models.py DELETED
@@ -1,980 +0,0 @@
1
- #
2
- # SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #
17
-
18
- from collections import OrderedDict
19
- from copy import deepcopy
20
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
21
- import numpy as np
22
- from onnx import shape_inference
23
- import onnx_graphsurgeon as gs
24
- from polygraphy.backend.onnx.loader import fold_constants
25
- import torch
26
- from transformers import CLIPTextModel
27
- from cuda import cudart
28
-
29
class Optimizer():
    """ONNX graph optimizer built on onnx-graphsurgeon.

    Wraps an ONNX model and applies TensorRT-oriented rewrites: cleanup,
    constant folding, shape inference, Cast/Swish/Resize fixes, and the
    replacement of recognized subgraphs with TRT plugin nodes (GroupNorm,
    LayerNorm, SplitGeLU, SeqLen2Spatial, fMHA, fMHCA).

    Most pattern matchers below walk the graph with graphsurgeon's
    ``node.o(i)`` (i-th consumer) / ``node.i(i)`` (i-th producer) helpers and
    rely on the exact subgraph layout exported by the SD UNet/VAE/CLIP models.
    """

    def __init__(
        self,
        onnx_graph,
        verbose=False
    ):
        # Import the ONNX protobuf into a mutable graphsurgeon graph.
        self.graph = gs.import_onnx(onnx_graph)
        self.verbose = verbose

    def info(self, prefix=''):
        """Print a one-line summary of graph size when verbose is enabled."""
        if self.verbose:
            print(f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs")

    def cleanup(self, return_onnx=False):
        """Remove dangling nodes/tensors and topologically sort the graph.

        Returns the exported ONNX model when ``return_onnx`` is True.
        """
        self.graph.cleanup().toposort()
        if return_onnx:
            return gs.export_onnx(self.graph)

    def select_outputs(self, keep, names=None):
        """Keep only the graph outputs at indices ``keep``, optionally renaming them."""
        self.graph.outputs = [self.graph.outputs[o] for o in keep]
        if names:
            for i, name in enumerate(names):
                self.graph.outputs[i].name = name

    def fold_constants(self, return_onnx=False):
        """Fold constant subexpressions via polygraphy (ORT shape inference allowed)."""
        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(onnx_graph)
        if return_onnx:
            return onnx_graph

    def infer_shapes(self, return_onnx=False):
        """Run ONNX shape inference over the current graph.

        Raises:
            TypeError: if the serialized model exceeds the 2GB protobuf limit
                that ``onnx.shape_inference`` cannot handle.
        """
        onnx_graph = gs.export_onnx(self.graph)
        if onnx_graph.ByteSize() > 2147483648:
            raise TypeError("ERROR: model size exceeds supported 2GB limit")
        else:
            onnx_graph = shape_inference.infer_shapes(onnx_graph)

        self.graph = gs.import_onnx(onnx_graph)
        if return_onnx:
            return onnx_graph

    def remove_casts(self):
        """Remove redundant Cast nodes around attention blocks.

        Returns the number of Cast nodes bypassed.
        """
        nRemoveCastNode = 0
        for node in self.graph.nodes:
            # Remove Cast nodes before qkv gemm
            if node.op in ["Add", "Transpose"] and len(node.outputs[0].outputs) == 3 and node.o().op == "Cast" and node.o(1).op == "Cast" and node.o(2).op == "Cast":
                for i in range(len(node.outputs[0].outputs)):
                    matMulNode = node.o(i, 0).o()
                    matMulNode.inputs[0] = node.outputs[0]
                    nRemoveCastNode += 1

            # Remove double cast nodes after Softmax Node
            if node.op == "Softmax" and node.o().op == "Cast" and node.o().o().op == "Cast":
                node.o().o().o().inputs[0] = node.outputs[0]
                nRemoveCastNode += 1

        self.cleanup()
        return nRemoveCastNode

    def remove_parallel_swish(self):
        """Deduplicate parallel Swish (Mul) branches hanging off a shared Gemm.

        The first Mul branch is kept; every subsequent Mul's downstream Gemm is
        rerouted to the kept output and the duplicate branch is cleared.
        Returns the number of removed branches.
        """
        mRemoveSwishNode = 0
        for node in self.graph.nodes:
            # ">6" is a heuristic for a Gemm whose output fans out to many
            # parallel branches -- NOTE(review): threshold not derived here, confirm.
            if node.op == "Gemm" and len(node.outputs[0].outputs) > 6:
                swishOutputTensor = None
                for nextNode in node.outputs[0].outputs:
                    if nextNode.op == "Mul":
                        if swishOutputTensor is None:
                            swishOutputTensor = nextNode.outputs[0]
                        else:
                            nextGemmNode = nextNode.o(0)
                            assert nextGemmNode.op == "Gemm", "Unexpected node type for nextGemmNode {}".format(nextGemmNode.name)
                            nextGemmNode.inputs = [swishOutputTensor, nextGemmNode.inputs[1], nextGemmNode.inputs[2]]
                            nextNode.outputs.clear()
                            mRemoveSwishNode += 1

        self.cleanup()
        return mRemoveSwishNode

    def resize_fix(self):
        '''
        This function loops through the graph looking for Resize nodes that uses scales for resize (has 3 inputs).
        It substitutes found Resize with Resize that takes the size of the output tensor instead of scales.
        It adds Shape->Slice->Concat
                Shape->Slice----^ subgraph to the graph to extract the shape of the output tensor.
        This fix is required for the dynamic shape support.
        '''
        mResizeNodes = 0
        for node in self.graph.nodes:
            if node.op == "Resize" and len(node.inputs) == 3:
                name = node.name + "/"

                # Anchor nodes whose outputs carry the target (H, W) and (B, C)
                # shapes -- positions assume the exported UNet upsample pattern.
                add_node = node.o().o().i(1)
                div_node = node.i()

                # Shape of the tensor that already has the target spatial size.
                shape_hw_out = gs.Variable(name=name + "shape_hw_out", dtype=np.int64, shape=[4])
                shape_hw = gs.Node(op="Shape", name=name+"shape_hw", inputs=[add_node.outputs[0]], outputs=[shape_hw_out])

                const_zero = gs.Constant(name=name + "const_zero", values=np.array([0], dtype=np.int64))
                const_two = gs.Constant(name=name + "const_two", values=np.array([2], dtype=np.int64))
                const_four = gs.Constant(name=name + "const_four", values=np.array([4], dtype=np.int64))

                # Slice out (H, W) = dims [2:4].
                slice_hw_out = gs.Variable(name=name + "slice_hw_out", dtype=np.int64, shape=[2])
                slice_hw = gs.Node(op="Slice", name=name+"slice_hw", inputs=[shape_hw_out, const_two, const_four, const_zero], outputs=[slice_hw_out])

                # Slice out (B, C) = dims [0:2] from the Resize input's producer.
                shape_bc_out = gs.Variable(name=name + "shape_bc_out", dtype=np.int64, shape=[2])
                shape_bc = gs.Node(op="Shape", name=name+"shape_bc", inputs=[div_node.outputs[0]], outputs=[shape_bc_out])

                slice_bc_out = gs.Variable(name=name + "slice_bc_out", dtype=np.int64, shape=[2])
                slice_bc = gs.Node(op="Slice", name=name+"slice_bc", inputs=[shape_bc_out, const_zero, const_two, const_zero], outputs=[slice_bc_out])

                # Concatenate into the full (B, C, H, W) target size.
                concat_bchw_out = gs.Variable(name=name + "concat_bchw_out", dtype=np.int64, shape=[4])
                concat_bchw = gs.Node(op="Concat", name=name+"concat_bchw", attrs={"axis": 0}, inputs=[slice_bc_out, slice_hw_out], outputs=[concat_bchw_out])

                # Empty variables fill the unused roi/scales inputs of Resize.
                none_var = gs.Variable.empty()

                resize_bchw = gs.Node(op="Resize", name=name+"resize_bchw", attrs=node.attrs, inputs=[node.inputs[0], none_var, none_var, concat_bchw_out], outputs=[node.outputs[0]])

                self.graph.nodes.extend([shape_hw, slice_hw, shape_bc, slice_bc, concat_bchw, resize_bchw])

                # Detach the old Resize so cleanup() removes it.
                node.inputs = []
                node.outputs = []

                mResizeNodes += 1

        self.cleanup()
        return mResizeNodes


    def adjustAddNode(self):
        """Reorder Add inputs so the constant bias is second.

        Returns the number of Add nodes adjusted.
        """
        nAdjustAddNode = 0
        for node in self.graph.nodes:
            # Change the bias const to the second input to allow Gemm+BiasAdd fusion in TRT.
            if node.op in ["Add"] and isinstance(node.inputs[0], gs.ir.tensor.Constant):
                tensor = node.inputs[1]
                bias = node.inputs[0]
                node.inputs = [tensor, bias]
                nAdjustAddNode += 1

        self.cleanup()
        return nAdjustAddNode

    def decompose_instancenorms(self):
        """Expand each InstanceNormalization into primitive ops.

        Builds ReduceMean/Sub/Pow/ReduceMean/Add(eps)/Sqrt/Div plus Mul(scale)
        and Add(bias); the original node is detached so cleanup() removes it.
        Returns the number of decomposed nodes.

        NOTE(review): the scale/bias reshape to (1, 32, 1) hard-codes 32
        channels-per-group as used by SD GroupNorm -- confirm for other models.
        """
        nRemoveInstanceNorm = 0
        for node in self.graph.nodes:
            if node.op == "InstanceNormalization":
                name = node.name + "/"
                input_tensor = node.inputs[0]
                output_tensor = node.outputs[0]
                mean_out = gs.Variable(name=name + "mean_out")
                mean_node = gs.Node(op="ReduceMean", name=name + "mean_node", attrs={"axes": [-1]}, inputs=[input_tensor], outputs=[mean_out])
                sub_out = gs.Variable(name=name + "sub_out")
                sub_node = gs.Node(op="Sub", name=name + "sub_node", attrs={}, inputs=[input_tensor, mean_out], outputs=[sub_out])
                pow_out = gs.Variable(name=name + "pow_out")
                pow_const = gs.Constant(name=name + "pow_const", values=np.array([2.0], dtype=np.float32))
                pow_node = gs.Node(op="Pow", name=name + "pow_node", attrs={}, inputs=[sub_out, pow_const], outputs=[pow_out])
                mean2_out = gs.Variable(name=name + "mean2_out")
                mean2_node = gs.Node(op="ReduceMean", name=name + "mean2_node", attrs={"axes": [-1]}, inputs=[pow_out], outputs=[mean2_out])
                epsilon_out = gs.Variable(name=name + "epsilon_out")
                epsilon_const = gs.Constant(name=name + "epsilon_const", values=np.array([node.attrs["epsilon"]], dtype=np.float32))
                epsilon_node = gs.Node(op="Add", name=name + "epsilon_node", attrs={}, inputs=[mean2_out, epsilon_const], outputs=[epsilon_out])
                sqrt_out = gs.Variable(name=name + "sqrt_out")
                sqrt_node = gs.Node(op="Sqrt", name=name + "sqrt_node", attrs={}, inputs=[epsilon_out], outputs=[sqrt_out])
                div_out = gs.Variable(name=name + "div_out")
                div_node = gs.Node(op="Div", name=name + "div_node", attrs={}, inputs=[sub_out, sqrt_out], outputs=[div_out])
                constantScale = gs.Constant("InstanceNormScaleV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[1].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
                constantBias = gs.Constant("InstanceBiasV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[2].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
                mul_out = gs.Variable(name=name + "mul_out")
                mul_node = gs.Node(op="Mul", name=name + "mul_node", attrs={}, inputs=[div_out, constantScale], outputs=[mul_out])
                add_node = gs.Node(op="Add", name=name + "add_node", attrs={}, inputs=[mul_out, constantBias], outputs=[output_tensor])
                self.graph.nodes.extend([mean_node, sub_node, pow_node, mean2_node, epsilon_node, sqrt_node, div_node, mul_node, add_node])
                node.inputs = []
                node.outputs = []
                nRemoveInstanceNorm += 1

        self.cleanup()
        return nRemoveInstanceNorm

    def insert_groupnorm_plugin(self):
        """Replace each matched GroupNorm subgraph with a "GroupNorm" plugin node.

        Matches the decomposed-norm pattern (see decompose_instancenorms) by
        walking consumer chains from a Reshape, extracts gamma/beta/epsilon,
        and optionally folds a trailing Swish (Sigmoid+Mul) into the plugin via
        the ``bSwish`` attribute. Returns the number of plugins inserted.
        """
        nGroupNormPlugin = 0
        for node in self.graph.nodes:
            if node.op == "Reshape" and node.outputs != [] and \
                node.o().op == "ReduceMean" and node.o(1).op == "Sub" and node.o().o() == node.o(1) and \
                node.o().o().o().o().o().o().o().o().o().o().o().op == "Mul" and \
                node.o().o().o().o().o().o().o().o().o().o().o().o().op == "Add" and \
                len(node.o().o().o().o().o().o().o().o().inputs[1].values.shape) == 3:
                # "node.outputs != []" is added for VAE

                inputTensor = node.i().inputs[0]

                gammaNode = node.o().o().o().o().o().o().o().o().o().o().o()
                index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
                gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
                constantGamma = gs.Constant("groupNormGamma-" + str(nGroupNormPlugin), np.ascontiguousarray(gamma.reshape(-1)))  # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!

                betaNode = gammaNode.o()
                index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
                beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
                constantBeta = gs.Constant("groupNormBeta-" + str(nGroupNormPlugin), np.ascontiguousarray(beta.reshape(-1)))

                epsilon = node.o().o().o().o().o().inputs[1].values.tolist()[0]

                if betaNode.o().op == "Sigmoid":  # need Swish
                    bSwish = True
                    lastNode = betaNode.o().o()  # Mul node of Swish
                else:
                    bSwish = False
                    lastNode = betaNode  # Cast node after Group Norm

                if lastNode.o().op == "Cast":
                    lastNode = lastNode.o()
                inputList = [inputTensor, constantGamma, constantBeta]
                # Plugin output is fp16, same shape as the input tensor.
                groupNormV = gs.Variable("GroupNormV-" + str(nGroupNormPlugin), np.dtype(np.float16), inputTensor.shape)
                groupNormN = gs.Node("GroupNorm", "GroupNormN-" + str(nGroupNormPlugin), inputs=inputList, outputs=[groupNormV], attrs=OrderedDict([('epsilon', epsilon), ('bSwish', int(bSwish))]))
                self.graph.nodes.append(groupNormN)

                # Reroute every consumer of the replaced subgraph's output to
                # the plugin output, then detach the old subgraph for cleanup.
                for subNode in self.graph.nodes:
                    if lastNode.outputs[0] in subNode.inputs:
                        index = subNode.inputs.index(lastNode.outputs[0])
                        subNode.inputs[index] = groupNormV
                node.i().inputs = []
                lastNode.outputs = []
                nGroupNormPlugin += 1

        self.cleanup()
        return nGroupNormPlugin

    def insert_layernorm_plugin(self):
        """Replace each matched LayerNorm subgraph with a "LayerNorm" plugin node.

        Matches the ReduceMean/Sub/Pow/ReduceMean/Add/Sqrt/Div/Mul/Add pattern,
        extracts gamma/beta, and reroutes consumers (or graph outputs) to the
        plugin output. Returns the number of plugins inserted.
        """
        nLayerNormPlugin = 0
        for node in self.graph.nodes:
            if node.op == 'ReduceMean' and \
                node.o().op == 'Sub' and node.o().inputs[0] == node.inputs[0] and \
                node.o().o(0).op =='Pow' and node.o().o(1).op =='Div' and \
                node.o().o(0).o().op == 'ReduceMean' and \
                node.o().o(0).o().o().op == 'Add' and \
                node.o().o(0).o().o().o().op == 'Sqrt' and \
                node.o().o(0).o().o().o().o().op == 'Div' and node.o().o(0).o().o().o().o() == node.o().o(1) and \
                node.o().o(0).o().o().o().o().o().op == 'Mul' and \
                node.o().o(0).o().o().o().o().o().o().op == 'Add' and \
                len(node.o().o(0).o().o().o().o().o().inputs[1].values.shape) == 1:

                if node.i().op == "Add":
                    inputTensor = node.inputs[0]  # CLIP
                else:
                    inputTensor = node.i().inputs[0]  # UNet and VAE

                gammaNode = node.o().o().o().o().o().o().o()
                index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
                gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
                constantGamma = gs.Constant("LayerNormGamma-" + str(nLayerNormPlugin), np.ascontiguousarray(gamma.reshape(-1)))  # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!

                betaNode = gammaNode.o()
                index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
                beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
                constantBeta = gs.Constant("LayerNormBeta-" + str(nLayerNormPlugin), np.ascontiguousarray(beta.reshape(-1)))

                inputList = [inputTensor, constantGamma, constantBeta]
                # NOTE(review): epsilon is hard-coded to 1e-5 rather than read
                # from the matched Add constant -- confirm all models use 1e-5.
                layerNormV = gs.Variable("LayerNormV-" + str(nLayerNormPlugin), np.dtype(np.float32), inputTensor.shape)
                layerNormN = gs.Node("LayerNorm", "LayerNormN-" + str(nLayerNormPlugin), inputs=inputList, attrs=OrderedDict([('epsilon', 1.e-5)]), outputs=[layerNormV])
                self.graph.nodes.append(layerNormN)
                nLayerNormPlugin += 1

                if betaNode.outputs[0] in self.graph.outputs:
                    # The LayerNorm fed a graph output directly: swap it in place.
                    index = self.graph.outputs.index(betaNode.outputs[0])
                    self.graph.outputs[index] = layerNormV
                else:
                    if betaNode.o().op == "Cast":
                        lastNode = betaNode.o()
                    else:
                        lastNode = betaNode
                    for subNode in self.graph.nodes:
                        if lastNode.outputs[0] in subNode.inputs:
                            index = subNode.inputs.index(lastNode.outputs[0])
                            subNode.inputs[index] = layerNormV
                    lastNode.outputs = []

        self.cleanup()
        return nLayerNormPlugin

    def insert_splitgelu_plugin(self):
        """Replace each GEGLU subgraph (anchored at its Erf node) with a "SplitGeLU" plugin.

        The plugin consumes the pre-split tensor and halves the last visible
        dimension of the output shape. Returns the number of plugins inserted.
        """
        nSplitGeLUPlugin = 0
        for node in self.graph.nodes:
            if node.op == "Erf":
                inputTensor = node.i().i().i().outputs[0]
                lastNode = node.o().o().o().o()
                outputShape = inputTensor.shape
                # SplitGeLU halves the channel dim (gate half is consumed).
                outputShape[2] = outputShape[2] // 2

                splitGeLUV = gs.Variable("splitGeLUV-" + str(nSplitGeLUPlugin), np.dtype(np.float32), outputShape)
                splitGeLUN = gs.Node("SplitGeLU", "splitGeLUN-" + str(nSplitGeLUPlugin), inputs=[inputTensor], outputs=[splitGeLUV])
                self.graph.nodes.append(splitGeLUN)

                for subNode in self.graph.nodes:
                    if lastNode.outputs[0] in subNode.inputs:
                        index = subNode.inputs.index(lastNode.outputs[0])
                        subNode.inputs[index] = splitGeLUV
                lastNode.outputs = []
                nSplitGeLUPlugin += 1

        self.cleanup()
        return nSplitGeLUPlugin

    def insert_seq2spatial_plugin(self):
        """Fuse Add+Add+Reshape+Transpose before a Conv into a "SeqLen2Spatial" plugin.

        Walks back from each Transpose->Conv pair, asserting the expected node
        types, and feeds (input, bias, residual, output-shape) to the plugin.
        Returns the number of plugins inserted.
        """
        nSeqLen2SpatialPlugin = 0
        for node in self.graph.nodes:
            if node.op == "Transpose" and node.o().op == "Conv":
                transposeNode = node
                reshapeNode = node.i()
                assert reshapeNode.op == "Reshape", "Unexpected node type for reshapeNode {}".format(reshapeNode.name)
                residualNode = reshapeNode.i(0)
                assert residualNode.op == "Add", "Unexpected node type for residualNode {}".format(residualNode.name)
                biasNode = residualNode.i(0)
                assert biasNode.op == "Add", "Unexpected node type for biasNode {}".format(biasNode.name)
                biasIndex = [type(i) == gs.ir.tensor.Constant for i in biasNode.inputs].index(True)
                bias = np.array(deepcopy(biasNode.inputs[biasIndex].values.tolist()), dtype=np.float32)
                biasInput = gs.Constant("AddAddSeqLen2SpatialBias-" + str(nSeqLen2SpatialPlugin), np.ascontiguousarray(bias.reshape(-1)))
                # The non-constant input of the bias Add is the activation.
                inputIndex = 1 - biasIndex
                inputTensor = biasNode.inputs[inputIndex]
                residualInput = residualNode.inputs[1]
                outputTensor = transposeNode.outputs[0]
                # Tensor carrying the target spatial shape, found by a fixed
                # walk up the producer chain of the Transpose.
                outputShapeTensor = transposeNode.i().i().i(1).i(1).i(1).i().inputs[0]
                seqLen2SpatialNode = gs.Node("SeqLen2Spatial", "AddAddSeqLen2Spatial-" + str(nSeqLen2SpatialPlugin),
                    inputs=[inputTensor, biasInput, residualInput, outputShapeTensor], outputs=[outputTensor])
                self.graph.nodes.append(seqLen2SpatialNode)
                biasNode.inputs.clear()
                transposeNode.outputs.clear()
                nSeqLen2SpatialPlugin += 1

        self.cleanup()
        return nSeqLen2SpatialPlugin

    def fuse_kv(self, node_k, node_v, fused_kv_idx, heads, num_dynamic=0):
        """Fuse the K and V MatMuls of a cross-attention block into one GEMM.

        Interleaves the K and V weights so the fused GEMM output has the
        [b, s_kv, h, 2, d] layout expected by the fMHCA plugin, reroutes the
        consumers, and detaches the original nodes. Returns the fused node.
        """
        # Get weights of K
        weights_k = node_k.inputs[1].values
        # Get weights of V
        weights_v = node_v.inputs[1].values
        # Input number of channels to K and V
        C = weights_k.shape[0]
        # Number of heads
        H = heads
        # Dimension per head
        D = weights_k.shape[1] // H

        # Concat and interleave weights such that the output of fused KV GEMM has [b, s_kv, h, 2, d] shape
        weights_kv = np.dstack([weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 2 * H * D)

        # K and V have the same input
        input_tensor = node_k.inputs[0]
        # K and V must have the same output which we feed into fmha plugin
        output_tensor_k = node_k.outputs[0]
        # Create tensor
        constant_weights_kv = gs.Constant("Weights_KV_{}".format(fused_kv_idx), np.ascontiguousarray(weights_kv))

        # Create fused KV node
        fused_kv_node = gs.Node(op="MatMul", name="MatMul_KV_{}".format(fused_kv_idx), inputs=[input_tensor, constant_weights_kv], outputs=[output_tensor_k])
        self.graph.nodes.append(fused_kv_node)

        # Connect the output of fused node to the inputs of the nodes after K and V
        node_v.o(num_dynamic).inputs[0] = output_tensor_k
        node_k.o(num_dynamic).inputs[0] = output_tensor_k
        for i in range(0,num_dynamic):
            # NOTE(review): after each clear() the consumer list shrinks, so
            # o() keeps targeting the next dynamic-shape consumer -- confirm.
            node_v.o().inputs.clear()
            node_k.o().inputs.clear()

        # Clear inputs and outputs of K and V to get these nodes cleared
        node_k.outputs.clear()
        node_v.outputs.clear()
        node_k.inputs.clear()
        node_v.inputs.clear()

        self.cleanup()
        return fused_kv_node

    def insert_fmhca(self, node_q, node_kv, final_tranpose, mhca_idx, heads, num_dynamic=0):
        """Insert an fMHCA (fused cross-attention) plugin between Q GEMM and the fused KV GEMM.

        Deletes the replaced attention subgraph by clearing its inputs, adds a
        Reshape to [b, s, h, 2, d] on the KV path, and rewires the final
        transpose output to the plugin.
        """
        # Get inputs and outputs for the fMHCA plugin
        # We take an output of reshape that follows the Q GEMM
        output_q = node_q.o(num_dynamic).o().inputs[0]
        output_kv = node_kv.o().inputs[0]
        output_final_tranpose = final_tranpose.outputs[0]

        # Clear the inputs of the nodes that follow the Q and KV GEMM
        # to delete these subgraphs (it will be substituted by fMHCA plugin)
        # NOTE(review): the same line appears twice; the second call looks
        # redundant (or relies on the consumer list shifting) -- confirm.
        node_kv.outputs[0].outputs[0].inputs.clear()
        node_kv.outputs[0].outputs[0].inputs.clear()
        node_q.o(num_dynamic).o().inputs.clear()
        for i in range(0,num_dynamic):
            node_q.o(i).o().o(1).inputs.clear()

        weights_kv = node_kv.inputs[1].values
        dims_per_head = weights_kv.shape[1] // (heads * 2)

        # Reshape dims
        shape = gs.Constant("Shape_KV_{}".format(mhca_idx), np.ascontiguousarray(np.array([0, 0, heads, 2, dims_per_head], dtype=np.int64)))

        # Reshape output tensor
        output_reshape = gs.Variable("ReshapeKV_{}".format(mhca_idx), np.dtype(np.float16), None)
        # Create fMHA plugin
        reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mhca_idx), inputs=[output_kv, shape], outputs=[output_reshape])
        # Insert node
        self.graph.nodes.append(reshape)

        # Create fMHCA plugin
        fmhca = gs.Node(op="fMHCA", name="fMHCA_{}".format(mhca_idx), inputs=[output_q, output_reshape], outputs=[output_final_tranpose])
        # Insert node
        self.graph.nodes.append(fmhca)

        # Connect input of fMHCA to output of Q GEMM
        node_q.o(num_dynamic).outputs[0] = output_q

        if num_dynamic > 0:
            # Dynamic shapes: feed the downstream Reshape with the runtime
            # shape of the Q input instead of a static shape constant.
            reshape2_input1_out = gs.Variable("Reshape2_fmhca{}_out".format(mhca_idx), np.dtype(np.int64), None)
            reshape2_input1_shape = gs.Node("Shape", "Reshape2_fmhca{}_shape".format(mhca_idx), inputs=[node_q.inputs[0]], outputs=[reshape2_input1_out])
            self.graph.nodes.append(reshape2_input1_shape)
            final_tranpose.o().inputs[1] = reshape2_input1_out

        # Clear outputs of transpose to get this subgraph cleared
        final_tranpose.outputs.clear()

        self.cleanup()

    def fuse_qkv(self, node_q, node_k, node_v, fused_qkv_idx, heads, num_dynamic=0):
        """Fuse the Q, K and V MatMuls of a self-attention block into one GEMM.

        Interleaves the three weight matrices so the fused GEMM output has the
        [b, s, h, 3, d] layout expected by the fMHA plugin, reroutes consumers,
        and detaches the original nodes. Returns the fused node.
        """
        # Get weights of Q
        weights_q = node_q.inputs[1].values
        # Get weights of K
        weights_k = node_k.inputs[1].values
        # Get weights of V
        weights_v = node_v.inputs[1].values

        # Input number of channels to Q, K and V
        C = weights_k.shape[0]
        # Number of heads
        H = heads
        # Hidden dimension per head
        D = weights_k.shape[1] // H

        # Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape
        weights_qkv = np.dstack([weights_q.reshape(C, H, D), weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 3 * H * D)

        input_tensor = node_k.inputs[0]  # K and V have the same input
        # Q, K and V must have the same output which we feed into fmha plugin
        output_tensor_k = node_k.outputs[0]
        # Concat and interleave weights such that the output of fused QKV GEMM has [b, s, h, 3, d] shape
        constant_weights_qkv = gs.Constant("Weights_QKV_{}".format(fused_qkv_idx), np.ascontiguousarray(weights_qkv))

        # Created a fused node
        fused_qkv_node = gs.Node(op="MatMul", name="MatMul_QKV_{}".format(fused_qkv_idx), inputs=[input_tensor, constant_weights_qkv], outputs=[output_tensor_k])
        self.graph.nodes.append(fused_qkv_node)

        # Connect the output of the fused node to the inputs of the nodes after Q, K and V
        node_q.o(num_dynamic).inputs[0] = output_tensor_k
        node_k.o(num_dynamic).inputs[0] = output_tensor_k
        node_v.o(num_dynamic).inputs[0] = output_tensor_k
        for i in range(0,num_dynamic):
            node_q.o().inputs.clear()
            node_k.o().inputs.clear()
            node_v.o().inputs.clear()

        # Clear inputs and outputs of Q, K and V to get these nodes cleared
        node_q.outputs.clear()
        node_k.outputs.clear()
        node_v.outputs.clear()

        node_q.inputs.clear()
        node_k.inputs.clear()
        node_v.inputs.clear()

        self.cleanup()
        return fused_qkv_node

    def insert_fmha(self, node_qkv, final_tranpose, mha_idx, heads, num_dynamic=0):
        """Insert an fMHA_V2 (fused self-attention) plugin after the fused QKV GEMM.

        Adds a Reshape of the QKV output to [b, s, h, 3, d], routes it into the
        plugin, and rewires the final transpose output.
        """
        # Get inputs and outputs for the fMHA plugin
        output_qkv = node_qkv.o().inputs[0]
        output_final_tranpose = final_tranpose.outputs[0]

        # Clear the inputs of the nodes that follow the QKV GEMM
        # to delete these subgraphs (it will be substituted by fMHA plugin)
        node_qkv.outputs[0].outputs[2].inputs.clear()
        node_qkv.outputs[0].outputs[1].inputs.clear()
        node_qkv.outputs[0].outputs[0].inputs.clear()

        weights_qkv = node_qkv.inputs[1].values
        dims_per_head = weights_qkv.shape[1] // (heads * 3)

        # Reshape dims
        shape = gs.Constant("Shape_QKV_{}".format(mha_idx), np.ascontiguousarray(np.array([0, 0, heads, 3, dims_per_head], dtype=np.int64)))

        # Reshape output tensor
        output_shape = gs.Variable("ReshapeQKV_{}".format(mha_idx), np.dtype(np.float16), None)
        # Create fMHA plugin
        reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mha_idx), inputs=[output_qkv, shape], outputs=[output_shape])
        # Insert node
        self.graph.nodes.append(reshape)

        # Create fMHA plugin
        fmha = gs.Node(op="fMHA_V2", name="fMHA_{}".format(mha_idx), inputs=[output_shape], outputs=[output_final_tranpose])
        # Insert node
        self.graph.nodes.append(fmha)

        if num_dynamic > 0:
            # Dynamic shapes: derive the downstream Reshape target from the
            # runtime shape of the QKV input.
            reshape2_input1_out = gs.Variable("Reshape2_{}_out".format(mha_idx), np.dtype(np.int64), None)
            reshape2_input1_shape = gs.Node("Shape", "Reshape2_{}_shape".format(mha_idx), inputs=[node_qkv.inputs[0]], outputs=[reshape2_input1_out])
            self.graph.nodes.append(reshape2_input1_shape)
            final_tranpose.o().inputs[1] = reshape2_input1_out

        # Clear outputs of transpose to get this subgraph cleared
        final_tranpose.outputs.clear()

        self.cleanup()

    def mha_mhca_detected(self, node, mha):
        """Detect an MHA (self-attention) or MHCA (cross-attention) pattern at ``node``.

        ``node`` is tested as the V GEMM; the matcher walks the fixed
        Reshape/Transpose/MatMul/Softmax chain to locate the Q and K GEMMs and
        the final Transpose, counting leading dynamic-shape (Shape) consumers.

        Returns:
            (detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v,
            final_tranpose); all anchors are None when not detected.
        """
        # Go from V GEMM down to the S*V MatMul and all way up to K GEMM
        # If we are looking for MHCA inputs of two matmuls (K and V) must be equal.
        # If we are looking for MHA inputs (K and V) must be not equal.
        if node.op == "MatMul" and len(node.outputs) == 1 and \
            ((mha and len(node.inputs[0].inputs) > 0 and node.i().op == "Add") or \
            (not mha and len(node.inputs[0].inputs) == 0)):

            if node.o().op == 'Shape':
                if node.o(1).op == 'Shape':
                    num_dynamic_kv = 3 if node.o(2).op == 'Shape' else 2
                else:
                    num_dynamic_kv = 1
                # For Cross-Attention, if batch axis is dynamic (in QKV), assume H*W (in Q) is dynamic as well
                num_dynamic_q = num_dynamic_kv if mha else num_dynamic_kv + 1
            else:
                num_dynamic_kv = 0
                num_dynamic_q = 0

            o = node.o(num_dynamic_kv)
            if o.op == "Reshape" and \
                o.o().op == "Transpose" and \
                o.o().o().op == "Reshape" and \
                o.o().o().o().op == "MatMul" and \
                o.o().o().o().i(0).op == "Softmax" and \
                o.o().o().o().i(1).op == "Reshape" and \
                o.o().o().o().i(0).i().op == "Mul" and \
                o.o().o().o().i(0).i().i().op == "MatMul" and \
                o.o().o().o().i(0).i().i().i(0).op == "Reshape" and \
                o.o().o().o().i(0).i().i().i(1).op == "Transpose" and \
                o.o().o().o().i(0).i().i().i(1).i().op == "Reshape" and \
                o.o().o().o().i(0).i().i().i(1).i().i().op == "Transpose" and \
                o.o().o().o().i(0).i().i().i(1).i().i().i().op == "Reshape" and \
                o.o().o().o().i(0).i().i().i(1).i().i().i().i().op == "MatMul" and \
                node.name != o.o().o().o().i(0).i().i().i(1).i().i().i().i().name:
                # "len(node.outputs) == 1" to make sure we are not in the already fused node
                node_q = o.o().o().o().i(0).i().i().i(0).i().i().i()
                node_k = o.o().o().o().i(0).i().i().i(1).i().i().i().i()
                node_v = node
                final_tranpose = o.o().o().o().o(num_dynamic_q).o()
                # Sanity check to make sure that the graph looks like expected
                if node_q.op == "MatMul" and final_tranpose.op == "Transpose":
                    return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose
        return False, 0, 0, None, None, None, None

    def fuse_kv_insert_fmhca(self, heads, mhca_index, sm):
        """Find one MHCA pattern, fuse its KV GEMMs and insert an fMHCA plugin.

        Returns True if a pattern was fused (callers loop until False).
        """
        nodes = self.graph.nodes
        # Iterate over graph and search for MHCA pattern
        for idx, _ in enumerate(nodes):
            # fMHCA can't be at the 2 last layers of the network. It is a guard from OOB
            if idx + 1 > len(nodes) or idx + 2 > len(nodes):
                continue

            # Get anchor nodes for fusion and fMHCA plugin insertion if the MHCA is detected
            detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
                self.mha_mhca_detected(nodes[idx], mha=False)
            if detected:
                assert num_dynamic_q == 0 or num_dynamic_q == num_dynamic_kv + 1
                # Skip the FMHCA plugin for SM75 except for when the dim per head is 40.
                if sm == 75 and node_q.inputs[1].shape[1] // heads == 160:
                    continue
                # Fuse K and V GEMMS
                node_kv = self.fuse_kv(node_k, node_v, mhca_index, heads, num_dynamic_kv)
                # Insert fMHCA plugin
                self.insert_fmhca(node_q, node_kv, final_tranpose, mhca_index, heads, num_dynamic_q)
                return True
        return False

    def fuse_qkv_insert_fmha(self, heads, mha_index):
        """Find one MHA pattern, fuse its QKV GEMMs and insert an fMHA plugin.

        Returns True if a pattern was fused (callers loop until False).
        """
        nodes = self.graph.nodes
        # Iterate over graph and search for MHA pattern
        for idx, _ in enumerate(nodes):
            # fMHA can't be at the 2 last layers of the network. It is a guard from OOB
            if idx + 1 > len(nodes) or idx + 2 > len(nodes):
                continue

            # Get anchor nodes for fusion and fMHA plugin insertion if the MHA is detected
            detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_tranpose = \
                self.mha_mhca_detected(nodes[idx], mha=True)
            if detected:
                assert num_dynamic_q == num_dynamic_kv
                # Fuse Q, K and V GEMMS
                node_qkv = self.fuse_qkv(node_q, node_k, node_v, mha_index, heads, num_dynamic_kv)
                # Insert fMHA plugin
                self.insert_fmha(node_qkv, final_tranpose, mha_index, heads, num_dynamic_kv)
                return True
        return False

    def insert_fmhca_plugin(self, num_heads, sm):
        """Repeatedly fuse MHCA patterns until none remain; returns the count."""
        mhca_index = 0
        while self.fuse_kv_insert_fmhca(num_heads, mhca_index, sm):
            mhca_index += 1
        return mhca_index

    def insert_fmha_plugin(self, num_heads):
        """Repeatedly fuse MHA patterns until none remain; returns the count."""
        mha_index = 0
        while self.fuse_qkv_insert_fmha(num_heads, mha_index):
            mha_index += 1
        return mha_index
636
-
637
class BaseModel():
    """Common base for ONNX-exportable Stable Diffusion submodels (CLIP/UNet/VAE).

    Holds export configuration (precision, device, batch and latent-shape
    limits) and defines the hooks the export/engine-build pipeline calls;
    subclasses override ``get_model``, ``get_input_names``, etc.
    """
    def __init__(
        self,
        hf_token,
        text_maxlen=77,
        embedding_dim=768,
        fp16=False,
        device='cuda',
        verbose=True,
        max_batch_size=16
    ):
        # hf_token: Hugging Face access token used by subclasses when
        # downloading pretrained weights.
        self.fp16 = fp16
        self.device = device
        self.verbose = verbose
        self.hf_token = hf_token

        # Defaults
        self.text_maxlen = text_maxlen
        self.embedding_dim = embedding_dim
        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_latent_shape = 256 // 8   # min image resolution: 256x256
        self.max_latent_shape = 1024 // 8  # max image resolution: 1024x1024

    def get_model(self):
        """Return the torch module to export (overridden by subclasses)."""
        pass

    def get_input_names(self):
        """Return the ONNX input tensor names (overridden by subclasses)."""
        pass

    def get_output_names(self):
        """Return the ONNX output tensor names (overridden by subclasses)."""
        pass

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping for torch.onnx.export, or None."""
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return dummy input tensors for export tracing (overridden by subclasses)."""
        pass

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return the TensorRT optimization profile (min/opt/max shapes), or None."""
        return None

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete binding shapes for inference, or None."""
        return None

    def optimize(self, onnx_graph, minimal_optimization=False):
        """Optimize the exported ONNX graph; the base implementation is a no-op."""
        return onnx_graph

    def check_dims(self, batch_size, image_height, image_width):
        """Validate batch and image dimensions; return (latent_height, latent_width).

        Raises:
            AssertionError: if batch_size is outside [min_batch, max_batch],
                either image dimension is not a multiple of 8, or a latent
                dimension falls outside [min_latent_shape, max_latent_shape].
        """
        assert batch_size >= self.min_batch and batch_size <= self.max_batch
        # Bug fix: both dimensions must be divisible by 8 since both latents
        # are computed with // 8; the original `or` let one invalid dim pass.
        assert image_height % 8 == 0 and image_width % 8 == 0
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
        assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
        """Return (min_batch, max_batch, min/max latent height, min/max latent width).

        Static flags pin the corresponding dimension to the requested value;
        otherwise the class-level limits are used.
        """
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // 8
        latent_width = image_width // 8
        min_latent_height = latent_height if static_shape else self.min_latent_shape
        max_latent_height = latent_height if static_shape else self.max_latent_shape
        min_latent_width = latent_width if static_shape else self.min_latent_shape
        max_latent_width = latent_width if static_shape else self.max_latent_shape
        return (min_batch, max_batch, min_latent_height, max_latent_height, min_latent_width, max_latent_width)
704
-
705
class CLIP(BaseModel):
    """Export/optimization descriptor for the CLIP text encoder (ViT-L/14)."""

    def get_model(self):
        # Load the pretrained text encoder onto the target device.
        return CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)

    def get_input_names(self):
        return ['input_ids']

    def get_output_names(self):
        return ['text_embeddings', 'pooler_output']

    def get_dynamic_axes(self):
        # Only the batch dimension is dynamic; sequence length is fixed at text_maxlen.
        return {'input_ids': {0: 'B'}, 'text_embeddings': {0: 'B'}}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        # TensorRT optimization profile: (min, opt, max) shapes for each input.
        self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_shape)[:2]
        return {
            'input_ids': [
                (min_batch, self.text_maxlen),
                (batch_size, self.text_maxlen),
                (max_batch, self.text_maxlen),
            ]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return {
            'input_ids': (batch_size, self.text_maxlen),
            'text_embeddings': (batch_size, self.text_maxlen, self.embedding_dim),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        # Dummy token ids used only to trace the graph for ONNX export.
        self.check_dims(batch_size, image_height, image_width)
        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)

    def optimize(self, onnx_graph, minimal_optimization=False):
        """Graph-surgery pass over the exported CLIP ONNX graph.

        With minimal_optimization the cast-removal and LayerNorm-plugin
        steps are skipped; the output pruning/renaming always runs.
        """
        run_full_opt = not minimal_optimization

        opt = Optimizer(onnx_graph, verbose=self.verbose)
        opt.info('CLIP: original')
        opt.select_outputs([0])  # delete graph output#1
        opt.cleanup()
        opt.info('CLIP: remove output[1]')
        opt.fold_constants()
        opt.info('CLIP: fold constants')
        opt.infer_shapes()
        opt.info('CLIP: shape inference')

        # Remove Cast nodes so the attention blocks can be optimized.
        if run_full_opt:
            num_casts_removed = opt.remove_casts()
            opt.info('CLIP: removed ' + str(num_casts_removed) + ' casts')

        # Replace LayerNorm subgraphs with the TensorRT plugin.
        if run_full_opt:
            num_layernorm_inserted = opt.insert_layernorm_plugin()
            opt.info('CLIP: inserted ' + str(num_layernorm_inserted) + ' LayerNorm plugins')

        opt.select_outputs([0], names=['text_embeddings'])  # rename network output
        final_graph = opt.cleanup(return_onnx=True)
        opt.info('CLIP: final')
        return final_graph
769
-
770
class UNet(BaseModel):
    """ONNX export and graph-optimization settings for the SD v1.4 UNet.

    All batch dimensions below are 2*batch_size ('2B') because classifier-free
    guidance concatenates the unconditional and conditional inputs into one batch.
    """

    def get_model(self):
        # Use the fp16 weight revision when half precision was requested.
        model_opts = {'revision': 'fp16', 'torch_dtype': torch.float16} if self.fp16 else {}
        return UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4",
            subfolder="unet",
            use_auth_token=self.hf_token,
            **model_opts).to(self.device)

    def get_input_names(self):
        return ['sample', 'timestep', 'encoder_hidden_states']

    def get_output_names(self):
        return ['latent']

    def get_dynamic_axes(self):
        # H/W are latent-space height/width (image size // 8).
        return {
            'sample': {0: '2B', 2: 'H', 3: 'W'},
            'encoder_hidden_states': {0: '2B'},
            'latent': {0: '2B', 2: 'H', 3: 'W'}
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        # TensorRT optimization profile: (min, opt, max) shape per input.
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \
            self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            'sample': [(2*min_batch, 4, min_latent_height, min_latent_width), (2*batch_size, 4, latent_height, latent_width), (2*max_batch, 4, max_latent_height, max_latent_width)],
            'encoder_hidden_states': [(2*min_batch, self.text_maxlen, self.embedding_dim), (2*batch_size, self.text_maxlen, self.embedding_dim), (2*max_batch, self.text_maxlen, self.embedding_dim)]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            'sample': (2*batch_size, 4, latent_height, latent_width),
            'encoder_hidden_states': (2*batch_size, self.text_maxlen, self.embedding_dim),
            'latent': (2*batch_size, 4, latent_height, latent_width)
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        # Dummy inputs used only to trace the graph for ONNX export.
        # NOTE(review): sample and timestep are fed as fp32 even in fp16 mode;
        # only encoder_hidden_states follows self.fp16 — presumably matching the
        # dtypes the exported model expects. Confirm before changing.
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dtype = torch.float16 if self.fp16 else torch.float32
        return (
            torch.randn(2*batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),
            torch.tensor([1.], dtype=torch.float32, device=self.device),
            torch.randn(2*batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device)
        )

    def optimize(self, onnx_graph, minimal_optimization=False):
        """Graph-surgery pipeline for the exported UNet ONNX graph.

        The passes run in a fixed order: node-level rewrites first, then
        cleanup/fold/shape-inference, then TensorRT plugin insertion. With
        minimal_optimization the rewrites run but all plugins are disabled.
        """
        enable_optimization = not minimal_optimization

        # Decompose InstanceNormalization into primitive Ops
        bRemoveInstanceNorm = enable_optimization
        # Remove Cast Node to optimize Attention block
        bRemoveCastNode = enable_optimization
        # Remove parallel Swish ops
        bRemoveParallelSwish = enable_optimization
        # Adjust the bias to be the second input to the Add ops
        bAdjustAddNode = enable_optimization
        # Change Resize node to take size instead of scale
        bResizeFix = enable_optimization

        # Common override for disabling all plugins below
        bDisablePlugins = minimal_optimization
        # Use multi-head attention Plugin
        bMHAPlugin = True
        # Use multi-head cross attention Plugin
        bMHCAPlugin = True
        # Insert GroupNormalization Plugin
        bGroupNormPlugin = True
        # Insert LayerNormalization Plugin
        bLayerNormPlugin = True
        # Insert Split+GeLU Plugin
        bSplitGeLUPlugin = True
        # Replace BiasAdd+ResidualAdd+SeqLen2Spatial with plugin
        bSeqLen2SpatialPlugin = True

        opt = Optimizer(onnx_graph, verbose=self.verbose)
        opt.info('UNet: original')

        if bRemoveInstanceNorm:
            num_instancenorm_replaced = opt.decompose_instancenorms()
            opt.info('UNet: replaced '+str(num_instancenorm_replaced)+' InstanceNorms')

        if bRemoveCastNode:
            num_casts_removed = opt.remove_casts()
            opt.info('UNet: removed '+str(num_casts_removed)+' casts')

        if bRemoveParallelSwish:
            num_parallel_swish_removed = opt.remove_parallel_swish()
            opt.info('UNet: removed '+str(num_parallel_swish_removed)+' parallel swish ops')

        if bAdjustAddNode:
            num_adjust_add = opt.adjustAddNode()
            opt.info('UNet: adjusted '+str(num_adjust_add)+' adds')

        if bResizeFix:
            num_resize_fix = opt.resize_fix()
            opt.info('UNet: fixed '+str(num_resize_fix)+' resizes')

        opt.cleanup()
        opt.info('UNet: cleanup')
        opt.fold_constants()
        opt.info('UNet: fold constants')
        opt.infer_shapes()
        opt.info('UNet: shape inference')

        num_heads = 8
        if bMHAPlugin and not bDisablePlugins:
            num_fmha_inserted = opt.insert_fmha_plugin(num_heads)
            opt.info('UNet: inserted '+str(num_fmha_inserted)+' fMHA plugins')

        if bMHCAPlugin and not bDisablePlugins:
            # Query the GPU's SM version: the fused cross-attention plugin
            # selects kernels by compute capability.
            props = cudart.cudaGetDeviceProperties(0)[1]
            sm = props.major * 10 + props.minor
            num_fmhca_inserted = opt.insert_fmhca_plugin(num_heads, sm)
            opt.info('UNet: inserted '+str(num_fmhca_inserted)+' fMHCA plugins')

        if bGroupNormPlugin and not bDisablePlugins:
            num_groupnorm_inserted = opt.insert_groupnorm_plugin()
            opt.info('UNet: inserted '+str(num_groupnorm_inserted)+' GroupNorm plugins')

        if bLayerNormPlugin and not bDisablePlugins:
            num_layernorm_inserted = opt.insert_layernorm_plugin()
            opt.info('UNet: inserted '+str(num_layernorm_inserted)+' LayerNorm plugins')

        if bSplitGeLUPlugin and not bDisablePlugins:
            num_splitgelu_inserted = opt.insert_splitgelu_plugin()
            opt.info('UNet: inserted '+str(num_splitgelu_inserted)+' SplitGeLU plugins')

        if bSeqLen2SpatialPlugin and not bDisablePlugins:
            num_seq2spatial_inserted = opt.insert_seq2spatial_plugin()
            opt.info('UNet: inserted '+str(num_seq2spatial_inserted)+' SeqLen2Spatial plugins')

        onnx_opt_graph = opt.cleanup(return_onnx=True)
        opt.info('UNet: final')
        return onnx_opt_graph
906
-
907
class VAE(BaseModel):
    """Export/optimization descriptor for the Stable Diffusion VAE decoder."""

    def get_model(self):
        # Route forward() to decode(): only latent -> image decoding is exported.
        decoder = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4",
            subfolder="vae",
            use_auth_token=self.hf_token).to(self.device)
        decoder.forward = decoder.decode
        return decoder

    def get_input_names(self):
        return ['latent']

    def get_output_names(self):
        return ['images']

    def get_dynamic_axes(self):
        # The decoder upsamples by 8x in each spatial dimension.
        return {
            'latent': {0: 'B', 2: 'H', 3: 'W'},
            'images': {0: 'B', 2: '8H', 3: '8W'},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        # TensorRT optimization profile: (min, opt, max) latent shapes.
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (min_batch, max_batch,
         min_latent_height, max_latent_height,
         min_latent_width, max_latent_width) = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_shape)
        return {
            'latent': [
                (min_batch, 4, min_latent_height, min_latent_width),
                (batch_size, 4, latent_height, latent_width),
                (max_batch, 4, max_latent_height, max_latent_width),
            ]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            'latent': (batch_size, 4, latent_height, latent_width),
            'images': (batch_size, 3, image_height, image_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        # Dummy latent used only to trace the graph for ONNX export.
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)

    def optimize(self, onnx_graph, minimal_optimization=False):
        """Graph-surgery pass over the exported VAE decoder ONNX graph.

        With minimal_optimization the InstanceNorm decomposition, cast
        removal, and GroupNorm plugin insertion are all skipped.
        """
        full_opt = not minimal_optimization

        opt = Optimizer(onnx_graph, verbose=self.verbose)
        opt.info('VAE: original')

        # Decompose InstanceNormalization into primitive ops.
        if full_opt:
            n_instancenorm = opt.decompose_instancenorms()
            opt.info(f'VAE: replaced {n_instancenorm} InstanceNorms')

        # Remove Cast nodes so the attention blocks can be optimized.
        if full_opt:
            n_casts = opt.remove_casts()
            opt.info(f'VAE: removed {n_casts} casts')

        opt.cleanup()
        opt.info('VAE: cleanup')
        opt.fold_constants()
        opt.info('VAE: fold constants')
        opt.infer_shapes()
        opt.info('VAE: shape inference')

        # Replace GroupNorm subgraphs with the TensorRT plugin.
        if full_opt:
            n_groupnorm = opt.insert_groupnorm_plugin()
            opt.info(f'VAE: inserted {n_groupnorm} GroupNorm plugins')

        optimized_graph = opt.cleanup(return_onnx=True)
        opt.info('VAE: final')
        return optimized_graph