Spaces:
Runtime error
Runtime error
| import torch | |
| from PIL import Image | |
| import os | |
| from torch.utils.data import Dataset | |
| import torch.distributed as dist | |
def ddp_setup(
    rank: int,
    world_size: int,
    *,
    master_addr: str = "localhost",
    master_port: str = "12355",
    backend: str = "nccl",
):
    """Initialise the default ``torch.distributed`` process group for this rank.

    The rendezvous address/port are exported through the ``MASTER_ADDR`` and
    ``MASTER_PORT`` environment variables, which ``init_process_group`` reads
    with its default ``env://`` init method.

    Args:
        rank: Global rank of this process. It is also used as the CUDA device
            index passed to ``torch.cuda.set_device``, so this assumes a single
            node with at least ``world_size`` visible GPUs.
        world_size: Total number of participating processes.
        master_addr: Host of the rank-0 rendezvous (default ``"localhost"``).
        master_port: Port of the rank-0 rendezvous (default ``"12355"``); an
            int is accepted and converted to ``str``.
        backend: ``torch.distributed`` backend name (default ``"nccl"``).
    """
    os.environ["MASTER_ADDR"] = master_addr
    # os.environ only accepts str values; allow callers to pass an int port.
    os.environ["MASTER_PORT"] = str(master_port)
    # Select this rank's GPU before the process group exists, so NCCL (which
    # uses the current CUDA device) never touches another rank's GPU.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
| # class BlipCaptionedWrapper(nn.Module): | |
| # def __init__(self, model, num_of_captions=1): | |
| # super().__init__() | |
| # self.model = model | |
| # self.num_of_captions = num_of_captions | |
| # | |
| # def forward(self, image_inputs): | |
| # return self.model.generate({"image": image_inputs}, use_nucleus_sampling=True, num_captions=self.num_of_captions) | |
class PrepareImageForBlip:
    """Callable adapter that forwards an image to a wrapped processor.

    Args:
        processor: Any callable taking a single image argument. Whatever it
            returns is passed back to the caller unchanged.
    """

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, image):
        """Run the stored processor on *image* and return its output."""
        process = self.processor
        return process(image)
class ImageDataset(Dataset):
    """Map-style dataset that loads images from a directory as grayscale RGB.

    Each item is a ``(image, filename)`` pair: ``image`` is the loaded (and
    optionally transformed) image, ``filename`` is the entry name relative to
    ``root_dir``.

    Args:
        root_dir: Directory that file names are resolved against.
        transform: Optional callable applied to each loaded PIL image.
        path_list: Optional explicit sequence of file names relative to
            ``root_dir``. When ``None``, every entry of ``root_dir`` is used
            in sorted order. Entries are not filtered, so this assumes the
            directory holds only image files.
    """

    def __init__(self, root_dir, transform=None, path_list=None):
        self.root_dir = root_dir
        if path_list is not None:
            # An explicitly supplied list (even an empty one) is honoured as-is.
            self.image_list = list(path_list)
        else:
            # os.listdir order is arbitrary. Sort so every process that builds
            # its own copy of the dataset sees the same index -> file mapping.
            self.image_list = sorted(os.listdir(root_dir))
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        filename = self.image_list[idx]
        image_path = os.path.join(self.root_dir, filename)
        # The context manager releases the file handle even if decoding
        # fails. convert() returns a new, fully loaded image that stays valid
        # after the source image is closed.
        with Image.open(image_path) as raw:
            # Reduce to single-channel luminance, then replicate it to three
            # channels: grayscale content in an RGB-shaped image.
            img = raw.convert("L").convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, filename
class ImageDatasetFromImageList(Dataset):
    """Dataset view over an in-memory sequence of already-prepared items.

    Indexing returns the stored element unchanged; no loading or transform
    is performed.
    """

    def __init__(self, image_list):
        # Kept by reference: later changes to the caller's list are visible.
        self.image_list = image_list

    def __getitem__(self, idx):
        items = self.image_list
        return items[idx]

    def __len__(self):
        return len(self.image_list)