import os
from typing import Any, Callable, List, Optional, Sequence, Union

import torch
import torch.distributed as dist
from PIL import Image
from torch.utils.data import Dataset


def ddp_setup(
    rank: int,
    world_size: int,
    *,
    master_addr: str = "localhost",
    master_port: Union[int, str] = "12355",
) -> None:
    """Initialise the default ``torch.distributed`` process group (NCCL backend).

    ``rank`` doubles as the CUDA device index, so this assumes one process per
    GPU with ranks matching local device ids (single-node style setup).

    Args:
        rank: Index of this process in the group; also selects the CUDA device.
        world_size: Total number of processes in the group.
        master_addr: Address of the rank-0 process, written to ``MASTER_ADDR``.
        master_port: Port on ``master_addr``, written to ``MASTER_PORT``.
    """
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    # Bind this process to its GPU *before* creating the NCCL communicator so
    # that collectives issued during/after init use this rank's device rather
    # than defaulting to cuda:0.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


# class BlipCaptionedWrapper(nn.Module):
#     def __init__(self, model, num_of_captions=1):
#         super().__init__()
#         self.model = model
#         self.num_of_captions = num_of_captions
#
#     def forward(self, image_inputs):
#         return self.model.generate({"image": image_inputs}, use_nucleus_sampling=True, num_captions=self.num_of_captions)


class PrepareImageForBlip:
    """Adapt a processor callable into a dataset ``transform``.

    Calling the instance simply forwards the image to ``processor`` and returns
    its result. ``processor`` is presumably a BLIP visual processor — confirm
    against callers.
    """

    def __init__(self, processor: Callable[[Any], Any]) -> None:
        self.processor = processor

    def __call__(self, image: Any) -> Any:
        return self.processor(image)


class ImageDataset(Dataset):
    """Dataset of image files under ``root_dir`` yielding ``(image, filename)``.

    Each image is converted to grayscale and then back to RGB, so the returned
    image has three identical channels. If ``transform`` is given, its output
    replaces the image in the returned tuple.
    """

    def __init__(
        self,
        root_dir: str,
        transform: Optional[Callable[[Image.Image], Any]] = None,
        path_list: Optional[Sequence[str]] = None,
    ) -> None:
        """
        Args:
            root_dir: Directory that filenames are resolved against.
            transform: Optional callable applied to the RGB image.
            path_list: Filenames relative to ``root_dir``, used in the given
                order. When ``None`` or empty, all entries of ``root_dir`` are
                used instead, sorted by name.
        """
        self.root_dir = root_dir
        if path_list:
            self.image_list = path_list
        else:
            # os.listdir returns entries in arbitrary order; sort so that the
            # index -> file mapping is reproducible across runs and across
            # processes (e.g. DDP ranks each building their own dataset).
            self.image_list = sorted(os.listdir(root_dir))
        self.transform = transform

    def __len__(self) -> int:
        return len(self.image_list)

    def __getitem__(self, idx: int):
        """Return ``(image, filename)`` for entry ``idx``.

        ``image`` is a 3-channel grayscale PIL image, or ``transform``'s output
        when a transform is set. ``filename`` is the entry as stored in the
        list (relative to ``root_dir``).
        """
        filename = self.image_list[idx]
        image_path = os.path.join(self.root_dir, filename)
        # Context manager guarantees the file handle is released; convert()
        # loads the pixels into a new in-memory image before the file closes.
        with Image.open(image_path) as src:
            # Grayscale round-trip: discards colour but keeps an RGB (3-channel)
            # image — presumably what the downstream processor expects.
            img = src.convert("L").convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, filename


class ImageDatasetFromImageList(Dataset):
    """Dataset view over an in-memory list; item ``idx`` is ``image_list[idx]``."""

    def __init__(self, image_list: List[Any]) -> None:
        self.image_list = image_list

    def __len__(self) -> int:
        return len(self.image_list)

    def __getitem__(self, idx: int) -> Any:
        return self.image_list[idx]