|
|
| from llava.datasets.builder import DATASETS |
|
|
| from typing import Dict, Optional, Sequence, List |
| from llava.datasets.data_cfgs import data_configs |
| from llava.datasets.base_dataset import FramesTaskDataset |
| from llava.datasets.data_cfgs import data_configs |
| import pickle |
| from pathlib import Path |
| import random |
| import numpy as np |
| from llava.datasets.prompts import tt_caption_prompt, internvid_prompt |
| from llava.constants import DEFAULT_VIDEO_TOKEN |
| from PIL import Image |
| import json |
| import torch |
| import os |
|
|
|
|
class LKVideoDataset(FramesTaskDataset):
    """Video dataset whose samples are directories of pre-extracted .jpeg frames.

    Each annotation entry is a dict with keys:
      - 'video': path to a directory of frame images named ``*_<idx>.jpeg``
      - 'conversations': multi-turn conversation list used as the text target
      - 'id' (optional): sample identifier, propagated into ``__getitem__`` output
    """

    def __init__(self, anno_path=None, data_args=None, fps=1.0, conv_type='multi',
                 select_datasets=None, name='lk_video'):
        """
        Args:
            anno_path: path to a JSON annotation file (a list of sample dicts).
            data_args: dataset configuration forwarded to FramesTaskDataset.
            fps: frame sampling rate forwarded to the base class.
            conv_type: conversation type; only 'multi' is supported.
            select_datasets: optional collection of sub-dataset names; when given,
                only samples whose video directory's parent folder name is in this
                collection are kept.
            name: dataset name forwarded to the base class.
        """
        self.default_fps = 1.0
        self.fps = fps
        self.conv_type = conv_type
        self.select_datasets = select_datasets
        self.annotation = self.get_dataset(anno_path)

        # BUG FIX: ('multi') is just the string 'multi', so the original
        # `in` was a substring test (e.g. conv_type='m' passed). Use a tuple.
        assert self.conv_type in ('multi',), "lk_video conv type must be multi"

        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def __len__(self):
        return len(self.annotation)

    def get_dataset(self, anno_path):
        """Load the JSON annotation list, optionally filtering by sub-dataset.

        The parent directory name of each sample's frame folder identifies the
        sub-dataset it came from; when ``self.select_datasets`` is set, only
        samples from those sub-datasets are kept.
        """
        anno_path = Path(anno_path)
        # JSON is text; open in text mode rather than 'rb'.
        with anno_path.open('r') as f:
            data = json.load(f)

        if self.select_datasets is not None:
            data = [
                sample for sample in data
                if Path(sample['video']).parent.name in self.select_datasets
            ]

        return data

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Return the raw multi-turn conversation list for a sample."""
        return item['conversations']

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]

        ret = {
            'images': self.vis_preprocess(item['video']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret

    @staticmethod
    def _sample_frames(frames, num_segments):
        """Uniformly subsample ``frames`` down to ``num_segments`` entries."""
        indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
        return [frames[ind] for ind in indices]

    def vis_preprocess(self, vis_path):
        """Load, subsample, and preprocess the .jpeg frames under ``vis_path``.

        Returns a list of preprocessed images (flattened when the base-class
        ``preprocess_image`` returns a list per frame).
        """
        image_files = []
        for fname in os.listdir(vis_path):
            if fname.endswith('.jpeg'):
                # Frame index is the last '_'-separated token, minus '.jpeg'.
                img_idx = int(fname.split('_')[-1][:-5])
                image_files.append((img_idx, fname))

        # Order frames by their numeric index, not lexicographic filename order.
        image_files.sort(key=lambda img: img[0])

        # Hard cap of 10 frames, then the configured num_segments cap
        # (num_segments comes from the FramesTaskDataset base class).
        if len(image_files) > 10:
            image_files = self._sample_frames(image_files, 10)
        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for _, fname in image_files:
            try:
                images.append(Image.open(os.path.join(vis_path, fname)).convert('RGB'))
            except Exception:
                # Best-effort: silently skip unreadable/corrupt frames.
                continue

        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images
|
|
|
|
@DATASETS.register_obj
def lk_video(data_args):
    """Factory registered under 'lk_video': build an LKVideoDataset from config.

    Reads 'fps' and 'conv_type' (required) and 'select_datasets' (optional)
    from ``data_args.external_args`` and instantiates the dataset with the
    configured training annotation path.
    """
    data_cfg = data_configs['lk_video']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    # dict.get replaces the original membership-check-then-index pattern.
    select_datasets = data_args.external_args.get('select_datasets')
    return LKVideoDataset(data_cfg['train_data_path'], data_args, fps, conv_type,
                          select_datasets=select_datasets)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|