prithivMLmods committed on
Commit
7e4d38b
·
verified ·
1 Parent(s): 6b8e801

Delete optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +0 -70
optimization.py DELETED
@@ -1,70 +0,0 @@
1
- """
2
- """
3
-
4
- from typing import Any
5
- from typing import Callable
6
- from typing import ParamSpec
7
- from torchao.quantization import quantize_
8
- from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
9
- import spaces
10
- import torch
11
- from torch.utils._pytree import tree_map
12
-
13
-
14
- P = ParamSpec('P')
15
-
16
-
17
- TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
18
- TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
19
-
20
- TRANSFORMER_DYNAMIC_SHAPES = {
21
- 'hidden_states': {
22
- 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
23
- },
24
- 'encoder_hidden_states': {
25
- 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
26
- },
27
- 'encoder_hidden_states_mask': {
28
- 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
29
- },
30
- 'image_rotary_emb': ({
31
- 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
32
- }, {
33
- 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
34
- }),
35
- }
36
-
37
-
38
- INDUCTOR_CONFIGS = {
39
- 'conv_1x1_as_mm': True,
40
- 'epilogue_fusion': False,
41
- 'coordinate_descent_tuning': True,
42
- 'coordinate_descent_check_all_directions': True,
43
- 'max_autotune': True,
44
- 'triton.cudagraphs': True,
45
- }
46
-
47
-
48
- def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
49
-
50
- @spaces.GPU(duration=1500)
51
- def compile_transformer():
52
-
53
- with spaces.aoti_capture(pipeline.transformer) as call:
54
- pipeline(*args, **kwargs)
55
-
56
- dynamic_shapes = tree_map(lambda t: None, call.kwargs)
57
- dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
58
-
59
- # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
60
-
61
- exported = torch.export.export(
62
- mod=pipeline.transformer,
63
- args=call.args,
64
- kwargs=call.kwargs,
65
- dynamic_shapes=dynamic_shapes,
66
- )
67
-
68
- return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
69
-
70
- spaces.aoti_apply(compile_transformer(), pipeline.transformer)