XinxuanLu's picture
Initial demo
becf13a verified
from torch.utils.data import Dataset
from PIL import Image
import os
import io
import json
import random
import torch
import numpy as np
from einops import rearrange
from xtuner.registry import BUILDER
from src.datasets.utils import crop2square
from glob import glob
class Text2ImageDataset(Dataset):
    """Text-to-image dataset backed by a JSON list of sample dicts.

    Each sample dict must provide an ``'image'`` path (relative to ``local_folder``)
    and a caption under the ``cap_source`` key. ``__getitem__`` returns a dict with
    ``pixel_values`` (channel-first float tensor scaled to ``[-1, 1]``),
    ``input_ids`` (tokenized, templated prompt truncated to ``max_length``) and
    ``type='text2image'``. A sample that fails to load is replaced by a randomly
    chosen other sample (see ``_retry``).

    Args:
        data_path: Path to a JSON file holding a list of sample dicts.
        local_folder: Directory that the ``'image'`` paths are relative to.
        image_size: Side length of the square output image.
        unconditional: Probability of replacing the caption with the generic
            prompt ``"Generate an image."``.
        tokenizer: Config passed to ``BUILDER.build`` to construct the tokenizer.
        prompt_template: Mapping whose ``'INSTRUCTION'`` entry is a format string
            with an ``{input}`` placeholder.
        max_length: Maximum number of prompt tokens kept.
        crop_image: If True the image is passed through ``crop2square``;
            otherwise it is stretched to a square of its longer side.
        cap_source: Key of the caption field in each sample dict.
        front_bg_indicator: If True, captions are prefixed with ``"real background, "``.
    """

    def __init__(self,
                 data_path,
                 local_folder,
                 image_size,
                 unconditional=0.1,
                 tokenizer=None,
                 prompt_template=None,
                 max_length=1024,
                 crop_image=True,
                 cap_source='caption',
                 front_bg_indicator=False,
                 ):
        super().__init__()
        self.data_path = data_path
        # Subclasses override _load_data and may set the attributes it needs
        # (e.g. cache_dir) before calling this __init__, so it runs first.
        self._load_data(data_path)
        self.unconditional = unconditional
        self.local_folder = local_folder
        self.cap_source = cap_source
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = prompt_template
        self.max_length = max_length
        self.crop_image = crop_image
        self.front_bg_indicator = front_bg_indicator

    def _load_data(self, data_path):
        """Populate ``self.data_list`` from the JSON file at ``data_path``."""
        # Captions often contain non-ASCII text; do not rely on the locale's default encoding.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data_list = json.load(f)
        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def __len__(self):
        return len(self.data_list)

    def _read_image(self, image_file):
        """Open ``image_file`` (relative to ``local_folder``) and validate its geometry.

        Raises:
            ValueError: if either side is 8 px or smaller, or if width/height is
                not strictly between 0.1 and 10.
        """
        image = Image.open(os.path.join(self.local_folder, image_file))
        size = image.size
        width, height = size
        # Explicit check rather than ``assert`` so validation survives ``python -O``;
        # the side check short-circuits before the division, so height is never 0 there.
        if not (width > 8 and height > 8 and 0.1 < width / height < 10):
            image.close()  # release the file handle of a rejected image
            raise ValueError(f"Image: {size}")
        return image

    def _process_text(self, text):
        """Build the templated prompt for ``text`` and tokenize it.

        With probability ``unconditional`` the caption is dropped in favour of
        the generic prompt ``"Generate an image."``.

        Returns:
            ``{'input_ids': 1-D tensor}`` truncated to ``max_length`` tokens.
        """
        if random.uniform(0, 1) < self.unconditional:
            prompt = "Generate an image."
        elif self.front_bg_indicator:
            prompt = f"Generate an image: real background, {text.strip()}"
        else:
            prompt = f"Generate an image: {text.strip()}"
        prompt = self.prompt_template['INSTRUCTION'].format(input=prompt)
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors='pt')[0]
        return dict(input_ids=input_ids[:self.max_length])

    def _process_image(self, image):
        """Square and resize ``image`` to ``image_size`` and convert it to a tensor.

        Expects a 3-channel (RGB) image; callers convert before calling.

        Returns:
            ``{'pixel_values': float tensor of shape (c, h, w) in [-1, 1]}``.
        """
        if self.crop_image:
            image = crop2square(image)
        else:
            # Stretch (not pad) to a square of the longer side.
            target_size = max(image.size)
            image = image.resize(size=(target_size, target_size))
        image = image.resize(size=(self.image_size, self.image_size))
        pixel_values = torch.from_numpy(np.array(image)).float()
        pixel_values = 2 * (pixel_values / 255) - 1  # [0, 255] -> [-1, 1]
        pixel_values = rearrange(pixel_values, 'h w c -> c h w')
        return dict(pixel_values=pixel_values)

    def _retry(self):
        """Return a uniformly random replacement sample after a load failure.

        Raises:
            IndexError: if the dataset is empty.
        """
        if len(self) == 0:
            raise IndexError("Cannot retry: the dataset is empty")
        return self.__getitem__(random.randrange(len(self)))

    def __getitem__(self, idx):
        # Index outside the try block: an invalid idx raises IndexError directly
        # instead of failing a second time while formatting the error message.
        data_sample = self.data_list[idx]
        try:
            image = self._read_image(data_sample['image']).convert('RGB')
            caption = data_sample[self.cap_source]
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')
            return data
        except Exception as e:
            print(f"Error when reading {self.data_path}:{data_sample}: {e}", flush=True)
            return self._retry()
class LargeText2ImageDataset(Text2ImageDataset):
    """Text-to-image dataset whose captions live in per-sample JSON annotation files.

    ``self.data_list`` holds only ``{'image': ..., 'annotation': ...}`` path pairs.
    The caption is read lazily from ``<cap_folder>/<annotation>`` under the
    ``cap_source`` key, which keeps the in-memory index small.

    Args:
        cap_folder: Directory containing the annotation files. Defaults to
            ``local_folder``. Because it is the first positional parameter, pass
            the parent's arguments by keyword.
        *args, **kwargs: Forwarded to :class:`Text2ImageDataset`.
    """

    def __init__(self, cap_folder=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cap_folder = self.local_folder if cap_folder is None else cap_folder

    def _load_data(self, data_path):
        """Load the index from one ``.json`` file or from every ``*.json`` in a directory.

        Directory contents are concatenated in sorted filename order so that the
        index -> sample mapping is reproducible across machines and processes
        (``glob`` itself returns files in arbitrary order).
        """
        if data_path.endswith(".json"):
            with open(data_path, 'r', encoding='utf-8') as f:
                self.data_list = json.load(f)
        else:
            self.data_list = []
            for json_file in sorted(glob(f'{data_path}/*.json')):
                with open(json_file, 'r', encoding='utf-8') as f:
                    self.data_list += json.load(f)
        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def __getitem__(self, idx):
        # Bind data_sample before the try block so the error handler can always
        # reference it (an invalid idx raises IndexError here instead of a
        # NameError inside the handler).
        data_sample = self.data_list[idx]
        try:
            image = self._read_image(data_sample['image']).convert('RGB')
            with open(f"{self.cap_folder}/{data_sample['annotation']}", 'r', encoding='utf-8') as f:
                caption = json.load(f)[self.cap_source]
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')
            return data
        except Exception as e:
            print(f"Error when reading {self.data_path}:{data_sample}: {e}", flush=True)
            return self._retry()
class BlipO3Dataset(Text2ImageDataset):
    """Text-to-image dataset read from WebDataset shards through HuggingFace ``datasets``.

    Each record's ``'jpg'`` field is decoded to RGB and its ``'txt'`` field is used
    as the caption. ``self.data_list`` holds ``{'idx': i}`` entries pointing into
    ``self.dataset``.

    Args:
        data_path: Glob pattern matching the shard files.
        cache_dir: Cache directory passed to ``datasets.load_dataset``.
        *args, **kwargs: Forwarded to :class:`Text2ImageDataset`.
        num_proc: Number of processes ``load_dataset`` uses (keyword-only, default 64).
    """

    def __init__(self,
                 data_path=None,
                 cache_dir=None,
                 *args,
                 num_proc=64,
                 **kwargs):
        # _load_data runs inside the parent __init__, so its inputs must be set first.
        self.data_path = data_path
        self.cache_dir = cache_dir
        self.num_proc = num_proc
        # Pass data_path positionally: ``f(data_path=..., *args)`` raises
        # "got multiple values for argument 'data_path'" whenever args is non-empty.
        super().__init__(data_path, *args, **kwargs)

    def _load_data(self, data_path):
        """Load the shards matching ``data_path`` and index every sample.

        On any failure the error is printed and the dataset is left empty.
        """
        try:
            from datasets import load_dataset
            print(f"Loading dataset from {data_path} with cache_dir {self.cache_dir}")
            # Sorted so sample indices do not depend on directory-listing order.
            data_files = sorted(glob(data_path))
            if not data_files:
                raise FileNotFoundError(f"No files match {data_path}")
            self.dataset = load_dataset("webdataset", data_files=data_files, cache_dir=self.cache_dir,
                                        split="train", num_proc=getattr(self, 'num_proc', 64))
            print(f"Loaded {len(self.dataset)} samples from {data_path}")
            self.data_list = [{'idx': idx} for idx in range(len(self.dataset))]
        except Exception as e:
            print(f"Error loading dataset: {e}")
            self.data_list = []
        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    @staticmethod
    def _decode_image(image_data):
        """Convert an image field to an RGB ``PIL.Image``.

        Accepts a ``{'bytes': ...}`` dict, an object with a ``.convert`` method,
        raw encoded ``bytes``, or anything ``np.array`` can turn into pixels.

        Raises:
            TypeError: if ``image_data`` matches none of these forms.
        """
        if isinstance(image_data, dict) and 'bytes' in image_data:
            return Image.open(io.BytesIO(image_data['bytes'])).convert('RGB')
        if hasattr(image_data, 'convert'):
            return image_data.convert('RGB')
        if isinstance(image_data, bytes):
            return Image.open(io.BytesIO(image_data)).convert('RGB')
        try:
            return Image.fromarray(np.array(image_data)).convert('RGB')
        except Exception as err:
            raise TypeError(f"Unknown type: {type(image_data)}") from err

    def __getitem__(self, idx):
        try:
            sample = self.dataset[self.data_list[idx]['idx']]
            image = self._decode_image(sample['jpg'])
            caption = sample['txt']
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')
            return data
        except Exception as e:
            print(f"Error when processing index {idx}: {e}", flush=True)
            import traceback
            traceback.print_exc()
            return self._retry()
class MidJourneyDataset(Text2ImageDataset):
    """Text-to-image dataset loaded with ``datasets.load_dataset`` (``'train'`` split).

    Each record's ``'image'`` field is decoded to RGB. The caption comes from the
    ``'llava'`` field when ``use_llava`` is True, otherwise from ``'prompt'``.

    Args:
        data_path: Dataset name or path passed to ``load_dataset``.
        cache_dir: Cache directory passed to ``load_dataset``.
        use_llava: Select the ``'llava'`` caption field instead of ``'prompt'``.
        front_bg_indicator: Forwarded to :class:`Text2ImageDataset`.
        *args, **kwargs: Forwarded to :class:`Text2ImageDataset`.
    """

    def __init__(self,
                 data_path="brivangl/midjourney-v6-llava",
                 cache_dir=None,
                 use_llava=False,
                 front_bg_indicator=False,
                 *args, **kwargs):
        # _load_data runs inside the parent __init__, so its inputs must be set first.
        self.data_path = data_path
        self.cache_dir = cache_dir
        self.use_llava = use_llava
        # Pass data_path positionally: ``f(data_path=..., *args)`` raises
        # "got multiple values for argument 'data_path'" whenever args is non-empty.
        super().__init__(data_path, *args, front_bg_indicator=front_bg_indicator, **kwargs)

    def _load_data(self, data_path):
        """Load the ``'train'`` split and index every sample.

        On any failure the error is printed and the dataset is left empty.
        """
        try:
            from datasets import load_dataset
            print(f"Loading dataset from {data_path} with cache_dir {self.cache_dir}")
            self.dataset = load_dataset(data_path, cache_dir=self.cache_dir)['train']
            print(f"Loaded {len(self.dataset)} samples from {data_path}")
            self.data_list = [{'idx': idx} for idx in range(len(self.dataset))]
        except Exception as e:
            print(f"Error loading dataset: {e}")
            self.data_list = []
        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    @staticmethod
    def _decode_image(image_data):
        """Convert an image field to an RGB ``PIL.Image``.

        Accepts a ``{'bytes': ...}`` dict, an object with a ``.convert`` method,
        raw encoded ``bytes``, or anything ``np.array`` can turn into pixels.

        Raises:
            TypeError: if ``image_data`` matches none of these forms.
        """
        if isinstance(image_data, dict) and 'bytes' in image_data:
            return Image.open(io.BytesIO(image_data['bytes'])).convert('RGB')
        if hasattr(image_data, 'convert'):
            return image_data.convert('RGB')
        if isinstance(image_data, bytes):
            return Image.open(io.BytesIO(image_data)).convert('RGB')
        try:
            return Image.fromarray(np.array(image_data)).convert('RGB')
        except Exception as err:
            raise TypeError(f"Unknown type: {type(image_data)}") from err

    def __getitem__(self, idx):
        try:
            sample = self.dataset[self.data_list[idx]['idx']]
            image = self._decode_image(sample['image'])
            caption = sample['llava'] if self.use_llava else sample['prompt']
            data = self._process_image(image)
            data.update(self._process_text(caption))
            data.update(type='text2image')
            return data
        except Exception as e:
            print(f"Error when processing index {idx}: {e}", flush=True)
            import traceback
            traceback.print_exc()
            return self._retry()
class ReconstructionDataset(Text2ImageDataset):
    """Image-reconstruction dataset built from local tar archives of images.

    Each archive matched by ``data_path`` is extracted once into
    ``cache_dir/<md5(tar_path)>``; a ``.extraction_complete`` marker file records
    a finished extraction. Every ``.jpg/.jpeg/.png/.webp`` member becomes one
    sample. Items carry ``pixel_values``, ``input_ids`` built from a random
    reconstruction prompt, and ``type='recon'``.

    Args:
        data_path: Glob pattern for the tar files. ``~`` is expanded, and
            shell-style brace groups (``{a,b}`` lists and ``{000..099}`` integer
            ranges) are expanded before globbing.
        image_size: Side length of the square output image.
        unconditional: Probability of using ``"Generate an image."`` instead of a
            reconstruction prompt.
        tokenizer: Config passed to ``BUILDER.build``.
        prompt_template: Mapping whose ``'INSTRUCTION'`` entry has an ``{input}`` slot.
        max_length: Maximum number of prompt tokens kept.
        crop_image: If True the image is passed through ``crop2square``;
            otherwise it is stretched to a square.
        cap_source: Stored only; captions are not used by this dataset.
        max_samples: Optional cap on the number of indexed images.
        use_downscale: If True, each image is resized to half ``image_size`` and
            back up before processing.
        cache_dir: Directory for extracted archives (required).

    Raises:
        ValueError: if ``cache_dir`` is None.
        RuntimeError: if no image is found in the matched archives.
    """

    def __init__(self,
                 data_path,
                 image_size,
                 unconditional=0.1,
                 tokenizer=None,
                 prompt_template=None,
                 max_length=1024,
                 crop_image=False,
                 cap_source='caption',
                 max_samples=None,
                 use_downscale=False,
                 cache_dir=None):
        # Fail fast: os.makedirs(None) below would otherwise raise an opaque TypeError.
        if cache_dir is None:
            raise ValueError("ReconstructionDataset requires `cache_dir` (directory for extracted tar archives)")
        # Text2ImageDataset.__init__ is not called; the attributes that the
        # inherited helpers (_process_image, _retry) read are set explicitly here.
        self.data_path = data_path
        self.unconditional = unconditional
        self.local_folder = None
        self.cap_source = cap_source
        self.image_size = image_size
        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = prompt_template
        self.max_length = max_length
        self.crop_image = crop_image
        self.max_samples = max_samples
        self.use_downscale = use_downscale
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        self._load_data(data_path)
        from src.datasets.text2image.consts import get_recon_prompt_list
        self.recon_prompts = get_recon_prompt_list()
        print(f"Loaded ReconstructionDataset with {len(self.data_list)} samples, {len(self.recon_prompts)} prompts, cache_dir: {self.cache_dir}", flush=True)

    @staticmethod
    def _expand_braces(pattern):
        """Expand shell-style brace groups in ``pattern`` into a list of glob patterns.

        ``{a,b}`` yields each comma-separated alternative. ``{lo..hi}`` (decimal
        integers) yields every integer in the inclusive range, zero-padded to the
        wider bound when either bound has a leading zero. A pattern without braces
        is returned unchanged as a one-element list. Duplicates are removed while
        keeping the first occurrence.
        """
        import re
        match = re.search(r'\{([^{}]*)\}', pattern)
        if match is None:
            return [pattern]
        head, body, tail = pattern[:match.start()], match.group(1), pattern[match.end():]
        bounds = re.fullmatch(r'(\d+)\.\.(\d+)', body)
        if bounds:
            lo, hi = bounds.groups()
            padded = any(len(b) > 1 and b.startswith('0') for b in (lo, hi))
            width = max(len(lo), len(hi)) if padded else 0
            step = 1 if int(lo) <= int(hi) else -1
            options = [str(i).zfill(width) for i in range(int(lo), int(hi) + step, step)]
        else:
            options = body.split(',')
        expanded = []
        for option in options:
            expanded.extend(ReconstructionDataset._expand_braces(head + option + tail))
        return list(dict.fromkeys(expanded))

    def _extract_tar_if_needed(self, tar_path):
        """Extract ``tar_path`` into ``cache_dir/<md5(tar_path)>`` unless already done.

        The ``.extraction_complete`` marker is written only after a successful
        extraction, so an interrupted extraction is redone on the next call.

        Returns:
            The extraction directory.

        Raises:
            Any error from opening or extracting the archive (re-raised after printing).
        """
        import tarfile
        import hashlib
        tar_hash = hashlib.md5(tar_path.encode()).hexdigest()
        extract_dir = os.path.join(self.cache_dir, tar_hash)
        lock_file = os.path.join(extract_dir, '.extraction_complete')
        if os.path.exists(lock_file):
            print(f"Using cached extraction for {tar_path} in {extract_dir}", flush=True)
            return extract_dir
        print(f"Extracting {tar_path} to {extract_dir}...", flush=True)
        os.makedirs(extract_dir, exist_ok=True)
        try:
            with tarfile.open(tar_path, 'r') as tar:
                if hasattr(tarfile, 'data_filter'):
                    # Reject members that would escape extract_dir (absolute paths,
                    # '..' components, outside links) where extraction filters exist.
                    tar.extractall(path=extract_dir, filter='data')
                else:
                    tar.extractall(path=extract_dir)
            with open(lock_file, 'w') as f:
                f.write(f"Extracted from {tar_path} at {os.path.getmtime(tar_path)}")
            print(f"Extraction complete: {tar_path} -> {extract_dir}", flush=True)
            return extract_dir
        except Exception as e:
            print(f"Error extracting {tar_path}: {e}", flush=True)
            raise

    def _load_data(self, data_path):
        """Index every image member of the tar archives matched by ``data_path``.

        Each entry in ``self.data_list`` is ``{'image': member_name, 'tar_idx': i,
        'cache_path': extracted_path}``. ``self.image_cache_paths`` maps member
        name -> extracted path; on duplicate names the later archive wins, so
        ``__getitem__`` uses the per-sample ``cache_path`` instead. Archives that
        fail to extract or list are printed and skipped.

        Raises:
            RuntimeError: if no image is found.
        """
        import tarfile
        import glob
        pattern = os.path.expanduser(data_path)
        # Sorted for a reproducible index; the set removes overlaps between patterns.
        self.tar_files = sorted({path
                                 for sub_pattern in self._expand_braces(pattern)
                                 for path in glob.glob(sub_pattern)})
        self.data_list = []
        self.image_cache_paths = {}
        image_exts = ('.jpg', '.jpeg', '.png', '.webp')
        for tar_idx, tar_path in enumerate(self.tar_files):
            if self.max_samples and len(self.data_list) >= self.max_samples:
                break
            try:
                extract_dir = self._extract_tar_if_needed(tar_path)
                with tarfile.open(tar_path, 'r') as tar:
                    for member in tar.getmembers():
                        if not (member.isfile() and member.name.lower().endswith(image_exts)):
                            continue
                        file_name = member.name
                        cache_path = os.path.join(extract_dir, file_name)
                        self.data_list.append({'image': file_name, 'tar_idx': tar_idx, 'cache_path': cache_path})
                        self.image_cache_paths[file_name] = cache_path
                        if self.max_samples and len(self.data_list) >= self.max_samples:
                            break
            except Exception as e:
                print(f"Error loading tar file {tar_path}: {e}", flush=True)
        print(f"Loaded {len(self.data_list)} images from {len(self.tar_files)} tar files: {self.tar_files}", flush=True)
        if len(self.data_list) == 0:
            raise RuntimeError(f"No valid images found in tar archives: {data_path}")

    def _read_image(self, image_file, cache_path=None):
        """Open an extracted image and validate its geometry.

        Args:
            image_file: Tar member name, used to look up ``image_cache_paths``
                when ``cache_path`` is not given.
            cache_path: Exact extracted path; preferred because member names
                may repeat across archives.

        Raises:
            ValueError: if ``image_file`` is unknown and no ``cache_path`` is given.
            RuntimeError: if the file cannot be opened, either side is 8 px or
                smaller, or width/height is not strictly between 0.1 and 10.
        """
        if cache_path is None:
            if image_file not in self.image_cache_paths:
                raise ValueError(f"Image file {image_file} not found in cache")
            cache_path = self.image_cache_paths[image_file]
        try:
            image = Image.open(cache_path)
            width, height = image.size
            # Explicit checks instead of ``assert`` so validation survives ``python -O``.
            if width <= 8 or height <= 8:
                problem = f"Image too small: {image.size}"
            elif not 0.1 < width / height < 10:
                problem = f"Image aspect ratio too extreme: {image.size}"
            else:
                return image
            image.close()  # release the file handle of a rejected image
            raise ValueError(problem)
        except Exception as e:
            raise RuntimeError(f"Error reading image from cache path {cache_path}: {e}") from e

    def _process_text(self, text):
        """Tokenize a random reconstruction prompt; ``text`` is ignored.

        With probability ``unconditional`` the generic prompt ``"Generate an
        image."`` is used instead. A ``-200`` token is spliced in after the first
        three tokens; presumably this is the image-placeholder index the model
        substitutes with image features. Confirm against the model code.

        Returns:
            ``{'input_ids': 1-D tensor}`` truncated to ``max_length`` tokens.
        """
        prompt = random.choice(self.recon_prompts)
        if random.uniform(0, 1) < self.unconditional:
            final_prompt = "Generate an image."
        else:
            final_prompt = f"\n{prompt}"
        final_prompt = self.prompt_template['INSTRUCTION'].format(input=final_prompt)
        input_ids = self.tokenizer.encode(final_prompt, add_special_tokens=True, return_tensors='pt')[0]
        input_ids = torch.cat([
            input_ids[:3],
            torch.tensor([-200], dtype=torch.long),
            input_ids[3:],
        ], dim=0)
        return dict(input_ids=input_ids[:self.max_length])

    def __getitem__(self, idx):
        try:
            data_sample = self.data_list[idx]
            image = self._read_image(data_sample['image'], data_sample.get('cache_path')).convert('RGB')
            if self.use_downscale:
                image = image.resize(size=(self.image_size // 2, self.image_size // 2))
                image = image.resize(size=(self.image_size, self.image_size))
            data = self._process_image(image)
            data.update(self._process_text(""))
            data.update(type='recon')
            return data
        except Exception as e:
            print(f"Error when processing index {idx}: {e}", flush=True)
            import traceback
            traceback.print_exc()
            return self._retry()
class MidjourneyReconstructionDataset(Text2ImageDataset):
    """Reconstruction dataset over the images of a ``datasets.load_dataset`` ``'train'`` split.

    Captions are ignored. Every item pairs the record's ``'image'`` (decoded to
    RGB) with a randomly chosen reconstruction prompt and is tagged ``type='recon'``.
    ``Text2ImageDataset.__init__`` is not called; the attributes needed by the
    inherited helpers are assigned here. Extra ``*args``/``**kwargs`` are accepted
    and ignored.
    """

    def __init__(self,
                 image_size,
                 data_path="brivangl/midjourney-v6-llava",
                 cache_dir=None,
                 unconditional=0.1,
                 tokenizer=None,
                 prompt_template=None,
                 max_length=1024,
                 crop_image=False,
                 cap_source='caption',
                 max_samples=None,
                 use_downscale=False,
                 *args, **kwargs):
        # Plain configuration attributes.
        self.data_path = data_path
        self.cache_dir = cache_dir
        self.local_folder = None
        self.image_size = image_size
        self.crop_image = crop_image
        self.cap_source = cap_source
        self.unconditional = unconditional
        self.prompt_template = prompt_template
        self.max_length = max_length
        self.max_samples = max_samples
        self.use_downscale = use_downscale
        # Objects that need construction.
        self.tokenizer = BUILDER.build(tokenizer)
        from src.datasets.text2image.consts import get_recon_prompt_list
        self.recon_prompts = get_recon_prompt_list()
        self._load_data(data_path)

    def _load_data(self, data_path):
        """Load the ``'train'`` split; on any failure print the error and stay empty."""
        try:
            from datasets import load_dataset
            print(f"Loading dataset from {data_path} with cache_dir {self.cache_dir}")
            self.dataset = load_dataset(data_path, cache_dir=self.cache_dir)['train']
            total = len(self.dataset)
            print(f"Loaded {total} samples from {data_path}")
            self.data_list = [{'idx': position} for position in range(total)]
        except Exception as e:
            print(f"Error loading dataset: {e}")
            self.data_list = []
        print(f"Load {len(self.data_list)} data samples from {data_path} for reconstruction", flush=True)

    def _process_text(self, text):
        """Tokenize a random reconstruction prompt (``text`` is unused).

        The prompt is drawn before the unconditional coin flip. A ``-200`` token
        is inserted after the first three token ids (presumably an image
        placeholder; confirm against the model code).
        """
        recon_prompt = random.choice(self.recon_prompts)
        use_generic = random.uniform(0, 1) < self.unconditional
        instruction = "Generate an image." if use_generic else f"\n{recon_prompt}"
        templated = self.prompt_template['INSTRUCTION'].format(input=instruction)
        token_ids = self.tokenizer.encode(templated, add_special_tokens=True, return_tensors='pt')[0]
        marker = torch.tensor([-200], dtype=torch.long)
        token_ids = torch.cat((token_ids[:3], marker, token_ids[3:]), dim=0)
        return {'input_ids': token_ids[:self.max_length]}

    @staticmethod
    def _to_rgb(image_data):
        """Turn a ``{'bytes': ...}`` dict, ``.convert``-able object, raw bytes, or array-like into RGB."""
        if isinstance(image_data, dict) and 'bytes' in image_data:
            return Image.open(io.BytesIO(image_data['bytes'])).convert('RGB')
        if hasattr(image_data, 'convert'):
            return image_data.convert('RGB')
        if isinstance(image_data, bytes):
            return Image.open(io.BytesIO(image_data)).convert('RGB')
        try:
            as_array = np.array(image_data)
            return Image.fromarray(as_array).convert('RGB')
        except Exception:
            raise TypeError(f"Unknown type: {type(image_data)}")

    def __getitem__(self, idx):
        try:
            record = self.dataset[self.data_list[idx]['idx']]
            data = self._process_image(self._to_rgb(record['image']))
            data.update(self._process_text(""))
            data['type'] = 'recon'
            return data
        except Exception as e:
            print(f"Error when processing index {idx}: {e}", flush=True)
            import traceback
            traceback.print_exc()
            return self._retry()

    def __len__(self):
        if not self.data_list:
            return 0
        return len(self.data_list)