Spaces:
Paused
Paused
| import torch | |
| import numpy as np | |
| from einops import rearrange | |
| from src.datasets.utils import crop2square | |
| from src.datasets.text2image.caption_datasets import CaptionDataset | |
| from PIL import Image | |
| import os | |
class ImageEditDataset(CaptionDataset):
    """Instruction-guided image-editing dataset built on ``CaptionDataset``.

    Each entry of ``self.data_list`` is read as a dict with an
    ``'input_image'`` (source image path, normally a list whose first entry is
    used), an ``'output_image'`` (target image path) and an ``'instruction'``
    (edit text). Samples are returned as ``type='image2image'`` dicts.
    """

    def _process_image(self, image):
        """Return the ``'pixel_values'`` entry of ``CaptionDataset._process_image``.

        Raises:
            ValueError: if ``self.image_process == 'crop2square'``, which this
                dataset rejects. An explicit raise is used instead of
                ``assert`` so the guard is not stripped under ``python -O``.
        """
        if self.image_process == 'crop2square':
            raise ValueError(
                f"{type(self).__name__} does not support image_process='crop2square'")
        return super()._process_image(image)['pixel_values']

    def _process_text(self, text):
        """Tokenize an edit instruction preceded by source-image placeholder tokens.

        The prompt is ``IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.image_length
        + IMG_END_TOKEN``, a newline, then ``text``. That string is formatted
        into ``prompt_template['INSTRUCTION']``. ``IMG_START_TOKEN`` is appended
        unless the template sets ``IMG_START_TOKEN_FOR_GENERATION`` to a falsy
        value (presumably to open the image to be generated; confirm against
        the model code).

        Returns:
            dict with ``input_ids``: the first row of
            ``self.tokenizer.encode(..., return_tensors='pt')``.
        """
        template = self.prompt_template
        image_tokens = (template['IMG_START_TOKEN']
                        + template['IMG_CONTEXT_TOKEN'] * self.image_length
                        + template['IMG_END_TOKEN'])
        prompt = template['INSTRUCTION'].format(input=f'{image_tokens}\n{text}')
        if template.get('IMG_START_TOKEN_FOR_GENERATION', True):
            prompt += template['IMG_START_TOKEN']
        input_ids = self.tokenizer.encode(
            prompt, return_tensors='pt', **self.tokenizer_kwargs)[0]
        return dict(input_ids=input_ids)

    def _open_edit_image(self, path):
        """Open ``path`` as an RGB image, prefixing ``self.image_folder`` when it is set.

        The file is opened in a ``with`` block so its handle is released
        deterministically. ``convert('RGB')`` returns a new, fully loaded image,
        so the result stays valid after the file is closed.
        """
        if self.image_folder is not None:
            path = os.path.join(self.image_folder, path)
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return one edit sample, or ``self._retry()`` if reading it fails.

        Returned keys: ``pixel_values_src`` (processed source image),
        ``pixel_values`` (processed target image), ``image_dir``,
        ``type='image2image'``, ``text`` (the instruction) and, when a tokenizer
        is configured, ``input_ids``.
        """
        if self.debug:
            idx = 0  # debug mode always serves the first sample
        data_sample = None
        try:
            data_sample = self.data_list[idx]

            source_path = data_sample['input_image']
            # 'input_image' is normally a list whose first entry is the source
            # image; also accept a bare path instead of taking its first char.
            if not isinstance(source_path, (str, os.PathLike)):
                source_path = source_path[0]
            source_image = self._open_edit_image(source_path)
            target_image = self._open_edit_image(data_sample['output_image'])
            prompt = data_sample['instruction']

            pixel_values_src = self._process_image(source_image)
            pixel_values = self._process_image(target_image)
            data = self._process_text(prompt) if self.tokenizer is not None else dict()
            data.update(
                pixel_values_src=pixel_values_src, pixel_values=pixel_values,
                image_dir=self.image_folder, type='image2image', text=prompt)
            return data
        except Exception as e:
            # Report the sample already fetched (None if indexing itself failed)
            # instead of re-indexing data_list, which could raise again here.
            print(f"Error when reading {self.data_path}:{data_sample}: {e}", flush=True)
            return self._retry()
class ReconstructDataset(CaptionDataset):
    """Self-reconstruction dataset: each image serves as both source and target.

    Every entry of ``self.data_list`` is read as a dict with an ``'image'``
    field. Each sample pairs that image with the fixed instruction
    ``"Keep the image as it is."`` and is returned as ``type='image2image'``.

    NOTE(review): this class does not override ``_process_text``, so the prompt
    is built by the inherited ``CaptionDataset._process_text``. Unlike
    ``ImageEditDataset``, it may therefore not contain source-image context
    tokens, even though ``pixel_values_src`` is returned. Confirm this is
    intended.
    """

    def _process_image(self, image):
        """Return the ``'pixel_values'`` entry of ``CaptionDataset._process_image``.

        Raises:
            ValueError: if ``self.image_process == 'crop2square'``, which this
                dataset rejects. An explicit raise is used instead of
                ``assert`` so the guard is not stripped under ``python -O``.
        """
        if self.image_process == 'crop2square':
            raise ValueError(
                f"{type(self).__name__} does not support image_process='crop2square'")
        return super()._process_image(image)['pixel_values']

    def __getitem__(self, idx):
        """Return one reconstruction sample, or ``self._retry()`` if reading it fails.

        Returned keys: ``pixel_values_src`` and ``pixel_values`` (the same
        processed image object under both keys), ``image_dir``, ``image_file``,
        ``type='image2image'``, ``text`` and, when a tokenizer is configured,
        ``input_ids``.
        """
        if self.debug:
            idx = 0  # debug mode always serves the first sample
        data_sample = None
        try:
            data_sample = self.data_list[idx]
            image_file = data_sample['image']
            image = self._read_image(image_file).convert('RGB')
            prompt = "Keep the image as it is."
            # One object is used as both source and target; an in-place edit of
            # one entry downstream would also affect the other.
            pixel_values = pixel_values_src = self._process_image(image)
            data = self._process_text(prompt) if self.tokenizer is not None else dict()
            data.update(
                pixel_values_src=pixel_values_src, pixel_values=pixel_values,
                image_dir=self.image_folder, image_file=image_file,
                type='image2image', text=prompt)
            return data
        except Exception as e:
            # Report the sample already fetched (None if indexing itself failed)
            # instead of re-indexing data_list, which could raise again here.
            print(f"Error when reading {self.data_path}:{data_sample}: {e}", flush=True)
            return self._retry()