# Copyright 2025 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer for supervised training."""
import argparse
import os
import sys

import deepspeed
import torch
import transformers

from janus.models import MultiModalityCausalLM, VLChatProcessor, VLMImageProcessor

from align_anything.datasets.janus import (
    SupervisedBatch,
    SupervisedDataset,
    SupervisedTokenizedDataset,
)
from align_anything.trainers.text_to_text.sft import SupervisedTrainer as SupervisedTextTrainer
from align_anything.utils.device_utils import torch_set_device
from align_anything.utils.multi_process import get_current_device
from align_anything.utils.tools import (
    custom_cfgs_to_dict,
    dict_to_namedtuple,
    read_cfgs,
    seed_everything,
    update_dict,
)


transformers.logging.set_verbosity_info()


class SupervisedTrainer(SupervisedTextTrainer):

    def init_datasets(self) -> None:
        """Initialize training and evaluation datasets."""
        # SupervisedTokenizedDataset is used for both splits because the inputs
        # include images; use SupervisedDataset instead for text-only data.
        self.train_dataloader, self.eval_dataloader = self.get_dataloaders(
            SupervisedTokenizedDataset, SupervisedTokenizedDataset
        )

    def update_configs(self, model_config, args, fields):
        """Mirror the given fields between `model_config` and `args`."""
        # For each field, fill in whichever side is unset: if `b` lacks a value,
        # copy it over from `a`; otherwise propagate `b`'s value back to `a`.
        cross_update = lambda a, b, field_name: (
            setattr(b, field_name, getattr(a, field_name))
            if getattr(b, field_name, None) is None
            else setattr(a, field_name, getattr(b, field_name))
        )
        for field in fields:
            cross_update(model_config, args, field)
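
        # Illustrative example (the field name is hypothetical):
        #   self.update_configs(model_config, args, ['max_length'])
        # copies model_config.max_length into args when args.max_length is None,
        # and otherwise writes the args value back into model_config.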

    def init_models(self) -> None:
        """Initialize model and tokenizer."""
        self.model = MultiModalityCausalLM.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        ).to(get_current_device())
        # Parameters can be selectively frozen here if needed, e.g. by toggling
        # `param.requires_grad` while iterating `self.model.named_parameters()`.
        # Optionally cast the whole model to bfloat16 for mixed-precision training.
        if self.cfgs.train_cfgs.bf16:
            self.model = self.model.to(torch.bfloat16)
        self.processor = VLChatProcessor.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        )
        self.tokenizer = self.processor.tokenizer

    def loss(self, sft_batch: SupervisedBatch) -> dict[str, torch.Tensor]:
        """Loss function for supervised finetuning."""
        # This trainer targets the image-editing task; the model's forward pass
        # computes the supervised loss internally from the tokenized batch
        # (which carries fields such as `source_image`).
        sft_batch['task'] = 'image_editing'
        outputs = self.model.forward(vl_chat_processor=self.processor, **sft_batch)
        return {
            'loss': outputs.loss,
        }


def main():
    # Set up distributed training.
    deepspeed.init_distributed()
    current_device = get_current_device()
    torch_set_device(current_device)

    # Read the default configs from the YAML file.
    task = os.path.join('janus', 'sft_gen')
    dict_cfgs, ds_cfgs = read_cfgs(mode='train', task=task)

    # Collect custom configs from the command line. The first unparsed token
    # (e.g. a `--local_rank=N` flag injected by the launcher) is skipped; the
    # rest are read as alternating `--key value` pairs.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    _, unparsed_args = parser.parse_known_args()
    keys = [k[2:] for k in unparsed_args[1::2]]
    values = list(unparsed_args[2::2])
    unparsed_args = dict(zip(keys, values))
    for k, v in unparsed_args.items():
        dict_cfgs = update_dict(dict_cfgs, custom_cfgs_to_dict(k, v))
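
    # For instance (the key syntax is illustrative and must match what
    # `custom_cfgs_to_dict` expects), `--train_cfgs.seed 42` would be merged
    # into dict_cfgs as an override of the YAML default.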

    # Set up training.
    cfgs = dict_to_namedtuple(dict_cfgs)
    seed_everything(cfgs.train_cfgs.seed)

    # Finetune the model.
    trainer = SupervisedTrainer(cfgs=cfgs, ds_cfgs=ds_cfgs)
    trainer.train()
    trainer.save()


if __name__ == '__main__':
    sys.exit(main())