|
|
| from llava.datasets.builder import DATASETS |
|
|
| from typing import Dict, Optional, Sequence, List |
| from llava.datasets.data_cfgs import data_configs |
| from llava.datasets.base_dataset import FramesTaskDataset |
| from llava.datasets.data_cfgs import data_configs |
| import pickle |
| from pathlib import Path |
| import random |
| import numpy as np |
| from llava.datasets.prompts import tt_caption_prompt, internvid_prompt |
| from llava.constants import DEFAULT_VIDEO_TOKEN |
| from PIL import Image |
| import json |
| import torch |
| import os |
|
|
|
|
class PromptV1Dataset(FramesTaskDataset):
    """Dataset built from the ``promptv1_2_internal`` JSON annotations.

    Each annotation record is expanded into one sample per requested task
    type (e.g. ``'refine_caption'`` or ``'qas'``). The chosen task is stored
    on the sample under the ``'task_type'`` key.
    """

    def __init__(self, anno_path=None, data_args=None, name='promptv1_2_internal', task_types=None):
        self.default_fps = 1.0
        self.task_types = task_types
        # NOTE(review): get_dataset() runs before FramesTaskDataset.__init__
        # and reads self.conv_type whenever a 'qas' task survives filtering.
        # This presumably relies on conv_type being available before the base
        # initializer runs (e.g. as a class attribute) -- confirm against the
        # base class.
        self.annotation = self.get_dataset(anno_path)
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def __len__(self):
        """Return the number of expanded (record, task type) samples."""
        return len(self.annotation)

    def get_dataset(self, anno_path):
        """Load the JSON annotation file and expand it per task type.

        Args:
            anno_path: path to a JSON file expected to hold an array of
                record objects.

        Returns:
            A list of shallow copies of the records, each with an added
            ``'task_type'`` key. A task type is skipped for a record when the
            record lacks that key or its value has length 0. When
            ``self.conv_type == 'single'``, every entry of a record's
            ``'qas'`` list becomes its own sample holding a one-element
            ``'qas'`` list.
        """
        dataset = []
        # json.load accepts a binary file and detects the UTF encoding itself.
        with Path(anno_path).open('rb') as f:
            data = json.load(f)
        for info in data:
            for task_type in self.task_types:
                # Filter first so records are only copied for tasks we keep.
                if task_type not in info or len(info[task_type]) == 0:
                    continue
                if task_type == 'qas' and self.conv_type == 'single':
                    # Split multi-turn QA into one single-turn sample per pair.
                    for qa_pair in info[task_type]:
                        one_info = info.copy()
                        one_info[task_type] = [qa_pair]
                        one_info['task_type'] = task_type
                        dataset.append(one_info)
                else:
                    info_task = info.copy()
                    info_task['task_type'] = task_type
                    dataset.append(info_task)
        return dataset

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Build a flat human/model conversation for one sample.

        For ``'refine_caption'`` samples, a randomly chosen caption prompt is
        paired with ``item['refine_caption']``. Every other task type is read
        as question answering from ``item['qas']``, whose entries are expected
        to carry ``'q'`` and ``'a'`` keys. The turn pairs are shuffled, and
        ``DEFAULT_VIDEO_TOKEN`` is prepended to the first human turn.

        Returns:
            A list of ``{'from': ..., 'value': ...}`` dicts alternating
            ``'human'`` and ``'model'``.
        """
        caption_prompt = getattr(self.data_args, 'caption_prompt', None)
        if caption_prompt:
            # SECURITY(review): eval() executes arbitrary code taken from the
            # data configuration. It is only acceptable while configs are
            # trusted. It presumably resolves the name of a prompt list
            # imported in this module (such as internvid_prompt).
            cap_prompt = eval(caption_prompt)
        else:
            # Also covers an attribute that is present but None/empty, which
            # previously crashed inside eval().
            cap_prompt = tt_caption_prompt

        if item['task_type'] == 'refine_caption':
            all_convs = [[
                {'from': 'human', 'value': random.choice(cap_prompt)},
                {'from': 'model', 'value': item['refine_caption']},
            ]]
        else:
            # NOTE(review): any task type other than 'refine_caption' is read
            # from item['qas'].
            all_convs = [
                [
                    {'from': 'human', 'value': qa['q']},
                    {'from': 'model', 'value': qa['a']},
                ]
                for qa in item['qas']
            ]

        random.shuffle(all_convs)
        # Only the first human turn (after shuffling) carries the video token.
        if all_convs:
            first_human_turn = all_convs[0][0]
            first_human_turn['value'] = DEFAULT_VIDEO_TOKEN + first_human_turn['value']
        return [turn for conv in all_convs for turn in conv]
|
|
|
|
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
|
|
| |
| |
| |
|
|
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
@DATASETS.register_obj
def promptv1_2_internal(data_args):
    """Construct the ``promptv1_2_internal`` training dataset.

    The annotation path comes from the ``'train_data_path'`` field of the
    ``'promptv1_2_internal'`` entry in ``data_configs``. The task types come
    from ``data_args.external_args['task_types']``.
    """
    cfg = data_configs['promptv1_2_internal']
    requested_tasks = data_args.external_args['task_types']
    train_path = cfg['train_data_path']
    return PromptV1Dataset(
        anno_path=train_path,
        data_args=data_args,
        task_types=requested_tasks,
    )
|
|
| |