Prompt48 commited on
Commit
4c4113d
·
verified ·
1 Parent(s): 904b4f7

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\accelerate\hooks.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//accelerate//hooks.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ from collections.abc import Mapping
17
+ from typing import Optional, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from .state import PartialState
23
+ from .utils import (
24
+ PrefixedDataset,
25
+ find_device,
26
+ named_module_tensors,
27
+ send_to_device,
28
+ set_module_tensor_to_device,
29
+ )
30
+ from .utils.imports import (
31
+ is_mlu_available,
32
+ is_musa_available,
33
+ is_npu_available,
34
+ )
35
+ from .utils.memory import clear_device_cache
36
+ from .utils.modeling import get_non_persistent_buffers
37
+ from .utils.other import recursive_getattr
38
+
39
+
40
+ _accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "sdaa", "musa"]
41
+
42
+
43
+ class ModelHook:
44
+ """
45
+ A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
46
+ with PyTorch existing hooks is that they get passed along the kwargs.
47
+
48
+ Class attribute:
49
+ - **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under
50
+ the `torch.no_grad()` context manager.
51
+ """
52
+
53
+ no_grad = False
54
+
55
+ def init_hook(self, module):
56
+ """
57
+ To be executed when the hook is attached to the module.
58
+
59
+ Args:
60
+ module (`torch.nn.Module`): The module attached to this hook.
61
+ """
62
+ return module
63
+
64
+ def pre_forward(self, module, *args, **kwargs):
65
+ """
66
+ To be executed just before the forward method of the model.
67
+
68
+ Args:
69
+ module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.
70
+ args (`Tuple[Any]`): The positional arguments passed to the module.
71
+ kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.
72
+
73
+ Returns:
74
+ `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`.
75
+ """
76
+ return args, kwargs
77
+
78
+ def post_forward(self, module, output):
79
+ """
80
+ To be executed just after the forward method of the model.
81
+
82
+ Args:
83
+ module (`torch.nn.Module`): The module whose forward pass been executed just before this event.
84
+ output (`Any`): The output of the module.
85
+
86
+ Returns:
87
+ `Any`: The processed `output`.
88
+ """
89
+ return output
90
+
91
+ def detach_hook(self, module):
92
+ """
93
+ To be executed when the hook is detached from a module.
94
+
95
+ Args:
96
+ module (`torch.nn.Module`): The module detached from this hook.
97
+ """
98
+ return module
99
+
100
+
101
+ class SequentialHook(ModelHook):
102
+ """
103
+ A hook that can contain several hooks and iterates through them at each event.
104
+ """
105
+
106
+ def __init__(self, *hooks):
107
+ self.hooks = hooks
108
+
109
+ def init_hook(self, module):
110
+ for hook in self.hooks:
111
+ module = hook.init_hook(module)
112
+ return module
113
+
114
+ def pre_forward(self, module, *args, **kwargs):
115
+ for hook in self.hooks:
116
+ args, kwargs = hook.pre_forward(module, *args, **kwargs)
117
+ return args, kwargs
118
+
119
+ def post_forward(self, module, output):
120
+ for hook in self.hooks:
121
+ output = hook.post_forward(module, output)
122
+ return output
123
+
124
+ def detach_hook(self, module):
125
+ for hook in self.hooks:
126
+ module = hook.detach_hook(module)
127
+ return module
128
+
129
+
130
+ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
131
+ """
132
+ Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
133
+ this behavior and restore the original `forward` method, use `remove_hook_from_module`.
134
+
135
+ <Tip warning={true}>
136
+
137
+ If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
138
+ together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
139
+
140
+ </Tip>
141
+
142
+ Args:
143
+ module (`torch.nn.Module`):
144
+ The module to attach a hook to.
145
+ hook (`ModelHook`):
146
+ The hook to attach.
147
+ append (`bool`, *optional*, defaults to `False`):
148
+ Whether the hook should be chained with an existing one (if module already contains a hook) or not.
149
+
150
+ Returns:
151
+ `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
152
+ be discarded).
153
+ """
154
+ if append and (getattr(module, "_hf_hook", None) is not None):
155
+ old_hook = module._hf_hook
156
+ remove_hook_from_module(module)
157
+ hook = SequentialHook(old_hook, hook)
158
+
159
+ if hasattr(module, "_hf_hook") and hasattr(module, "_old_forward"):
160
+ # If we already put some hook on this module, we replace it with the new one.
161
+ old_forward = module._old_forward
162
+ else:
163
+ old_forward = module.forward
164
+ module._old_forward = old_forward
165
+
166
+ module = hook.init_hook(module)
167
+ module._hf_hook = hook
168
+
169
+ def new_forward(module, *args, **kwargs):
170
+ args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
171
+ if module._hf_hook.no_grad:
172
+ with torch.no_grad():
173
+ output = module._old_forward(*args, **kwargs)
174
+ else:
175
+ output = module._old_forward(*args, **kwargs)
176
+ return module._hf_hook.post_forward(module, output)
177
+
178
+ # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
179
+ # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
180
+ if "GraphModuleImpl" in str(type(module)):
181
+ module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
182
+ else:
183
+ module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
184
+
185
+ return module
186
+
187
+
188
+ def remove_hook_from_module(module: nn.Module, recurse=False):
189
+ """
190
+ Removes any hook attached to a module via `add_hook_to_module`.
191
+
192
+ Args:
193
+ module (`torch.nn.Module`): The module to attach a hook to.
194
+ recurse (`bool`, **optional**): Whether to remove the hooks recursively
195
+
196
+ Returns:
197
+ `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
198
+ be discarded).
199
+ """
200
+
201
+ if hasattr(module, "_hf_hook"):
202
+ module._hf_hook.detach_hook(module)
203
+ delattr(module, "_hf_hook")
204
+
205
+ if hasattr(module, "_old_forward"):
206
+ # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
207
+ # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
208
+ if "GraphModuleImpl" in str(type(module)):
209
+ module.__class__.forward = module._old_forward
210
+ else:
211
+ module.forward = module._old_forward
212
+ delattr(module, "_old_forward")
213
+
214
+ # Remove accelerate added warning hooks from dispatch_model
215
+ for attr in _accelerate_added_attributes:
216
+ module.__dict__.pop(attr, None)
217
+
218
+ if recurse:
219
+ for child in module.children():
220
+ remove_hook_from_module(child, recurse)
221
+
222
+ return module
223
+
224
+
225
+ class AlignDevicesHook(ModelHook):
226
+ """
227
+ A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
228
+ associated module, potentially offloading the weights after the forward pass.
229
+
230
+ Args:
231
+ execution_device (`torch.device`, *optional*):
232
+ The device on which inputs and model weights should be placed before the forward pass.
233
+ offload (`bool`, *optional*, defaults to `False`):
234
+ Whether or not the weights should be offloaded after the forward pass.
235
+ io_same_device (`bool`, *optional*, defaults to `False`):
236
+ Whether or not the output should be placed on the same device as the input was.
237
+ weights_map (`Mapping[str, torch.Tensor]`, *optional*):
238
+ When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
239
+ offload_buffers (`bool`, *optional*, defaults to `False`):
240
+ Whether or not to include the associated module's buffers when offloading.
241
+ place_submodules (`bool`, *optional*, defaults to `False`):
242
+ Whether to place the submodules on `execution_device` during the `init_hook` event.
243
+ """
244
+
245
+ def __init__(
246
+ self,
247
+ execution_device: Optional[Union[int, str, torch.device]] = None,
248
+ offload: bool = False,
249
+ io_same_device: bool = False,
250
+ weights_map: Optional[Mapping] = None,
251
+ offload_buffers: bool = False,
252
+ place_submodules: bool = False,
253
+ skip_keys: Optional[Union[str, list[str]]] = None,
254
+ tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
255
+ ):
256
+ self.execution_device = execution_device
257
+ self.offload = offload
258
+ self.io_same_device = io_same_device
259
+ self.weights_map = weights_map
260
+ self.offload_buffers = offload_buffers
261
+ self.place_submodules = place_submodules
262
+ self.skip_keys = skip_keys
263
+
264
+ # Will contain the input device when `io_same_device=True`.
265
+ self.input_device = None
266
+ self.param_original_devices = {}
267
+ self.buffer_original_devices = {}
268
+ self.tied_params_names = set()
269
+
270
+ # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
271
+ # for tied weights already loaded on the target execution device.
272
+ self.tied_params_map = tied_params_map
273
+
274
+ def __repr__(self):
275
+ return (
276
+ f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
277
+ f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
278
+ f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
279
+ )
280
+
281
+ def init_hook(self, module):
282
+ # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
283
+ if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
284
+ self.tied_params_map = None
285
+
286
+ if not self.offload and self.execution_device is not None:
287
+ for name, _ in named_module_tensors(module, recurse=self.place_submodules):
288
+ set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
289
+ elif self.offload:
290
+ self.original_devices = {
291
+ name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
292
+ }
293
+ if self.weights_map is None:
294
+ self.weights_map = {
295
+ name: param.to("cpu")
296
+ for name, param in named_module_tensors(
297
+ module, include_buffers=self.offload_buffers, recurse=self.place_submodules
298
+ )
299
+ }
300
+ for name, _ in named_module_tensors(
301
+ module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
302
+ ):
303
+ # When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer,
304
+ # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
305
+ # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str]
306
+ # to add on the fly pointers to `tied_params_map` in the pre_forward call.
307
+ if (
308
+ self.tied_params_map is not None
309
+ and recursive_getattr(module, name).data_ptr() in self.tied_params_map
310
+ ):
311
+ self.tied_params_names.add(name)
312
+
313
+ set_module_tensor_to_device(module, name, "meta")
314
+
315
+ if not self.offload_buffers and self.execution_device is not None:
316
+ for name, _ in module.named_buffers(recurse=self.place_submodules):
317
+ set_module_tensor_to_device(
318
+ module, name, self.execution_device, tied_params_map=self.tied_params_map
319
+ )
320
+ elif self.offload_buffers and self.execution_device is not None:
321
+ for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
322
+ set_module_tensor_to_device(
323
+ module, name, self.execution_device, tied_params_map=self.tied_params_map
324
+ )
325
+
326
+ return module
327
+
328
+ def pre_forward(self, module, *args, **kwargs):
329
+ if self.io_same_device:
330
+ self.input_device = find_device([args, kwargs])
331
+ if self.offload:
332
+ self.tied_pointers_to_remove = set()
333
+
334
+ for name, _ in named_module_tensors(
335
+ module,
336
+ include_buffers=self.offload_buffers,
337
+ recurse=self.place_submodules,
338
+ remove_non_persistent=True,
339
+ ):
340
+ fp16_statistics = None
341
+ value = self.weights_map[name]
342
+ if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
343
+ if value.dtype == torch.int8:
344
+ fp16_statistics = self.weights_map[name.replace("weight", "SCB")]
345
+
346
+ # In case we are using offloading with tied weights, we need to keep track of the offloaded weights
347
+ # that are loaded on device at this point, as we will need to remove them as well from the dictionary
348
+ # self.tied_params_map in order to allow to free memory.
349
+ if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
350
+ self.tied_params_map[value.data_ptr()] = {}
351
+
352
+ if (
353
+ value is not None
354
+ and self.tied_params_map is not None
355
+ and value.data_ptr() in self.tied_params_map
356
+ and self.execution_device not in self.tied_params_map[value.data_ptr()]
357
+ ):
358
+ self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
359
+
360
+ set_module_tensor_to_device(
361
+ module,
362
+ name,
363
+ self.execution_device,
364
+ value=value,
365
+ fp16_statistics=fp16_statistics,
366
+ tied_params_map=self.tied_params_map,
367
+ )
368
+
369
+ return send_to_device(args, self.execution_device), send_to_device(
370
+ kwargs, self.execution_device, skip_keys=self.skip_keys
371
+ )
372
+
373
+ def post_forward(self, module, output):
374
+ if self.offload:
375
+ for name, _ in named_module_tensors(
376
+ module,
377
+ include_buffers=self.offload_buffers,
378
+ recurse=self.place_submodules,
379
+ remove_non_persistent=True,
380
+ ):
381
+ set_module_tensor_to_device(module, name, "meta")
382
+ if type(module).__name__ == "Linear8bitLt":
383
+ module.state.SCB = None
384
+ module.state.CxB = None
385
+
386
+ # We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. submodules): remove them from
387
+ # this dictionary to allow the garbage collector to do its job.
388
+ for value_pointer, device in self.tied_pointers_to_remove:
389
+ if isinstance(device, int):
390
+ if is_npu_available():
391
+ device = f"npu:{device}"
392
+ elif is_mlu_available():
393
+ device = f"mlu:{device}"
394
+ elif is_musa_available():
395
+ device = f"musa:{device}"
396
+ if device in self.tied_params_map[value_pointer]:
397
+ del self.tied_params_map[value_pointer][device]
398
+ self.tied_pointers_to_remove = set()
399
+ if self.io_same_device and self.input_device is not None:
400
+ output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
401
+
402
+ return output
403
+
404
+ def detach_hook(self, module):
405
+ if self.offload:
406
+ for name, device in self.original_devices.items():
407
+ if device != torch.device("meta"):
408
+ set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
409
+ return module
410
+
411
+
412
+ def attach_execution_device_hook(
413
+ module: torch.nn.Module,
414
+ execution_device: Union[int, str, torch.device],
415
+ skip_keys: Optional[Union[str, list[str]]] = None,
416
+ preload_module_classes: Optional[list[str]] = None,
417
+ tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
418
+ ):
419
+ """
420
+ Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
421
+ execution device
422
+
423
+ Args:
424
+ module (`torch.nn.Module`):
425
+ The module where we want to attach the hooks.
426
+ execution_device (`int`, `str` or `torch.device`):
427
+ The device on which inputs and model weights should be placed before the forward pass.
428
+ skip_keys (`str` or `List[str]`, *optional*):
429
+ A list of keys to ignore when moving inputs or outputs between devices.
430
+ preload_module_classes (`List[str]`, *optional*):
431
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
432
+ of the forward. This should only be used for classes that have submodules which are registered but not
433
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
434
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
435
+ tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
436
+ A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
437
+ device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
438
+ instead of duplicating memory.
439
+ """
440
+ if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
441
+ add_hook_to_module(
442
+ module,
443
+ AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
444
+ )
445
+
446
+ # Break the recursion if we get to a preload module.
447
+ if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
448
+ return
449
+
450
+ for child in module.children():
451
+ attach_execution_device_hook(
452
+ child,
453
+ execution_device,
454
+ skip_keys=skip_keys,
455
+ preload_module_classes=preload_module_classes,
456
+ tied_params_map=tied_params_map,
457
+ )
458
+
459
+
460
+ def attach_align_device_hook(
461
+ module: torch.nn.Module,
462
+ execution_device: Optional[torch.device] = None,
463
+ offload: bool = False,
464
+ weights_map: Optional[Mapping] = None,
465
+ offload_buffers: bool = False,
466
+ module_name: str = "",
467
+ skip_keys: Optional[Union[str, list[str]]] = None,
468
+ preload_module_classes: Optional[list[str]] = None,
469
+ tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
470
+ ):
471
+ """
472
+ Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
473
+ buffers.
474
+
475
+ Args:
476
+ module (`torch.nn.Module`):
477
+ The module where we want to attach the hooks.
478
+ execution_device (`torch.device`, *optional*):
479
+ The device on which inputs and model weights should be placed before the forward pass.
480
+ offload (`bool`, *optional*, defaults to `False`):
481
+ Whether or not the weights should be offloaded after the forward pass.
482
+ weights_map (`Mapping[str, torch.Tensor]`, *optional*):
483
+ When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
484
+ offload_buffers (`bool`, *optional*, defaults to `False`):
485
+ Whether or not to include the associated module's buffers when offloading.
486
+ module_name (`str`, *optional*, defaults to `""`):
487
+ The name of the module.
488
+ skip_keys (`str` or `List[str]`, *optional*):
489
+ A list of keys to ignore when moving inputs or outputs between devices.
490
+ preload_module_classes (`List[str]`, *optional*):
491
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
492
+ of the forward. This should only be used for classes that have submodules which are registered but not
493
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
494
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
495
+ tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
496
+ A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
497
+ device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
498
+ instead of duplicating memory.
499
+ """
500
+ # Attach the hook on this module if it has any direct tensor.
501
+ directs = named_module_tensors(module)
502
+ full_offload = (
503
+ offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
504
+ )
505
+
506
+ if len(list(directs)) > 0 or full_offload:
507
+ if weights_map is not None:
508
+ prefix = f"{module_name}." if len(module_name) > 0 else ""
509
+ prefixed_weights_map = PrefixedDataset(weights_map, prefix)
510
+ else:
511
+ prefixed_weights_map = None
512
+ hook = AlignDevicesHook(
513
+ execution_device=execution_device,
514
+ offload=offload,
515
+ weights_map=prefixed_weights_map,
516
+ offload_buffers=offload_buffers,
517
+ place_submodules=full_offload,
518
+ skip_keys=skip_keys,
519
+ tied_params_map=tied_params_map,
520
+ )
521
+ add_hook_to_module(module, hook, append=True)
522
+
523
+ # We stop the recursion in case we hit the full offload.
524
+ if full_offload:
525
+ return
526
+
527
+ # Recurse on all children of the module.
528
+ for child_name, child in module.named_children():
529
+ child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
530
+ attach_align_device_hook(
531
+ child,
532
+ execution_device=execution_device,
533
+ offload=offload,
534
+ weights_map=weights_map,
535
+ offload_buffers=offload_buffers,
536
+ module_name=child_name,
537
+ preload_module_classes=preload_module_classes,
538
+ skip_keys=skip_keys,
539
+ tied_params_map=tied_params_map,
540
+ )
541
+
542
+
543
+ def remove_hook_from_submodules(module: nn.Module):
544
+ """
545
+ Recursively removes all hooks attached on the submodules of a given model.
546
+
547
+ Args:
548
+ module (`torch.nn.Module`): The module on which to remove all hooks.
549
+ """
550
+ remove_hook_from_module(module)
551
+ for child in module.children():
552
+ remove_hook_from_submodules(child)
553
+
554
+
555
+ def attach_align_device_hook_on_blocks(
556
+ module: nn.Module,
557
+ execution_device: Optional[Union[torch.device, dict[str, torch.device]]] = None,
558
+ offload: Union[bool, dict[str, bool]] = False,
559
+ weights_map: Optional[Mapping] = None,
560
+ offload_buffers: bool = False,
561
+ module_name: str = "",
562
+ skip_keys: Optional[Union[str, list[str]]] = None,
563
+ preload_module_classes: Optional[list[str]] = None,
564
+ tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
565
+ ):
566
+ """
567
+ Attaches `AlignDevicesHook` to all blocks of a given model as needed.
568
+
569
+ Args:
570
+ module (`torch.nn.Module`):
571
+ The module where we want to attach the hooks.
572
+ execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
573
+ The device on which inputs and model weights should be placed before the forward pass. It can be one device
574
+ for the whole module, or a dictionary mapping module name to device.
575
+ offload (`bool`, *optional*, defaults to `False`):
576
+ Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
577
+ module, or a dictionary mapping module name to boolean.
578
+ weights_map (`Mapping[str, torch.Tensor]`, *optional*):
579
+ When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
580
+ offload_buffers (`bool`, *optional*, defaults to `False`):
581
+ Whether or not to include the associated module's buffers when offloading.
582
+ module_name (`str`, *optional*, defaults to `""`):
583
+ The name of the module.
584
+ skip_keys (`str` or `List[str]`, *optional*):
585
+ A list of keys to ignore when moving inputs or outputs between devices.
586
+ preload_module_classes (`List[str]`, *optional*):
587
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
588
+ of the forward. This should only be used for classes that have submodules which are registered but not
589
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
590
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
591
+ tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
592
+ A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
593
+ device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
594
+ instead of duplicating memory.
595
+ """
596
+ # If one device and one offload, we've got one hook.
597
+ if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
598
+ if not offload:
599
+ hook = AlignDevicesHook(
600
+ execution_device=execution_device,
601
+ io_same_device=True,
602
+ skip_keys=skip_keys,
603
+ place_submodules=True,
604
+ tied_params_map=tied_params_map,
605
+ )
606
+ add_hook_to_module(module, hook)
607
+ else:
608
+ attach_align_device_hook(
609
+ module,
610
+ execution_device=execution_device,
611
+ offload=True,
612
+ weights_map=weights_map,
613
+ offload_buffers=offload_buffers,
614
+ module_name=module_name,
615
+ skip_keys=skip_keys,
616
+ tied_params_map=tied_params_map,
617
+ )
618
+ return
619
+
620
+ if not isinstance(execution_device, Mapping):
621
+ execution_device = {key: execution_device for key in offload.keys()}
622
+ if not isinstance(offload, Mapping):
623
+ offload = {key: offload for key in execution_device.keys()}
624
+
625
+ if module_name in execution_device and module_name in offload and not offload[module_name]:
626
+ hook = AlignDevicesHook(
627
+ execution_device=execution_device[module_name],
628
+ offload_buffers=offload_buffers,
629
+ io_same_device=(module_name == ""),
630
+ place_submodules=True,
631
+ skip_keys=skip_keys,
632
+ tied_params_map=tied_params_map,
633
+ )
634
+ add_hook_to_module(module, hook)
635
+ attach_execution_device_hook(
636
+ module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map
637
+ )
638
+ elif module_name in execution_device and module_name in offload:
639
+ attach_align_device_hook(
640
+ module,
641
+ execution_device=execution_device[module_name],
642
+ offload=True,
643
+ weights_map=weights_map,
644
+ offload_buffers=offload_buffers,
645
+ module_name=module_name,
646
+ skip_keys=skip_keys,
647
+ preload_module_classes=preload_module_classes,
648
+ tied_params_map=tied_params_map,
649
+ )
650
+ if not hasattr(module, "_hf_hook"):
651
+ hook = AlignDevicesHook(
652
+ execution_device=execution_device[module_name],
653
+ io_same_device=(module_name == ""),
654
+ skip_keys=skip_keys,
655
+ tied_params_map=tied_params_map,
656
+ )
657
+ add_hook_to_module(module, hook)
658
+ attach_execution_device_hook(
659
+ module,
660
+ execution_device[module_name],
661
+ preload_module_classes=preload_module_classes,
662
+ skip_keys=skip_keys,
663
+ tied_params_map=tied_params_map,
664
+ )
665
+ elif module_name == "":
666
+ hook = AlignDevicesHook(
667
+ execution_device=execution_device.get(""),
668
+ io_same_device=True,
669
+ skip_keys=skip_keys,
670
+ tied_params_map=tied_params_map,
671
+ )
672
+ add_hook_to_module(module, hook)
673
+
674
+ for child_name, child in module.named_children():
675
+ child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
676
+ attach_align_device_hook_on_blocks(
677
+ child,
678
+ execution_device=execution_device,
679
+ offload=offload,
680
+ weights_map=weights_map,
681
+ offload_buffers=offload_buffers,
682
+ module_name=child_name,
683
+ preload_module_classes=preload_module_classes,
684
+ skip_keys=skip_keys,
685
+ tied_params_map=tied_params_map,
686
+ )
687
+
688
+
689
+ class CpuOffload(ModelHook):
690
+ """
691
+ Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU after
692
+ the forward, the user needs to call the `init_hook` method again for this.
693
+
694
+ Args:
695
+ execution_device(`str`, `int` or `torch.device`, *optional*):
696
+ The device on which the model should be executed. Will default to the MPS device if it's available, then
697
+ GPU 0 if there is a GPU, and finally to the CPU.
698
+ prev_module_hook (`UserCpuOffloadHook`, *optional*):
699
+ The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
700
+ passed, its offload method will be called just before the forward of the model to which this hook is
701
+ attached.
702
+ """
703
+
704
+ def __init__(
705
+ self,
706
+ execution_device: Optional[Union[str, int, torch.device]] = None,
707
+ prev_module_hook: Optional["UserCpuOffloadHook"] = None,
708
+ ):
709
+ self.prev_module_hook = prev_module_hook
710
+
711
+ self.execution_device = execution_device if execution_device is not None else PartialState().default_device
712
+
713
+ def init_hook(self, module):
714
+ return module.to("cpu")
715
+
716
+ def pre_forward(self, module, *args, **kwargs):
717
+ if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
718
+ prev_module = self.prev_module_hook.model
719
+ prev_device = next(prev_module.parameters()).device
720
+
721
+ # Only offload the previous module if it is not already on CPU.
722
+ if prev_device != torch.device("cpu"):
723
+ self.prev_module_hook.offload()
724
+ clear_device_cache()
725
+
726
+ # If the current device is already the self.execution_device, we can skip the transfer.
727
+ current_device = next(module.parameters()).device
728
+ if current_device == self.execution_device:
729
+ return args, kwargs
730
+
731
+ module.to(self.execution_device)
732
+ return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
733
+
734
+
735
+ class UserCpuOffloadHook:
736
+ """
737
+ A simple hook grouping a model and a `ModelHook`, which provides easy APIs for to call the init method of the hook
738
+ or remove it entirely.
739
+ """
740
+
741
+ def __init__(self, model, hook):
742
+ self.model = model
743
+ self.hook = hook
744
+
745
+ def offload(self):
746
+ self.hook.init_hook(self.model)
747
+
748
+ def remove(self):
749
+ remove_hook_from_module(self.model)
750
+
751
+
752
+ class LayerwiseCastingHook(ModelHook):
753
+ r"""
754
+ A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
755
+ for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
756
+ footprint.
757
+ """
758
+
759
+ _is_stateful = False
760
+
761
+ def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
762
+ self.storage_dtype = storage_dtype
763
+ self.compute_dtype = compute_dtype
764
+ self.non_blocking = non_blocking
765
+
766
+ def init_hook(self, module: torch.nn.Module):
767
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
768
+ return module
769
+
770
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
771
+ module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
772
+ return args, kwargs
773
+
774
+ def post_forward(self, module: torch.nn.Module, output):
775
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
776
+ return output