import torch
import random
import json

from typing import Dict, List
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.prompts import tt_caption_prompt, ocr_prompt
from llava.constants import DEFAULT_VIDEO_TOKEN


class SyntheticOCRDataset(FramesTaskDataset):
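    """Frames-based dataset for synthetic OCR video captioning.

    Each annotation record is expected to provide 'video_path', 'gpt_caption',
    and 'ocr_list' (a list of strings); 'id' is optional.
    """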
    def __init__(self, anno_path, data_args=None, fps=2.0, name='synthetic_ocr'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # Overrides whatever fallback fps the base class set in its __init__.
        self.default_fps = 0.1

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
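        """Return the preprocessed frames and paired conversation for annotation i."""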
        item = self.annotation[i]

        ret = {
            'images': self.vis_preprocess(item['video_path']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret

    def text_preprocess(self, item) -> List[Dict[str, str]]:
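        """Build one human/model exchange: a captioning prompt on the human side,
        and a target that appends an OCR prompt plus the comma-joined OCR strings
        to the GPT caption.
        """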
        # Allow the caption prompt list to be overridden via data_args; the
        # config value is the name of a prompt list, resolved with eval().
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt

        conversations = [
            {
                'from': 'human',
                'value': DEFAULT_VIDEO_TOKEN + random.choice(cap_prompt)
            },
            {
                'from': 'model',
                'value': item['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(item['ocr_list'])
            }
        ]
        return conversations


@DATASETS.register_obj
def synthetic_ocr(data_args):
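    """Registry entry point: build a SyntheticOCRDataset, taking
    train_data_path from external_args when provided and from the global
    data config otherwise.
    """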
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_configs['synthetic_ocr']['train_data_path']
    return SyntheticOCRDataset(train_data_path, data_args, 2.0)


if __name__ == '__main__':
    # Smoke test: assemble the model-side target string for every annotation
    # to verify the expected keys ('gpt_caption', 'ocr_list') are present.
    with open('/mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/synthetic_ocr/train_filtered.json') as f:
        data = json.load(f)

    for sample in data:
        res = sample['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(sample['ocr_list'])