| import os | |
| import torch | |
| import random | |
| import json | |
| from pathlib import Path | |
| from llava.datasets.builder import DATASETS | |
| from typing import Dict, Optional, Sequence, List | |
| from llava.datasets.data_cfgs import data_configs | |
| from llava.datasets.base_dataset import FramesTaskDataset | |
| from llava.datasets.prompts import tt_caption_prompt, ocr_prompt | |
| from llava.constants import DEFAULT_VIDEO_TOKEN | |
class SyntheticOCRDataset(FramesTaskDataset):
    """Frames dataset pairing a synthetic video caption with its OCR strings.

    Each annotation entry is expected to carry ``video_path``,
    ``gpt_caption`` and ``ocr_list`` (and optionally ``id``); the target
    turn concatenates the caption, a randomly chosen OCR prompt and the
    comma-joined OCR strings.
    """

    def __init__(self, anno_path, data_args=None, fps=2.0, name='synthetic_ocr'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # NOTE(review): assigned *after* super().__init__, so if the base
        # class reads default_fps during init it will not see 0.1 — confirm
        # this ordering is intentional.
        self.default_fps = 0.1

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return preprocessed frames and conversation turns for sample *i*."""
        item = self.annotation[i]
        ret = {
            'images': self.vis_preprocess(item['video_path']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Build a single human/model conversation pair for *item*.

        The human turn is the video token plus a random caption prompt; the
        model turn is the GPT caption followed by a random OCR prompt and the
        comma-joined OCR strings.
        """
        if hasattr(self.data_args, 'caption_prompt'):
            # SECURITY: eval() on a config-supplied string — executes
            # arbitrary code if the config is ever untrusted. Kept for
            # backward compatibility; prefer a name->prompt-list lookup.
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        return [
            {
                'from': 'human',
                'value': DEFAULT_VIDEO_TOKEN + random.choice(cap_prompt)
            },
            {
                'from': 'model',
                'value': item['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(item['ocr_list'])
            }
        ]
def synthetic_ocr(data_args):
    """Construct a SyntheticOCRDataset at 2.0 fps.

    The annotation path comes from ``data_args.external_args`` when the
    caller supplied one, otherwise from the registered default in
    ``data_configs``.
    """
    external = data_args.external_args
    train_data_path = (
        external['train_data_path']
        if 'train_data_path' in external
        else data_configs["synthetic_ocr"]['train_data_path']
    )
    return SyntheticOCRDataset(train_data_path, data_args, 2.0)
if __name__ == '__main__':
    # Ad-hoc sanity loop: load the filtered annotations and assemble the
    # model-turn string for every sample (output left commented out).
    with open('/mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/synthetic_ocr/train_filtered.json') as f:
        samples = json.load(f)
    for entry in samples:
        caption = entry['gpt_caption']
        ocr_text = ','.join(entry['ocr_list'])
        res = caption + ' ' + random.choice(ocr_prompt) + ocr_text
        # print(res)