Spaces:
Runtime error
Runtime error
COMPILED_TRANSFORMER_1 = compiled_transformer_1
Browse files- optimization.py +10 -0
optimization.py
CHANGED
|
@@ -23,6 +23,10 @@ from optimization_utils import ZeroGPUCompiledModel, ZeroGPUWeights
|
|
| 23 |
|
| 24 |
P = ParamSpec('P')
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
|
| 27 |
LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
|
| 28 |
LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
|
|
@@ -99,6 +103,8 @@ def load_compiled_transformers_from_hub(
|
|
| 99 |
|
| 100 |
|
| 101 |
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
|
|
|
|
|
|
|
| 102 |
@spaces.GPU(duration=1500)
|
| 103 |
def compile_transformer():
|
| 104 |
pipeline.load_lora_weights(
|
|
@@ -156,6 +162,10 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
| 156 |
else:
|
| 157 |
compiled_transformer_1, compiled_transformer_2 = compile_transformer()
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
pipeline.transformer.forward = compiled_transformer_1
|
| 160 |
drain_module_parameters(pipeline.transformer)
|
| 161 |
|
|
|
|
| 23 |
|
| 24 |
P = ParamSpec('P')
|
| 25 |
|
| 26 |
+
# Expose compiled models so app.py can offer them for download
|
| 27 |
+
COMPILED_TRANSFORMER_1 = None
|
| 28 |
+
COMPILED_TRANSFORMER_2 = None
|
| 29 |
+
|
| 30 |
LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
|
| 31 |
LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
|
| 32 |
LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
|
| 106 |
+
global COMPILED_TRANSFORMER_1, COMPILED_TRANSFORMER_2
|
| 107 |
+
|
| 108 |
@spaces.GPU(duration=1500)
|
| 109 |
def compile_transformer():
|
| 110 |
pipeline.load_lora_weights(
|
|
|
|
| 162 |
else:
|
| 163 |
compiled_transformer_1, compiled_transformer_2 = compile_transformer()
|
| 164 |
|
| 165 |
+
# expose for downloads
|
| 166 |
+
COMPILED_TRANSFORMER_1 = compiled_transformer_1
|
| 167 |
+
COMPILED_TRANSFORMER_2 = compiled_transformer_2
|
| 168 |
+
|
| 169 |
pipeline.transformer.forward = compiled_transformer_1
|
| 170 |
drain_module_parameters(pipeline.transformer)
|
| 171 |
|