| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Trainer for supervised training.""" |
|
|
|
|
| import argparse |
| import os |
| import sys |
|
|
| import deepspeed |
| import torch |
| import transformers |
| from janus.models import MultiModalityCausalLM, VLChatProcessor, VLMImageProcessor |
|
|
| from align_anything.datasets.janus import SupervisedBatch, SupervisedDataset, SupervisedTokenizedDataset |
| from align_anything.trainers.text_to_text.sft import SupervisedTrainer as SupervisedtextTrainer |
| from align_anything.utils.device_utils import torch_set_device |
| from align_anything.utils.multi_process import get_current_device |
| from align_anything.utils.tools import ( |
| custom_cfgs_to_dict, |
| dict_to_namedtuple, |
| read_cfgs, |
| seed_everything, |
| update_dict, |
| ) |
|
|
|
|
# Emit detailed HuggingFace transformers logs (model loading, config resolution)
# so distributed training startup is observable.
transformers.logging.set_verbosity_info()
|
|
|
|
class SuperviseTrainer(SupervisedtextTrainer):
    """Supervised fine-tuning trainer for the Janus multi-modality model.

    Reuses the text-to-text SFT training loop and overrides dataset, model,
    and loss initialization for image-editing generation batches.
    """

    def init_datasets(self) -> None:
        """Initialize training and evaluation dataloaders from pre-tokenized datasets."""
        self.train_dataloader, self.eval_dataloader = self.get_dataloaders(
            SupervisedTokenizedDataset, SupervisedTokenizedDataset
        )

    def update_configs(self, model_config, args, fields):
        """Bidirectionally reconcile ``fields`` between ``model_config`` and ``args``.

        For each field name: if ``args`` is missing a value (``None`` or absent),
        copy it over from ``model_config``; otherwise write the ``args`` value
        back onto ``model_config`` so both objects end up in agreement.
        """
        def _cross_update(source, target, field_name):
            # Target wins when it already has a value; source fills the gap otherwise.
            if getattr(target, field_name, None) is None:
                setattr(target, field_name, getattr(source, field_name))
            else:
                setattr(source, field_name, getattr(target, field_name))

        for field_name in fields:
            _cross_update(model_config, args, field_name)

    def init_models(self) -> None:
        """Load the Janus multi-modality model and its chat processor/tokenizer."""
        self.model = MultiModalityCausalLM.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        ).to(get_current_device())
        # Optionally cast to bfloat16 for memory-efficient mixed-precision training.
        if self.cfgs.train_cfgs.bf16:
            self.model = self.model.to(torch.bfloat16)
        self.processor = VLChatProcessor.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        )
        self.tokenizer = self.processor.tokenizer

    def loss(self, sft_batch: SupervisedBatch) -> dict[str, torch.Tensor]:
        """Compute the supervised fine-tuning loss for one batch.

        Tags the batch as an image-editing task, then forwards it (plus the
        chat processor, which the model needs for multi-modal preprocessing)
        through the model. Leftover per-step debug ``print`` calls removed:
        they flooded stdout every step and raised ``KeyError`` whenever
        ``source_image`` was missing from the batch.
        """
        sft_batch['task'] = 'image_editing'
        outputs = self.model.forward(vl_chat_processor=self.processor, **sft_batch)
        return {
            'loss': outputs.loss,
        }
|
|
|
|
def main():
    """Entry point: set up distributed training, merge configs, and run SFT."""
    # Bring up the distributed backend and pin this rank to its device.
    deepspeed.init_distributed()
    device = get_current_device()
    torch_set_device(device)

    # Load the base training + DeepSpeed configs for the janus/sft_gen task.
    dict_cfgs, ds_cfgs = read_cfgs(mode='train', task=os.path.join('janus', 'sft_gen'))

    # Fold any `--key value` CLI overrides into the config dict.
    # NOTE: the first unparsed token is intentionally skipped, matching the
    # repo-wide CLI convention (pairs start at index 1).
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    _, unparsed = parser.parse_known_args()
    overrides = dict(zip((flag[2:] for flag in unparsed[1::2]), unparsed[2::2]))
    for key, value in overrides.items():
        dict_cfgs = update_dict(dict_cfgs, custom_cfgs_to_dict(key, value))

    # Freeze the merged config and seed all RNGs for reproducibility.
    cfgs = dict_to_namedtuple(dict_cfgs)
    seed_everything(cfgs.train_cfgs.seed)

    # Train and persist the final checkpoint.
    trainer = SuperviseTrainer(cfgs=cfgs, ds_cfgs=ds_cfgs)
    trainer.train()
    trainer.save()
|
|
|
|
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == '__main__':
    sys.exit(main())
|
|