| from torch.utils.data import Dataset |
| from PIL import Image |
| import os |
| import json |
| import random |
| import torch |
| import numpy as np |
| from einops import rearrange |
| from xtuner.registry import BUILDER |
| from mmengine.registry import DATASETS |
| from src.datasets.utils import crop2square |
| from glob import glob |
| from typing import List, Dict, Any, Optional |
| import mmap |
| import struct |
| from src.datasets.utils import crop2square, encode_fn |
| from xtuner.utils import DEFAULT_IMAGE_TOKEN |
|
|
|
|
@BUILDER.register_module()
class Text2ImageDataset(Dataset):
    """Text-to-image generation dataset backed by one JSON list file.

    Every sample is a dict holding an ``image`` path (relative to
    ``local_folder``) and a caption stored under ``cap_source``.
    ``__getitem__`` yields channels-first pixel values normalized to
    [-1, 1] plus the tokenized generation prompt.
    """

    def __init__(self,
                 data_path,
                 local_folder,
                 image_size,
                 unconditional=0.1,
                 tokenizer=None,
                 prompt_template=None,
                 max_length=1024,
                 crop_image=True,
                 cap_source='caption',
                 ):
        super().__init__()
        self.data_path = data_path
        self._load_data(data_path)
        self.unconditional = unconditional  # prob. of dropping the caption
        self.local_folder = local_folder
        self.cap_source = cap_source
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = prompt_template
        self.max_length = max_length
        self.crop_image = crop_image
        self.metainfo = {'task': 'unified'}
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)

    def _load_data(self, data_path):
        # Single JSON file containing the full list of sample dicts.
        with open(data_path, 'r') as f:
            self.data_list = json.load(f)
        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def __len__(self):
        return len(self.data_list)

    def _read_image(self, image_file):
        # Reject degenerate inputs: tiny images and extreme aspect ratios.
        image = Image.open(os.path.join(self.local_folder, image_file))
        assert image.width > 8 and image.height > 8, f"Image: {image.size}"
        assert image.width / image.height > 0.1, f"Image: {image.size}"
        assert image.width / image.height < 10, f"Image: {image.size}"
        return image

    def _process_text(self, text):
        # With probability `unconditional`, drop the caption so the model
        # also learns unconditional generation (classifier-free guidance).
        if random.uniform(0, 1) < self.unconditional:
            prompt = "Generate an image."
        else:
            prompt = f"Generate an image: {text.strip()}"
        prompt = self.prompt_template['INSTRUCTION'].format(input=prompt)
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors='pt')[0]
        return dict(input_ids=input_ids[:self.max_length])

    def _process_image(self, image):
        # Make the image square: either a center crop, or a (distorting)
        # resize to the longest side.
        if self.crop_image:
            square = crop2square(image)
        else:
            square = image.resize((max(image.size),) * 2)
        square = square.resize((self.image_size, self.image_size))
        # uint8 [0, 255] -> float [-1, 1], channels-first layout.
        pixel_values = torch.from_numpy(np.array(square)).float()
        pixel_values = pixel_values.div(255).mul(2).sub(1)
        pixel_values = rearrange(pixel_values, 'h w c -> c h w')
        return dict(pixel_values=pixel_values)

    def _retry(self):
        # Substitute a random other sample on failure.
        return self.__getitem__(random.choice(range(self.__len__())))

    def __getitem__(self, idx):
        try:
            data_sample = self.data_list[idx]
            image = self._read_image(data_sample['image']).convert('RGB')
            caption = data_sample[self.cap_source]
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')
            return data
        except Exception as e:
            print(f"Error when reading {self.data_path}:{self.data_list[idx]}: {e}", flush=True)
            return self._retry()
|
|
@DATASETS.register_module()
@BUILDER.register_module()
class LargeText2ImageDataset(Text2ImageDataset):
    """Text2ImageDataset variant for large corpora.

    Differences from the parent:
      * ``data_path`` may be a directory holding sharded ``*.json`` index
        files instead of a single JSON file.
      * Captions live in per-sample annotation JSON files under
        ``cap_folder`` rather than inline in the index.
    """

    def __init__(self, cap_folder=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Annotation files default to living next to the images.
        self.cap_folder = self.local_folder if cap_folder is None else cap_folder

    def _load_data(self, data_path):
        if data_path.endswith(".json"):
            with open(data_path, 'r') as f:
                self.data_list = json.load(f)
        else:
            # A directory of shards: concatenate every *.json inside it.
            self.data_list = []
            for json_file in glob(f'{data_path}/*.json'):
                with open(json_file, 'r') as f:
                    self.data_list += json.load(f)
        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def __getitem__(self, idx):
        # Pre-bind so the except-branch can never hit an UnboundLocalError
        # when self.data_list[idx] itself raises (e.g. IndexError).
        data_sample = None
        try:
            data_sample = self.data_list[idx]
            image = self._read_image(data_sample['image']).convert('RGB')
            # The caption is stored in a standalone annotation JSON.
            with open(f"{self.cap_folder}/{data_sample['annotation']}", 'r') as f:
                caption = json.load(f)[self.cap_source]
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')
            return data
        except Exception as e:
            print(f"Error when reading {self.data_path}:{data_sample}: {e}", flush=True)
            return self._retry()
|
|
|
|
@DATASETS.register_module()
@BUILDER.register_module()
class MMapT2IDataset(Dataset):
    """Map-style text-to-image dataset with mmap-based random access.

    The mmap is opened once in ``__init__``; ``__getitem__`` then reads
    any JSONL line in O(1) via a precomputed byte-offset index.

    Layout of the companion ``idx_path`` file:
        8 bytes   little-endian uint64  -> number of lines
        n*8 bytes little-endian uint64  -> byte offset of each line
    """

    def __init__(
        self,
        jsonl_path: str,
        idx_path: str,
        image_size: int,
        tokenizer: Optional[Dict] = None,
        template_map_fn: Optional[Dict] = None,
        cap_source: str = "prompt",
        max_length: int = 2048,
        image_length: int = 512,
        unconditional: float = 0.01,
        crop_image: bool = False,
    ):
        super().__init__()
        self.jsonl_path = jsonl_path
        self.image_size = image_size
        self.cap_source = cap_source
        self.max_length = max_length
        self.unconditional = unconditional
        self.crop_image = crop_image
        self.tokenizer = BUILDER.build(tokenizer)
        # NOTE(review): despite the name this is used as a prompt-template
        # dict with an 'INSTRUCTION' entry (see _build_prompt), mirroring
        # `prompt_template` in the sibling datasets.
        self.template_map_fn = template_map_fn
        self._open_mmap(jsonl_path, idx_path)
        self.metainfo = {'task': 'unified'}

    def _open_mmap(self, jsonl_path: str, idx_path: str):
        # Read-only access is all we need: "rb" (the old "r+b" demanded
        # write permission on the data file, which breaks on read-only
        # mounts) paired with mmap.ACCESS_READ.
        self._jsonl_fp = open(jsonl_path, "rb")
        self._mm = mmap.mmap(self._jsonl_fp.fileno(), 0, access=mmap.ACCESS_READ)
        with open(idx_path, "rb") as f:
            nlines = struct.unpack("<Q", f.read(8))[0]
            self._offsets = np.frombuffer(f.read(8 * nlines), dtype=np.uint64)
        print(f"[MMapT2IDataset] {jsonl_path}: {nlines} lines indexed")

    def __del__(self):
        # Best-effort cleanup: the original leaked both the mmap and its
        # backing file handle for the process lifetime.
        for attr in ("_mm", "_jsonl_fp"):
            handle = getattr(self, attr, None)
            if handle is not None:
                try:
                    handle.close()
                except Exception:
                    pass

    def __len__(self) -> int:
        return self._offsets.size

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def _read_line(self, idx: int) -> str:
        """Return the raw JSONL line at ``idx`` via the offset index."""
        off = int(self._offsets[idx])
        self._mm.seek(off)
        return self._mm.readline().decode("utf-8")

    def _load_image(self, path: str) -> torch.Tensor:
        """Load, square, resize and normalize an image to (c, h, w) in [-1, 1]."""
        img = Image.open(path).convert("RGB")
        if self.crop_image:
            img = crop2square(img)
        else:
            # Distorting resize to the longest side keeps the full content.
            target_size = max(img.size)
            img = img.resize((target_size, target_size))
        img = img.resize((self.image_size, self.image_size))
        arr = np.asarray(img, dtype=np.uint8)
        px = torch.as_tensor(arr).float() / 255.0
        px = 2 * px - 1
        return rearrange(px, "h w c -> c h w")

    def _build_prompt(self, caption: str) -> torch.Tensor:
        """Tokenize the generation instruction.

        With probability ``unconditional`` the caption is dropped so the
        model also learns unconditional generation (CFG training).
        """
        if random.random() < self.unconditional:
            caption = "Generate an image."
        else:
            caption = f"Generate an image: {caption.strip()}"
        instr = self.template_map_fn["INSTRUCTION"].format(input=caption)
        ids = self.tokenizer.encode(
            instr, add_special_tokens=True, return_tensors="pt"
        )[0][: self.max_length]
        return ids

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample = json.loads(self._read_line(idx))
        pixel_values = self._load_image(sample["image"])
        caption = sample.get(self.cap_source, "")
        input_ids = self._build_prompt(caption)
        data = dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            type="text2image",
            image_file=sample["image"],
            idx=idx,
        )
        return data
|
|
|
|
@DATASETS.register_module()
@BUILDER.register_module()
class ReconstructDataset(Dataset):
    """Image-reconstruction dataset.

    The prompt is always the fixed instruction "Repeat this image."
    (with the ``<image>`` placeholder expanded), and source/target pixel
    tensors come from the same image, so the model learns to reproduce
    its visual input.
    """

    def __init__(self,
                 data_path: str,
                 image_size: int,
                 tokenizer=None,
                 prompt_template=None,
                 cap_source: str = "prompt",
                 max_length: int = 8192,
                 crop_image: bool = True,
                 img_prefix: str = ""):
        super().__init__()
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.prompt_template = prompt_template
        self.cap_source = cap_source
        self.max_length = max_length
        self.crop_image = crop_image
        self.img_prefix = img_prefix
        self._load_data(data_path)
        # <image> tokens for one square latent grid, plus 64 extras.
        tokens_per_side = self.image_size // 16
        self.image_token_repeat = tokens_per_side * tokens_per_side + 64
        self.metainfo = {'task': 'unified'}

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def _load_data(self, path):
        # JSONL index: one sample dict per line.
        with open(path) as f:
            self.data_list = [json.loads(l) for l in f]
        print(f"[I2ICaptionReconstructDataset] Loaded {len(self.data_list)} samples from {path}")

    def _add_prefix(self, rel):
        if not self.img_prefix:
            return rel
        return os.path.join(self.img_prefix, rel.lstrip("/"))

    def _read_image(self, path):
        img = Image.open(path).convert("RGB")
        # Reject tiny images and extreme aspect ratios.
        assert img.width > 8 and img.height > 8 and 0.1 < img.width / img.height < 10
        return img

    def _process_image(self, img):
        # Square the image, resize, then map uint8 [0,255] -> [-1,1], CHW.
        if self.crop_image:
            img = crop2square(img)
        else:
            img = img.resize((max(img.size),) * 2)
        img = img.resize((self.image_size, self.image_size))
        px = torch.from_numpy(np.array(img)).float() / 255.
        px = 2 * px - 1
        return rearrange(px, "h w c -> c h w")

    def _encode_prompt(self, text):
        # NOTE: the incoming caption is deliberately discarded —
        # reconstruction always uses the fixed instruction below.
        text = "Repeat this image."
        wrapped = self.prompt_template["INSTRUCTION"].format(input=f"<image>\n{text.strip()}")
        wrapped = wrapped.replace("<image>", "<image>" * self.image_token_repeat)
        input_ids = self.tokenizer.encode(wrapped, add_special_tokens=True, return_tensors="pt")[0]
        mask = (input_ids != self.tokenizer.pad_token_id).long()
        return input_ids[:self.max_length], mask[:self.max_length]

    def __len__(self):
        return len(self.data_list)

    def _retry(self):
        # Substitute a random other sample on failure.
        return self.__getitem__(random.randrange(len(self)))

    def __getitem__(self, idx):
        try:
            sample = self.data_list[idx]
            src_img = self._read_image(self._add_prefix(sample["image"]))
            caption = sample[self.cap_source]
            px_src = self._process_image(src_img)
            px_tgt = self._process_image(src_img)  # target == source
            input_ids, mask = self._encode_prompt(caption)
            return {
                "pixel_values_src": px_src,
                "pixel_values": px_tgt,
                "input_ids": input_ids,
                "attention_mask": mask,
                "type": "image_edit"
            }
        except Exception as e:
            print(f"[I2ICaptionReconstructDataset] Error @ {idx}: {e}")
            return self._retry()
|
|
@DATASETS.register_module()
@BUILDER.register_module()
class UncondReconstructDataset(Dataset):
    """Unconditional image-reconstruction dataset.

    Yields identical source/target pixel tensors with EMPTY text inputs,
    for training pure image reconstruction without a prompt.
    """

    def __init__(self,
                 data_path: str,
                 image_size: int,
                 tokenizer=None,
                 prompt_template=None,
                 cap_source: str = "prompt",
                 max_length: int = 8192,
                 crop_image: bool = True,
                 img_prefix: str = ""):
        super().__init__()
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.prompt_template = prompt_template
        self.max_length = max_length
        self.crop_image = crop_image
        self.img_prefix = img_prefix
        self.cap_source = cap_source
        self._load_data(data_path)
        # <image> tokens for one square latent grid, plus 64 extras.
        m = n = self.image_size // 16
        self.image_token_repeat = m * n + 64
        self.metainfo = {'task': 'unified'}

    def _load_data(self, path):
        # JSONL index: one sample dict per line.
        with open(path) as f:
            self.data_list = [json.loads(l) for l in f]
        print(f"[I2IUncondReconstructDataset] Loaded {len(self.data_list)} samples from {path}")

    def _add_prefix(self, rel_path):
        return os.path.join(self.img_prefix, rel_path.lstrip("/")) if self.img_prefix else rel_path

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def _read_image(self, path):
        image = Image.open(path).convert("RGB")
        # Reject tiny images and extreme aspect ratios.
        assert image.width > 8 and image.height > 8 and 0.1 < image.width / image.height < 10
        return image

    def _process_image(self, img):
        # Square the image, resize, then map uint8 [0,255] -> [-1,1], CHW.
        img = crop2square(img) if self.crop_image else img.resize((max(img.size),)*2)
        img = img.resize((self.image_size, self.image_size))
        px = torch.from_numpy(np.array(img)).float() / 255.
        px = 2 * px - 1
        return rearrange(px, "h w c -> c h w")

    def __len__(self):
        return len(self.data_list)

    def _load_item(self, idx):
        """Load one sample; raises on any failure (no retry at this level)."""
        sample = self.data_list[idx]
        path = self._add_prefix(sample["image"])
        img = self._read_image(path)
        px = self._process_image(img)
        # Empty text tensors: this task is unconditional by design.
        input_ids = torch.zeros(0, dtype=torch.long)
        attention_mask = torch.zeros(0, dtype=torch.long)
        return {
            "pixel_values_src": px,
            "pixel_values": px.clone(),
            "type": "image_edit",
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def _retry(self, max_tries=5):
        # Try random substitute samples via the RAISING loader. The old
        # code called __getitem__, which swallows its own exceptions and
        # recurses, so the max_tries bound could never actually trigger.
        for _ in range(max_tries):
            idx = random.randrange(len(self))
            try:
                return self._load_item(idx)
            except Exception as e:
                print(f"[I2IUncondReconstructDataset] Error @ {idx}: {e}")
        raise RuntimeError("Exceeded max retries in I2IUncondReconstructDataset")

    def __getitem__(self, idx):
        try:
            return self._load_item(idx)
        except Exception as e:
            print(f"[I2IUncondReconstructDataset] Error @ {idx}: {e}")
            return self._retry()
|
|
|
|
|
|
@DATASETS.register_module()
@BUILDER.register_module()
class Text2ImageJSONLDataset(Dataset):
    """Text-to-image dataset whose index is a JSONL file.

    Each line is a dict with an absolute ``image`` path and an inline
    caption under ``cap_source``.
    """

    def __init__(self,
                 data_path,
                 image_size,
                 tokenizer=None,
                 prompt_template=None,
                 cap_source='prompt',
                 max_length=1024,
                 unconditional=0.1,
                 crop_image=True,
                 ):
        super().__init__()
        self.data_path = data_path
        self._load_data(data_path)
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.prompt_template = prompt_template
        self.cap_source = cap_source
        self.max_length = max_length
        self.unconditional = unconditional  # prob. of dropping the caption
        self.crop_image = crop_image
        self.metainfo = {'task': 'unified'}

    def _load_data(self, data_path):
        # One JSON object per line.
        with open(data_path, 'r') as f:
            self.data_list = [json.loads(line.strip()) for line in f]
        print(f"Loaded {len(self.data_list)} samples from {data_path}")

    def full_init(self):
        """Dummy full_init for MMEngine ConcatDataset compatibility."""
        pass

    def __len__(self):
        return len(self.data_list)

    def _read_image(self, image_file):
        image = Image.open(image_file).convert('RGB')
        # Reject tiny images and extreme aspect ratios.
        assert image.width > 8 and image.height > 8
        assert 0.1 < image.width / image.height < 10
        return image

    def _process_image(self, image):
        # Square the image, resize, then map uint8 [0,255] -> [-1,1], CHW.
        if self.crop_image:
            image = crop2square(image)
        else:
            image = image.resize((max(image.size),) * 2)
        image = image.resize((self.image_size, self.image_size))
        pixel_values = torch.from_numpy(np.array(image)).float() / 255.0
        pixel_values = 2 * pixel_values - 1
        pixel_values = rearrange(pixel_values, 'h w c -> c h w')
        return dict(pixel_values=pixel_values)

    def _process_text(self, text):
        # With probability `unconditional` drop the caption (CFG training).
        if random.uniform(0, 1) < self.unconditional:
            text = "Generate an image."
        else:
            text = f"Generate an image: {text.strip()}"
        prompt = self.prompt_template['INSTRUCTION'].format(input=text)
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors='pt')[0]
        return dict(input_ids=input_ids[:self.max_length])

    def _retry(self):
        # Substitute a random other sample on failure.
        return self.__getitem__(random.randrange(len(self.data_list)))

    def __getitem__(self, idx):
        try:
            sample = self.data_list[idx]
            image = self._read_image(sample['image'])
            caption = sample[self.cap_source]
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')
            return data
        except Exception as e:
            print(f"[JSONLDataset] Error reading sample #{idx}: {e}")
            return self._retry()
|
|
|
|
|
|
| |
@DATASETS.register_module()
@BUILDER.register_module()
class ImageEditJSONLDataset(Dataset):
    """
    Dataset for <src, tgt, prompt> image editing, now decoupled from tokenization logic.

    Each JSONL line provides ``images`` (source list, first entry used),
    ``image`` (edit target) and an edit instruction under ``cap_source``.
    Tokenization (including <image> expansion) is delegated to the shared
    ``encode_fn`` helper.
    """
    def __init__(self,
                 data_path: str,
                 image_size: int,
                 tokenizer=None,
                 prompt_template=None,
                 max_length: int = 8192,
                 cap_source: str = "prompt",
                 unconditional: float = 0,
                 crop_image: bool = False,
                 img_prefix: str = ""):
        super().__init__()
        self.data_path = data_path
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = prompt_template
        self.max_length = max_length
        self.cap_source = cap_source
        self.unconditional = unconditional  # prob. of dropping the instruction
        self.crop_image = crop_image
        self.img_prefix = img_prefix
        self._load_data(data_path)
        # <image> tokens for one square latent grid, plus 64 extras.
        m = n = self.image_size // 16
        self.image_token_repeat = m * n + 64
        self.metainfo = {'task': 'unified'}
        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
        print(f"Registered <image> token at index {self.image_token_idx}")

    def _load_data(self, path):
        # JSONL index: one sample dict per line.
        with open(path) as f:
            self.data_list = [json.loads(l) for l in f]
        print(f"[ImageEditJSONLDataset] Loaded {len(self.data_list)} samples from {path}")

    def full_init(self):
        """Dummy full_init for MMEngine ConcatDataset compatibility."""
        pass

    def _add_prefix(self, rel_path):
        return os.path.join(self.img_prefix, rel_path.lstrip("/")) if self.img_prefix else rel_path

    def _read_image(self, path):
        # HACK: remap a renamed storage mount; remove once paths migrate.
        path = path.replace("datasets_vlm02", "datasets_vlm")
        img = Image.open(path).convert("RGB")
        # Reject tiny images and extreme aspect ratios.
        assert img.width > 8 and img.height > 8 and 0.1 < img.width / img.height < 10
        return img

    def _process_image(self, img):
        # Square the image, resize, then map uint8 [0,255] -> [-1,1], CHW.
        img = crop2square(img) if self.crop_image else img.resize((max(img.size),) * 2)
        img = img.resize((self.image_size, self.image_size))
        px = torch.from_numpy(np.array(img)).float() / 255.
        px = 2 * px - 1
        return rearrange(px, "h w c -> c h w")

    def _prepare_prompt_text(self, raw_text: str):
        """Strip image-placeholder tokens and handle unconditional drop."""
        # Fix: remove placeholders cumulatively. The previous loop restarted
        # from raw_text on every iteration, so only the LAST placeholder
        # variant was actually stripped from the returned text.
        txt = raw_text
        for bad_token in ["[IMAGE]", "<image_placeholder>", "<image_plaeholder>", "<image>"]:
            txt = txt.replace(bad_token, "")
        txt = txt.strip()
        if random.random() < self.unconditional:
            txt = "Edit this image."
        return txt

    def _retry(self):
        # Substitute a random other sample on failure.
        return self.__getitem__(random.randrange(len(self)))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        try:
            sample = self.data_list[idx]
            src_path, tgt_path = map(self._add_prefix, [sample["images"][0], sample["image"]])
            src_img, tgt_img = map(self._read_image, [src_path, tgt_path])
            px_src, px_tgt = map(self._process_image, [src_img, tgt_img])
            prompt_text = self._prepare_prompt_text(sample[self.cap_source])
            # Tokenization (incl. <image> expansion) delegated to encode_fn.
            encoded_text = encode_fn(
                example=prompt_text,
                tokenizer=self.tokenizer,
                prompt_template=self.prompt_template,
                max_length=self.max_length,
                image_length=self.image_token_repeat,
                image_token_idx=self.image_token_idx
            )
            return {
                "pixel_values_src": px_src,
                "pixel_values": px_tgt,
                "input_ids": torch.tensor(encoded_text["input_ids"], dtype=torch.long),
                "attention_mask": torch.tensor(encoded_text["attention_mask"], dtype=torch.long),
                "type": "image_edit",
            }
        except Exception as e:
            print(f"[ImageEditJSONLDataset] Error @ {idx}: {e} from {self.data_path}")
            return self._retry()
|
|
|
|
|
|
|
|