| from torch.utils.data import Dataset |
| from PIL import Image |
| import os |
| import json |
| import random |
| import torch |
| import numpy as np |
| from einops import rearrange |
| from xtuner.registry import BUILDER |
| from xtuner.dataset.utils import expand2square |
| from src.datasets.utils import crop2square, encode_fn, load_jsonl |
| from xtuner.utils import DEFAULT_IMAGE_TOKEN |
| from transformers import AutoImageProcessor |
|
|
|
|
class VLMDataset(Dataset):
    """Vision-language dataset yielding an image tensor plus a tokenized QA turn.

    Each JSONL record must provide an ``image`` path and, under ``key_value``,
    a conversation list whose first entry is the question and second entry is
    the answer. Failed samples are logged and replaced by a random retry.
    """

    def __init__(
        self,
        data_path,
        image_size,
        tokenizer=None,
        template_map_fn=None,
        max_length=2048,
        min_image_size=80,
        pad_image=True,
        local_folder="",
        key_value="conversations",
        processor_path="checkpoint/siglip2-so400m-patch16-512",
    ):
        """Build the dataset.

        Args:
            data_path: Path to a JSONL file of samples.
            image_size: Side length (pixels) every image is resized to.
            tokenizer: XTuner BUILDER config for the tokenizer.
            template_map_fn: BUILDER config; must also expose a ``"template"`` key.
            max_length: Maximum token length passed to ``encode_fn``.
            min_image_size: Images with width/height <= this are rejected.
            pad_image: If True, crop images to a square before resizing.
            local_folder: Optional root directory joined onto each record's
                image path (previously accepted but ignored; the default ""
                keeps the old behavior).
            key_value: Key of the conversation list in each record.
            processor_path: Checkpoint path for the image processor
                (previously hard-coded).
        """
        super().__init__()
        self.data_path = data_path
        self._load_data(data_path)
        self.image_size = image_size
        self.local_folder = local_folder

        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = template_map_fn["template"]
        self.template_map_fn = BUILDER.build(template_map_fn)
        self.max_length = max_length
        self.pad_image = pad_image
        self.min_image_size = min_image_size
        self.key_value = key_value
        # NOTE(review): self.processor is instantiated but never used inside
        # this class — presumably kept for an external collator; confirm
        # before removing.
        self.processor = AutoImageProcessor.from_pretrained(processor_path)
        self.metainfo = {"task": "unified"}
        self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
        # One token per 16x16 patch, plus 64 extra slots
        # (presumably register/thumbnail tokens — TODO confirm against the model).
        m = n = self.image_size // 16
        self.image_token_repeat = m * n + 64

        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
        print(f"Registered <image> token at index {self.image_token_idx}")

    def _load_data(self, data_path: str):
        """Load every JSONL record from ``data_path`` into ``self.data_list``."""
        self.data_list = load_jsonl(data_path)
        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def __len__(self):
        return len(self.data_list)

    def _read_image(self, image_file):
        """Open ``image_file``, validate its size/aspect ratio, return RGB.

        Raises:
            ValueError: If either side is <= ``min_image_size`` or the aspect
                ratio falls outside (0.1, 10).
        """
        image = Image.open(image_file)
        # Raise explicitly instead of `assert` so validation survives
        # `python -O`; __getitem__ catches these and retries another sample.
        if image.width <= self.min_image_size or image.height <= self.min_image_size:
            raise ValueError(f"Image: {image.size}")
        aspect = image.width / image.height
        if not 0.1 < aspect < 10:
            raise ValueError(f"Image: {image.size}")
        return image.convert("RGB")

    def _process_image(self, image: "Image.Image"):
        """Convert a PIL image to a CHW float tensor scaled to [-1, 1]."""
        if self.pad_image:
            # Crop to a square so the fixed-size resize keeps the aspect ratio.
            image = crop2square(image)
        image = image.resize((self.image_size, self.image_size))
        arr = np.array(image).astype(np.float32) / 255.0
        arr = 2 * arr - 1  # [0, 1] -> [-1, 1]
        tensor = torch.from_numpy(arr).permute(2, 0, 1)  # HWC -> CHW
        return {"pixel_values": tensor}

    def _process_text(self, question, answer):
        """Wrap a QA pair in the prompt template and tokenize it."""
        data_dict = dict(
            conversation=[
                {
                    "input": f"{self.DEFAULT_IMAGE_TOKEN}\n{question}",
                    "output": answer,
                }
            ]
        )
        data_dict.update(self.template_map_fn(data_dict))
        data_dict.update(
            encode_fn(
                example=data_dict,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                image_length=self.image_token_repeat,
                input_ids_with_output=True,
                with_image_token=True,
                truncation='right',
                image_token_idx=self.image_token_idx,
                image_token_str=self.DEFAULT_IMAGE_TOKEN,
            )
        )
        data_dict["type"] = "image2text"
        return data_dict

    def _retry(self):
        """Replace a failed sample with a uniformly random one.

        NOTE(review): recursion via __getitem__ — many consecutive bad
        samples could exhaust the stack; acceptable for sparse failures.
        """
        return self[random.randrange(len(self))]

    def __getitem__(self, idx):
        data_sample = None  # keep defined for the except-handler message
        try:
            data_sample = self.data_list[idx]
            # Honor local_folder (previously accepted but ignored);
            # os.path.join("", p) == p, so the default stays backward compatible.
            image_file = os.path.join(self.local_folder, data_sample["image"])
            # _read_image already returns an RGB-converted image.
            image = self._read_image(image_file)
            data = self._process_image(image)
            del image
            question = (
                data_sample[self.key_value][0]["value"]
                .replace("<image>", "")
                .strip()
            )
            answer = (
                data_sample[self.key_value][1]["value"]
                .replace("<image>", "")
                .strip()
            )

            data.update(self._process_text(question, answer))

            data.update(image_file=data_sample["image"])

            return data

        except Exception as e:
            # Access the image path defensively: the record may be malformed,
            # or indexing may have failed before data_sample was assigned.
            img = data_sample.get("image") if isinstance(data_sample, dict) else None
            print(
                f"Error when reading data_sample:{data_sample},{self.data_path}:{img}: {e}",
                flush=True,
            )
            return self._retry()
|
|