litwell committed on
Commit
6f287f0
·
verified ·
1 Parent(s): 30ec4fd

Upload models/src/training/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/src/training/train.py +228 -0
models/src/training/train.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from peft import LoraConfig, get_peft_model
4
+ import ast
5
+ from transformers import AutoProcessor, BitsAndBytesConfig, Qwen2VLForConditionalGeneration, HfArgumentParser, Qwen2_5_VLForConditionalGeneration
6
+ from training.trainer import QwenTrainer
7
+ from training.data import make_supervised_data_module
8
+ from training.params import DataArguments, ModelArguments, TrainingArguments
9
+ from training.train_utils import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3, safe_save_model_for_hf_trainer
10
+ import pathlib
11
+ from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl, apply_liger_kernel_to_qwen2_5_vl
12
+ from monkey_patch_forward import replace_qwen2_5_with_mixed_modality_forward, replace_qwen_2_with_mixed_modality_forward
13
+
14
+ local_rank = None
15
+
16
+ def rank0_print(*args):
17
+ if local_rank == 0 or local_rank == '0' or local_rank is None:
18
+ print(*args)
19
+
20
+ def find_target_linear_names(model, num_lora_modules=-1, lora_namespan_exclude=[], verbose=True):
21
+ linear_cls = torch.nn.modules.Linear
22
+ embedding_cls = torch.nn.modules.Embedding
23
+ lora_module_names = []
24
+
25
+ for name, module in model.named_modules():
26
+ if any(ex_keyword in name for ex_keyword in lora_namespan_exclude):
27
+ continue
28
+ if isinstance(module, (linear_cls, embedding_cls)):
29
+ lora_module_names.append(name)
30
+
31
+ if num_lora_modules > 0:
32
+ lora_module_names = lora_module_names[-num_lora_modules:]
33
+ if verbose:
34
+ rank0_print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
35
+ return lora_module_names
36
+
37
+ def set_requires_grad(parameters, requires_grad):
38
+ for p in parameters:
39
+ p.requires_grad = requires_grad
40
+
41
+ def configure_vision_tower(model, training_args, compute_dtype, device):
42
+ vision_tower = model.visual
43
+ vision_tower.to(dtype=compute_dtype, device=device)
44
+
45
+ vision_model_params = model.visual.parameters()
46
+ set_requires_grad(vision_model_params, not training_args.freeze_vision_tower)
47
+
48
+ # Handle merger specifically
49
+ merger_params = model.visual.merger.parameters()
50
+ set_requires_grad(merger_params, training_args.tune_merger)
51
+
52
+ def configure_llm(model, training_args):
53
+ lm_head = model.lm_head.parameters()
54
+ set_requires_grad(lm_head, not training_args.freeze_llm)
55
+
56
+ llm_params = model.model.parameters()
57
+ set_requires_grad(llm_params, not training_args.freeze_llm)
58
+
59
+
60
+ def train():
61
+ global local_rank
62
+
63
+ parser = HfArgumentParser(
64
+ (ModelArguments, DataArguments, TrainingArguments))
65
+
66
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
67
+ use_liger = training_args.use_liger
68
+ if "Qwen2.5" in model_args.model_id:
69
+ # It monkey patches the forward to handle mixed modality inputs.
70
+ replace_qwen2_5_with_mixed_modality_forward(use_liger=use_liger)
71
+ # This is becuase mixed-modality training monkey-patches the model forward method.
72
+ if use_liger:
73
+ apply_liger_kernel_to_qwen2_5_vl(fused_linear_cross_entropy=False)
74
+ else:
75
+ # It monkey patches the forward to handle mixed modality inputs.
76
+ replace_qwen_2_with_mixed_modality_forward(use_liger=use_liger)
77
+ # This is becuase mixed-modality training monkey-patches the model forward method.
78
+ if use_liger:
79
+ apply_liger_kernel_to_qwen2_vl(fused_linear_cross_entropy=False)
80
+
81
+
82
+ if training_args.lora_enable and not training_args.freeze_llm:
83
+ raise ValueError("If `lora_enable` is True, `freeze_llm` must also be True.")
84
+
85
+ if not training_args.lora_enable:
86
+ assert not training_args.vision_lora, \
87
+ "Error: training_args.lora_enable is not enabled, but training_args.vision_lora is enabled."
88
+
89
+ if training_args.vision_lora and not training_args.freeze_vision_tower:
90
+ raise ValueError("If `vision_lora` is True, `freeze_vision_tower` must also be True.")
91
+
92
+ else:
93
+ if training_args.lora_namespan_exclude is not None:
94
+ training_args.lora_namespan_exclude = ast.literal_eval(training_args.lora_namespan_exclude)
95
+ else:
96
+ training_args.lora_namespan_exclude = []
97
+
98
+ if not training_args.vision_lora:
99
+ training_args.lora_namespan_exclude += ["visual"]
100
+
101
+ local_rank = training_args.local_rank
102
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
103
+
104
+ bnb_model_from_pretrained_args = {}
105
+ if training_args.bits in [4,8]:
106
+ bnb_model_from_pretrained_args.update(dict(
107
+ device_map={"":training_args.device},
108
+ quantization_config = BitsAndBytesConfig(
109
+ load_in_4bit=training_args.bits==4,
110
+ load_in_8bit=training_args.bits==8,
111
+ llm_int8_skip_modules=["visual"],
112
+ llm_int8_threshold=6.0,
113
+ llm_int8_has_fp16_weight=False,
114
+ bnb_4bit_compute_dtype=compute_dtype,
115
+ bnb_4bit_use_double_quant=training_args.double_quant,
116
+ bnb_4bit_quant_type=training_args.quant_type,
117
+ )
118
+ ))
119
+
120
+ if "Qwen2.5" in model_args.model_id:
121
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
122
+ model_args.model_id,
123
+ torch_dtype=compute_dtype,
124
+ attn_implementation="flash_attention_2" if not training_args.disable_flash_attn2 else "sdpa",
125
+ **bnb_model_from_pretrained_args
126
+ )
127
+ else:
128
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
129
+ model_args.model_id,
130
+ torch_dtype=compute_dtype,
131
+ attn_implementation="flash_attention_2" if not training_args.disable_flash_attn2 else "sdpa",
132
+ **bnb_model_from_pretrained_args
133
+ )
134
+
135
+ model.config.use_cache = False
136
+ model_to_configure = model
137
+ configure_llm(model_to_configure, training_args)
138
+ configure_vision_tower(model_to_configure, training_args, compute_dtype, training_args.device)
139
+
140
+ if training_args.bits in [4,8]:
141
+ model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
142
+ from peft import prepare_model_for_kbit_training
143
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing, gradient_checkpointing_kwargs={"use_reentrant": True})
144
+
145
+ if training_args.gradient_checkpointing:
146
+ model.enable_input_require_grads()
147
+ training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
148
+
149
+ if training_args.lora_enable:
150
+ lora_namespan_exclude = training_args.lora_namespan_exclude
151
+ peft_config = LoraConfig(
152
+ r=training_args.lora_rank,
153
+ lora_alpha=training_args.lora_alpha,
154
+ target_modules=find_target_linear_names(model, lora_namespan_exclude=lora_namespan_exclude, num_lora_modules=training_args.num_lora_modules),
155
+ lora_dropout=training_args.lora_dropout,
156
+ bias=training_args.lora_bias
157
+ )
158
+ if training_args.bits == 16:
159
+ if training_args.bf16:
160
+ model.to(torch.bfloat16)
161
+ if training_args.fp16:
162
+ model.to(torch.float16)
163
+ rank0_print("Adding LoRA to the model...")
164
+ model = get_peft_model(model, peft_config)
165
+
166
+ processor = AutoProcessor.from_pretrained(model_args.model_id,
167
+ # The default setting is padding_side="left"
168
+ # When training using the right-side padding is more efficient.
169
+ padding_side="right")
170
+
171
+ # model.config.tokenizer_model_max_length = processor.tokenizer.model_max_length
172
+ model.config.tokenizer_padding_side = processor.tokenizer.padding_side
173
+ model.config.vision_lr = training_args.vision_lr
174
+
175
+ if training_args.bits in [4, 8]:
176
+ from peft.tuners.lora import LoraLayer
177
+ for name, module in model.named_modules():
178
+ if isinstance(module, LoraLayer):
179
+ if training_args.bf16:
180
+ module = module.to(torch.bfloat16)
181
+ if 'norm' in name:
182
+ module = module.to(torch.float32)
183
+
184
+ if 'lm_head' in name or 'embed_token' in name:
185
+ if hasattr(module, 'weight'):
186
+ if training_args.bf16 and module.weight.dtype == torch.float32:
187
+ module = module.to(torch.bfloat16)
188
+
189
+ data_module = make_supervised_data_module(model_id=model_args.model_id,
190
+ processor=processor,
191
+ data_args=data_args)
192
+
193
+ trainer = QwenTrainer(
194
+ model=model,
195
+ processor=processor,
196
+ args=training_args,
197
+ **data_module
198
+ )
199
+
200
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
201
+ trainer.train(resume_from_checkpoint=True)
202
+ else:
203
+ trainer.train()
204
+
205
+ trainer.save_state()
206
+
207
+ model.config.use_cache = True
208
+
209
+ if training_args.lora_enable:
210
+ state_dict = get_peft_state_maybe_zero_3(
211
+ model.named_parameters(), training_args.lora_bias
212
+ )
213
+
214
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
215
+ model.named_parameters(), require_grad_only=False
216
+ )
217
+
218
+ if local_rank == 0 or local_rank == -1:
219
+ model.config.save_pretrained(training_args.output_dir)
220
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
221
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_state_dict.bin"))
222
+ else:
223
+ safe_save_model_for_hf_trainer(trainer, output_dir=training_args.output_dir)
224
+
225
+
226
+
227
+ if __name__ == "__main__":
228
+ train()