Spaces:
Running on Zero
Running on Zero
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| import os | |
| import io | |
| import json | |
| import random | |
| import torch | |
| import numpy as np | |
| from einops import rearrange | |
| from xtuner.registry import BUILDER | |
| from src.datasets.utils import crop2square | |
| from glob import glob | |
| class Text2ImageDataset(Dataset): | |
| def __init__(self, | |
| data_path, | |
| local_folder, | |
| image_size, | |
| unconditional=0.1, | |
| tokenizer=None, | |
| prompt_template=None, | |
| max_length=1024, | |
| crop_image=True, | |
| cap_source='caption', | |
| front_bg_indicator=False, | |
| ): | |
| super().__init__() | |
| self.data_path = data_path | |
| self._load_data(data_path) | |
| self.unconditional = unconditional | |
| self.local_folder = local_folder | |
| self.cap_source = cap_source | |
| self.image_size = image_size | |
| self.tokenizer = BUILDER.build(tokenizer) | |
| self.prompt_template = prompt_template | |
| self.max_length = max_length | |
| self.crop_image = crop_image | |
| self.front_bg_indicator = front_bg_indicator | |
| def _load_data(self, data_path): | |
| with open(data_path, 'r') as f: | |
| self.data_list = json.load(f) | |
| print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True) | |
| def __len__(self): | |
| return len(self.data_list) | |
| def _read_image(self, image_file): | |
| image = Image.open(os.path.join(self.local_folder, image_file)) | |
| assert image.width > 8 and image.height > 8, f"Image: {image.size}" | |
| assert image.width / image.height > 0.1, f"Image: {image.size}" | |
| assert image.width / image.height < 10, f"Image: {image.size}" | |
| return image | |
| def _process_text(self, text): | |
| if random.uniform(0, 1) < self.unconditional: | |
| prompt = "Generate an image." | |
| else: | |
| if self.front_bg_indicator: | |
| prompt = f"Generate an image: real background, {text.strip()}" | |
| else: | |
| prompt = f"Generate an image: {text.strip()}" | |
| prompt = self.prompt_template['INSTRUCTION'].format(input=prompt) | |
| input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors='pt')[0] | |
| return dict(input_ids=input_ids[:self.max_length]) | |
| def _process_image(self, image): | |
| data = dict() | |
| if self.crop_image: | |
| image = crop2square(image) | |
| else: | |
| target_size = max(image.size) | |
| image = image.resize(size=(target_size, target_size)) | |
| image = image.resize(size=(self.image_size, self.image_size)) | |
| pixel_values = torch.from_numpy(np.array(image)).float() | |
| pixel_values = pixel_values / 255 | |
| pixel_values = 2 * pixel_values - 1 | |
| pixel_values = rearrange(pixel_values, 'h w c -> c h w') | |
| data.update(pixel_values=pixel_values) | |
| return data | |
| def _retry(self): | |
| return self.__getitem__(random.choice(range(self.__len__()))) | |
| def __getitem__(self, idx): | |
| try: | |
| data_sample = self.data_list[idx] | |
| image = self._read_image(data_sample['image']).convert('RGB') | |
| caption = data_sample[self.cap_source] | |
| data = self._process_image(image) | |
| data.update(self._process_text(caption)) | |
| data.update(type='text2image') | |
| return data | |
| except Exception as e: | |
| print(f"Error when reading {self.data_path}:{self.data_list[idx]}: {e}", flush=True) | |
| return self._retry() | |
| class LargeText2ImageDataset(Text2ImageDataset): | |
| # self.data_list only contains paths of images and captions | |
| def __init__(self, cap_folder=None, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.cap_folder = self.local_folder if cap_folder is None else cap_folder | |
| def _load_data(self, data_path): # image path and annotation path are saved in a json file | |
| if data_path.endswith(".json"): | |
| with open(data_path, 'r') as f: | |
| self.data_list = json.load(f) | |
| else: | |
| self.data_list = [] | |
| json_files = glob(f'{data_path}/*.json') | |
| for json_file in json_files: | |
| with open(json_file, 'r') as f: | |
| self.data_list += json.load(f) | |
| print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True) | |
| def __getitem__(self, idx): | |
| try: | |
| data_sample = self.data_list[idx] | |
| image = self._read_image(data_sample['image']).convert('RGB') | |
| with open(f"{self.cap_folder}/{data_sample['annotation']}", 'r') as f: | |
| caption = json.load(f)[self.cap_source] | |
| data = self._process_image(image) | |
| data.update(self._process_text(caption)) | |
| data.update(type='text2image') | |
| return data | |
| except Exception as e: | |
| print(f"Error when reading {self.data_path}:{data_sample}: {e}", flush=True) | |
| return self._retry() | |
| class BlipO3Dataset(Text2ImageDataset): | |
| def __init__(self, | |
| data_path=None, | |
| cache_dir=None, | |
| *args, **kwargs): | |
| self.data_path = data_path | |
| self.cache_dir = cache_dir | |
| super().__init__(data_path=data_path, *args, **kwargs) | |
| def _load_data(self, data_path): | |
| try: | |
| from datasets import load_dataset | |
| print(f"Loading dataset from {data_path} with cache_dir {self.cache_dir}") | |
| data_files = glob(data_path) | |
| self.dataset = load_dataset("webdataset", data_files=data_files, cache_dir=self.cache_dir, split="train", num_proc=64) | |
| print(f"Loaded {len(self.dataset)} samples from {data_path}") | |
| self.data_list = [] | |
| for idx in range(len(self.dataset)): | |
| self.data_list.append({ | |
| 'idx': idx, | |
| }) | |
| except Exception as e: | |
| print(f"Error loading dataset: {e}") | |
| self.data_list = [] | |
| print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True) | |
| def __getitem__(self, idx): | |
| try: | |
| data_sample = self.data_list[idx] | |
| original_idx = data_sample['idx'] | |
| sample = self.dataset[original_idx] | |
| image_data = sample['jpg'] | |
| if isinstance(image_data, dict) and 'bytes' in image_data: | |
| image = Image.open(io.BytesIO(image_data['bytes'])).convert('RGB') | |
| elif hasattr(image_data, 'convert'): | |
| image = image_data.convert('RGB') | |
| elif isinstance(image_data, bytes): | |
| image = Image.open(io.BytesIO(image_data)).convert('RGB') | |
| else: | |
| try: | |
| image = Image.fromarray(np.array(image_data)).convert('RGB') | |
| except Exception: | |
| raise TypeError(f"Unknown type: {type(image_data)}") | |
| caption = sample['txt'] | |
| data = self._process_image(image) | |
| data.update(self._process_text(caption)) | |
| data.update(type='text2image') | |
| return data | |
| except Exception as e: | |
| print(f"Error when processing index {idx}: {e}", flush=True) | |
| import traceback | |
| traceback.print_exc() | |
| return self._retry() | |
| class MidJourneyDataset(Text2ImageDataset): | |
| def __init__(self, | |
| data_path="brivangl/midjourney-v6-llava", | |
| cache_dir=None, | |
| use_llava=False, | |
| front_bg_indicator=False, | |
| *args, **kwargs): | |
| self.data_path = data_path | |
| self.cache_dir = cache_dir | |
| self.use_llava = use_llava | |
| super().__init__(data_path=data_path, front_bg_indicator=front_bg_indicator, *args, **kwargs) | |
| def _load_data(self, data_path): | |
| try: | |
| from datasets import load_dataset | |
| print(f"Loading dataset from {data_path} with cache_dir {self.cache_dir}") | |
| self.dataset = load_dataset(data_path, cache_dir=self.cache_dir)['train'] | |
| print(f"Loaded {len(self.dataset)} samples from {data_path}") | |
| self.data_list = [] | |
| for idx in range(len(self.dataset)): | |
| self.data_list.append({ | |
| 'idx': idx, | |
| }) | |
| except Exception as e: | |
| print(f"Error loading dataset: {e}") | |
| self.data_list = [] | |
| print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True) | |
| def __getitem__(self, idx): | |
| try: | |
| data_sample = self.data_list[idx] | |
| original_idx = data_sample['idx'] | |
| sample = self.dataset[original_idx] | |
| image_data = sample['image'] | |
| if isinstance(image_data, dict) and 'bytes' in image_data: | |
| image = Image.open(io.BytesIO(image_data['bytes'])).convert('RGB') | |
| elif hasattr(image_data, 'convert'): | |
| image = image_data.convert('RGB') | |
| elif isinstance(image_data, bytes): | |
| image = Image.open(io.BytesIO(image_data)).convert('RGB') | |
| else: | |
| try: | |
| image = Image.fromarray(np.array(image_data)).convert('RGB') | |
| except Exception: | |
| raise TypeError(f"Unknown type: {type(image_data)}") | |
| if self.use_llava: | |
| caption = sample['llava'] | |
| else: | |
| caption = sample['prompt'] | |
| data = self._process_image(image) | |
| data.update(self._process_text(caption)) | |
| data.update(type='text2image') | |
| return data | |
| except Exception as e: | |
| print(f"Error when processing index {idx}: {e}", flush=True) | |
| import traceback | |
| traceback.print_exc() | |
| return self._retry() | |
| class ReconstructionDataset(Text2ImageDataset): | |
| def __init__(self, | |
| data_path, | |
| image_size, | |
| unconditional=0.1, | |
| tokenizer=None, | |
| prompt_template=None, | |
| max_length=1024, | |
| crop_image=False, | |
| cap_source='caption', | |
| max_samples=None, | |
| use_downscale=False, | |
| cache_dir=None): | |
| self.data_path = data_path | |
| self.unconditional = unconditional | |
| self.local_folder = None | |
| self.cap_source = cap_source | |
| self.image_size = image_size | |
| self.tokenizer = BUILDER.build(tokenizer) | |
| self.prompt_template = prompt_template | |
| self.max_length = max_length | |
| self.crop_image = crop_image | |
| self.max_samples = max_samples | |
| self.use_downscale = use_downscale | |
| self.cache_dir = cache_dir | |
| os.makedirs(self.cache_dir, exist_ok=True) | |
| self._load_data(data_path) | |
| from src.datasets.text2image.consts import get_recon_prompt_list | |
| self.recon_prompts = get_recon_prompt_list() | |
| print(f"Loaded ReconstructionDataset with {len(self.data_list)} samples, {len(self.recon_prompts)} prompts, cache_dir: {self.cache_dir}", flush=True) | |
| def _extract_tar_if_needed(self, tar_path): | |
| import tarfile | |
| import hashlib | |
| tar_hash = hashlib.md5(tar_path.encode()).hexdigest() | |
| extract_dir = os.path.join(self.cache_dir, tar_hash) | |
| lock_file = os.path.join(extract_dir, '.extraction_complete') | |
| if os.path.exists(lock_file): | |
| print(f"Using cached extraction for {tar_path} in {extract_dir}", flush=True) | |
| return extract_dir | |
| print(f"Extracting {tar_path} to {extract_dir}...", flush=True) | |
| os.makedirs(extract_dir, exist_ok=True) | |
| try: | |
| with tarfile.open(tar_path, 'r') as tar: | |
| tar.extractall(path=extract_dir) | |
| with open(lock_file, 'w') as f: | |
| f.write(f"Extracted from {tar_path} at {os.path.getmtime(tar_path)}") | |
| print(f"Extraction complete: {tar_path} -> {extract_dir}", flush=True) | |
| return extract_dir | |
| except Exception as e: | |
| print(f"Error extracting {tar_path}: {e}", flush=True) | |
| raise | |
| def _load_data(self, data_path): | |
| import tarfile | |
| import glob | |
| self.tar_files = glob.glob(os.path.expanduser(data_path.replace('{', '[').replace('}', ']'))) | |
| self.data_list = [] | |
| self.image_cache_paths = {} | |
| for tar_idx, tar_path in enumerate(self.tar_files): | |
| try: | |
| extract_dir = self._extract_tar_if_needed(tar_path) | |
| with tarfile.open(tar_path, 'r') as tar: | |
| for member in tar.getmembers(): | |
| if member.isfile() and member.name.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')): | |
| file_name = member.name | |
| cache_path = os.path.join(extract_dir, file_name) | |
| self.data_list.append({'image': file_name, 'tar_idx': tar_idx}) | |
| self.image_cache_paths[file_name] = cache_path | |
| if self.max_samples and len(self.data_list) >= self.max_samples: | |
| break | |
| if self.max_samples and len(self.data_list) >= self.max_samples: | |
| break | |
| except Exception as e: | |
| print(f"Error loading tar file {tar_path}: {e}", flush=True) | |
| print(f"Loaded {len(self.data_list)} images from {len(self.tar_files)} tar files: {self.tar_files}", flush=True) | |
| if len(self.data_list) == 0: | |
| raise RuntimeError(f"No valid images found in tar archives: {data_path}") | |
| def _read_image(self, image_file): | |
| if image_file not in self.image_cache_paths: | |
| raise ValueError(f"Image file {image_file} not found in cache") | |
| cache_path = self.image_cache_paths[image_file] | |
| try: | |
| image = Image.open(cache_path) | |
| assert image.width > 8 and image.height > 8, f"Image too small: {image.size}" | |
| assert image.width / image.height > 0.1, f"Image aspect ratio too extreme: {image.size}" | |
| assert image.width / image.height < 10, f"Image aspect ratio too extreme: {image.size}" | |
| return image | |
| except Exception as e: | |
| raise RuntimeError(f"Error reading image from cache path {cache_path}: {e}") | |
| def _process_text(self, text): | |
| prompt = random.choice(self.recon_prompts) | |
| if random.uniform(0, 1) < self.unconditional: | |
| final_prompt = "Generate an image." | |
| else: | |
| final_prompt = f"\n{prompt}" | |
| final_prompt = self.prompt_template['INSTRUCTION'].format(input=final_prompt) | |
| input_ids = self.tokenizer.encode(final_prompt, add_special_tokens=True, return_tensors='pt')[0] | |
| # print(f"Prompt: {final_prompt}", flush=True) | |
| input_ids = torch.cat([ | |
| input_ids[:3], | |
| torch.tensor([-200], dtype=torch.long), | |
| input_ids[3:], | |
| ], dim=0) | |
| return dict(input_ids=input_ids[:self.max_length]) | |
| def __getitem__(self, idx): | |
| try: | |
| data_sample = self.data_list[idx] | |
| image = self._read_image(data_sample['image']).convert('RGB') | |
| if self.use_downscale: | |
| image = image.resize(size=(self.image_size // 2, self.image_size // 2)) | |
| image = image.resize(size=(self.image_size, self.image_size)) | |
| data = self._process_image(image) | |
| data.update(self._process_text("")) | |
| data.update(type='recon') | |
| return data | |
| except Exception as e: | |
| print(f"Error when processing index {idx}: {e}", flush=True) | |
| import traceback | |
| traceback.print_exc() | |
| return self._retry() | |
| class MidjourneyReconstructionDataset(Text2ImageDataset): | |
| def __init__(self, | |
| image_size, | |
| data_path="brivangl/midjourney-v6-llava", | |
| cache_dir=None, | |
| unconditional=0.1, | |
| tokenizer=None, | |
| prompt_template=None, | |
| max_length=1024, | |
| crop_image=False, | |
| cap_source='caption', | |
| max_samples=None, | |
| use_downscale=False, | |
| *args, **kwargs): | |
| self.data_path = data_path | |
| self.unconditional = unconditional | |
| self.local_folder = None | |
| self.cap_source = cap_source | |
| self.image_size = image_size | |
| self.tokenizer = BUILDER.build(tokenizer) | |
| self.prompt_template = prompt_template | |
| self.max_length = max_length | |
| self.crop_image = crop_image | |
| self.max_samples = max_samples | |
| self.use_downscale = use_downscale | |
| self.cache_dir = cache_dir | |
| from src.datasets.text2image.consts import get_recon_prompt_list | |
| self.recon_prompts = get_recon_prompt_list() | |
| self._load_data(data_path) | |
| def _load_data(self, data_path): | |
| try: | |
| from datasets import load_dataset | |
| print(f"Loading dataset from {data_path} with cache_dir {self.cache_dir}") | |
| self.dataset = load_dataset(data_path, cache_dir=self.cache_dir)['train'] | |
| print(f"Loaded {len(self.dataset)} samples from {data_path}") | |
| self.data_list = [] | |
| for idx in range(len(self.dataset)): | |
| self.data_list.append({ | |
| 'idx': idx, | |
| }) | |
| except Exception as e: | |
| print(f"Error loading dataset: {e}") | |
| self.data_list = [] | |
| print(f"Load {len(self.data_list)} data samples from {data_path} for reconstruction", flush=True) | |
| def _process_text(self, text): | |
| prompt = random.choice(self.recon_prompts) | |
| if random.uniform(0, 1) < self.unconditional: | |
| final_prompt = "Generate an image." | |
| else: | |
| final_prompt = f"\n{prompt}" | |
| final_prompt = self.prompt_template['INSTRUCTION'].format(input=final_prompt) | |
| input_ids = self.tokenizer.encode(final_prompt, add_special_tokens=True, return_tensors='pt')[0] | |
| input_ids = torch.cat([ | |
| input_ids[:3], | |
| torch.tensor([-200], dtype=torch.long), | |
| input_ids[3:], | |
| ], dim=0) | |
| return dict(input_ids=input_ids[:self.max_length]) | |
| def __getitem__(self, idx): | |
| try: | |
| data_sample = self.data_list[idx] | |
| original_idx = data_sample['idx'] | |
| sample = self.dataset[original_idx] | |
| image_data = sample['image'] | |
| if isinstance(image_data, dict) and 'bytes' in image_data: | |
| image = Image.open(io.BytesIO(image_data['bytes'])).convert('RGB') | |
| elif hasattr(image_data, 'convert'): | |
| image = image_data.convert('RGB') | |
| elif isinstance(image_data, bytes): | |
| image = Image.open(io.BytesIO(image_data)).convert('RGB') | |
| else: | |
| try: | |
| image = Image.fromarray(np.array(image_data)).convert('RGB') | |
| except Exception: | |
| raise TypeError(f"Unknown type: {type(image_data)}") | |
| data = self._process_image(image) | |
| data.update(self._process_text("")) | |
| data.update(type='recon') | |
| return data | |
| except Exception as e: | |
| print(f"Error when processing index {idx}: {e}", flush=True) | |
| import traceback | |
| traceback.print_exc() | |
| return self._retry() | |
| def __len__(self): | |
| return len(self.data_list) if self.data_list else 0 |