Fabrice-TIERCELIN commited on
Commit
30d1371
·
verified ·
1 Parent(s): 34937e3

loader Hub + delete duplicate

Browse files
Files changed (1) hide show
  1. optimization.py +158 -172
optimization.py CHANGED
@@ -1,173 +1,159 @@
1
- from typing import Any
2
- from typing import Callable
3
- from typing import ParamSpec
4
-
5
- import os
6
- import spaces
7
- import torch
8
- from torch.utils._pytree import tree_map_only
9
- from torchao.quantization import quantize_
10
- from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
11
- from torchao.quantization import Int8WeightOnlyConfig
12
- from huggingface_hub import hf_hub_download
13
-
14
- from io import BytesIO
15
-
16
- from optimization_utils import capture_component_call
17
- from optimization_utils import aoti_compile
18
- from optimization_utils import drain_module_parameters
19
-
20
- # NEW: import classes to rebuild compiled objects
21
- from optimization_utils import ZeroGPUCompiledModel, ZeroGPUWeights
22
-
23
-
24
- P = ParamSpec('P')
25
-
26
- # Expose compiled models so app.py can offer them for download
27
- COMPILED_TRANSFORMER_1 = None
28
- COMPILED_TRANSFORMER_2 = None
29
-
30
- LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
31
- LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
32
- LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
33
-
34
- TRANSFORMER_DYNAMIC_SHAPES = {
35
- 'hidden_states': {
36
- 2: LATENT_FRAMES_DIM,
37
- 3: 2 * LATENT_PATCHED_HEIGHT_DIM,
38
- 4: 2 * LATENT_PATCHED_WIDTH_DIM,
39
- },
40
- }
41
-
42
- INDUCTOR_CONFIGS = {
43
- 'conv_1x1_as_mm': True,
44
- 'epilogue_fusion': False,
45
- 'coordinate_descent_tuning': True,
46
- 'coordinate_descent_check_all_directions': True,
47
- 'max_autotune': True,
48
- 'triton.cudagraphs': True,
49
- }
50
-
51
-
52
- def _deserialize_zerogpu_aoti(payload: dict[str, Any]) -> ZeroGPUCompiledModel:
53
- """
54
- Rebuild a ZeroGPUCompiledModel from a stable serialized dict produced by
55
- ZeroGPUCompiledModel.to_serializable_dict().
56
- """
57
- if not isinstance(payload, dict):
58
- raise ValueError(f"Expected dict payload, got: {type(payload)}")
59
-
60
- fmt = payload.get("format")
61
- if fmt != "zerogpu_aoti_v1":
62
- raise ValueError(f"Unsupported payload format: {fmt!r}")
63
-
64
- archive_bytes = payload.get("archive_bytes")
65
- constants_map = payload.get("constants_map")
66
-
67
- if not isinstance(archive_bytes, (bytes, bytearray)):
68
- raise ValueError("payload['archive_bytes'] must be bytes")
69
- if not isinstance(constants_map, dict):
70
- raise ValueError("payload['constants_map'] must be a dict of tensors")
71
-
72
- # Recreate in-memory archive file (what aoti_load_package expects)
73
- archive_file = BytesIO(archive_bytes)
74
-
75
- # Ensure constants are CPU tensors (ZeroGPUWeights will pin/copy for runtime)
76
- constants_map = {k: v.detach().to("cpu") for k, v in constants_map.items()}
77
-
78
- weights = ZeroGPUWeights(constants_map, to_cuda=False)
79
- return ZeroGPUCompiledModel(archive_file, weights)
80
-
81
-
82
- def load_compiled_transformers_from_hub(
83
- repo_id: str,
84
- filename_1: str = "compiled_transformer_1.pt",
85
- filename_2: str = "compiled_transformer_2.pt",
86
- ):
87
- """
88
- Charge les artefacts précompilés depuis le Hub.
89
-
90
- IMPORTANT: les fichiers .pt doivent contenir le dict sérialisé produit par
91
- ZeroGPUCompiledModel.to_serializable_dict() (format "zerogpu_aoti_v1").
92
- """
93
- path_1 = hf_hub_download(repo_id=repo_id, filename=filename_1)
94
- path_2 = hf_hub_download(repo_id=repo_id, filename=filename_2)
95
-
96
- payload_1 = torch.load(path_1, map_location="cpu", weights_only=False)
97
- payload_2 = torch.load(path_2, map_location="cpu", weights_only=False)
98
-
99
- compiled_1 = _deserialize_zerogpu_aoti(payload_1)
100
- compiled_2 = _deserialize_zerogpu_aoti(payload_2)
101
-
102
- return compiled_1, compiled_2
103
-
104
-
105
- def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
106
- global COMPILED_TRANSFORMER_1, COMPILED_TRANSFORMER_2
107
-
108
- @spaces.GPU(duration=1500)
109
- def compile_transformer():
110
- pipeline.load_lora_weights(
111
- "Kijai/WanVideo_comfy",
112
- weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
113
- adapter_name="lightx2v",
114
- )
115
- kwargs_lora = {"load_into_transformer_2": True}
116
- pipeline.load_lora_weights(
117
- "Kijai/WanVideo_comfy",
118
- weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
119
- adapter_name="lightx2v_2",
120
- **kwargs_lora,
121
- )
122
- pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
123
- pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
124
- pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
125
- pipeline.unload_lora_weights()
126
-
127
- with capture_component_call(pipeline, "transformer") as call:
128
- pipeline(*args, **kwargs)
129
-
130
- dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
131
- dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
132
-
133
- quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
134
- quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
135
-
136
- exported_1 = torch.export.export(
137
- mod=pipeline.transformer,
138
- args=call.args,
139
- kwargs=call.kwargs,
140
- dynamic_shapes=dynamic_shapes,
141
- )
142
- exported_2 = torch.export.export(
143
- mod=pipeline.transformer_2,
144
- args=call.args,
145
- kwargs=call.kwargs,
146
- dynamic_shapes=dynamic_shapes,
147
- )
148
-
149
- compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
150
- compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
151
- return compiled_1, compiled_2
152
-
153
- quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
154
-
155
- use_precompiled = False
156
- precompiled_repo = "Fabrice-TIERCELIN/Wan_2.2_compiled"
157
-
158
- if use_precompiled:
159
- compiled_transformer_1, compiled_transformer_2 = load_compiled_transformers_from_hub(
160
- repo_id=precompiled_repo
161
- )
162
- else:
163
- compiled_transformer_1, compiled_transformer_2 = compile_transformer()
164
-
165
- # expose for downloads
166
- COMPILED_TRANSFORMER_1 = compiled_transformer_1
167
- COMPILED_TRANSFORMER_2 = compiled_transformer_2
168
-
169
- pipeline.transformer.forward = compiled_transformer_1
170
- drain_module_parameters(pipeline.transformer)
171
-
172
- pipeline.transformer_2.forward = compiled_transformer_2
173
  drain_module_parameters(pipeline.transformer_2)
 
1
+ """
2
+ """
3
+
4
+ from typing import Any
5
+ from typing import Callable
6
+ from typing import ParamSpec
7
+
8
+ import os
9
+ import spaces
10
+ import torch
11
+ from torch.utils._pytree import tree_map_only
12
+ from torchao.quantization import quantize_
13
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
14
+ from torchao.quantization import Int8WeightOnlyConfig
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ from optimization_utils import capture_component_call
18
+ from optimization_utils import aoti_compile
19
+ from optimization_utils import drain_module_parameters
20
+ from optimization_utils import zerogpu_compiled_from_serializable_dict
21
+ from optimization_utils import ZeroGPUCompiledModel
22
+
23
+
24
+ P = ParamSpec('P')
25
+
26
+ LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
27
+ LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
28
+ LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
29
+
30
+ TRANSFORMER_DYNAMIC_SHAPES = {
31
+ 'hidden_states': {
32
+ 2: LATENT_FRAMES_DIM,
33
+ 3: 2 * LATENT_PATCHED_HEIGHT_DIM,
34
+ 4: 2 * LATENT_PATCHED_WIDTH_DIM,
35
+ },
36
+ }
37
+
38
+ INDUCTOR_CONFIGS = {
39
+ 'conv_1x1_as_mm': True,
40
+ 'epilogue_fusion': False,
41
+ 'coordinate_descent_tuning': True,
42
+ 'coordinate_descent_check_all_directions': True,
43
+ 'max_autotune': True,
44
+ 'triton.cudagraphs': True,
45
+ }
46
+
47
+
48
+ def _strtobool(v: str | None, default: bool = True) -> bool:
49
+ if v is None:
50
+ return default
51
+ return v.strip().lower() in ("1", "true", "yes", "y", "on")
52
+
53
+
54
def _load_compiled_pt(path: str):
    """Deserialize a compiled-transformer artifact stored at *path*.

    Two on-disk formats are accepted:
      * the stable dict produced by ``to_serializable_dict()``
        (``payload["format"] == "zerogpu_aoti_v1"``), or
      * a legacy direct pickle of a ``ZeroGPUCompiledModel``.

    Raises ValueError for anything else.
    """
    # weights_only=False: artifacts may contain pickled project objects.
    # NOTE(review): only load artifacts from trusted repos — this is pickle.
    obj = torch.load(path, map_location="cpu", weights_only=False)

    # Current format: a plain dict payload with an explicit format tag.
    if isinstance(obj, dict) and obj.get("format") == "zerogpu_aoti_v1":
        return zerogpu_compiled_from_serializable_dict(obj)

    # Legacy format: the compiled model object was pickled directly.
    if isinstance(obj, ZeroGPUCompiledModel):
        return obj

    dict_keys = list(obj.keys()) if isinstance(obj, dict) else None
    raise ValueError(
        f"Unsupported compiled transformer file format at {path}. "
        f"Got type={type(obj)} keys={dict_keys}"
    )
74
+
75
+
76
def load_compiled_transformers_from_hub(
    repo_id: str,
    filename_1: str = "compiled_transformer_1.pt",
    filename_2: str = "compiled_transformer_2.pt",
):
    """Download and deserialize both precompiled transformers from the Hub.

    IMPORTANT:
        The expected files are either the dict exported via
        ``to_serializable_dict()`` (format ``'zerogpu_aoti_v1'``) or a direct
        pickle of ``ZeroGPUCompiledModel``.
    """
    local_paths = [
        hf_hub_download(repo_id=repo_id, filename=name)
        for name in (filename_1, filename_2)
    ]
    compiled_1, compiled_2 = (_load_compiled_pt(p) for p in local_paths)
    return compiled_1, compiled_2
94
+
95
+
96
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Quantize and AoT-compile the pipeline's two transformers, in place.

    By default the precompiled artifacts are fetched from the Hub; set the
    ``WAN_USE_PRECOMPILED`` env var to a falsy value to run the full
    fuse/quantize/export/compile path on GPU instead. In both cases the
    compiled callables replace ``transformer.forward`` and
    ``transformer_2.forward``, and the original module parameters are
    drained to free memory.

    Args:
        pipeline: The video pipeline to optimize (mutated in place).
        *args, **kwargs: A representative pipeline call used to trace the
            transformer inputs when compiling from scratch.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Fuse the Lightx2v step-distill LoRA into both transformers so the
        # exported graphs already contain the fused weights.
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
            adapter_name="lightx2v",
        )
        kwargs_lora = {"load_into_transformer_2": True}
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
            adapter_name="lightx2v_2",
            **kwargs_lora,
        )
        pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
        pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
        pipeline.unload_lora_weights()

        # Run one real pipeline call to capture the exact args/kwargs the
        # transformer receives; these become the export example inputs.
        with capture_component_call(pipeline, "transformer") as call:
            pipeline(*args, **kwargs)

        # Mark every captured tensor/bool kwarg static (None), then overlay
        # the latent dims that must remain dynamic across resolutions.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())

        exported_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        exported_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )

        compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
        compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
        return compiled_1, compiled_2

    # Text encoder quantization (same on both paths).
    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())

    # Both knobs are env-overridable; defaults preserve prior behavior
    # (precompiled artifacts from the default repo). This also puts the
    # _strtobool helper to use instead of leaving it dead code.
    use_precompiled = _strtobool(os.getenv("WAN_USE_PRECOMPILED"), default=True)
    precompiled_repo = os.getenv("WAN_PRECOMPILED_REPO", "Fabrice-TIERCELIN/Wan_2.2_compiled")

    if use_precompiled:
        compiled_transformer_1, compiled_transformer_2 = load_compiled_transformers_from_hub(
            repo_id=precompiled_repo
        )
    else:
        compiled_transformer_1, compiled_transformer_2 = compile_transformer()

    # Swap in the compiled callables and release the now-unused parameters.
    pipeline.transformer.forward = compiled_transformer_1
    drain_module_parameters(pipeline.transformer)

    pipeline.transformer_2.forward = compiled_transformer_2
    drain_module_parameters(pipeline.transformer_2)