| from functools import partial |
| from typing import Tuple, Dict, Any, Type |
|
|
| from transformers.trainer import DataCollator |
|
|
| from .shikra import ShikraTrainer |
| from .base_engine import TrainerForMMLLM, Seq2Seq2DataCollatorWithImage |
|
|
# Registry of trainer classes keyed by model type name.
# prepare_trainer_collator() indexes this mapping with `model_args.type`,
# so a new model type becomes selectable by adding an entry here.
TYPE2TRAINER = {'shikra': ShikraTrainer}
|
|
|
|
def prepare_trainer_collator(
        model_args,
        preprocessor: Dict[str, Any],
        collator_kwargs: Dict[str, Any]
) -> Tuple[Type[TrainerForMMLLM], Dict[str, DataCollator]]:
    """Select the trainer class and build the train/eval data collators.

    Args:
        model_args: Object exposing a ``type`` attribute; its value is used
            as the key into ``TYPE2TRAINER``.
        preprocessor: Passed to every collator as the ``preprocessor``
            keyword argument.
        collator_kwargs: Extra keyword arguments forwarded unchanged to
            ``Seq2Seq2DataCollatorWithImage``.

    Returns:
        A ``(trainer_cls, collators)`` pair. ``trainer_cls`` is the class
        registered for ``model_args.type``. ``collators`` is a dict with the
        keys ``"train_collator"`` (built with ``inference_mode=False``) and
        ``"eval_collator"`` (built with ``inference_mode=True``).

    Raises:
        KeyError: If ``model_args.type`` is not a key of ``TYPE2TRAINER``.
    """
    # Resolve the trainer first so an unknown type fails before any
    # collator is constructed.
    selected_trainer = TYPE2TRAINER[model_args.type]

    # Both collators share every argument except `inference_mode`, so bind
    # the common ones once and vary only that flag per call.
    build_collator = partial(
        Seq2Seq2DataCollatorWithImage,
        preprocessor=preprocessor,
        **collator_kwargs,
    )
    train_collator = build_collator(inference_mode=False)
    eval_collator = build_collator(inference_mode=True)

    return selected_trainer, {
        "train_collator": train_collator,
        "eval_collator": eval_collator,
    }
|
|