| | import logging |
| |
|
| | import torch |
| | import datasets |
| | import cv2 |
| |
|
| | import numpy as np |
| | from base64 import b64decode |
| | from io import BytesIO |
| | from PIL import Image |
| | from torch.utils.data import ConcatDataset |
| | from llava.datasets.builder import DATASETS |
| | from typing import Dict, Optional, Sequence, List |
| | from llava.datasets.data_cfgs import data_configs |
| | from llava.datasets.base_dataset import ImageTaskDataset |
| | from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN |
| | from llava.utils import master_print |
| |
|
| |
|
class M3ITDataset(ImageTaskDataset):
    """Instruction dataset built from the ``MMInstruction/M3IT`` hub dataset.

    The ``train`` split of every selected task is loaded with
    ``datasets.load_dataset`` and the splits are concatenated into a single
    ``ConcatDataset``. Each sample is returned as a dict holding the
    preprocessed images and a two-turn conversation.
    """

    def __init__(self, anno_path, data_args=None, name='m3it', selected_tasks=None):
        """Load and concatenate the ``train`` split of each selected task.

        Args:
            anno_path: Forwarded unchanged to ``ImageTaskDataset.__init__``.
            data_args: Forwarded unchanged to ``ImageTaskDataset.__init__``.
            name: Dataset name; forwarded to the base class and used in logs.
            selected_tasks: Iterable of M3IT configuration names to load.

        Raises:
            ValueError: If no task is selected, or if none of the selected
                tasks provides a ``train`` split.
        """
        super().__init__(anno_path, data_args, name)

        self.selected_tasks = selected_tasks
        task_names = list(selected_tasks) if selected_tasks is not None else []
        if not task_names:
            raise ValueError(
                f"Dataset {name} needs at least one entry in selected_tasks."
            )

        dataset_list = [
            datasets.load_dataset("MMInstruction/M3IT", task, num_proc=16)
            for task in task_names
        ]

        target_dataset_list = []
        master_print('#' * 50)
        # Pair every loaded dataset with the task name it was requested under,
        # so a missing split can be reported without touching d['train'] again.
        for task, loaded in zip(task_names, dataset_list):
            try:
                train_split = loaded['train']
            except KeyError:
                print(f"{task} has no train set.")
                continue
            target_dataset_list.append(train_split)
            master_print(f"TASK {task}, SIZE {len(train_split)}")

        if not target_dataset_list:
            raise ValueError(
                f"None of the selected tasks for dataset {name} has a train set."
            )

        self.dataset = ConcatDataset(target_dataset_list)
        master_print(f"Finished loading dataset {name} {len(self.dataset)} samples...")

    def __len__(self):
        """Return the total number of samples across all loaded tasks."""
        return len(self.dataset)

    def text_preprocess(self, item, is_video=False) -> List[Dict[str, str]]:
        """Build the two-turn conversation for one sample.

        Args:
            item: Sample providing the ``instruction``, ``inputs`` and
                ``outputs`` fields.
            is_video: When true the video placeholder token is used,
                otherwise the image placeholder token.

        Returns:
            A list with a ``human`` turn (instruction, media token and the
            optional question) followed by a ``model`` turn (the answer).
        """
        instruction = item['instruction']
        question = item['inputs']
        answer = item['outputs']

        media_token = DEFAULT_VIDEO_TOKEN if is_video else DEFAULT_IMAGE_TOKEN
        query = f"{instruction} {media_token}"
        if question:
            # The question is appended directly after the media token.
            query += question

        return [
            {'from': 'human', 'value': query},
            {'from': 'model', 'value': answer},
        ]

    def bin2image(self, image_base64_str):
        """Decode one base64-encoded image and run it through ``preprocess_image``.

        Args:
            image_base64_str: Base64 string holding an encoded image file.

        Returns:
            Whatever ``self.preprocess_image`` returns for the decoded RGB
            image (``preprocess_image`` is defined outside this class).
        """
        # convert("RGB") already guarantees a three-channel image, so no
        # separate grayscale expansion is required afterwards.
        img = Image.open(BytesIO(b64decode(image_base64_str))).convert("RGB")
        return self.preprocess_image(img)

    def vis_preprocess(self, image_base64_str_list) -> Optional[List]:
        """Decode and preprocess every image of a sample.

        Args:
            image_base64_str_list: Iterable of base64-encoded images.

        Returns:
            A flat list of preprocessed images; list results coming from
            ``bin2image`` are flattened into it. ``None`` is returned when
            any image fails to decode or preprocess.
        """
        try:
            formatted_images = []
            for image_base64_str in image_base64_str_list:
                image = self.bin2image(image_base64_str)
                if isinstance(image, list):
                    formatted_images.extend(image)
                else:
                    formatted_images.append(image)
            return formatted_images
        except Exception as e:
            # Best effort: a broken sample is reported and skipped by the
            # caller rather than aborting the whole run.
            logging.warning("M3IT image preprocessing failed: %s", e)
            return None

    def __getitem__(self, i) -> Optional[Dict]:
        """Return sample ``i`` as ``{'images': ..., 'conversations': ...}``.

        Returns ``None`` when the sample's images cannot be preprocessed, so
        the consumer of this dataset has to tolerate ``None`` entries.
        """
        item = self.dataset[i]

        images = self.vis_preprocess(item['image_base64_str'])
        if images is None:
            return None

        # NOTE(review): any non-empty image list is flagged as video here,
        # including a single image - confirm whether ``> 1`` was intended.
        is_video = len(images) > 0

        return {
            'images': images,
            'conversations': self.text_preprocess(item, is_video),
        }
| |
|
| |
|
@DATASETS.register_obj
def m3it(data_args):
    """Build the M3IT dataset, registered through ``DATASETS.register_obj``.

    The task list comes from ``data_configs['m3it']['default_tasks']`` unless
    ``data_args.external_args`` carries a ``'tasks'`` entry, which then takes
    precedence.
    """
    default_tasks = data_configs['m3it']['default_tasks']
    overrides = data_args.external_args
    chosen_tasks = overrides['tasks'] if 'tasks' in overrides else default_tasks

    return M3ITDataset(anno_path=None, data_args=data_args, selected_tasks=chosen_tasks)
| |
|