Fabrice-TIERCELIN commited on
Commit
627fa06
·
verified ·
1 Parent(s): 5ad4bf0

def to_serializable_dict(self) -> dict[str, Any]:

Browse files
Files changed (1) hide show
  1. optimization_utils.py +131 -107
optimization_utils.py CHANGED
@@ -1,107 +1,131 @@
1
- """
2
- """
3
- import contextlib
4
- from contextvars import ContextVar
5
- from io import BytesIO
6
- from typing import Any
7
- from typing import cast
8
- from unittest.mock import patch
9
-
10
- import torch
11
- from torch._inductor.package.package import package_aoti
12
- from torch.export.pt2_archive._package import AOTICompiledModel
13
- from torch.export.pt2_archive._package_weights import Weights
14
-
15
-
16
- INDUCTOR_CONFIGS_OVERRIDES = {
17
- 'aot_inductor.package_constants_in_so': False,
18
- 'aot_inductor.package_constants_on_disk': True,
19
- 'aot_inductor.package': True,
20
- }
21
-
22
-
23
- class ZeroGPUWeights:
24
- def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
25
- if to_cuda:
26
- self.constants_map = {name: tensor.to('cuda') for name, tensor in constants_map.items()}
27
- else:
28
- self.constants_map = constants_map
29
- def __reduce__(self):
30
- constants_map: dict[str, torch.Tensor] = {}
31
- for name, tensor in self.constants_map.items():
32
- tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
33
- constants_map[name] = tensor_.copy_(tensor).detach().share_memory_()
34
- return ZeroGPUWeights, (constants_map, True)
35
-
36
-
37
- class ZeroGPUCompiledModel:
38
- def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
39
- self.archive_file = archive_file
40
- self.weights = weights
41
- self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
42
- def __call__(self, *args, **kwargs):
43
- if (compiled_model := self.compiled_model.get()) is None:
44
- compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
45
- compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
46
- self.compiled_model.set(compiled_model)
47
- return compiled_model(*args, **kwargs)
48
- def __reduce__(self):
49
- return ZeroGPUCompiledModel, (self.archive_file, self.weights)
50
-
51
-
52
- def aoti_compile(
53
- exported_program: torch.export.ExportedProgram,
54
- inductor_configs: dict[str, Any] | None = None,
55
- ):
56
- inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
57
- gm = cast(torch.fx.GraphModule, exported_program.module())
58
- assert exported_program.example_inputs is not None
59
- args, kwargs = exported_program.example_inputs
60
- artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
61
- archive_file = BytesIO()
62
- files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
63
- package_aoti(archive_file, files)
64
- weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
65
- zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
66
- return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
67
-
68
-
69
- @contextlib.contextmanager
70
- def capture_component_call(
71
- pipeline: Any,
72
- component_name: str,
73
- component_method='forward',
74
- ):
75
-
76
- class CapturedCallException(Exception):
77
- def __init__(self, *args, **kwargs):
78
- super().__init__()
79
- self.args = args
80
- self.kwargs = kwargs
81
-
82
- class CapturedCall:
83
- def __init__(self):
84
- self.args: tuple[Any, ...] = ()
85
- self.kwargs: dict[str, Any] = {}
86
-
87
- component = getattr(pipeline, component_name)
88
- captured_call = CapturedCall()
89
-
90
- def capture_call(*args, **kwargs):
91
- raise CapturedCallException(*args, **kwargs)
92
-
93
- with patch.object(component, component_method, new=capture_call):
94
- try:
95
- yield captured_call
96
- except CapturedCallException as e:
97
- captured_call.args = e.args
98
- captured_call.kwargs = e.kwargs
99
-
100
-
101
- def drain_module_parameters(module: torch.nn.Module):
102
- state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
103
- state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
104
- module.load_state_dict(state_dict, assign=True)
105
- for name, param in state_dict.items():
106
- meta = state_dict_meta[name]
107
- param.data = torch.Tensor([]).to(**meta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+ import contextlib
4
+ from contextvars import ContextVar
5
+ from io import BytesIO
6
+ from typing import Any
7
+ from typing import cast
8
+ from unittest.mock import patch
9
+
10
+ import torch
11
+ from torch._inductor.package.package import package_aoti
12
+ from torch.export.pt2_archive._package import AOTICompiledModel
13
+ from torch.export.pt2_archive._package_weights import Weights
14
+
15
+
16
+ INDUCTOR_CONFIGS_OVERRIDES = {
17
+ 'aot_inductor.package_constants_in_so': False,
18
+ 'aot_inductor.package_constants_on_disk': True,
19
+ 'aot_inductor.package': True,
20
+ }
21
+
22
+
23
+ class ZeroGPUWeights:
24
+ def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
25
+ if to_cuda:
26
+ self.constants_map = {name: tensor.to('cuda') for name, tensor in constants_map.items()}
27
+ else:
28
+ self.constants_map = constants_map
29
+ def __reduce__(self):
30
+ constants_map: dict[str, torch.Tensor] = {}
31
+ for name, tensor in self.constants_map.items():
32
+ tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
33
+ constants_map[name] = tensor_.copy_(tensor).detach().share_memory_()
34
+ return ZeroGPUWeights, (constants_map, True)
35
+
36
+
37
+ class ZeroGPUCompiledModel:
38
+ def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
39
+ self.archive_file = archive_file
40
+ self.weights = weights
41
+ self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
42
+ def __call__(self, *args, **kwargs):
43
+ if (compiled_model := self.compiled_model.get()) is None:
44
+ compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
45
+ compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
46
+ self.compiled_model.set(compiled_model)
47
+ return compiled_model(*args, **kwargs)
48
+ def __reduce__(self):
49
+ return ZeroGPUCompiledModel, (self.archive_file, self.weights)
50
+
51
+ def to_serializable_dict(self) -> dict[str, Any]:
52
+ """
53
+ Return a stable representation that can be stored to disk and later re-loaded
54
+ with torch.load, without depending on Gradio runtime state.
55
+ """
56
+ # BytesIO is file-like; extract raw bytes
57
+ if hasattr(self.archive_file, "getvalue"):
58
+ archive_bytes = self.archive_file.getvalue()
59
+ else:
60
+ # fallback best-effort
61
+ pos = self.archive_file.tell()
62
+ self.archive_file.seek(0)
63
+ archive_bytes = self.archive_file.read()
64
+ self.archive_file.seek(pos)
65
+
66
+ # store constants on CPU in a safe format
67
+ constants_cpu = {k: v.detach().to("cpu") for k, v in self.weights.constants_map.items()}
68
+
69
+ return {
70
+ "format": "zerogpu_aoti_v1",
71
+ "archive_bytes": archive_bytes,
72
+ "constants_map": constants_cpu,
73
+ }
74
+
75
+
76
+ def aoti_compile(
77
+ exported_program: torch.export.ExportedProgram,
78
+ inductor_configs: dict[str, Any] | None = None,
79
+ ):
80
+ inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
81
+ gm = cast(torch.fx.GraphModule, exported_program.module())
82
+ assert exported_program.example_inputs is not None
83
+ args, kwargs = exported_program.example_inputs
84
+ artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
85
+ archive_file = BytesIO()
86
+ files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
87
+ package_aoti(archive_file, files)
88
+ weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
89
+ zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
90
+ return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
91
+
92
+
93
+ @contextlib.contextmanager
94
+ def capture_component_call(
95
+ pipeline: Any,
96
+ component_name: str,
97
+ component_method='forward',
98
+ ):
99
+
100
+ class CapturedCallException(Exception):
101
+ def __init__(self, *args, **kwargs):
102
+ super().__init__()
103
+ self.args = args
104
+ self.kwargs = kwargs
105
+
106
+ class CapturedCall:
107
+ def __init__(self):
108
+ self.args: tuple[Any, ...] = ()
109
+ self.kwargs: dict[str, Any] = {}
110
+
111
+ component = getattr(pipeline, component_name)
112
+ captured_call = CapturedCall()
113
+
114
+ def capture_call(*args, **kwargs):
115
+ raise CapturedCallException(*args, **kwargs)
116
+
117
+ with patch.object(component, component_method, new=capture_call):
118
+ try:
119
+ yield captured_call
120
+ except CapturedCallException as e:
121
+ captured_call.args = e.args
122
+ captured_call.kwargs = e.kwargs
123
+
124
+
125
+ def drain_module_parameters(module: torch.nn.Module):
126
+ state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
127
+ state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
128
+ module.load_state_dict(state_dict, assign=True)
129
+ for name, param in state_dict.items():
130
+ meta = state_dict_meta[name]
131
+ param.data = torch.Tensor([]).to(**meta)