# DeepGen_Test/src/datasets/image2image/edit_datasets.py
# Uploaded by TienVu2204 (commit ed8f267, "upload file"; 3.75 kB).
import torch
import numpy as np
from einops import rearrange
from src.datasets.utils import crop2square
from src.datasets.text2image.caption_datasets import CaptionDataset
from PIL import Image
import os
class ImageEditDataset(CaptionDataset):
    """Instruction-guided image-editing dataset built on ``CaptionDataset``.

    Each record in ``self.data_list`` is expected to provide:

    * ``input_image`` -- a sequence whose first entry is the source image path,
    * ``output_image`` -- the target image path,
    * ``instruction`` -- the editing prompt.

    Image paths are joined onto ``self.image_folder`` when it is not ``None``.
    """

    def _process_image(self, image):
        """Return the ``pixel_values`` produced by ``CaptionDataset._process_image``.

        Raises:
            ValueError: if ``self.image_process`` is ``'crop2square'``, which
                this dataset rejects. An explicit check is used instead of
                ``assert`` so the guard still runs under ``python -O``.
        """
        if self.image_process == 'crop2square':
            raise ValueError(
                "ImageEditDataset does not support image_process='crop2square'")
        return super()._process_image(image)['pixel_values']

    def _process_text(self, text):
        """Tokenize an instruction prefixed with source-image placeholder tokens.

        The user turn is ``IMG_START_TOKEN + IMG_CONTEXT_TOKEN * image_length
        + IMG_END_TOKEN``, a newline, then ``text``. It is wrapped by the
        template's ``INSTRUCTION`` format string. Unless the template sets
        ``IMG_START_TOKEN_FOR_GENERATION`` to a falsy value, one more
        ``IMG_START_TOKEN`` is appended (presumably to cue generation of the
        edited image).

        Returns:
            dict with ``input_ids``: row 0 of the tokenizer's
            ``return_tensors='pt'`` output.
        """
        prompt_template = self.prompt_template
        image_tokens = (prompt_template['IMG_START_TOKEN']
                        + prompt_template['IMG_CONTEXT_TOKEN'] * self.image_length
                        + prompt_template['IMG_END_TOKEN'])
        prompt = prompt_template['INSTRUCTION'].format(input=f'{image_tokens}\n{text}')
        if prompt_template.get('IMG_START_TOKEN_FOR_GENERATION', True):
            prompt += prompt_template['IMG_START_TOKEN']
        input_ids = self.tokenizer.encode(
            prompt, return_tensors='pt', **self.tokenizer_kwargs)[0]
        return dict(input_ids=input_ids)

    def _open_rgb_image(self, path):
        """Open ``path`` (relative to ``self.image_folder`` when set) as RGB.

        The file is opened in a ``with`` block so its handle is released even
        for lazily-loaded or multi-frame formats. ``convert`` returns a new,
        fully loaded image, so it stays usable after the file is closed.
        """
        if self.image_folder is not None:
            path = os.path.join(self.image_folder, path)
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return one source/target editing sample.

        The returned dict holds ``pixel_values_src`` (source image),
        ``pixel_values`` (target image), ``image_dir``, ``type='image2image'``
        and ``text`` (the instruction). When a tokenizer is configured, it
        also holds the ``_process_text`` output. When ``self.debug`` is set,
        record 0 is always used. On any failure the error is printed and
        ``self._retry()``'s result is returned instead.
        """
        if self.debug:
            idx = 0
        data_sample = None
        try:
            data_sample = self.data_list[idx]
            source_image = self._open_rgb_image(data_sample['input_image'][0])
            target_image = self._open_rgb_image(data_sample['output_image'])
            prompt = data_sample['instruction']

            pixel_values_src = self._process_image(source_image)
            pixel_values = self._process_image(target_image)
            data = self._process_text(prompt) if self.tokenizer is not None else dict()
            data.update(
                pixel_values_src=pixel_values_src, pixel_values=pixel_values,
                image_dir=self.image_folder, type='image2image', text=prompt)
            return data
        except Exception as e:
            # Do not re-index data_list here: if that lookup is what failed
            # (e.g. IndexError), doing it again would raise from inside the
            # handler and skip the retry.
            record = data_sample if data_sample is not None else f'index {idx}'
            print(f"Error when reading {self.data_path}:{record}: {e}", flush=True)
            return self._retry()
class ReconstructDataset(CaptionDataset):
    """Reconstruction ("identity edit") dataset built on ``CaptionDataset``.

    Every sample pairs an image with itself under the fixed instruction
    ``"Keep the image as it is."``. Records must provide an ``image`` entry,
    which is loaded through ``self._read_image``.

    NOTE(review): this class does not override ``_process_text``, so the
    prompt is tokenized by ``CaptionDataset._process_text``. Confirm whether
    source-image placeholder tokens are expected in the prompt here.
    """

    def _process_image(self, image):
        """Return the ``pixel_values`` produced by ``CaptionDataset._process_image``.

        Raises:
            ValueError: if ``self.image_process`` is ``'crop2square'``, which
                this dataset rejects. An explicit check is used instead of
                ``assert`` so the guard still runs under ``python -O``.
        """
        if self.image_process == 'crop2square':
            raise ValueError(
                "ReconstructDataset does not support image_process='crop2square'")
        return super()._process_image(image)['pixel_values']

    def __getitem__(self, idx):
        """Return one reconstruction sample (source == target).

        The returned dict holds ``pixel_values_src`` and ``pixel_values``
        (the same tensor object), ``image_dir``, ``image_file``,
        ``type='image2image'`` and ``text``. When a tokenizer is configured,
        it also holds the ``_process_text`` output. When ``self.debug`` is
        set, record 0 is always used. On any failure the error is printed and
        ``self._retry()``'s result is returned instead.
        """
        if self.debug:
            idx = 0
        data_sample = None
        try:
            data_sample = self.data_list[idx]
            image = self._read_image(data_sample['image']).convert('RGB')
            prompt = "Keep the image as it is."
            # Source and target reference one tensor object, not two copies.
            pixel_values = pixel_values_src = self._process_image(image)
            data = self._process_text(prompt) if self.tokenizer is not None else dict()
            data.update(
                pixel_values_src=pixel_values_src, pixel_values=pixel_values,
                image_dir=self.image_folder, image_file=data_sample['image'],
                type='image2image', text=prompt)
            return data
        except Exception as e:
            # Do not re-index data_list here: if that lookup is what failed
            # (e.g. IndexError), doing it again would raise from inside the
            # handler and skip the retry.
            record = data_sample if data_sample is not None else f'index {idx}'
            print(f"Error when reading {self.data_path}:{record}: {e}", flush=True)
            return self._retry()