IFMedTechdemo committed on
Commit
8ebd4db
·
verified ·
1 Parent(s): c2c0be3

Create optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +74 -0
optimization.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Optimization module for Qwen-Image-Edit using TorchAO quantization and AoTI compilation.
3
+ """
4
+
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import ParamSpec
8
+ from torchao.quantization import quantize_
9
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
10
+ import spaces
11
+ import torch
12
+ from torch.utils._pytree import tree_map
13
+
14
+
15
+ P = ParamSpec('P')
16
+
17
+
18
+ TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
19
+ TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
20
+
21
+ TRANSFORMER_DYNAMIC_SHAPES = {
22
+ 'hidden_states': {
23
+ 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
24
+ },
25
+ 'encoder_hidden_states': {
26
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
27
+ },
28
+ 'encoder_hidden_states_mask': {
29
+ 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
30
+ },
31
+ 'image_rotary_emb': ({
32
+ 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
33
+ }, {
34
+ 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
35
+ }),
36
+ }
37
+
38
+
39
# Inductor codegen options forwarded to spaces.aoti_compile below.
# NOTE(review): these look like aggressive-autotune settings for a one-off
# ahead-of-time compile (long compile time traded for runtime speed) — confirm
# against the target deployment before changing.
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,                          # treat 1x1 convs as matmuls
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,               # extra autotuning passes
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,                       # capture CUDA graphs for replay
}
47
+
48
+
49
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """
    Ahead-of-time compile the transformer of a Qwen-Image-Edit pipeline
    in place, for faster subsequent inference.

    The pipeline is run once with ``*args``/``**kwargs`` so that the exact
    arguments the transformer receives can be captured; the transformer is
    then exported with dynamic sequence-length axes and compiled via
    ``spaces.aoti_compile``, and the compiled artifact is applied back onto
    ``pipeline.transformer``.

    Args:
        pipeline: The diffusers pipeline whose ``transformer`` is compiled.
        *args: Positional arguments for one warm-up pipeline call.
        **kwargs: Keyword arguments for one warm-up pipeline call.
    """
    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Run the pipeline once, recording the args/kwargs with which the
        # transformer is actually invoked.
        with spaces.aoti_capture(pipeline.transformer) as captured:
            pipeline(*args, **kwargs)

        # Default every captured kwarg leaf to static (None), then override
        # the sequence-length axes that must remain dynamic.
        dynamic_shapes = {
            **tree_map(lambda _leaf: None, captured.kwargs),
            **TRANSFORMER_DYNAMIC_SHAPES,
        }

        # Optional: Uncomment to enable Float8 quantization
        # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

        exported_program = torch.export.export(
            mod=pipeline.transformer,
            args=captured.args,
            kwargs=captured.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        return spaces.aoti_compile(exported_program, INDUCTOR_CONFIGS)

    # Compile on a GPU worker, then splice the compiled module into the
    # live pipeline.
    spaces.aoti_apply(compile_transformer(), pipeline.transformer)