prithivMLmods committed on
Commit
c3dd14a
·
verified ·
1 Parent(s): 7986899

update optimization

Browse files
Files changed (1) hide show
  1. optimization.py +71 -0
optimization.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ from typing import Callable
3
+ from typing import ParamSpec
4
+ import spaces
5
+ import torch
6
+ from torch.utils._pytree import tree_map
7
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights
8
+
9
# Captures the wrapped pipeline's full call signature so that
# optimize_pipeline_ can forward *args / **kwargs with precise typing.
P = ParamSpec('P')


# Symbolic dimensions for torch.export: the image and text sequence lengths
# vary between calls, so they must be exported as dynamic rather than baked
# into the compiled graph.
TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')

# Dynamic-shape spec for the first transformer block's keyword arguments,
# keyed by argument name; inner dicts map tensor axis index -> Dim.
# NOTE(review): axis 1 of the hidden-state tensors is presumably the sequence
# axis (batch-first layout) — confirm against the transformer's forward().
TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
    },
    'encoder_hidden_states': {
        1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
    },
    'encoder_hidden_states_mask': {
        1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
    },
    # A pair of rotary-embedding tensors: (image part, text part), each
    # dynamic along axis 0.
    'image_rotary_emb': ({
        0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
    }, {
        0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
    }),
}


# TorchInductor options passed to spaces.aoti_compile. Autotuning flags trade
# a longer one-off compile for faster steady-state inference; CUDA graphs cut
# per-launch overhead.
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}
41
+
42
+
43
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Ahead-of-time compile the pipeline's transformer blocks in place.

    Runs the pipeline once under ``spaces.aoti_capture`` to record the exact
    inputs of the first transformer block, exports that single block with
    dynamic sequence-length axes, compiles it with Inductor, and then rebinds
    ``forward`` on every transformer block to the compiled artifact (each
    block keeps its own weights).

    Args:
        pipeline: A diffusers-style pipeline exposing
            ``pipeline.transformer.transformer_blocks``; mutated in place.
        *args, **kwargs: A representative pipeline call used as the
            capture/export example input.
    """

    @spaces.GPU(duration=1500)
    def _export_and_compile_first_block():
        first_block = pipeline.transformer.transformer_blocks[0]

        # Record the args/kwargs the first block receives during a real run.
        with spaces.aoti_capture(first_block) as capture:
            pipeline(*args, **kwargs)

        # Start with every captured kwarg fully static, then mark the
        # sequence-length axes as dynamic.
        shapes = tree_map(lambda _: None, capture.kwargs)
        shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # Export just the first block — all blocks share one architecture.
        exported_program = torch.export.export(
            mod=first_block,
            args=capture.args,
            kwargs=capture.kwargs,
            dynamic_shapes=shapes,
        )
        return spaces.aoti_compile(exported_program, INDUCTOR_CONFIGS)

    artifact = _export_and_compile_first_block()

    # Reuse the single compiled archive for every block, swapping in each
    # block's own weights.
    for transformer_block in pipeline.transformer.transformer_blocks:
        block_weights = ZeroGPUWeights(transformer_block.state_dict())
        transformer_block.forward = ZeroGPUCompiledModel(
            artifact.archive_file, block_weights
        )