# Copyright 2025 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer for supervised fine-tuning of Janus models."""

import argparse
import os
import sys

import deepspeed
import torch
import transformers
from janus.models import MultiModalityCausalLM, VLChatProcessor

from align_anything.datasets.janus import SupervisedBatch, SupervisedTokenizedDataset
from align_anything.trainers.text_to_text.sft import SupervisedTrainer as SupervisedTextTrainer
from align_anything.utils.device_utils import torch_set_device
from align_anything.utils.multi_process import get_current_device
from align_anything.utils.tools import (
    custom_cfgs_to_dict,
    dict_to_namedtuple,
    read_cfgs,
    seed_everything,
    update_dict,
)


transformers.logging.set_verbosity_info()


class SupervisedTrainer(SupervisedTextTrainer):

    def init_datasets(self) -> None:
        """Initialize training and evaluation datasets."""
        # Tokenized datasets are used for both splits so that image inputs
        # are supported.
        self.train_dataloader, self.eval_dataloader = self.get_dataloaders(
            SupervisedTokenizedDataset,
            SupervisedTokenizedDataset,
        )

    def update_configs(self, model_config, args, fields) -> None:
        """Cross-fill ``fields`` between ``model_config`` and ``args``."""

        def cross_update(a, b, field_name):
            # Whichever side is missing the field (i.e. it is ``None``) is
            # filled in from the other side.
            if getattr(b, field_name, None) is None:
                setattr(b, field_name, getattr(a, field_name))
            else:
                setattr(a, field_name, getattr(b, field_name))

        for field in fields:
            cross_update(model_config, args, field)

    def init_models(self) -> None:
        """Initialize the model and tokenizer."""
        self.model = MultiModalityCausalLM.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        ).to(get_current_device())
        if self.cfgs.train_cfgs.bf16:
            self.model = self.model.to(torch.bfloat16)
        self.processor = VLChatProcessor.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        )
        self.tokenizer = self.processor.tokenizer

    def loss(self, sft_batch: SupervisedBatch) -> dict[str, torch.Tensor]:
        """Loss function for supervised fine-tuning."""
        # This trainer targets the image-editing task.
        sft_batch['task'] = 'image_editing'
        outputs = self.model(vl_chat_processor=self.processor, **sft_batch)
        return {
            'loss': outputs.loss,
        }
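# A minimal illustration of ``SupervisedTrainer.update_configs`` above. The
# namespaces and field values here are hypothetical and not part of this
# module's runtime path:
#
#     from types import SimpleNamespace
#     model_config = SimpleNamespace(hidden_size=4096, vocab_size=None)
#     args = SimpleNamespace(hidden_size=None, vocab_size=102400)
#     trainer.update_configs(model_config, args, ['hidden_size', 'vocab_size'])
#     # -> args.hidden_size == 4096 and model_config.vocab_size == 102400
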
def main():
    # Set up distributed training.
    deepspeed.init_distributed()
    current_device = get_current_device()
    torch_set_device(current_device)

    # Read the default configs from the YAML file.
    task = os.path.join('janus', 'sft_gen')
    dict_cfgs, ds_cfgs = read_cfgs(mode='train', task=task)

    # Get custom configs from the command line. The first unparsed token is
    # expected to be the ``--local_rank=N`` argument injected by the deepspeed
    # launcher, so the ``--key value`` override pairs start at index 1.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    _, unparsed_args = parser.parse_known_args()
    keys = [k[2:] for k in unparsed_args[1::2]]
    values = list(unparsed_args[2::2])
    unparsed_args = dict(zip(keys, values))
    for k, v in unparsed_args.items():
        dict_cfgs = update_dict(dict_cfgs, custom_cfgs_to_dict(k, v))

    # Set up training.
    cfgs = dict_to_namedtuple(dict_cfgs)
    seed_everything(cfgs.train_cfgs.seed)

    # Fine-tune the model.
    trainer = SupervisedTrainer(cfgs=cfgs, ds_cfgs=ds_cfgs)
    trainer.train()
    trainer.save()


if __name__ == '__main__':
    sys.exit(main())
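# Example launch (a sketch: the script path, GPU count, and override flags
# below are hypothetical; consult the repository's scripts for the supported
# invocation):
#
#     deepspeed --num_gpus 8 align_anything/trainers/janus/sft_gen.py \
#         --model_name_or_path /path/to/Janus-Pro \
#         --output_dir ./output/janus_sft_gen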