Fabrice-TIERCELIN commited on
Commit
34937e3
·
verified ·
1 Parent(s): e49ba69
Files changed (1) hide show
  1. optimization_utils.py +138 -131
optimization_utils.py CHANGED
@@ -1,131 +1,138 @@
1
- """
2
- """
3
- import contextlib
4
- from contextvars import ContextVar
5
- from io import BytesIO
6
- from typing import Any
7
- from typing import cast
8
- from unittest.mock import patch
9
-
10
- import torch
11
- from torch._inductor.package.package import package_aoti
12
- from torch.export.pt2_archive._package import AOTICompiledModel
13
- from torch.export.pt2_archive._package_weights import Weights
14
-
15
-
16
- INDUCTOR_CONFIGS_OVERRIDES = {
17
- 'aot_inductor.package_constants_in_so': False,
18
- 'aot_inductor.package_constants_on_disk': True,
19
- 'aot_inductor.package': True,
20
- }
21
-
22
-
23
- class ZeroGPUWeights:
24
- def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
25
- if to_cuda:
26
- self.constants_map = {name: tensor.to('cuda') for name, tensor in constants_map.items()}
27
- else:
28
- self.constants_map = constants_map
29
- def __reduce__(self):
30
- constants_map: dict[str, torch.Tensor] = {}
31
- for name, tensor in self.constants_map.items():
32
- tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
33
- constants_map[name] = tensor_.copy_(tensor).detach().share_memory_()
34
- return ZeroGPUWeights, (constants_map, True)
35
-
36
-
37
- class ZeroGPUCompiledModel:
38
- def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
39
- self.archive_file = archive_file
40
- self.weights = weights
41
- self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
42
- def __call__(self, *args, **kwargs):
43
- if (compiled_model := self.compiled_model.get()) is None:
44
- compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
45
- compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
46
- self.compiled_model.set(compiled_model)
47
- return compiled_model(*args, **kwargs)
48
- def __reduce__(self):
49
- return ZeroGPUCompiledModel, (self.archive_file, self.weights)
50
-
51
- def to_serializable_dict(self) -> dict[str, Any]:
52
- """
53
- Return a stable representation that can be stored to disk and later re-loaded
54
- with torch.load, without depending on Gradio runtime state.
55
- """
56
- # BytesIO is file-like; extract raw bytes
57
- if hasattr(self.archive_file, "getvalue"):
58
- archive_bytes = self.archive_file.getvalue()
59
- else:
60
- # fallback best-effort
61
- pos = self.archive_file.tell()
62
- self.archive_file.seek(0)
63
- archive_bytes = self.archive_file.read()
64
- self.archive_file.seek(pos)
65
-
66
- # store constants on CPU in a safe format
67
- constants_cpu = {k: v.detach().to("cpu") for k, v in self.weights.constants_map.items()}
68
-
69
- return {
70
- "format": "zerogpu_aoti_v1",
71
- "archive_bytes": archive_bytes,
72
- "constants_map": constants_cpu,
73
- }
74
-
75
-
76
- def aoti_compile(
77
- exported_program: torch.export.ExportedProgram,
78
- inductor_configs: dict[str, Any] | None = None,
79
- ):
80
- inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
81
- gm = cast(torch.fx.GraphModule, exported_program.module())
82
- assert exported_program.example_inputs is not None
83
- args, kwargs = exported_program.example_inputs
84
- artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
85
- archive_file = BytesIO()
86
- files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
87
- package_aoti(archive_file, files)
88
- weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
89
- zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
90
- return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
91
-
92
-
93
- @contextlib.contextmanager
94
- def capture_component_call(
95
- pipeline: Any,
96
- component_name: str,
97
- component_method='forward',
98
- ):
99
-
100
- class CapturedCallException(Exception):
101
- def __init__(self, *args, **kwargs):
102
- super().__init__()
103
- self.args = args
104
- self.kwargs = kwargs
105
-
106
- class CapturedCall:
107
- def __init__(self):
108
- self.args: tuple[Any, ...] = ()
109
- self.kwargs: dict[str, Any] = {}
110
-
111
- component = getattr(pipeline, component_name)
112
- captured_call = CapturedCall()
113
-
114
- def capture_call(*args, **kwargs):
115
- raise CapturedCallException(*args, **kwargs)
116
-
117
- with patch.object(component, component_method, new=capture_call):
118
- try:
119
- yield captured_call
120
- except CapturedCallException as e:
121
- captured_call.args = e.args
122
- captured_call.kwargs = e.kwargs
123
-
124
-
125
- def drain_module_parameters(module: torch.nn.Module):
126
- state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
127
- state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
128
- module.load_state_dict(state_dict, assign=True)
129
- for name, param in state_dict.items():
130
- meta = state_dict_meta[name]
131
- param.data = torch.Tensor([]).to(**meta)
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ import contextlib
4
+ from contextvars import ContextVar
5
+ from io import BytesIO
6
+ from typing import Any
7
+ from typing import cast
8
+ from unittest.mock import patch
9
+
10
+ import torch
11
+ from torch._inductor.package.package import package_aoti
12
+ from torch.export.pt2_archive._package import AOTICompiledModel
13
+ from torch.export.pt2_archive._package_weights import Weights
14
+
15
+
16
+ INDUCTOR_CONFIGS_OVERRIDES = {
17
+ 'aot_inductor.package_constants_in_so': False,
18
+ 'aot_inductor.package_constants_on_disk': True,
19
+ 'aot_inductor.package': True,
20
+ }
21
+
22
+
23
+ class ZeroGPUWeights:
24
+ def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
25
+ if to_cuda:
26
+ self.constants_map = {name: tensor.to('cuda') for name, tensor in constants_map.items()}
27
+ else:
28
+ self.constants_map = constants_map
29
+ def __reduce__(self):
30
+ constants_map: dict[str, torch.Tensor] = {}
31
+ for name, tensor in self.constants_map.items():
32
+ tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
33
+ constants_map[name] = tensor_.copy_(tensor).detach().share_memory_()
34
+ return ZeroGPUWeights, (constants_map, True)
35
+
36
+
37
+ class ZeroGPUCompiledModel:
38
+ def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
39
+ self.archive_file = archive_file
40
+ self.weights = weights
41
+ self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
42
+ def __call__(self, *args, **kwargs):
43
+ if (compiled_model := self.compiled_model.get()) is None:
44
+ compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
45
+ compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
46
+ self.compiled_model.set(compiled_model)
47
+ return compiled_model(*args, **kwargs)
48
+ def __reduce__(self):
49
+ return ZeroGPUCompiledModel, (self.archive_file, self.weights)
50
+
51
+
52
+ def zerogpu_compiled_from_serializable_dict(payload: dict[str, Any]) -> ZeroGPUCompiledModel:
53
+ """
54
+ Rebuild a ZeroGPUCompiledModel from a stable dict representation produced by:
55
+ ZeroGPUCompiledModel.to_serializable_dict()
56
+
57
+ Expected format:
58
+ {
59
+ "format": "zerogpu_aoti_v1",
60
+ "archive_bytes": <bytes>,
61
+ "constants_map": {name: Tensor(cpu), ...}
62
+ }
63
+ """
64
+ fmt = payload.get("format")
65
+ if fmt != "zerogpu_aoti_v1":
66
+ raise ValueError(f"Unsupported compiled payload format: {fmt!r}")
67
+
68
+ archive_bytes = payload["archive_bytes"]
69
+ constants_map = payload["constants_map"]
70
+
71
+ if not isinstance(archive_bytes, (bytes, bytearray)):
72
+ raise TypeError("payload['archive_bytes'] must be bytes/bytearray")
73
+
74
+ if not isinstance(constants_map, dict):
75
+ raise TypeError("payload['constants_map'] must be a dict")
76
+
77
+ # Ensure tensors are CPU and detached (safe)
78
+ constants_cpu = {}
79
+ for k, v in constants_map.items():
80
+ if not isinstance(v, torch.Tensor):
81
+ raise TypeError(f"constants_map[{k!r}] is not a Tensor")
82
+ constants_cpu[k] = v.detach().to("cpu")
83
+
84
+ archive_file = BytesIO(bytes(archive_bytes))
85
+ weights = ZeroGPUWeights(constants_cpu, to_cuda=False)
86
+ return ZeroGPUCompiledModel(archive_file, weights)
87
+
88
+
89
+ def aoti_compile(
90
+ exported_program: torch.export.ExportedProgram,
91
+ inductor_configs: dict[str, Any] | None = None,
92
+ ):
93
+ inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
94
+ gm = cast(torch.fx.GraphModule, exported_program.module())
95
+ assert exported_program.example_inputs is not None
96
+ args, kwargs = exported_program.example_inputs
97
+ artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
98
+ archive_file = BytesIO()
99
+ files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
100
+ package_aoti(archive_file, files)
101
+ weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
102
+ zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
103
+ return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
104
+
105
+
106
+ @contextlib.contextmanager
107
+ def capture_component_call(
108
+ pipeline: Any,
109
+ component_name: str,
110
+ component_method='forward',
111
+ ):
112
+
113
+ class CapturedCallException(Exception):
114
+ def __init__(self, *args, **kwargs):
115
+ super().__init__()
116
+ self.args = args
117
+ self.kwargs = kwargs
118
+
119
+ class CapturedCall:
120
+ def __init__(self):
121
+ self.args: tuple[Any, ...] = ()
122
+ self.kwargs: dict[str, Any] = {}
123
+
124
+ component = getattr(pipeline, component_name)
125
+ captured_call = CapturedCall()
126
+
127
+ def capture_call(*args, **kwargs):
128
+ raise CapturedCallException(*args, **kwargs)
129
+
130
+ with patch.object(component, component_method, new=capture_call):
131
+ try:
132
+ yield captured_call
133
+ except CapturedCallException as e:
134
+ captured_call.args = e.args
135
+ captured_call.kwargs = e.kwargs
136
+
137
+
138
+ def drain_module_parameters(module: torch.n