import math
import os
import random

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CustomCocoDataset(Dataset):
    """Folder-of-images dataset yielding a random square crop plus a resized copy.

    Each item is a dict with three keys:

    * ``jpg``  -- the random ``img_size`` x ``img_size`` crop, converted with
      ``to_tensor``.
    * ``hint`` -- the same crop resized (bicubic) to ``hint_size`` x
      ``hint_size`` and converted with ``to_tensor``.
    * ``txt``  -- always the empty string (no captions are used).

    Args:
        img_folder: Directory that directly contains the image files.
        img_size: Side length of the random crop.
        hint_size: Side length of the resized hint image.
    """

    # Extensions (case-sensitive) that qualify a directory entry as an image.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, img_folder, img_size=512, hint_size=448):
        self.img_folder = img_folder
        self.img_size = img_size
        self.hint_size = hint_size
        # Keep the real file names. Rebuilding the path from the stem with a
        # fixed '.png' suffix fails for every '.jpg' / '.jpeg' entry.
        # Sorted because os.listdir() returns entries in arbitrary order, which
        # would make the index -> image mapping differ between runs/machines.
        self.img_files = sorted(
            f for f in os.listdir(img_folder)
            if f.endswith(self._IMAGE_EXTENSIONS)
        )
        # Extension-less names, index-aligned with ``img_files``; retained for
        # code that reads ``dataset.ids``.
        self.ids = [os.path.splitext(f)[0] for f in self.img_files]

    def __len__(self):
        """Return the number of image files found in ``img_folder``."""
        return len(self.ids)

    def __getitem__(self, index):
        """Load image ``index`` and return ``dict(jpg=..., txt=..., hint=...)``."""
        img_path = os.path.join(self.img_folder, self.img_files[index])
        # Image.open() is lazy and keeps the file handle open; convert() forces
        # the pixel data to load and returns an independent image, so the
        # handle can be released as soon as the ``with`` block exits.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # Random rescale + random square crop (returns a numpy array).
        crop = random_crop_arr(
            image, self.img_size, min_crop_frac=0.8, max_crop_frac=1.0
        )

        # Back to PIL so the torchvision functional transforms can be applied.
        cropped_image = Image.fromarray(crop)

        jpg_image = transforms.functional.to_tensor(cropped_image)

        hint_image = transforms.functional.resize(
            cropped_image,
            (self.hint_size, self.hint_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )
        hint_image = transforms.functional.to_tensor(hint_image)

        # Captions are intentionally empty.
        prompt = ""

        return dict(jpg=jpg_image, txt=prompt, hint=hint_image)


def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """Randomly rescale ``pil_image`` and cut a random square crop from it.

    The image is resized so that its shorter side has a random integer length
    between ``ceil(image_size / max_crop_frac)`` and
    ``ceil(image_size / min_crop_frac)`` (both inclusive); a window of
    ``image_size`` x ``image_size`` pixels is then taken at a random position.

    Args:
        pil_image: Source PIL image.
        image_size: Side length of the returned crop, in pixels.
        min_crop_frac: Smallest fraction of the shorter side the crop may cover.
        max_crop_frac: Largest fraction of the shorter side the crop may cover.

    Returns:
        A numpy array holding the crop; its first two axes are both
        ``image_size`` long.
    """
    min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
    max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
    smaller_dim_size = random.randrange(
        min_smaller_dim_size, max_smaller_dim_size + 1
    )

    # We are not on a new enough PIL to support the reducing_gap
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    while min(pil_image.size) >= 2 * smaller_dim_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Final bicubic resize so the shorter side equals smaller_dim_size.
    scale = smaller_dim_size / min(pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    # Pick the top-left corner so the window stays fully inside the image.
    crop_y = random.randrange(arr.shape[0] - image_size + 1)
    crop_x = random.randrange(arr.shape[1] - image_size + 1)
    return arr[crop_y:crop_y + image_size, crop_x:crop_x + image_size]


if __name__ == "__main__":
    # Smoke test: build the dataset, pull one batch and dump it to PNG files.
    dataset = CustomCocoDataset(
        "/home/t2vg-a100-G4-1/projects/dataset/LSDIR_raw/images/train"
    )
    print(len(dataset))
    print(dataset[0])

    from torch.utils.data import DataLoader

    dataloader = DataLoader(
        dataset, batch_size=4, num_workers=2, pin_memory=True, drop_last=True
    )

    # Fetch a single batch from the DataLoader.
    batch = next(iter(dataloader))

    # Pull the individual fields out of the batch.
    jpg_images = batch['jpg']
    hint_images = batch['hint']
    prompts = batch['txt']

    # Print the prompts.
    print(f"Prompt: {prompts}")

    # Visualize and save every image pair of this first batch.
    import matplotlib.pyplot as plt

    for i, (jpg_image, hint_image) in enumerate(zip(jpg_images, hint_images)):
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.title(f"JPG Image {i+1} (512x512)")
        # Move the first axis last, as imshow expects.
        plt.imshow(jpg_image.permute(1, 2, 0))

        plt.subplot(1, 2, 2)
        plt.title(f"Hint Image {i+1} (448x448)")
        # Move the first axis last, as imshow expects.
        plt.imshow(hint_image.permute(1, 2, 0))

        # Save the figure to disk.
        plt.savefig(f'output_image_{i+1}.png')

        # Close the figure to free its memory.
        plt.close()