Masaaki Kawata commited on
Commit
fb1a5ba
Β·
1 Parent(s): fec4ab4

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +7 -8
optimization.py CHANGED
@@ -8,9 +8,8 @@ from typing import ParamSpec
8
  import spaces
9
  import torch
10
  from torch.utils._pytree import tree_map_only
11
-
12
- from optimization_utils import capture_component_call
13
- from optimization_utils import aoti_compile
14
 
15
 
16
  P = ParamSpec('P')
@@ -38,7 +37,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
38
  @spaces.GPU(duration=1500)
39
  def compile_transformer():
40
 
41
- with capture_component_call(pipeline, 'transformer') as call:
42
  pipeline(*args, **kwargs)
43
 
44
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
@@ -46,6 +45,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
46
 
47
  pipeline.transformer.fuse_qkv_projections()
48
 
 
 
49
  exported = torch.export.export(
50
  mod=pipeline.transformer,
51
  args=call.args,
@@ -53,8 +54,6 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
53
  dynamic_shapes=dynamic_shapes,
54
  )
55
 
56
- return aoti_compile(exported, INDUCTOR_CONFIGS)
57
 
58
- transformer_config = pipeline.transformer.config
59
- pipeline.transformer = compile_transformer()
60
- pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
 
8
  import spaces
9
  import torch
10
  from torch.utils._pytree import tree_map_only
11
+ from torchao.quantization import quantize_
12
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
 
13
 
14
 
15
  P = ParamSpec('P')
 
37
  @spaces.GPU(duration=1500)
38
  def compile_transformer():
39
 
40
+ with spaces.aoti_capture(pipeline.transformer) as call:
41
  pipeline(*args, **kwargs)
42
 
43
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
 
45
 
46
  pipeline.transformer.fuse_qkv_projections()
47
 
48
+ quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
49
+
50
  exported = torch.export.export(
51
  mod=pipeline.transformer,
52
  args=call.args,
 
54
  dynamic_shapes=dynamic_shapes,
55
  )
56
 
57
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
58
 
59
+ spaces.aoti_apply(compile_transformer(), pipeline.transformer)