Spaces:
Runtime error
Runtime error
File size: 1,834 Bytes
02d3a85 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 | import torch
from PIL import Image
import os
from torch.utils.data import Dataset
import torch.distributed as dist
def ddp_setup(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: str = "12355",
    backend: str = "nccl",
):
    """Join the default ``torch.distributed`` process group and pin this rank's GPU.

    Writes ``MASTER_ADDR`` / ``MASTER_PORT`` into ``os.environ``, overwriting
    any existing values. It then calls ``dist.init_process_group`` and finally
    makes CUDA device ``rank`` the current device for this process.

    Args:
        rank: Index of this process within the group. It is also used as the
            CUDA device index, which assumes a single node where the global
            rank equals the local GPU index (TODO confirm for multi-node use).
        world_size: Total number of processes in the group.
        master_addr: Host of the rendezvous (rank-0) process. Defaults to
            ``"localhost"``, the previously hard-coded value.
        master_port: TCP port used for rendezvous. Defaults to ``"12355"``,
            the previously hard-coded value. It is stored as a string because
            environment values must be strings.
        backend: ``torch.distributed`` backend name. Defaults to ``"nccl"``.
    """
    # No init_method/store is passed, so init_process_group uses its default
    # env:// rendezvous, which reads these two variables.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    # Give each rank its own current CUDA device instead of all sharing device 0.
    torch.cuda.set_device(rank)
# class BlipCaptionedWrapper(nn.Module):
# def __init__(self, model, num_of_captions=1):
# super().__init__()
# self.model = model
# self.num_of_captions = num_of_captions
#
# def forward(self, image_inputs):
# return self.model.generate({"image": image_inputs}, use_nucleus_sampling=True, num_captions=self.num_of_captions)
class PrepareImageForBlip:
    """Callable adapter that forwards an image to a wrapped processor.

    Calling an instance returns exactly what ``processor(image)`` returns. This
    lets a processor object be passed wherever a one-argument transform
    callable is expected. The processor is presumably a BLIP image
    preprocessor; that is unverified here.
    """

    def __init__(self, processor):
        # Any one-argument callable; it is invoked unchanged on every call.
        self.processor = processor

    def __call__(self, image):
        process = self.processor
        return process(image)
class ImageDataset(Dataset):
    """Map-style dataset that loads image files as 3-channel grayscale.

    Each item is an ``(image, filename)`` pair. The image is first converted to
    PIL mode ``"L"`` (grayscale) and then to ``"RGB"``. Colour information is
    therefore discarded, but the image keeps three identical channels for
    consumers that expect RGB input.

    Args:
        root_dir: Directory that file names are joined onto.
        transform: Optional callable applied to the PIL image before it is
            returned.
        path_list: Optional iterable of file names, relative to ``root_dir``.
            When ``None``, every entry returned by ``os.listdir(root_dir)`` is
            used. An explicit empty iterable yields an empty dataset rather
            than falling back to the whole directory.
    """

    def __init__(self, root_dir, transform=None, path_list=None):
        self.root_dir = root_dir
        if path_list is not None:
            # Copy into a list so generators and other iterables support
            # len() and integer indexing.
            self.image_list = list(path_list)
        else:
            # NOTE: os.listdir order is arbitrary and includes every entry
            # (sub-directories, non-image files). Pass path_list to control it.
            self.image_list = os.listdir(root_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        filename = self.image_list[idx]
        image_path = os.path.join(self.root_dir, filename)
        # Open in a context manager so the file handle is released. convert()
        # loads the pixel data and returns a new image, so the result stays
        # valid after the file is closed.
        with Image.open(image_path) as raw:
            img = raw.convert("L").convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, filename
class ImageDatasetFromImageList(Dataset):
    """Map-style dataset over an in-memory sequence of items.

    Items are returned exactly as stored; no loading or preprocessing happens
    here.
    """

    def __init__(self, image_list):
        # Kept by reference, not copied: later changes to the caller's sequence
        # are visible through the dataset.
        self.image_list = image_list

    def __len__(self):
        items = self.image_list
        return len(items)

    def __getitem__(self, idx):
        items = self.image_list
        item = items[idx]
        return item