| | |
| | import json |
| | import logging |
| | import os |
| |
|
| | import torch |
| | from datasets import Dataset as HFDataset |
| | from datasets import DatasetDict, load_from_disk |
| | from mmengine import print_log |
| | from mmengine.config import Config, ConfigDict |
| | from PIL import Image |
| | from torch.utils.data import Dataset |
| |
|
| |
|
| | from xtuner.registry import BUILDER |
| | from xtuner.dataset.utils import expand2square |
| | from .encode_fns import encode_fn |
| | from xtuner.dataset.llava import load_jsonl |
| |
|
| | from xtuner.dataset.huggingface import build_origin_dataset |
| |
|
| | import copy |
| | import numpy as np |
| |
|
| | from projects.omg_llava.dataset.utils import expand2square_mask |
| | import torch.nn.functional as F |
| |
|
| |
|
class RegionLLaVALazyDataset(Dataset):
    """Image/text dataset that tokenizes samples on access and adds region masks.

    Text records come either from a dataset saved on disk
    (``offline_processed_text_folder``) or from a ``.json`` / ``.jsonl``
    annotation file (``data_path``). ``__getitem__`` loads the image, runs
    ``dataset_map_fn``, ``template_map_fn`` and ``encode_fn`` on the record and
    attaches a ``region_masks`` tensor resized to the feature-map resolution.

    Args:
        image_folder: Directory that relative image paths are joined onto.
        image_processor: An image processor object, or a ``dict`` / ``Config``
            / ``ConfigDict`` that ``BUILDER.build`` turns into one.
        data_path: Path to a ``.json`` or ``.jsonl`` annotation file.
        tokenizer: A tokenizer object, or a config built via ``BUILDER.build``.
        offline_processed_text_folder: Folder passed to ``load_from_disk``;
            takes precedence over ``data_path`` when both are given.
        max_dataset_length: Must be ``None`` when loading from ``data_path``.
        dataset_map_fn: Callable applied to each record in ``__getitem__``.
        template_map_fn: Callable (or config to build one) applied after
            ``dataset_map_fn``.
        max_length: Forwarded to ``encode_fn`` as ``max_length``.
        pad_image_to_square: Pad the image (and region masks) to a square
            before preprocessing.
        lazy: When true, ``modality_length`` reports a constant placeholder
            length instead of inspecting the records.
        feat_down_ratio: Divisor applied to the image size to obtain the side
            length of the region-mask grid.
        repeats: Number of times the dataset is virtually repeated.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 lazy=False,
                 feat_down_ratio=14,
                 repeats=1,
                 ):
        super().__init__()

        config_types = (dict, Config, ConfigDict)

        assert offline_processed_text_folder or (data_path and tokenizer)
        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from'
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            self.text_data = load_from_disk(offline_processed_text_folder)
        else:
            print("Loading {}!!!".format(data_path))
            if data_path.endswith('.json'):
                # Context manager so the file handle is closed deterministically.
                with open(data_path) as json_file:
                    json_data = json.load(json_file)
            elif data_path.endswith('.jsonl'):
                json_data = load_jsonl(data_path)
            else:
                raise NotImplementedError
            print("Loaded {}!!!".format(data_path))

            # Integer ids are converted to strings before the records are
            # handed to HFDataset.from_list.
            for record in json_data:
                if isinstance(record.get('id'), int):
                    record['id'] = str(record['id'])
            json_data = DatasetDict({'train': HFDataset.from_list(json_data)})

            assert max_dataset_length is None, "max_dataset_length is not supported in Lazy mode"
            self.text_data = build_origin_dataset(json_data, 'train')

        self.image_folder = image_folder
        if isinstance(image_processor, config_types):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
        # Same crop_size -> size fallback that __getitem__ uses for text-only
        # samples, so both code paths accept the same processors.
        processor_size = self._processor_size()
        assert processor_size['height'] == processor_size['width']
        self.image_size = processor_size['height']

        self.feat_size = self.image_size // feat_down_ratio
        self.pad_image_to_square = pad_image_to_square

        self.lazy = lazy
        # __getitem__ always tokenizes through encode_fn, so the tokenizer and
        # max_length must exist regardless of `lazy`; previously they were only
        # assigned when lazy=True and lazy=False raised AttributeError.
        if isinstance(tokenizer, config_types):
            tokenizer = BUILDER.build(tokenizer)
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.dataset_map_fn = dataset_map_fn
        if isinstance(template_map_fn, config_types):
            template_map_fn = BUILDER.build(template_map_fn)
        self.template_map_fn = template_map_fn
        self.repeats = repeats

    def _processor_size(self):
        """Return the processor's ``crop_size``, or ``size`` if it has none."""
        if hasattr(self.image_processor, 'crop_size'):
            return self.image_processor.crop_size
        return self.image_processor.size

    @property
    def modality_length(self):
        """Per-sample lengths, one entry for every index in ``range(len(self))``.

        In lazy mode every sample reports the placeholder length 1000.
        Otherwise the length is ``len(record['input_ids'])`` (1000 when the
        record has no ``input_ids``), negated for records without an image.
        The list is tiled ``repeats`` times so it lines up with ``__len__``
        and with the modulo indexing done in ``__getitem__``.
        """
        if self.lazy:
            return [1000] * len(self.text_data) * self.repeats

        length_list = []
        for data_dict in self.text_data:
            if 'input_ids' in data_dict:
                cur_len = len(data_dict['input_ids'])
            else:
                cur_len = 1000
            if data_dict.get('image', None) is None:
                cur_len = -cur_len
            length_list.append(cur_len)
        return length_list * self.repeats

    def __len__(self):
        """Number of records multiplied by ``repeats``."""
        return len(self.text_data) * self.repeats

    def __getitem__(self, index):
        """Return the processed sample at ``index`` (wrapped modulo the record count).

        The returned dict contains ``pixel_values``, whatever the map /
        template / encode functions add, and ``region_masks``.
        """
        index = index % len(self.text_data)
        data_dict = copy.deepcopy(self.text_data[index])
        if 'image' not in data_dict and 'image_name' in data_dict:
            data_dict['image'] = data_dict['image_name']

        if data_dict.get('image', None) is not None:
            image_path = os.path.join(self.image_folder, data_dict['image'])
            # convert() returns a new image, so the opened file can be closed
            # as soon as the conversion is done.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.pad_image_to_square:
                image = expand2square(
                    image,
                    tuple(
                        int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image
        else:
            # Text-only sample: use an all-zero image of the processor's size.
            crop_size = self._processor_size()
            data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
                                                    crop_size['width'])

        result = self.dataset_map_fn(data_dict)
        data_dict.update(result)

        result = self.template_map_fn(data_dict)
        data_dict.update(result)

        result = encode_fn(data_dict, tokenizer=self.tokenizer,
                           max_length=self.max_length, with_image_token=True)
        data_dict.update(result)

        if 'region_masks' not in data_dict:
            # No region annotation: a single all-ones mask over the feature grid.
            region_masks = np.ones((1, self.feat_size, self.feat_size), dtype=np.uint8)
            region_masks = torch.from_numpy(region_masks)
        else:
            # NOTE(review): assumes the map functions leave `region_masks` as a
            # NumPy array of stacked 2-D masks - confirm against dataset_map_fn.
            region_masks = data_dict['region_masks']
            if self.pad_image_to_square:
                region_masks = expand2square_mask(region_masks)
            region_masks = torch.from_numpy(region_masks)
            region_masks = F.interpolate(
                region_masks.unsqueeze(0),
                size=(self.feat_size, self.feat_size),
                mode='nearest').squeeze(0)

        data_dict['region_masks'] = region_masks
        return data_dict
| |
|
class CombineDataset(Dataset):
    """Concatenate several datasets and serve them in a fixed shuffled order.

    Each entry of ``datasets_cfgs`` is a config mapping whose ``'type'`` item
    is the dataset class and whose remaining items are its keyword arguments.
    One tokenizer is built from the first config's ``tokenizer`` entry (a
    mapping with a ``'type'`` callable plus keyword arguments) and passed to
    every dataset as ``tokenizer``.

    Args:
        datasets_cfgs: Non-empty sequence of dataset configs. The first one
            must expose its tokenizer config as the attribute ``tokenizer``.
            The configs are not modified.
    """

    def __init__(self,
                 datasets_cfgs,
                 ):
        super().__init__()

        # Work on shallow copies so the caller's configs keep their 'type'
        # entries and can be used to build the dataset again.
        tokenizer_kwargs = copy.copy(datasets_cfgs[0].tokenizer)
        tokenizer_type = tokenizer_kwargs.pop('type')
        self.tokenizer = tokenizer_type(**tokenizer_kwargs)

        self.datasets = []
        self.datasets_length = []
        for dataset_cfg in datasets_cfgs:
            dataset_kwargs = copy.copy(dataset_cfg)
            dataset_type = dataset_kwargs.pop('type')
            # Every dataset receives the same already-built tokenizer.
            dataset_kwargs['tokenizer'] = self.tokenizer
            dataset = dataset_type(**dataset_kwargs)
            self.datasets.append(dataset)
            self.datasets_length.append(len(dataset))

        # dataset_threthold[i] is the exclusive upper bound of the global
        # (pre-shuffle) indices that belong to datasets[i].
        self.dataset_threthold = []
        running_total = 0
        for length in self.datasets_length:
            running_total += length
            self.dataset_threthold.append(running_total)

        # A private RandomState(42) yields the same permutation as seeding the
        # global generator with 42, without resetting NumPy's process-wide
        # random state.
        rng = np.random.RandomState(42)
        self.shuffled_index = np.arange(self.dataset_threthold[-1])
        rng.shuffle(self.shuffled_index)

    @property
    def modality_length(self):
        """Per-sample lengths aligned with the order used by ``__getitem__``.

        The sub-datasets' ``modality_length`` lists are concatenated and then
        permuted with ``shuffled_index`` so that entry ``i`` describes the
        sample returned by ``self[i]``.
        """
        length_list = []
        for dataset in self.datasets:
            length_list += dataset.modality_length
        return [length_list[i] for i in self.shuffled_index]

    def __len__(self):
        """Total number of samples across all sub-datasets."""
        return self.dataset_threthold[-1]

    def __getitem__(self, index):
        """Return the sample that shuffled position ``index`` maps to."""
        index = int(self.shuffled_index[index])
        # Index of the first threshold strictly greater than `index`, i.e. the
        # sub-dataset that owns this global index.
        dataset_idx = int(
            np.searchsorted(self.dataset_threthold, index, side='right'))
        if dataset_idx == 0:
            _index = index
        else:
            _index = index - self.dataset_threthold[dataset_idx - 1]

        return self.datasets[dataset_idx][_index]