Fabrice-TIERCELIN commited on
Commit
7482601
·
verified ·
1 Parent(s): d8ae88b

COMPILED_TRANSFORMER_1 = compiled_transformer_1

Browse files
Files changed (1) hide show
  1. optimization.py +10 -0
optimization.py CHANGED
@@ -23,6 +23,10 @@ from optimization_utils import ZeroGPUCompiledModel, ZeroGPUWeights
23
 
24
  P = ParamSpec('P')
25
 
 
 
 
 
26
  LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
27
  LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
28
  LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
@@ -99,6 +103,8 @@ def load_compiled_transformers_from_hub(
99
 
100
 
101
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
 
 
102
  @spaces.GPU(duration=1500)
103
  def compile_transformer():
104
  pipeline.load_lora_weights(
@@ -156,6 +162,10 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
156
  else:
157
  compiled_transformer_1, compiled_transformer_2 = compile_transformer()
158
 
 
 
 
 
159
  pipeline.transformer.forward = compiled_transformer_1
160
  drain_module_parameters(pipeline.transformer)
161
 
 
23
 
24
  P = ParamSpec('P')
25
 
26
+ # Expose compiled models so app.py can offer them for download
27
+ COMPILED_TRANSFORMER_1 = None
28
+ COMPILED_TRANSFORMER_2 = None
29
+
30
  LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
31
  LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
32
  LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
 
103
 
104
 
105
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
106
+ global COMPILED_TRANSFORMER_1, COMPILED_TRANSFORMER_2
107
+
108
  @spaces.GPU(duration=1500)
109
  def compile_transformer():
110
  pipeline.load_lora_weights(
 
162
  else:
163
  compiled_transformer_1, compiled_transformer_2 = compile_transformer()
164
 
165
+ # expose for downloads
166
+ COMPILED_TRANSFORMER_1 = compiled_transformer_1
167
+ COMPILED_TRANSFORMER_2 = compiled_transformer_2
168
+
169
  pipeline.transformer.forward = compiled_transformer_1
170
  drain_module_parameters(pipeline.transformer)
171