Prompt48 commited on
Commit
d1cd809
·
verified ·
1 Parent(s): 15045f9

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\accelerate\scheduler.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//accelerate//scheduler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
16
+
17
+ import warnings
18
+
19
+ from .state import AcceleratorState, GradientState
20
+
21
+
22
+ warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")
23
+
24
+
25
+ class AcceleratedScheduler:
26
+ """
27
+ A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful
28
+ to avoid making a scheduler step too fast when gradients went overflow and there was no training step (in mixed
29
+ precision training)
30
+
31
+ When performing gradient accumulation scheduler lengths should not be changed accordingly, Accelerate will always
32
+ step the scheduler to account for it.
33
+
34
+ Args:
35
+ scheduler (`torch.optim.lr_scheduler._LRScheduler`):
36
+ The scheduler to wrap.
37
+ optimizers (one or a list of `torch.optim.Optimizer`):
38
+ The optimizers used.
39
+ step_with_optimizer (`bool`, *optional*, defaults to `True`):
40
+ Whether or not the scheduler should be stepped at each optimizer step.
41
+ split_batches (`bool`, *optional*, defaults to `False`):
42
+ Whether or not the dataloaders split one batch across the different processes (so batch size is the same
43
+ regardless of the number of processes) or create batches on each process (so batch size is the original
44
+ batch size multiplied by the number of processes).
45
+ """
46
+
47
+ def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
48
+ self.scheduler = scheduler
49
+ self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
50
+ self.split_batches = split_batches
51
+ self.step_with_optimizer = step_with_optimizer
52
+ self.gradient_state = GradientState()
53
+
54
+ def step(self, *args, **kwargs):
55
+ if not self.step_with_optimizer:
56
+ # No link between scheduler and optimizer -> just step
57
+ self.scheduler.step(*args, **kwargs)
58
+ return
59
+
60
+ # Otherwise, first make sure the optimizer was stepped.
61
+ if not self.gradient_state.sync_gradients:
62
+ if self.gradient_state.adjust_scheduler:
63
+ self.scheduler._step_count += 1
64
+ return
65
+
66
+ for opt in self.optimizers:
67
+ if opt.step_was_skipped:
68
+ return
69
+ if self.split_batches:
70
+ # Split batches -> the training dataloader batch size is not changed so one step per training step
71
+ self.scheduler.step(*args, **kwargs)
72
+ else:
73
+ # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
74
+ # num_processes steps per training step
75
+ num_processes = AcceleratorState().num_processes
76
+ for _ in range(num_processes):
77
+ # Special case when using OneCycle and `drop_last` was not used
78
+ if hasattr(self.scheduler, "total_steps"):
79
+ if self.scheduler._step_count <= self.scheduler.total_steps:
80
+ self.scheduler.step(*args, **kwargs)
81
+ else:
82
+ self.scheduler.step(*args, **kwargs)
83
+
84
+ # Passthroughs
85
+ def get_last_lr(self):
86
+ return self.scheduler.get_last_lr()
87
+
88
+ def state_dict(self):
89
+ return self.scheduler.state_dict()
90
+
91
+ def load_state_dict(self, state_dict):
92
+ self.scheduler.load_state_dict(state_dict)
93
+
94
+ def get_lr(self):
95
+ return self.scheduler.get_lr()
96
+
97
+ def print_lr(self, *args, **kwargs):
98
+ return self.scheduler.print_lr(*args, **kwargs)