| from llava.datasets.builder import DATASETS | |
| from typing import Dict, Optional, Sequence, List | |
| from llava.datasets.data_cfgs import data_configs | |
| from llava.datasets.base_dataset import ImageTaskDataset | |
| from llava.constants import DEFAULT_IMAGE_TOKEN | |
class LLaVAPretrainDataset(ImageTaskDataset):
    """Image-text pretraining dataset in the LLaVA annotation format.

    Each annotation item carries a ``'qas'`` list of question/answer dicts
    (keys ``'q'`` and ``'a'``); :meth:`text_preprocess` flattens them into
    the ``{'from': ..., 'value': ...}`` conversation turns the trainer
    consumes. Image loading is inherited from ``ImageTaskDataset``.
    """

    def __init__(self, anno_path, data_args=None, name='llava_pretrain'):
        """Delegate to the base image-task dataset.

        Args:
            anno_path: Path to the annotation file.
            data_args: Optional data-configuration object passed through.
            name: Dataset identifier (defaults to ``'llava_pretrain'``).
        """
        super().__init__(anno_path=anno_path, data_args=data_args, name=name)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Flatten the QA pairs of *item* into a list of conversation turns.

        Every question becomes a ``'human'`` turn with the image token
        prepended, immediately followed by the corresponding ``'model'``
        answer turn.

        NOTE(review): the image token is prepended to *every* question,
        not only the first turn, and with no separator between token and
        question text — presumably intentional for this data format, but
        worth confirming against the conversation template used downstream.
        """
        turns: List[Dict[str, str]] = []
        for pair in item['qas']:
            turns.append({'from': 'human', 'value': DEFAULT_IMAGE_TOKEN + pair['q']})
            turns.append({'from': 'model', 'value': pair['a']})
        return turns
def llava_pretrain(data_args):
    """Factory: build the LLaVA pretraining dataset from the global config.

    Looks up the ``'llava_pretrain'`` entry in ``data_configs`` and
    constructs a :class:`LLaVAPretrainDataset` over its training split.

    NOTE(review): ``DATASETS`` is imported at the top of this module but
    not used here — confirm whether this builder was meant to be
    registered with the dataset registry.
    """
    cfg = data_configs["llava_pretrain"]
    return LLaVAPretrainDataset(cfg['train_data_path'], data_args)