asdjghh committed on
Commit
91748c3
·
verified ·
1 Parent(s): f3af360

Upload sft_gen.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sft_gen.py +125 -0
sft_gen.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 PKU-Alignment Team. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Trainer for supervised training."""
16
+
17
+
18
+ import argparse
19
+ import os
20
+ import sys
21
+
22
+ import deepspeed
23
+ import torch
24
+ import transformers
25
+ from janus.models import MultiModalityCausalLM, VLChatProcessor, VLMImageProcessor
26
+
27
+ from align_anything.datasets.janus import SupervisedBatch, SupervisedDataset, SupervisedTokenizedDataset
28
+ from align_anything.trainers.text_to_text.sft import SupervisedTrainer as SupervisedtextTrainer
29
+ from align_anything.utils.device_utils import torch_set_device
30
+ from align_anything.utils.multi_process import get_current_device
31
+ from align_anything.utils.tools import (
32
+ custom_cfgs_to_dict,
33
+ dict_to_namedtuple,
34
+ read_cfgs,
35
+ seed_everything,
36
+ update_dict,
37
+ )
38
+
39
+
40
+ transformers.logging.set_verbosity_info()
41
+
42
+
43
class SuperviseTrainer(SupervisedtextTrainer):
    """Supervised finetuning trainer for the Janus multi-modality model.

    Reuses the text-to-text SFT training loop from ``SupervisedtextTrainer``
    and overrides dataset, model, and loss initialization for
    image-editing / image-generation SFT.
    """

    def init_datasets(self) -> None:
        """Initialize training and evaluation dataloaders.

        Uses ``SupervisedTokenizedDataset`` for both splits — required when
        batches carry image inputs.
        """
        self.train_dataloader, self.eval_dataloader = self.get_dataloaders(
            SupervisedTokenizedDataset, SupervisedTokenizedDataset
        )

    def update_configs(self, model_config, args, fields) -> None:
        """Two-way sync of ``fields`` between ``model_config`` and ``args``.

        For each field name: if ``args`` is missing the value (``None``),
        it is filled from ``model_config``; otherwise ``model_config`` is
        overwritten with the value from ``args``.

        Raises:
            AttributeError: if ``args`` lacks the field and ``model_config``
                does not define it either.
        """
        for field_name in fields:
            if getattr(args, field_name, None) is None:
                setattr(args, field_name, getattr(model_config, field_name))
            else:
                setattr(model_config, field_name, getattr(args, field_name))

    def init_models(self) -> None:
        """Initialize model, processor, and tokenizer from the configured path."""
        self.model = MultiModalityCausalLM.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        ).to(get_current_device())
        # Optional bf16 cast controlled by the training config.
        if self.cfgs.train_cfgs.bf16:
            self.model = self.model.to(torch.bfloat16)

        self.processor = VLChatProcessor.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        )
        self.tokenizer = self.processor.tokenizer

    def loss(self, sft_batch: SupervisedBatch) -> dict[str, torch.Tensor]:
        """Loss function for supervised finetuning.

        The model's forward pass dispatches on ``task``; this trainer always
        runs the image-editing objective.
        """
        sft_batch['task'] = 'image_editing'
        outputs = self.model.forward(vl_chat_processor=self.processor, **sft_batch)
        return {
            'loss': outputs.loss,
        }
93
+
94
+
95
def main():
    """Entry point: set up distributed training, load configs, and run SFT."""
    # Distributed initialization must precede any device work.
    deepspeed.init_distributed()
    torch_set_device(get_current_device())

    # Defaults come from the task YAML; command-line overrides go on top.
    dict_cfgs, ds_cfgs = read_cfgs(mode='train', task=os.path.join('janus', 'sft_gen'))

    # Collect unrecognized `--key value` pairs from the command line.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    _, unparsed_args = parser.parse_known_args()
    flag_names = [flag[2:] for flag in unparsed_args[1::2]]
    flag_values = list(unparsed_args[2::2])
    for key, value in dict(zip(flag_names, flag_values)).items():
        dict_cfgs = update_dict(dict_cfgs, custom_cfgs_to_dict(key, value))

    # Freeze the config and seed all RNGs for reproducibility.
    cfgs = dict_to_namedtuple(dict_cfgs)
    seed_everything(cfgs.train_cfgs.seed)

    # Finetune, then persist the resulting model.
    trainer = SuperviseTrainer(cfgs=cfgs, ds_cfgs=ds_cfgs)
    trainer.train()
    trainer.save()


if __name__ == '__main__':
    sys.exit(main())