event_retrieval / complex_image_search /utils /data_loading_utils.py
sanskar753's picture
Upload folder using huggingface_hub
02d3a85 verified
import torch
from PIL import Image
import os
from torch.utils.data import Dataset
import torch.distributed as dist
def ddp_setup(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: str = "12355",
) -> None:
    """Initialise the NCCL process group for this worker and select its GPU.

    Args:
        rank: Index of the current process. It is also used as the CUDA
            device index passed to ``torch.cuda.set_device``.
        world_size: Total number of processes taking part in the group.
        master_addr: Value written to the ``MASTER_ADDR`` environment
            variable. Defaults to ``"localhost"`` (the previously
            hard-coded value).
        master_port: Value written to the ``MASTER_PORT`` environment
            variable. An ``int`` is accepted as well; it is converted with
            ``str()``. Defaults to ``"12355"`` (the previously hard-coded
            value).
    """
    # These variables are set before init_process_group is called; any value
    # already present in the environment is overwritten, as before.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
# class BlipCaptionedWrapper(nn.Module):
# def __init__(self, model, num_of_captions=1):
# super().__init__()
# self.model = model
# self.num_of_captions = num_of_captions
#
# def forward(self, image_inputs):
# return self.model.generate({"image": image_inputs}, use_nucleus_sampling=True, num_captions=self.num_of_captions)
class PrepareImageForBlip(object):
    """Callable adapter that forwards an image to a stored processor.

    The wrapped ``processor`` is kept on the instance and invoked with the
    image as its single positional argument each time the adapter is called.
    """

    def __init__(self, processor):
        """Store the callable ``processor`` to apply to each image."""
        self.processor = processor

    def __call__(self, image):
        """Return whatever ``self.processor`` returns for ``image``."""
        processed = self.processor(image)
        return processed
class ImageDataset(Dataset):
    """Dataset yielding ``(image, filename)`` pairs for files under a directory.

    Args:
        root_dir: Directory that every entry of the image list is joined to.
        transform: Optional callable applied to the loaded PIL image before
            it is returned.
        path_list: Optional explicit list of file names (relative to
            ``root_dir``). When ``None``, the names come from
            ``os.listdir(root_dir)``.
    """

    def __init__(self, root_dir, transform=None, path_list=None):
        self.root_dir = root_dir
        # Compare against None rather than truthiness: an explicitly empty
        # list must give an empty dataset, not a listing of the whole
        # directory.
        if path_list is not None:
            self.image_list = path_list
        else:
            self.image_list = os.listdir(root_dir)
        self.transform = transform

    def __len__(self):
        """Return the number of entries in the image list."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load entry ``idx`` and return ``(image, filename)``.

        ``filename`` is the entry from the image list exactly as stored,
        not the joined path.
        """
        filename = self.image_list[idx]
        image_path = os.path.join(self.root_dir, filename)
        # The context manager closes the underlying file as soon as the
        # pixel data has been converted, instead of leaving it to the
        # garbage collector.
        with Image.open(image_path) as raw:
            # Converted to grayscale and then back to three-channel RGB, so
            # colour information is discarded.
            # NOTE(review): presumably intentional - confirm before changing.
            img = raw.convert("L").convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, filename
class ImageDatasetFromImageList(Dataset):
    """Dataset view over an in-memory sequence of items.

    Items are returned exactly as stored; no loading or transformation is
    performed.
    """

    def __init__(self, image_list):
        """Keep a reference to ``image_list`` (it is not copied)."""
        self.image_list = image_list

    def __len__(self):
        """Return how many items the underlying sequence holds."""
        items = self.image_list
        return len(items)

    def __getitem__(self, idx):
        """Return the stored item at position ``idx``."""
        items = self.image_list
        return items[idx]