Prompt48 committed on
Commit
fb750ba
·
verified ·
1 Parent(s): 8813ad8

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\accelerate\optimizer.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//accelerate//optimizer.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+
17
+ import torch
18
+
19
+ from .state import AcceleratorState, GradientState
20
+ from .utils import DistributedType, honor_type, is_lomo_available, is_torch_xla_available
21
+
22
+
23
+ if is_torch_xla_available():
24
+ import torch_xla.core.xla_model as xm
25
+ import torch_xla.runtime as xr
26
+
27
+
28
def move_to_device(state, device):
    """
    Recursively move every tensor found in `state` to `device`.

    Args:
        state:
            A tensor, a (possibly nested) list / tuple / dict of tensors, or any other object.
        device:
            The target device, passed unchanged to `torch.Tensor.to`.

    Returns:
        An object with the same nesting as `state`. Lists and tuples are rebuilt through `honor_type`, dicts are
        rebuilt with `type(state)(...)`, tensors are replaced by the result of `.to(device)`, and any other value is
        returned as is.
    """
    if isinstance(state, (list, tuple)):
        # `honor_type` rebuilds the sequence from the generator of moved elements.
        return honor_type(state, (move_to_device(t, device) for t in state))
    elif isinstance(state, dict):
        # Rebuild with the original mapping class so dict subclasses keep their type.
        return type(state)({k: move_to_device(v, device) for k, v in state.items()})
    elif isinstance(state, torch.Tensor):
        return state.to(device)
    # Anything that is not a container or a tensor is passed through untouched.
    return state
36
+
37
+
38
class AcceleratedOptimizer(torch.optim.Optimizer):
    """
    Internal wrapper around a torch optimizer.

    Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
    accumulation.

    Args:
        optimizer (`torch.optim.optimizer.Optimizer`):
            The optimizer to wrap.
        device_placement (`bool`, *optional*, defaults to `True`):
            Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
            `optimizer` on the right device.
        scaler (`torch.amp.GradScaler` or `torch.cuda.amp.GradScaler`, *optional*):
            The scaler to use in the step function if training with mixed precision.
    """

    def __init__(self, optimizer, device_placement=True, scaler=None):
        # The base-class initializer is not invoked here: `state`, `param_groups` and `defaults` are exposed through
        # the properties below, which delegate to the wrapped optimizer.
        self.optimizer = optimizer
        self.scaler = scaler
        self.accelerator_state = AcceleratorState()
        self.gradient_state = GradientState()
        self.device_placement = device_placement
        self._is_overflow = False

        if self.scaler is not None:
            # Keep both the original and a flag-setting variant of `optimizer.step`, so that `step` can tell whether
            # the scaler actually invoked the optimizer.
            self._accelerate_step_called = False
            self._optimizer_original_step_method = self.optimizer.step
            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)

        # Handle device placement
        if device_placement:
            state_dict = self.optimizer.state_dict()
            if self.accelerator_state.distributed_type == DistributedType.XLA:
                # NOTE(review): the return value is discarded here; confirm that `send_cpu_data_to_device` updates
                # `state_dict` in place.
                xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
            else:
                state_dict = move_to_device(state_dict, self.accelerator_state.device)
            self.optimizer.load_state_dict(state_dict)

    @property
    def state(self):
        """The `state` attribute of the wrapped optimizer."""
        return self.optimizer.state

    @state.setter
    def state(self, state):
        self.optimizer.state = state

    @property
    def param_groups(self):
        """The `param_groups` attribute of the wrapped optimizer."""
        return self.optimizer.param_groups

    @param_groups.setter
    def param_groups(self, param_groups):
        self.optimizer.param_groups = param_groups

    @property
    def defaults(self):
        """The `defaults` attribute of the wrapped optimizer."""
        return self.optimizer.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optimizer.defaults = defaults

    def add_param_group(self, param_group):
        """Forward `param_group` to the wrapped optimizer's `add_param_group`."""
        self.optimizer.add_param_group(param_group)

    def load_state_dict(self, state_dict):
        """
        Load `state_dict` into the wrapped optimizer.

        On XLA with device placement enabled, `xm.send_cpu_data_to_device` is called on `state_dict` first.
        """
        if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement:
            xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
        self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        """Return the state dict of the wrapped optimizer."""
        return self.optimizer.state_dict()

    def zero_grad(self, set_to_none=None):
        """
        Zero the gradients of the wrapped optimizer, but only when gradients are being synchronized.

        Args:
            set_to_none (`bool`, *optional*):
                Forwarded to the wrapped optimizer's `zero_grad` when it accepts that argument; `None` is replaced by
                `True` in that case.

        Raises:
            ValueError: If `set_to_none` is given but the wrapped optimizer's `zero_grad` has no such parameter.
        """
        if self.gradient_state.sync_gradients:
            accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
            if accept_arg:
                if set_to_none is None:
                    set_to_none = True
                self.optimizer.zero_grad(set_to_none=set_to_none)
            else:
                if set_to_none is not None:
                    raise ValueError("`set_to_none` for Optimizer.zero_grad` is not supported by this optimizer.")
                self.optimizer.zero_grad()

    def train(self):
        """
        Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
        """
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()
        elif (
            hasattr(self.optimizer, "optimizer")
            and hasattr(self.optimizer.optimizer, "train")
            and callable(self.optimizer.optimizer.train)
        ):
            # the deepspeed optimizer further wraps the optimizer
            self.optimizer.optimizer.train()

    def eval(self):
        """
        Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
        """
        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
            self.optimizer.eval()
        elif (
            hasattr(self.optimizer, "optimizer")
            and hasattr(self.optimizer.optimizer, "eval")
            and callable(self.optimizer.optimizer.eval)
        ):
            # Same fallback as `train`: reach through one extra wrapping layer, otherwise an optimizer put in train
            # mode through that layer could never be switched back.
            self.optimizer.optimizer.eval()

    def step(self, closure=None):
        """
        Perform an optimizer step if gradients are being synchronized.

        When a scaler is set, the step goes through `scaler.step` / `scaler.update`, and `step_was_skipped` records
        whether the scaler ended up not calling the wrapped optimizer's `step`.

        Args:
            closure (`Callable`, *optional*):
                Forwarded to the wrapped optimizer's `step` (through the scaler when one is set).
        """
        if is_lomo_available():
            from lomo_optim import AdaLomo, Lomo

        if (
            not self.gradient_state.is_xla_gradients_synced
            and self.accelerator_state.distributed_type == DistributedType.XLA
        ):
            gradients = xm._fetch_gradients(self.optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xr.world_size())
            self.gradient_state.is_xla_gradients_synced = True

        if is_lomo_available():
            # `step` should be a no-op for LOMO optimizers.
            if isinstance(self.optimizer, (Lomo, AdaLomo)):
                return

        if self.gradient_state.sync_gradients:
            if self.scaler is not None:
                # Temporarily swap in the flag-setting `step` so we can observe whether the scaler calls it.
                self.optimizer.step = self._optimizer_patched_step_method
                try:
                    self.scaler.step(self.optimizer, closure)
                    self.scaler.update()

                    # If the optimizer step was skipped, gradient overflow was detected.
                    self._is_overflow = not self._accelerate_step_called
                finally:
                    # Always restore the original method and the indicator, even if the scaler raised, so the
                    # wrapped optimizer is never left patched.
                    self.optimizer.step = self._optimizer_original_step_method
                    self._accelerate_step_called = False
            else:
                self.optimizer.step(closure)
        if self.accelerator_state.distributed_type == DistributedType.XLA:
            self.gradient_state.is_xla_gradients_synced = False

    def _switch_parameters(self, parameters_map):
        """Replace each parameter in every param group by `parameters_map[p]` when `p` is a key of the map."""
        for param_group in self.optimizer.param_groups:
            param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]

    @property
    def step_was_skipped(self):
        """Whether or not the optimizer step was skipped."""
        return self._is_overflow

    def __getstate__(self):
        # The step-patching attributes are left out of the pickled state; `__setstate__` rebuilds them.
        _ignored_keys = [
            "_accelerate_step_called",
            "_optimizer_original_step_method",
            "_optimizer_patched_step_method",
        ]
        return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}

    def __setstate__(self, state):
        self.__dict__.update(state)
        if self.scaler is not None:
            # Recreate the attributes that `__getstate__` dropped.
            self._accelerate_step_called = False
            self._optimizer_original_step_method = self.optimizer.step
            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
206
+
207
+
208
def patch_optimizer_step(accelerated_optimizer: "AcceleratedOptimizer", method):
    """
    Wrap `method` so that calling it flags `accelerated_optimizer` as having stepped.

    Args:
        accelerated_optimizer (`AcceleratedOptimizer`):
            The object whose `_accelerate_step_called` attribute is set to `True` on every call of the wrapper.
        method (`Callable`):
            The callable to wrap.

    Returns:
        `Callable`: A function that sets the flag, then forwards all positional and keyword arguments to `method`
        and returns its result.
    """

    def patched_step(*args, **kwargs):
        # Record the call before delegating, so the flag is set even if `method` raises.
        accelerated_optimizer._accelerate_step_called = True
        return method(*args, **kwargs)

    return patched_step