| import os |
| import random |
| import json |
| from pathlib import Path |
| from llava.datasets.builder import DATASETS |
|
|
| from typing import Dict, Optional, Sequence, List |
| from llava.datasets.data_cfgs import data_configs |
| from llava.datasets.base_dataset import FramesTaskDataset |
| from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2 |
| from llava.constants import DEFAULT_VIDEO_TOKEN |
|
|
|
|
class TTVqaDataset(FramesTaskDataset):
    """Frame-based video captioning / VQA dataset for TT data.

    Each annotation item may carry a free-form ``'caption'`` and/or a list of
    ``'qas'`` entries (dicts with ``'q'`` and ``'a'`` keys); both are turned
    into two-turn human/model conversations by :meth:`text_preprocess`.
    """

    def __init__(self, anno_path, data_args=None, fps=2.0, data_cfgs=None, name='tt_vqa'):
        """Initialize the dataset.

        Args:
            anno_path: Path to the annotation file (forwarded to the base class).
            data_args: Optional arguments object; if it has a ``caption_prompt``
                attribute, that string names the prompt list to use.
            fps: Frame sampling rate forwarded to the base class.
            data_cfgs: Optional dataset config dict; ``data_cfgs['fps']``
                becomes ``self.default_fps`` when provided.
            name: Registry name of the dataset.
        """
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # Fall back to the sampling fps when no config dict is supplied;
        # previously data_cfgs=None raised TypeError on subscripting.
        self.default_fps = data_cfgs['fps'] if data_cfgs is not None else fps

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Convert one annotation item into a flat list of conversation turns.

        The caption (if present) and every QA pair each become a two-turn
        conversation; the conversations are shuffled and the video token is
        prepended to the very first human turn.

        Args:
            item: Annotation dict, optionally containing ``'caption'`` and
                ``'qas'`` keys.

        Returns:
            Flat list of ``{'from': ..., 'value': ...}`` turn dicts.
        """
        all_convs = []
        if hasattr(self.data_args, 'caption_prompt'):
            # SECURITY NOTE(review): eval() resolves a prompt list by name
            # (e.g. 'tt_caption_prompt2') but will execute arbitrary
            # expressions if the config string is untrusted — consider a
            # whitelist/dict lookup instead.
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        if 'caption' in item:
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['caption']
                }
            ])
        # .get avoids the membership test; the loop index was unused.
        for qa in item.get('qas', []):
            all_convs.append([
                {
                    'from': 'human',
                    'value': qa['q']
                },
                {
                    'from': 'model',
                    'value': qa['a']
                }
            ])

        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                # Only the first human turn carries the video placeholder.
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)

        return conversations
|
|
|
|
@DATASETS.register_obj
def tt_vqa(data_args):
    """Factory registered under 'tt_vqa': build the TT-VQA training dataset.

    The training data path from ``data_args.external_args`` takes precedence
    over the default path in the registry config.
    """
    cfg = data_configs["tt_vqa"]
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = cfg['train_data_path']
    return TTVqaDataset(train_data_path, data_args, 2.0, cfg)
|
|
|
|