jiuface committed on
Commit
d1c7d07
·
verified ·
1 Parent(s): 95e4e92

Create optimization_utils.py

Browse files
Files changed (1) hide show
  1. optimization_utils.py +105 -0
optimization_utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from contextvars import ContextVar
3
+ from io import BytesIO
4
+ from typing import Any
5
+ from typing import cast
6
+ from unittest.mock import patch
7
+
8
+ import torch
9
+ from torch._inductor.package.package import package_aoti
10
+ from torch.export.pt2_archive._package import AOTICompiledModel
11
+ from torch.export.pt2_archive._package_weights import Weights
12
+
13
+
14
# Inductor options that `aoti_compile` merges on top of any caller-supplied
# configuration; because they are applied last, they win over user values.
# NOTE(review): judging by the option names, these keep constants out of the
# compiled shared object and emit them separately -- confirm against the
# Inductor config reference for the torch version in use.
INDUCTOR_CONFIGS_OVERRIDES: dict[str, bool] = {
    'aot_inductor.package_constants_in_so': False,
    'aot_inductor.package_constants_on_disk': True,
    'aot_inductor.package': True,
}
19
+
20
+
21
class ZeroGPUWeights:
    """Picklable holder for a ``name -> tensor`` mapping of model constants.

    Pickling goes through ``__reduce__``: each tensor is copied into a pinned
    CPU buffer placed in shared memory, and the unpickling side rebuilds the
    object with ``to_cuda=True`` so the tensors are moved to ``'cuda'`` there.
    """

    def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
        """Store ``constants_map``.

        Args:
            constants_map: Mapping from constant name to tensor.
            to_cuda: When true, store a new dict whose tensors were moved to
                ``'cuda'``; otherwise keep a reference to the given dict.
        """
        if not to_cuda:
            self.constants_map = constants_map
            return
        on_device: dict[str, torch.Tensor] = {}
        for key, value in constants_map.items():
            on_device[key] = value.to('cuda')
        self.constants_map = on_device

    def __reduce__(self):
        """Return the pickle recipe ``(ZeroGPUWeights, (cpu_copies, True))``."""
        staged: dict[str, torch.Tensor] = {}
        for key, value in self.constants_map.items():
            # Allocate a pinned CPU buffer of matching shape/dtype, fill it
            # from the source tensor, then expose it through shared memory.
            pinned = torch.empty_like(value, device='cpu').pin_memory()
            pinned.copy_(value)
            staged[key] = pinned.detach().share_memory_()
        return ZeroGPUWeights, (staged, True)
33
+
34
+
35
class ZeroGPUCompiledModel:
    """Callable wrapper around an AOTInductor package and its weights.

    The package is loaded on the first call and the loaded model is cached in
    a ``ContextVar`` owned by this instance, so a context that has not yet
    loaded the model sees ``None`` and performs its own load.
    """

    def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
        """Remember the packaged archive and the weights to bind at load time."""
        self.archive_file = archive_file
        self.weights = weights
        self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)

    def __call__(self, *args, **kwargs):
        """Run the compiled model, loading it first if this context has none."""
        loaded = self.compiled_model.get()
        if loaded is not None:
            return loaded(*args, **kwargs)
        loaded = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
        # Bind the externally held tensors as the model's constants.
        loaded.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
        self.compiled_model.set(loaded)
        return loaded(*args, **kwargs)

    def __reduce__(self):
        """Pickle as ``(archive_file, weights)``; the loaded model is not included."""
        return ZeroGPUCompiledModel, (self.archive_file, self.weights)
48
+
49
+
50
def aoti_compile(
    exported_program: torch.export.ExportedProgram,
    inductor_configs: dict[str, Any] | None = None,
):
    """Compile an exported program with AOTInductor and wrap the result.

    Args:
        exported_program: Program to compile; its ``example_inputs`` must be
            set, since they are passed to the compiler as ``(args, kwargs)``.
        inductor_configs: Optional Inductor options. Entries in
            ``INDUCTOR_CONFIGS_OVERRIDES`` take precedence over these.

    Returns:
        A ``ZeroGPUCompiledModel`` holding the in-memory package archive and
        the weights extracted from the compilation artifacts.

    Raises:
        AssertionError: If ``exported_program.example_inputs`` is ``None``.
        ValueError: If the artifacts do not contain exactly one ``Weights``.
    """
    merged_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
    graph_module = cast(torch.fx.GraphModule, exported_program.module())
    assert exported_program.example_inputs is not None
    example_args, example_kwargs = exported_program.example_inputs
    artifacts = torch._inductor.aot_compile(
        graph_module,
        example_args,
        example_kwargs,
        options=merged_configs,
    )

    # Only the string artifacts are packaged into the in-memory archive.
    archive_file = BytesIO()
    packaged: list[str | Weights] = [item for item in artifacts if isinstance(item, str)]
    package_aoti(archive_file, packaged)

    # Exactly one Weights artifact is expected; the unpacking enforces that.
    (weights,) = (item for item in artifacts if isinstance(item, Weights))
    constants: dict[str, torch.Tensor] = {}
    for name in weights:
        constants[name] = weights.get_weight(name)[0]
    return ZeroGPUCompiledModel(archive_file, ZeroGPUWeights(constants))
65
+
66
+
67
@contextlib.contextmanager
def capture_component_call(
    pipeline: Any,
    component_name: str,
    component_method='forward',
):
    """Intercept the first call to a pipeline component and record its arguments.

    While the context is active, ``component_method`` on the attribute
    ``component_name`` of ``pipeline`` is replaced by a stub that raises a
    private exception carrying the call's arguments. That exception ends the
    ``with`` body, is swallowed here, and its arguments are copied onto the
    yielded object.

    Args:
        pipeline: Object exposing the component as an attribute.
        component_name: Attribute name of the component on ``pipeline``.
        component_method: Name of the method to intercept.

    Yields:
        An object whose ``args`` (tuple) and ``kwargs`` (dict) are filled in
        once the intercepted call has happened; both stay empty otherwise.
    """

    class CapturedCallException(Exception):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.args = args
            self.kwargs = kwargs

    class CapturedCall:
        def __init__(self):
            self.args: tuple[Any, ...] = ()
            self.kwargs: dict[str, Any] = {}

    def _intercept(*call_args, **call_kwargs):
        # Abort the caller's `with` body right at the intercepted call.
        raise CapturedCallException(*call_args, **call_kwargs)

    target = getattr(pipeline, component_name)
    record = CapturedCall()
    patcher = patch.object(target, component_method, new=_intercept)

    with patcher:
        try:
            yield record
        except CapturedCallException as captured:
            record.args, record.kwargs = captured.args, captured.kwargs
97
+
98
+
99
+ def drain_module_parameters(module: torch.nn.Module):
100
+ state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
101
+ state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
102
+ module.load_state_dict(state_dict, assign=True)
103
+ for name, param in state_dict.items():
104
+ meta = state_dict_meta[name]
105
+ param.data = torch.Tensor([]).to(**meta)