multimodalart HF Staff committed on
Commit
3174181
·
verified ·
1 Parent(s): 8e36388

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +6 -7
optimization.py CHANGED
@@ -14,8 +14,8 @@ from torch.utils._pytree import tree_map
14
  P = ParamSpec('P')
15
 
16
 
17
- TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
18
- TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
19
 
20
  TRANSFORMER_DYNAMIC_SHAPES = {
21
  'hidden_states': {
@@ -27,11 +27,10 @@ TRANSFORMER_DYNAMIC_SHAPES = {
27
  'encoder_hidden_states_mask': {
28
  1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
29
  },
30
- #'image_rotary_emb': ({
31
- # 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
32
- #}, {
33
- # 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
34
- #}),
35
  }
36
 
37
 
 
14
  P = ParamSpec('P')
15
 
16
 
17
+ TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length', min=1024, max=16384)
18
+ TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length', min=64, max=1024)
19
 
20
  TRANSFORMER_DYNAMIC_SHAPES = {
21
  'hidden_states': {
 
27
  'encoder_hidden_states_mask': {
28
  1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
29
  },
30
+ 'image_rotary_emb': (
31
+ {0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM}, # vid_freqs: [img_seq_len, rope_dim]
32
+ {0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM}, # txt_freqs: [txt_seq_len, rope_dim]
33
+ ),
 
34
  }
35
 
36