IFMedTechdemo commited on
Commit
e70b695
·
verified ·
1 Parent(s): 62ee321

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +4 -8
optimization.py CHANGED
@@ -1,7 +1,6 @@
1
  """
2
  Optimization module for Qwen-Image-Edit using TorchAO quantization and AoTI compilation.
3
  """
4
-
5
  from typing import Any
6
  from typing import Callable
7
  from typing import ParamSpec
@@ -47,21 +46,18 @@ INDUCTOR_CONFIGS = {
47
 
48
 
49
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
50
- """
51
- Optimizes the Qwen-Image-Edit pipeline using AoT compilation and quantization.
52
- This function pre-compiles the transformer for faster inference.
53
- """
54
  @spaces.GPU(duration=1500)
55
  def compile_transformer():
 
56
  with spaces.aoti_capture(pipeline.transformer) as call:
57
  pipeline(*args, **kwargs)
58
 
59
  dynamic_shapes = tree_map(lambda t: None, call.kwargs)
60
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
61
 
62
- # Optional: Uncomment to enable Float8 quantization
63
  # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
64
-
65
  exported = torch.export.export(
66
  mod=pipeline.transformer,
67
  args=call.args,
@@ -71,4 +67,4 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
71
 
72
  return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
73
 
74
- spaces.aoti_apply(compile_transformer(), pipeline.transformer)
 
1
  """
2
  Optimization module for Qwen-Image-Edit using TorchAO quantization and AoTI compilation.
3
  """
 
4
  from typing import Any
5
  from typing import Callable
6
  from typing import ParamSpec
 
46
 
47
 
48
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
49
+
 
 
 
50
  @spaces.GPU(duration=1500)
51
  def compile_transformer():
52
+
53
  with spaces.aoti_capture(pipeline.transformer) as call:
54
  pipeline(*args, **kwargs)
55
 
56
  dynamic_shapes = tree_map(lambda t: None, call.kwargs)
57
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
58
 
 
59
  # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
60
+
61
  exported = torch.export.export(
62
  mod=pipeline.transformer,
63
  args=call.args,
 
67
 
68
  return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
69
 
70
+ spaces.aoti_apply(compile_transformer(), pipeline.transformer)