Masaaki Kawata committed on
Commit
88f11ab
·
1 Parent(s): 61715a7

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +11 -22
optimization.py CHANGED
@@ -4,37 +4,24 @@
4
  from typing import Any
5
  from typing import Callable
6
  from typing import ParamSpec
7
- from torchao.quantization import quantize_
8
- from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
  import spaces
10
  import torch
11
- from torch.utils._pytree import tree_map
 
 
12
 
13
 
14
  P = ParamSpec('P')
15
 
16
 
17
- TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
18
- TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
19
 
20
  TRANSFORMER_DYNAMIC_SHAPES = {
21
- 'hidden_states': {
22
- 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
23
- },
24
- 'encoder_hidden_states': {
25
- 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
26
- },
27
- 'encoder_hidden_states_mask': {
28
- 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
29
- },
30
- 'image_rotary_emb': ({
31
- 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
32
- }, {
33
- 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
34
- }),
35
  }
36
 
37
-
38
  INDUCTOR_CONFIGS = {
39
  'conv_1x1_as_mm': True,
40
  'epilogue_fusion': False,
@@ -53,10 +40,12 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
53
  with spaces.aoti_capture(pipeline.transformer) as call:
54
  pipeline(*args, **kwargs)
55
 
56
- dynamic_shapes = tree_map(lambda t: None, call.kwargs)
57
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
58
 
59
- # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
 
 
60
 
61
  exported = torch.export.export(
62
  mod=pipeline.transformer,
 
4
  from typing import Any
5
  from typing import Callable
6
  from typing import ParamSpec
7
+
 
8
  import spaces
9
  import torch
10
+ from torch.utils._pytree import tree_map_only
11
+ from torchao.quantization import quantize_
12
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
13
 
14
 
15
  P = ParamSpec('P')
16
 
17
 
18
+ TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=4096, max=8212)
 
19
 
20
  TRANSFORMER_DYNAMIC_SHAPES = {
21
+ 'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
22
+ 'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
 
 
 
 
 
 
 
 
 
 
 
 
23
  }
24
 
 
25
  INDUCTOR_CONFIGS = {
26
  'conv_1x1_as_mm': True,
27
  'epilogue_fusion': False,
 
40
  with spaces.aoti_capture(pipeline.transformer) as call:
41
  pipeline(*args, **kwargs)
42
 
43
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
44
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
45
 
46
+ pipeline.transformer.fuse_qkv_projections()
47
+
48
+ quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
49
 
50
  exported = torch.export.export(
51
  mod=pipeline.transformer,