cbensimon HF Staff commited on
Commit
356294b
·
verified ·
1 Parent(s): 0c785c6
Files changed (1) hide show
  1. optimization.py +20 -4
optimization.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  """
3
 
 
4
  from typing import Any
5
  from typing import Callable
6
  from typing import ParamSpec
@@ -38,17 +39,25 @@ INDUCTOR_CONFIGS = {
38
 
39
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
40
 
 
 
41
  @spaces.GPU(duration=1500)
42
  def compile_transformer():
43
 
 
 
44
  with capture_component_call(pipeline, 'transformer') as call:
45
  pipeline(*args, **kwargs)
46
 
 
 
47
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
48
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
49
 
50
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
51
 
 
 
52
  hidden_states: torch.Tensor = call.kwargs['hidden_states']
53
  hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
54
  if hidden_states.shape[-1] > hidden_states.shape[-2]:
@@ -65,6 +74,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
65
  dynamic_shapes=dynamic_shapes,
66
  )
67
 
 
 
68
  exported_portrait = torch.export.export(
69
  mod=pipeline.transformer,
70
  args=call.args,
@@ -72,10 +83,15 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
72
  dynamic_shapes=dynamic_shapes,
73
  )
74
 
75
- return (
76
- aoti_compile(exported_landscape, INDUCTOR_CONFIGS),
77
- aoti_compile(exported_portrait, INDUCTOR_CONFIGS),
78
- )
 
 
 
 
 
79
 
80
  compiled_landscape, compiled_portrait = compile_transformer()
81
 
 
1
  """
2
  """
3
 
4
+ from datetime import datetime
5
  from typing import Any
6
  from typing import Callable
7
  from typing import ParamSpec
 
39
 
40
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
41
 
42
+ t0 = datetime.now()
43
+
44
  @spaces.GPU(duration=1500)
45
  def compile_transformer():
46
 
47
+ print('compile_transformer', -(t0 - (t0 := datetime.now())))
48
+
49
  with capture_component_call(pipeline, 'transformer') as call:
50
  pipeline(*args, **kwargs)
51
 
52
+ print('capture_component_call', -(t0 - (t0 := datetime.now())))
53
+
54
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
55
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
56
 
57
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
58
 
59
+ print('quantize_', -(t0 - (t0 := datetime.now())))
60
+
61
  hidden_states: torch.Tensor = call.kwargs['hidden_states']
62
  hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
63
  if hidden_states.shape[-1] > hidden_states.shape[-2]:
 
74
  dynamic_shapes=dynamic_shapes,
75
  )
76
 
77
+ print('exported_landscape', -(t0 - (t0 := datetime.now())))
78
+
79
  exported_portrait = torch.export.export(
80
  mod=pipeline.transformer,
81
  args=call.args,
 
83
  dynamic_shapes=dynamic_shapes,
84
  )
85
 
86
+ print('exported_portrait', -(t0 - (t0 := datetime.now())))
87
+
88
+ compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
89
+ print('compiled_landscape', -(t0 - (t0 := datetime.now())))
90
+
91
+ compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
92
+ print('compiled_portrait', -(t0 - (t0 := datetime.now())))
93
+
94
+ return compiled_landscape, compiled_portrait
95
 
96
  compiled_landscape, compiled_portrait = compile_transformer()
97