litwell committed
Commit f8e18e6 · verified · 1 Parent(s): 569465e

Upload models/src/training/trainer.py with huggingface_hub

Files changed (1):
  models/src/training/trainer.py  +230 -0
models/src/training/trainer.py ADDED
@@ -0,0 +1,230 @@
+import os
+import torch
+import torch.nn as nn
+
+from transformers import Trainer
+from transformers.trainer import (
+    is_sagemaker_mp_enabled,
+    get_parameter_names,
+    ALL_LAYERNORM_LAYERS,
+    is_peft_available,
+    WEIGHTS_NAME,
+    TRAINING_ARGS_NAME,
+    SAFE_WEIGHTS_NAME,
+    TRAINER_STATE_NAME,
+    PREFIX_CHECKPOINT_DIR,
+    logger,
+)
+import safetensors
+from peft import PeftModel
+from typing import Optional
+import numpy as np
+from transformers.processing_utils import ProcessorMixin
+from transformers.modeling_utils import PreTrainedModel
+from peft import PeftModel
+from training.train_utils import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+    from deepspeed import zero
+    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+    if hasattr(param, "ds_id"):
+        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+            if not ignore_status:
+                print(name, "no ignore status")
+        with zero.GatheredParameters([param]):
+            param = param.data.detach().cpu().clone()
+    else:
+        param = param.detach().cpu().clone()
+    return param
+
+class QwenTrainer(Trainer):
+
+    def __init__(self, processor, *args, **kwargs):
+        super(QwenTrainer, self).__init__(*args, **kwargs)
+        self.processor = processor
+
+    def create_optimizer(self):
+        """
+        Setup the optimizer.
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+        """
+        if is_sagemaker_mp_enabled():
+            return super().create_optimizer()
+
+        opt_model = self.model
+
+        if self.optimizer is None:
+            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+            decay_parameters = [name for name in decay_parameters if "bias" not in name]
+            lr_mapper = {}
+            visual_parameters = []
+            merger_parameters = []
+
+            if self.args.vision_lr is not None:
+                lr_mapper["visual"] = self.args.vision_lr
+                visual_parameters = [name for name, _ in opt_model.named_parameters() if "visual" in name and "merger" not in name]
+            if self.args.merger_lr is not None:
+                lr_mapper["merger"] = self.args.merger_lr
+                merger_parameters = [name for name, _ in opt_model.named_parameters() if "merger" in name]
+
+            if len(lr_mapper) > 0:
+                special_lr_parameters = merger_parameters + visual_parameters
+
+                optimizer_grouped_parameters = [
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
+                        "weight_decay": self.args.weight_decay,
+                    },
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
+                        "weight_decay": 0.0,
+                    },
+                ]
+
+                if visual_parameters:
+                    optimizer_grouped_parameters.extend(
+                        [
+                            {
+                                "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in visual_parameters and p.requires_grad)],
+                                "weight_decay": self.args.weight_decay,
+                                "lr": self.args.vision_lr,
+                            },
+                            {
+                                "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in visual_parameters and p.requires_grad)],
+                                "weight_decay": 0.0,
+                                "lr": self.args.vision_lr,
+                            },
+                        ]
+                    )
+
+                if merger_parameters:
+                    optimizer_grouped_parameters.extend(
+                        [
+                            {
+                                "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in merger_parameters and p.requires_grad)],
+                                "weight_decay": self.args.weight_decay,
+                                "lr": self.args.merger_lr,
+                            },
+                            {
+                                "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in merger_parameters and p.requires_grad)],
+                                "weight_decay": 0.0,
+                                "lr": self.args.merger_lr,
+                            },
+                        ]
+                    )
+            else:
+                optimizer_grouped_parameters = [
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
+                        "weight_decay": self.args.weight_decay,
+                    },
+                    {
+                        "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
+                        "weight_decay": 0.0,
+                    },
+                ]
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+            if optimizer_cls.__name__ == "Adam8bit":
+                import bitsandbytes
+
+                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                skipped = 0
+                for module in opt_model.modules():
+                    if isinstance(module, nn.Embedding):
+                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+                        logger.info(f"skipped {module}: {skipped/2**20}M params")
+                        manager.register_module_override(module, "weight", {"optim_bits": 32})
+                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+                logger.info(f"skipped: {skipped/2**20}M params")
+
+        return self.optimizer
+
+    def _save_checkpoint(self, model, trial):
+        if self.args.lora_enable:
+            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+            if self.hp_search_backend is None and trial is None:
+                self.store_flos()
+
+            run_dir = self._get_output_dir(trial=trial)
+            output_dir = os.path.join(run_dir, checkpoint_folder)
+
+            self.save_model(output_dir, _internal_call=True)
+
+            non_lora_weights = get_peft_state_non_lora_maybe_zero_3(self.model.named_parameters(), require_grad_only=False)
+            torch.save(non_lora_weights, os.path.join(output_dir, "non_lora_state_dict.bin"))
+
+            if not self.args.save_only_model:
+                # Save optimizer and scheduler
+                self._save_optimizer_and_scheduler(output_dir)
+                # Save RNG state
+                self._save_rng_state(output_dir)
+
+            # Save the Trainer state
+            if self.args.should_save:
+                # Update the `TrainerControl` state to where we are currently
+                self.state.stateful_callbacks["TrainerControl"] = self.control.state()
+                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+            if self.args.push_to_hub:
+                self._push_from_checkpoint(output_dir)
+
+            # Maybe delete some older checkpoints.
+            if self.args.should_save:
+                # Solely rely on numerical checkpoint id for rotation.
+                # mtime is not reliable especially on some fuse fs in cloud environments.
+                self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
+
+        else:
+            super(QwenTrainer, self)._save_checkpoint(model, trial)
+
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        # If we are executing this function, we are the process zero, so we don't check for that.
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+        logger.info(f"Saving model checkpoint to {output_dir}")
+
+        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        if not isinstance(self.model, supported_classes):
+            if state_dict is None:
+                state_dict = self.model.state_dict()
+
+            if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
+                self.accelerator.unwrap_model(self.model).save_pretrained(
+                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                )
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                if self.args.save_safetensors:
+                    safetensors.torch.save_file(
+                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
+                    )
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            self.model.save_pretrained(
+                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+            )
+
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
+
+        if self.processor is not None:
+            self.processor.save_pretrained(output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+    # def training_step(self, model, inputs):
+    #     for name, param in model.named_parameters():
+    #         if 'visual' in name and param.requires_grad:
+    #             print(f"Training parameter {name}")
+    #
+    #     return super().training_step(model, inputs)
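
For context, below is a minimal usage sketch of the uploaded trainer, not part of the commit itself. It assumes the repo's own config and data code provide `training_args` (a TrainingArguments subclass exposing the `vision_lr`, `merger_lr`, and `lora_enable` fields this class reads), a `train_dataset`, and a `collator`; the Qwen2-VL checkpoint name is illustrative.

# Usage sketch (assumptions: training_args, train_dataset and collator come from the
# repo's own argument/dataset code; the model checkpoint below is only an example).
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from training.trainer import QwenTrainer

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

trainer = QwenTrainer(
    processor=processor,          # saved next to the model weights in _save()
    model=model,
    args=training_args,           # must expose vision_lr, merger_lr, lora_enable
    train_dataset=train_dataset,
    data_collator=collator,
)
trainer.train()
trainer.save_model(training_args.output_dir)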